[Enhancement] Better index put ONNX export. (#704)
* Add rewriter for tensor setitem * add version checkpull/690/head
parent
c498cd2d6b
commit
0310c168dc
|
@ -10,6 +10,7 @@ from .masked_fill import masked_fill__onnxruntime
|
|||
from .normalize import normalize__ncnn
|
||||
from .repeat import tensor__repeat__tensorrt
|
||||
from .size import tensor__size__ncnn
|
||||
from .tensor_setitem import tensor__setitem__default
|
||||
from .topk import topk__dynamic, topk__tensorrt
|
||||
from .triu import triu
|
||||
|
||||
|
@ -18,5 +19,6 @@ __all__ = [
|
|||
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
|
||||
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
|
||||
'triu', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
|
||||
'chunk__torchscript', 'masked_fill__onnxruntime'
|
||||
'chunk__torchscript', 'masked_fill__onnxruntime',
|
||||
'tensor__setitem__default'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from packaging.version import parse
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.__setitem__')
|
||||
def tensor__setitem__default(ctx, self, key, value):
|
||||
"""Rewrite `setitem` to ease the index put."""
|
||||
|
||||
# only support torch>=1.9.0
|
||||
if parse(torch.__version__) < parse('1.9.0'):
|
||||
return ctx.origin_func(self, key, value)
|
||||
|
||||
if isinstance(key, slice):
|
||||
key = (key, )
|
||||
|
||||
if not isinstance(key, Sequence):
|
||||
return ctx.origin_func(self, key, value)
|
||||
|
||||
for k in key:
|
||||
if not isinstance(k, slice) or k.step is not None:
|
||||
return ctx.origin_func(self, key, value)
|
||||
|
||||
out = value
|
||||
for i, k in enumerate(key):
|
||||
if k == slice(None):
|
||||
continue
|
||||
|
||||
cat_list = []
|
||||
|
||||
# slice self start
|
||||
if k.start is not None:
|
||||
self_slice_start = (slice(None), ) * i + (slice(
|
||||
0, k.start), ) + key[i + 1:]
|
||||
self_start = self[self_slice_start]
|
||||
cat_list.append(self_start)
|
||||
|
||||
# add value
|
||||
cat_list.append(out)
|
||||
|
||||
# slice self end
|
||||
if k.stop is not None:
|
||||
self_slice_end = (slice(None), ) * i + (slice(
|
||||
k.stop, None), ) + key[i + 1:]
|
||||
self_end = self[self_slice_end]
|
||||
cat_list.append(self_end)
|
||||
|
||||
# concate
|
||||
out = torch.cat(cat_list, dim=i)
|
||||
|
||||
# self assign
|
||||
# Note that set item does not return any value
|
||||
self[...] = out
|
|
@ -6,6 +6,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from packaging.version import parse
|
||||
|
||||
from mmdeploy.utils import Backend
|
||||
from mmdeploy.utils.test import (WrapFunction, backend_checker,
|
||||
|
@ -309,3 +310,34 @@ def test_masked_fill_onnxruntime(input):
|
|||
deploy_cfg=deploy_cfg_ort,
|
||||
run_with_backend=True)
|
||||
assert rewrite_output is not None
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
|
||||
@pytest.mark.skipif(
|
||||
parse(torch.__version__) < parse('1.9.0'), reason='requires torch>1.8.0')
|
||||
@pytest.mark.parametrize('x', [torch.rand(1, 3, 16, 16)])
|
||||
@pytest.mark.parametrize('y', [torch.rand(1, 3, 4, 4)])
|
||||
def test_tensor_setitem(x, y):
|
||||
import onnx
|
||||
|
||||
from mmdeploy.utils.test import get_onnx_model
|
||||
|
||||
def setitem_slice(x, y):
|
||||
H, W = y.shape[2:]
|
||||
x[:, :, 2:H + 2, 2:W + 2] = y
|
||||
return x
|
||||
|
||||
wrapped_func = WrapFunction(setitem_slice)
|
||||
model_inputs = {'x': x, 'y': y}
|
||||
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(input_shape=None),
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
codebase_config=dict(type='mmdet', task='ObjectDetection')))
|
||||
ir_file_path = get_onnx_model(wrapped_func, model_inputs, deploy_cfg)
|
||||
|
||||
onnx_model = onnx.load(ir_file_path)
|
||||
nodes = onnx_model.graph.node
|
||||
for node in nodes:
|
||||
assert node.op_type != 'ScatterND'
|
||||
|
|
Loading…
Reference in New Issue