Fix pytorch2onnx failed for interpolate op with PyTorch==1.6.0(mmdet#4646) (#848)

* Fix pytorch2onnx for yolov3 with torch==1.6.0

* update and add test for F.interpolate

* add test_onnx.py with build_cuda
pull/863/head
RunningLeon 2021-02-26 10:49:12 +08:00 committed by GitHub
parent 58a8483352
commit 72e4cc12bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 8 deletions

View File

@ -185,7 +185,7 @@ jobs:
- name: Run unittests and generate coverage report
run: |
pip install -r requirements/test.txt
coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py
coverage run --branch --source=mmcv -m pytest tests/
coverage xml
coverage report -m
# Only upload coverage report for python3.7 && pytorch1.6

View File

@ -228,18 +228,32 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
def _interpolate_get_scales_if_available(g, scales):
if len(scales) == 0:
return None
available_scales = _maybe_get_const(scales[0],
'fs') != -1 and not _is_none(scales[0])
# scales[0] is ListType in Pytorch == 1.7.0
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' else 'f'
available_scales = _maybe_get_const(
scales[0], scale_desc) != -1 and not _is_none(scales[0])
if not available_scales:
return None
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
scales_list = g.op(
'Constant', value_t=torch.tensor(_maybe_get_const(scales[0], 'fs')))
# modify to support PyTorch==1.7.0
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
scales = g.op('Concat', offsets, scales_list, axis_i=0)
if scale_desc == 'fs':
scales_list = g.op(
'Constant',
value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc)))
# modify to support PyTorch==1.7.0
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
scales = g.op('Concat', offsets, scales_list, axis_i=0)
else:
# for PyTorch < 1.7.0
scales_list = []
for scale in scales:
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
# ONNX only supports float for the scales. double -> float.
unsqueezed_scale = g.op(
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
scales_list.append(unsqueezed_scale)
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
return scales

View File

@ -299,3 +299,29 @@ def test_simplify():
numel_after = len(slim_onnx_model.graph.node)
os.remove(onnx_file)
assert numel_before == 18 and numel_after == 1, 'Simplify failed.'
def test_interpolate():
from mmcv.onnx.symbolic import register_extra_symbolics
opset_version = 11
register_extra_symbolics(opset_version)
def func(feat, scale_factor=2):
out = nn.functional.interpolate(feat, scale_factor=scale_factor)
return out
net = WrapFunction(func)
net = net.cpu().eval()
dummy_input = torch.randn(2, 4, 8, 8).cpu()
torch.onnx.export(
net,
dummy_input,
onnx_file,
input_names=['input'],
opset_version=opset_version)
sess = rt.InferenceSession(onnx_file)
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
pytorch_result = func(dummy_input).detach().numpy()
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)