mirror of https://github.com/open-mmlab/mmcv.git
Fix pytorch2onnx failed for interpolate op with PyTorch==1.6.0(mmdet#4646) (#848)
* Fix pytorch2onnx for yolov3 with torch==1.6.0 * update and add test for F.interpolate * add test_onnx.py with build_cudapull/863/head
parent
58a8483352
commit
72e4cc12bd
|
@ -185,7 +185,7 @@ jobs:
|
|||
- name: Run unittests and generate coverage report
|
||||
run: |
|
||||
pip install -r requirements/test.txt
|
||||
coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py
|
||||
coverage run --branch --source=mmcv -m pytest tests/
|
||||
coverage xml
|
||||
coverage report -m
|
||||
# Only upload coverage report for python3.7 && pytorch1.6
|
||||
|
|
|
@ -228,18 +228,32 @@ def _interpolate_size_to_scales(g, input, output_size, dim):
|
|||
def _interpolate_get_scales_if_available(g, scales):
|
||||
if len(scales) == 0:
|
||||
return None
|
||||
available_scales = _maybe_get_const(scales[0],
|
||||
'fs') != -1 and not _is_none(scales[0])
|
||||
# scales[0] is ListType in Pytorch == 1.7.0
|
||||
scale_desc = 'fs' if scales[0].type().kind() == 'ListType' else 'f'
|
||||
available_scales = _maybe_get_const(
|
||||
scales[0], scale_desc) != -1 and not _is_none(scales[0])
|
||||
|
||||
if not available_scales:
|
||||
return None
|
||||
|
||||
offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
|
||||
scales_list = g.op(
|
||||
'Constant', value_t=torch.tensor(_maybe_get_const(scales[0], 'fs')))
|
||||
# modify to support PyTorch==1.7.0
|
||||
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
|
||||
scales = g.op('Concat', offsets, scales_list, axis_i=0)
|
||||
if scale_desc == 'fs':
|
||||
scales_list = g.op(
|
||||
'Constant',
|
||||
value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc)))
|
||||
# modify to support PyTorch==1.7.0
|
||||
# https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
|
||||
scales = g.op('Concat', offsets, scales_list, axis_i=0)
|
||||
else:
|
||||
# for PyTorch < 1.7.0
|
||||
scales_list = []
|
||||
for scale in scales:
|
||||
unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
|
||||
# ONNX only supports float for the scales. double -> float.
|
||||
unsqueezed_scale = g.op(
|
||||
'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
|
||||
scales_list.append(unsqueezed_scale)
|
||||
scales = g.op('Concat', offsets, *scales_list, axis_i=0)
|
||||
return scales
|
||||
|
||||
|
||||
|
|
|
@ -299,3 +299,29 @@ def test_simplify():
|
|||
numel_after = len(slim_onnx_model.graph.node)
|
||||
os.remove(onnx_file)
|
||||
assert numel_before == 18 and numel_after == 1, 'Simplify failed.'
|
||||
|
||||
|
||||
def test_interpolate():
|
||||
from mmcv.onnx.symbolic import register_extra_symbolics
|
||||
opset_version = 11
|
||||
register_extra_symbolics(opset_version)
|
||||
|
||||
def func(feat, scale_factor=2):
|
||||
out = nn.functional.interpolate(feat, scale_factor=scale_factor)
|
||||
return out
|
||||
|
||||
net = WrapFunction(func)
|
||||
net = net.cpu().eval()
|
||||
dummy_input = torch.randn(2, 4, 8, 8).cpu()
|
||||
torch.onnx.export(
|
||||
net,
|
||||
dummy_input,
|
||||
onnx_file,
|
||||
input_names=['input'],
|
||||
opset_version=opset_version)
|
||||
sess = rt.InferenceSession(onnx_file)
|
||||
onnx_result = sess.run(None, {'input': dummy_input.detach().numpy()})
|
||||
pytorch_result = func(dummy_input).detach().numpy()
|
||||
if os.path.exists(onnx_file):
|
||||
os.remove(onnx_file)
|
||||
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
|
||||
|
|
Loading…
Reference in New Issue