rewrite torch.cat for TensorRT when input is dynamic ()

pull/1888/head
AllentDan 2023-03-17 15:49:52 +08:00 committed by GitHub
parent 847a906e6f
commit 8e2f6556be
3 changed files with 57 additions and 2 deletions
mmdeploy/pytorch/functions
tests/test_pytorch

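Note that the rewrite only kicks in when the deployment config declares dynamic axes; static-shape exports keep the original `torch.cat`. Below is a minimal sketch of a config that would trigger the new code path. The axis names mirror the ones used in the test further down; the `Config` import from mmengine and the bare `backend_config` are assumptions for illustration, not part of this patch.

# Hedged sketch only: a deploy config whose onnx_config declares dynamic_axes,
# which is what makes get_dynamic_axes(ctx.cfg) return a non-None value inside
# the rewriter. Axis names follow the test below; the mmengine import and the
# minimal backend_config are assumptions.
from mmengine import Config

deploy_cfg = Config(
    dict(
        onnx_config=dict(
            input_shape=None,
            output_names=['output'],
            dynamic_axes=dict(input={0: 'dim0', 1: 'dim1'})),
        backend_config=dict(type='tensorrt')))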
@@ -2,6 +2,7 @@
 from . import adaptive_pool  # noqa: F401,F403
 from . import any  # noqa: F401,F403
 from . import atan2  # noqa: F401,F403
+from . import cat  # noqa: F401,F403
 from . import chunk  # noqa: F401,F403
 from . import clip  # noqa: F401,F403
 from . import expand  # noqa: F401,F403

@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+from torch import Tensor
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import get_dynamic_axes
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='torch.cat', backend='tensorrt')
+def cat__tensorrt(tensors: Sequence[Tensor], *args, **kwargs) -> torch.Tensor:
+    """Rewrite `cat` for TensorRT backend.
+
+    cat in TensorRT does not support bool or uint8 type when input is dynamic.
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+    if get_dynamic_axes(ctx.cfg) is None:
+        return ctx.origin_func(tensors, *args, **kwargs)
+    if len(tensors) > 0 and (tensors[0].dtype in [torch.bool, torch.uint8]):
+        original_dtype = tensors[0].dtype
+        tensors = [i.to(torch.int32) for i in tensors]
+        return ctx.origin_func(tensors, *args, **kwargs).to(original_dtype)
+    return ctx.origin_func(tensors, *args, **kwargs)

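The cast-and-restore trick above can be checked standalone in plain PyTorch (no mmdeploy or TensorRT required). The snippet below is only an illustration that the bool -> int32 -> bool round trip is lossless; it is not part of the patch.

import torch

# Two boolean tensors, e.g. masks produced by a detection model.
masks = [torch.rand(2, 4) > 0.5, torch.rand(2, 4) > 0.5]

direct = torch.cat(masks, dim=-1)                          # plain torch.cat
as_int32 = [m.to(torch.int32) for m in masks]              # what the rewrite feeds to cat
restored = torch.cat(as_int32, dim=-1).to(masks[0].dtype)  # cast back afterwards

assert restored.dtype == direct.dtype == torch.bool
assert torch.equal(restored, direct)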
@@ -19,10 +19,13 @@ deploy_cfg_ncnn = Config(
         codebase_config=dict(type='mmdet', task='ObjectDetection')))
 
 
-def get_trt_config(output_names, shape):
+def get_trt_config(output_names, shape, dynamic_axes=None):
     deploy_cfg_tensorrt = Config(
         dict(
-            onnx_config=dict(input_shape=None, output_names=output_names),
+            onnx_config=dict(
+                input_shape=None,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes),
             backend_config=dict(
                 type='tensorrt',
                 common_config=dict(
@@ -615,3 +618,30 @@ def test_linspace__default():
     assert np.allclose(
         model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
+
+
+@backend_checker(Backend.TENSORRT)
+@pytest.mark.parametrize('dtype', [torch.bool, torch.float32])
+@pytest.mark.parametrize('dynamic_axes',
+                         [None, dict(input=dict({
+                             0: 'dim0',
+                             1: 'dim1'
+                         }))])
+def test_cat__tensorrt(dtype, dynamic_axes):
+    input = torch.rand(2, 4)
+    model = WrapFunction(lambda input: torch.cat(
+        [input.to(dtype), input.to(dtype)], -1))
+    pytorch_output = model(input)
+    rewrite_output, _ = get_rewrite_outputs(
+        model,
+        model_inputs={'input': input},
+        deploy_cfg=get_trt_config(['output'],
+                                  shape=[2, 4],
+                                  dynamic_axes=dynamic_axes),
+        run_with_backend=True)
+    assert pytorch_output.dtype == rewrite_output[0].dtype
+    assert torch.allclose(
+        pytorch_output.cpu().float(),
+        rewrite_output[0].cpu().float(),
+        rtol=1e-3,
+        atol=1e-5)
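For a quick manual check outside the parametrized test, the rewriter can also be exercised through a plain ONNX export. The sketch below assumes it lives in the same test module (so `get_trt_config` is in scope) and that `RewriterContext(cfg=..., backend=...)` behaves as it does in the other mmdeploy rewriter tests; the toy module and output file name are illustrative only.

import torch
from mmdeploy.core import RewriterContext


class CatBoolModel(torch.nn.Module):
    """Toy module: concatenates a boolean mask with itself."""

    def forward(self, x):
        mask = x > 0.5
        return torch.cat([mask, mask], dim=-1)


deploy_cfg = get_trt_config(['output'], shape=[2, 4],
                            dynamic_axes=dict(input={0: 'dim0', 1: 'dim1'}))

# With the TensorRT rewriters active, the exported Concat should receive int32
# inputs and the result should be cast back to bool afterwards.
with RewriterContext(cfg=deploy_cfg, backend='tensorrt'), torch.no_grad():
    torch.onnx.export(
        CatBoolModel(), torch.rand(2, 4), 'cat_bool.onnx',
        input_names=['input'], output_names=['output'],
        dynamic_axes=dict(input={0: 'dim0', 1: 'dim1'}))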