add unittest for onnx convert (#608)

* add unittest for onnx convert

* build onnx and onnxruntime in CI

* skip onnx op unit test while using CUDA

* fix offset==0 case in NMS

* remove tmp file used in test

* delete tmp file before assert so that we can remove the tmp file anyway
pull/629/head
robin Han 2020-10-26 11:33:35 +08:00 committed by GitHub
parent 65a60a3d7d
commit 23b2bdbf52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 3 deletions

View File

@ -111,7 +111,7 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install unittest dependencies
run: pip install pytest coverage lmdb PyTurboJPEG
run: pip install pytest coverage lmdb PyTurboJPEG onnx==1.6.0 onnxruntime==1.2.0
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
@ -181,7 +181,7 @@ jobs:
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
coverage run --branch --source=mmcv -m pytest tests/
coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py
coverage xml
coverage report -m
# Only upload coverage report for python3.7 && pytorch1.5
@ -220,6 +220,8 @@ jobs:
- name: Install Pillow
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install ONNX
run: pip install onnx==1.6.0 onnxruntime==1.2.0
- name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
- name: Build and install

View File

@ -111,6 +111,9 @@ def nms(boxes, scores, iou_threshold, offset=0):
# ONNX only support offset == 1
boxes[:, -2:] -= 1
inds = NMSop.apply(boxes, scores, iou_threshold, offset)
if torch.onnx.is_in_onnx_export() and offset == 0:
# ONNX only support offset == 1
boxes[:, -2:] += 1
dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
if is_numpy:
dets = dets.cpu().numpy()

View File

@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcv
known_third_party = Cython,addict,cv2,m2r,numpy,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
known_third_party = Cython,addict,cv2,m2r,numpy,onnx,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -0,0 +1,62 @@
import os
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import torch
import torch.nn as nn
onnx_file = 'tmp.onnx'
class WrapFunction(nn.Module):
def __init__(self, wrapped_function):
super(WrapFunction, self).__init__()
self.wrapped_function = wrapped_function
def forward(self, *args, **kwargs):
return self.wrapped_function(*args, **kwargs)
class Testonnx(object):
def test_nms(self):
from mmcv.ops import nms
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
pytorch_dets, _ = nms(boxes, scores, iou_threshold=0.3, offset=0)
pytorch_score = pytorch_dets[:, 4]
nms = partial(nms, iou_threshold=0.3, offset=0)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file)
onnx_dets, _ = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
onnx_score = onnx_dets[:, 4]
os.remove(onnx_file)
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)