From 23b2bdbf52c8c4960dc696ec35901146f839fd6d Mon Sep 17 00:00:00 2001 From: robin Han Date: Mon, 26 Oct 2020 11:33:35 +0800 Subject: [PATCH] add unittest for onnx convert (#608) * add unittest for onnx convert * build onnx and onnxruntime in CI * skip onnx op unit test while using CUDA * fix offset==0 case in NMS * remove tmp file used in test * delete tmp file before assert so that we can remove the tmp file anyway --- .github/workflows/build.yml | 6 ++-- mmcv/ops/nms.py | 3 ++ setup.cfg | 2 +- tests/test_ops/test_onnx.py | 62 +++++++++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 tests/test_ops/test_onnx.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ef9807231..db180c3c2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -111,7 +111,7 @@ jobs: - name: Install PyTorch run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install unittest dependencies - run: pip install pytest coverage lmdb PyTurboJPEG + run: pip install pytest coverage lmdb PyTurboJPEG onnx==1.6.0 onnxruntime==1.2.0 - name: Build and install run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report @@ -181,7 +181,7 @@ jobs: run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report run: | - coverage run --branch --source=mmcv -m pytest tests/ + coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py coverage xml coverage report -m # Only upload coverage report for python3.7 && pytorch1.5 @@ -220,6 +220,8 @@ jobs: - name: Install Pillow run: pip install Pillow==6.2.2 if: ${{matrix.torchvision == '0.4.2'}} + - name: Install ONNX + run: pip install onnx==1.6.0 onnxruntime==1.2.0 - name: Install PyTorch run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir - name: Build and install diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index e982b1472..e225ca11d 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -111,6 +111,9 @@ def nms(boxes, scores, iou_threshold, offset=0): # ONNX only support offset == 1 boxes[:, -2:] -= 1 inds = NMSop.apply(boxes, scores, iou_threshold, offset) + if torch.onnx.is_in_onnx_export() and offset == 0: + # ONNX only support offset == 1 + boxes[:, -2:] += 1 dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) if is_numpy: dets = dets.cpu().numpy() diff --git a/setup.cfg b/setup.cfg index a46dce37c..234d24c0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcv -known_third_party = Cython,addict,cv2,m2r,numpy,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf +known_third_party = Cython,addict,cv2,m2r,numpy,onnx,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py new file mode 100644 index 000000000..e0bc7a4ac --- /dev/null +++ b/tests/test_ops/test_onnx.py @@ -0,0 +1,62 @@ +import os +from functools import partial + +import numpy as np +import onnx +import onnxruntime as rt +import torch +import torch.nn as nn + +onnx_file = 'tmp.onnx' + + +class WrapFunction(nn.Module): + + def __init__(self, wrapped_function): + super(WrapFunction, self).__init__() + self.wrapped_function = wrapped_function + + def forward(self, *args, **kwargs): + return self.wrapped_function(*args, **kwargs) + + +class Testonnx(object): + + def test_nms(self): + from mmcv.ops import nms + np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0], + [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]], + dtype=np.float32) + np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32) + boxes = torch.from_numpy(np_boxes) + scores = torch.from_numpy(np_scores) + pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0) + pytorch_score = pytorch_dets[:, 4] + nms = partial(nms, iou_threshold=0.3, offset=0) + wrapped_model = WrapFunction(nms) + wrapped_model.cpu().eval() + with torch.no_grad(): + torch.onnx.export( + wrapped_model, (boxes, scores), + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=['boxes', 'scores'], + opset_version=11) + onnx_model = onnx.load(onnx_file) + + # get onnx output + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [ + node.name for node in onnx_model.graph.initializer + ] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 2) + sess = rt.InferenceSession(onnx_file) + onnx_dets, _ = sess.run(None, { + 'scores': scores.detach().numpy(), + 'boxes': boxes.detach().numpy() + }) + onnx_score = onnx_dets[:, 4] + os.remove(onnx_file) + assert np.allclose(pytorch_score, onnx_score, atol=1e-3)