diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aad4adacd..79a90b0c2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -185,7 +185,7 @@ jobs: - name: Run unittests and generate coverage report run: | pip install -r requirements/test.txt - coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py + coverage run --branch --source=mmcv -m pytest tests/ coverage xml coverage report -m # Only upload coverage report for python3.7 && pytorch1.6 diff --git a/mmcv/onnx/onnx_utils/symbolic_helper.py b/mmcv/onnx/onnx_utils/symbolic_helper.py index eb57c1b89..1839774ad 100644 --- a/mmcv/onnx/onnx_utils/symbolic_helper.py +++ b/mmcv/onnx/onnx_utils/symbolic_helper.py @@ -228,18 +228,32 @@ def _interpolate_size_to_scales(g, input, output_size, dim): def _interpolate_get_scales_if_available(g, scales): if len(scales) == 0: return None - available_scales = _maybe_get_const(scales[0], - 'fs') != -1 and not _is_none(scales[0]) + # scales[0] is ListType in Pytorch == 1.7.0 + scale_desc = 'fs' if scales[0].type().kind() == 'ListType' else 'f' + available_scales = _maybe_get_const( + scales[0], scale_desc) != -1 and not _is_none(scales[0]) if not available_scales: return None offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32)) - scales_list = g.op( - 'Constant', value_t=torch.tensor(_maybe_get_const(scales[0], 'fs'))) - # modify to support PyTorch==1.7.0 - # https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501 - scales = g.op('Concat', offsets, scales_list, axis_i=0) + if scale_desc == 'fs': + scales_list = g.op( + 'Constant', + value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc))) + # modify to support PyTorch==1.7.0 + # https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501 + scales = g.op('Concat', offsets, scales_list, axis_i=0) + else: + # for PyTorch < 1.7.0 + scales_list = [] + for scale in scales: + unsqueezed_scale = _unsqueeze_helper(g, scale, 0) + # ONNX only supports float for the scales. double -> float. + unsqueezed_scale = g.op( + 'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float']) + scales_list.append(unsqueezed_scale) + scales = g.op('Concat', offsets, *scales_list, axis_i=0) return scales diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index f6f3e68f7..63c9c5628 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -299,3 +299,29 @@ def test_simplify(): numel_after = len(slim_onnx_model.graph.node) os.remove(onnx_file) assert numel_before == 18 and numel_after == 1, 'Simplify failed.' + + +def test_interpolate(): + from mmcv.onnx.symbolic import register_extra_symbolics + opset_version = 11 + register_extra_symbolics(opset_version) + + def func(feat, scale_factor=2): + out = nn.functional.interpolate(feat, scale_factor=scale_factor) + return out + + net = WrapFunction(func) + net = net.cpu().eval() + dummy_input = torch.randn(2, 4, 8, 8).cpu() + torch.onnx.export( + net, + dummy_input, + onnx_file, + input_names=['input'], + opset_version=opset_version) + sess = rt.InferenceSession(onnx_file) + onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()}) + pytorch_result = func(dummy_input).detach().numpy() + if os.path.exists(onnx_file): + os.remove(onnx_file) + assert np.allclose(pytorch_result, onnx_result, atol=1e-3)