mirror of https://github.com/open-mmlab/mmcv.git
add unittest for onnx convert (#608)
* add unittest for onnx convert * build onnx and onnxruntime in CI * skip onnx op unit test while using CUDA * fix offset==0 case in NMS * remove tmp file used in test * delete tmp file before assert so that we can remove the tmp file anywaypull/629/head
parent
65a60a3d7d
commit
23b2bdbf52
|
@ -111,7 +111,7 @@ jobs:
|
|||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install unittest dependencies
|
||||
run: pip install pytest coverage lmdb PyTurboJPEG
|
||||
run: pip install pytest coverage lmdb PyTurboJPEG onnx==1.6.0 onnxruntime==1.2.0
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
|
@ -181,7 +181,7 @@ jobs:
|
|||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
coverage run --branch --source=mmcv -m pytest tests/
|
||||
coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Only upload coverage report for python3.7 && pytorch1.5
|
||||
|
@ -220,6 +220,8 @@ jobs:
|
|||
- name: Install Pillow
|
||||
run: pip install Pillow==6.2.2
|
||||
if: ${{matrix.torchvision == '0.4.2'}}
|
||||
- name: Install ONNX
|
||||
run: pip install onnx==1.6.0 onnxruntime==1.2.0
|
||||
- name: Install PyTorch
|
||||
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
|
||||
- name: Build and install
|
||||
|
|
|
@ -111,6 +111,9 @@ def nms(boxes, scores, iou_threshold, offset=0):
|
|||
# ONNX only support offset == 1
|
||||
boxes[:, -2:] -= 1
|
||||
inds = NMSop.apply(boxes, scores, iou_threshold, offset)
|
||||
if torch.onnx.is_in_onnx_export() and offset == 0:
|
||||
# ONNX only support offset == 1
|
||||
boxes[:, -2:] += 1
|
||||
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
|
||||
if is_numpy:
|
||||
dets = dets.cpu().numpy()
|
||||
|
|
|
@ -14,6 +14,6 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcv
|
||||
known_third_party = Cython,addict,cv2,m2r,numpy,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
|
||||
known_third_party = Cython,addict,cv2,m2r,numpy,onnx,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as rt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
onnx_file = 'tmp.onnx'
|
||||
|
||||
|
||||
class WrapFunction(nn.Module):
|
||||
|
||||
def __init__(self, wrapped_function):
|
||||
super(WrapFunction, self).__init__()
|
||||
self.wrapped_function = wrapped_function
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.wrapped_function(*args, **kwargs)
|
||||
|
||||
|
||||
class Testonnx(object):
|
||||
|
||||
def test_nms(self):
|
||||
from mmcv.ops import nms
|
||||
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
|
||||
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
|
||||
dtype=np.float32)
|
||||
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
|
||||
boxes = torch.from_numpy(np_boxes)
|
||||
scores = torch.from_numpy(np_scores)
|
||||
pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
|
||||
pytorch_score = pytorch_dets[:, 4]
|
||||
nms = partial(nms, iou_threshold=0.3, offset=0)
|
||||
wrapped_model = WrapFunction(nms)
|
||||
wrapped_model.cpu().eval()
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
wrapped_model, (boxes, scores),
|
||||
onnx_file,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=['boxes', 'scores'],
|
||||
opset_version=11)
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
|
||||
# get onnx output
|
||||
input_all = [node.name for node in onnx_model.graph.input]
|
||||
input_initializer = [
|
||||
node.name for node in onnx_model.graph.initializer
|
||||
]
|
||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
||||
assert (len(net_feed_input) == 2)
|
||||
sess = rt.InferenceSession(onnx_file)
|
||||
onnx_dets, _ = sess.run(None, {
|
||||
'scores': scores.detach().numpy(),
|
||||
'boxes': boxes.detach().numpy()
|
||||
})
|
||||
onnx_score = onnx_dets[:, 4]
|
||||
os.remove(onnx_file)
|
||||
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
|
Loading…
Reference in New Issue