[Enhancement] Better index put ONNX export. (#704)

* Add rewriter for tensor setitem

* add version check
pull/690/head
q.yao 2022-07-20 19:32:46 +08:00 committed by GitHub
parent c498cd2d6b
commit 0310c168dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 1 deletions

View File

@ -10,6 +10,7 @@ from .masked_fill import masked_fill__onnxruntime
from .normalize import normalize__ncnn
from .repeat import tensor__repeat__tensorrt
from .size import tensor__size__ncnn
from .tensor_setitem import tensor__setitem__default
from .topk import topk__dynamic, topk__tensorrt
from .triu import triu
@ -18,5 +19,6 @@ __all__ = [
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
'chunk__torchscript', 'masked_fill__onnxruntime'
'chunk__torchscript', 'masked_fill__onnxruntime',
'tensor__setitem__default'
]

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
from packaging.version import parse
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.__setitem__')
def tensor__setitem__default(ctx, self, key, value):
"""Rewrite `setitem` to ease the index put."""
# only support torch>=1.9.0
if parse(torch.__version__) < parse('1.9.0'):
return ctx.origin_func(self, key, value)
if isinstance(key, slice):
key = (key, )
if not isinstance(key, Sequence):
return ctx.origin_func(self, key, value)
for k in key:
if not isinstance(k, slice) or k.step is not None:
return ctx.origin_func(self, key, value)
out = value
for i, k in enumerate(key):
if k == slice(None):
continue
cat_list = []
# slice self start
if k.start is not None:
self_slice_start = (slice(None), ) * i + (slice(
0, k.start), ) + key[i + 1:]
self_start = self[self_slice_start]
cat_list.append(self_start)
# add value
cat_list.append(out)
# slice self end
if k.stop is not None:
self_slice_end = (slice(None), ) * i + (slice(
k.stop, None), ) + key[i + 1:]
self_end = self[self_slice_end]
cat_list.append(self_end)
# concate
out = torch.cat(cat_list, dim=i)
# self assign
# Note that set item does not return any value
self[...] = out

View File

@ -6,6 +6,7 @@ import numpy as np
import pytest
import torch
import torch.nn.functional as F
from packaging.version import parse
from mmdeploy.utils import Backend
from mmdeploy.utils.test import (WrapFunction, backend_checker,
@ -309,3 +310,34 @@ def test_masked_fill_onnxruntime(input):
deploy_cfg=deploy_cfg_ort,
run_with_backend=True)
assert rewrite_output is not None
@backend_checker(Backend.ONNXRUNTIME)
@pytest.mark.skipif(
parse(torch.__version__) < parse('1.9.0'), reason='requires torch>1.8.0')
@pytest.mark.parametrize('x', [torch.rand(1, 3, 16, 16)])
@pytest.mark.parametrize('y', [torch.rand(1, 3, 4, 4)])
def test_tensor_setitem(x, y):
import onnx
from mmdeploy.utils.test import get_onnx_model
def setitem_slice(x, y):
H, W = y.shape[2:]
x[:, :, 2:H + 2, 2:W + 2] = y
return x
wrapped_func = WrapFunction(setitem_slice)
model_inputs = {'x': x, 'y': y}
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(input_shape=None),
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmdet', task='ObjectDetection')))
ir_file_path = get_onnx_model(wrapped_func, model_inputs, deploy_cfg)
onnx_model = onnx.load(ir_file_path)
nodes = onnx_model.graph.node
for node in nodes:
assert node.op_type != 'ScatterND'