diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 5ee8ef348..19515408a 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -2,6 +2,7 @@ from . import adaptive_pool # noqa: F401,F403 from . import any # noqa: F401,F403 from . import atan2 # noqa: F401,F403 +from . import cat # noqa: F401,F403 from . import chunk # noqa: F401,F403 from . import clip # noqa: F401,F403 from . import expand # noqa: F401,F403 diff --git a/mmdeploy/pytorch/functions/cat.py b/mmdeploy/pytorch/functions/cat.py new file mode 100644 index 000000000..ea1f623bd --- /dev/null +++ b/mmdeploy/pytorch/functions/cat.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence + +import torch +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import get_dynamic_axes + + +@FUNCTION_REWRITER.register_rewriter(func_name='torch.cat', backend='tensorrt') +def cat__tensorrt(tensors: Sequence[Tensor], *args, **kwargs) -> torch.Tensor: + """Rewrite `cat` for TensorRT backend. + + cat in TensorRT does not support bool or uint8 type when input is dynamic. + """ + ctx = FUNCTION_REWRITER.get_context() + if get_dynamic_axes(ctx.cfg) is None: + return ctx.origin_func(tensors, *args, **kwargs) + if len(tensors) > 0 and (tensors[0].dtype in [torch.bool, torch.uint8]): + original_dtype = tensors[0].dtype + tensors = [i.to(torch.int32) for i in tensors] + return ctx.origin_func(tensors, *args, **kwargs).to(original_dtype) + return ctx.origin_func(tensors, *args, **kwargs) diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index 245bfba9d..638e8b62c 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -19,10 +19,13 @@ deploy_cfg_ncnn = Config( codebase_config=dict(type='mmdet', task='ObjectDetection'))) -def get_trt_config(output_names, shape): +def get_trt_config(output_names, shape, dynamic_axes=None): deploy_cfg_tensorrt = Config( dict( - onnx_config=dict(input_shape=None, output_names=output_names), + onnx_config=dict( + input_shape=None, + output_names=output_names, + dynamic_axes=dynamic_axes), backend_config=dict( type='tensorrt', common_config=dict( @@ -615,3 +618,30 @@ def test_linspace__default(): assert np.allclose( model_output, rewrite_outputs, rtol=1e-03, atol=1e-05) + + +@backend_checker(Backend.TENSORRT) +@pytest.mark.parametrize('dtype', [torch.bool, torch.float32]) +@pytest.mark.parametrize('dynamic_axes', + [None, dict(input=dict({ + 0: 'dim0', + 1: 'dim1' + }))]) +def test_cat__tensorrt(dtype, dynamic_axes): + input = torch.rand(2, 4) + model = WrapFunction(lambda input: torch.cat( + [input.to(dtype), input.to(dtype)], -1)) + pytorch_output = model(input) + rewrite_output, _ = get_rewrite_outputs( + model, + model_inputs={'input': input}, + deploy_cfg=get_trt_config(['output'], + shape=[2, 4], + dynamic_axes=dynamic_axes), + run_with_backend=True) + assert pytorch_output.dtype == rewrite_output[0].dtype + assert torch.allclose( + pytorch_output.cpu().float(), + rewrite_output[0].cpu().float(), + rtol=1e-3, + atol=1e-5)