From 0310c168dce12fa79ce83bfb4d673f17587bf89a Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 20 Jul 2022 19:32:46 +0800 Subject: [PATCH] [Enhancement] Better index put ONNX export. (#704) * Add rewriter for tensor setitem * add version check --- mmdeploy/pytorch/functions/__init__.py | 4 +- mmdeploy/pytorch/functions/tensor_setitem.py | 57 ++++++++++++++++++++ tests/test_pytorch/test_pytorch_functions.py | 32 +++++++++++ 3 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/pytorch/functions/tensor_setitem.py diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 3e84ee6ed..6a2ac2285 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -10,6 +10,7 @@ from .masked_fill import masked_fill__onnxruntime from .normalize import normalize__ncnn from .repeat import tensor__repeat__tensorrt from .size import tensor__size__ncnn +from .tensor_setitem import tensor__setitem__default from .topk import topk__dynamic, topk__tensorrt from .triu import triu @@ -18,5 +19,6 @@ __all__ = [ 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', 'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', - 'chunk__torchscript', 'masked_fill__onnxruntime' + 'chunk__torchscript', 'masked_fill__onnxruntime', + 'tensor__setitem__default' ] diff --git a/mmdeploy/pytorch/functions/tensor_setitem.py b/mmdeploy/pytorch/functions/tensor_setitem.py new file mode 100644 index 000000000..70ebda68d --- /dev/null +++ b/mmdeploy/pytorch/functions/tensor_setitem.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +from packaging.version import parse + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.__setitem__') +def tensor__setitem__default(ctx, self, key, value): + """Rewrite `setitem` to ease the index put.""" + + # only support torch>=1.9.0 + if parse(torch.__version__) < parse('1.9.0'): + return ctx.origin_func(self, key, value) + + if isinstance(key, slice): + key = (key, ) + + if not isinstance(key, Sequence): + return ctx.origin_func(self, key, value) + + for k in key: + if not isinstance(k, slice) or k.step is not None: + return ctx.origin_func(self, key, value) + + out = value + for i, k in enumerate(key): + if k == slice(None): + continue + + cat_list = [] + + # slice self start + if k.start is not None: + self_slice_start = (slice(None), ) * i + (slice( + 0, k.start), ) + key[i + 1:] + self_start = self[self_slice_start] + cat_list.append(self_start) + + # add value + cat_list.append(out) + + # slice self end + if k.stop is not None: + self_slice_end = (slice(None), ) * i + (slice( + k.stop, None), ) + key[i + 1:] + self_end = self[self_slice_end] + cat_list.append(self_end) + + # concate + out = torch.cat(cat_list, dim=i) + + # self assign + # Note that set item does not return any value + self[...] = out diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index fca863af3..2508556cb 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch import torch.nn.functional as F +from packaging.version import parse from mmdeploy.utils import Backend from mmdeploy.utils.test import (WrapFunction, backend_checker, @@ -309,3 +310,34 @@ def test_masked_fill_onnxruntime(input): deploy_cfg=deploy_cfg_ort, run_with_backend=True) assert rewrite_output is not None + + +@backend_checker(Backend.ONNXRUNTIME) +@pytest.mark.skipif( + parse(torch.__version__) < parse('1.9.0'), reason='requires torch>1.8.0') +@pytest.mark.parametrize('x', [torch.rand(1, 3, 16, 16)]) +@pytest.mark.parametrize('y', [torch.rand(1, 3, 4, 4)]) +def test_tensor_setitem(x, y): + import onnx + + from mmdeploy.utils.test import get_onnx_model + + def setitem_slice(x, y): + H, W = y.shape[2:] + x[:, :, 2:H + 2, 2:W + 2] = y + return x + + wrapped_func = WrapFunction(setitem_slice) + model_inputs = {'x': x, 'y': y} + + deploy_cfg = mmcv.Config( + dict( + onnx_config=dict(input_shape=None), + backend_config=dict(type='onnxruntime'), + codebase_config=dict(type='mmdet', task='ObjectDetection'))) + ir_file_path = get_onnx_model(wrapped_func, model_inputs, deploy_cfg) + + onnx_model = onnx.load(ir_file_path) + nodes = onnx_model.graph.node + for node in nodes: + assert node.op_type != 'ScatterND'