rewrite torch.cat for TensorRT when input is dynamic (#1851)
parent
847a906e6f
commit
8e2f6556be
mmdeploy/pytorch/functions
tests/test_pytorch
|
@ -2,6 +2,7 @@
|
|||
from . import adaptive_pool # noqa: F401,F403
|
||||
from . import any # noqa: F401,F403
|
||||
from . import atan2 # noqa: F401,F403
|
||||
from . import cat # noqa: F401,F403
|
||||
from . import chunk # noqa: F401,F403
|
||||
from . import clip # noqa: F401,F403
|
||||
from . import expand # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import get_dynamic_axes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='torch.cat', backend='tensorrt')
|
||||
def cat__tensorrt(tensors: Sequence[Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
"""Rewrite `cat` for TensorRT backend.
|
||||
|
||||
cat in TensorRT does not support bool or uint8 type when input is dynamic.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if get_dynamic_axes(ctx.cfg) is None:
|
||||
return ctx.origin_func(tensors, *args, **kwargs)
|
||||
if len(tensors) > 0 and (tensors[0].dtype in [torch.bool, torch.uint8]):
|
||||
original_dtype = tensors[0].dtype
|
||||
tensors = [i.to(torch.int32) for i in tensors]
|
||||
return ctx.origin_func(tensors, *args, **kwargs).to(original_dtype)
|
||||
return ctx.origin_func(tensors, *args, **kwargs)
|
|
@ -19,10 +19,13 @@ deploy_cfg_ncnn = Config(
|
|||
codebase_config=dict(type='mmdet', task='ObjectDetection')))
|
||||
|
||||
|
||||
def get_trt_config(output_names, shape):
|
||||
def get_trt_config(output_names, shape, dynamic_axes=None):
|
||||
deploy_cfg_tensorrt = Config(
|
||||
dict(
|
||||
onnx_config=dict(input_shape=None, output_names=output_names),
|
||||
onnx_config=dict(
|
||||
input_shape=None,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes),
|
||||
backend_config=dict(
|
||||
type='tensorrt',
|
||||
common_config=dict(
|
||||
|
@ -615,3 +618,30 @@ def test_linspace__default():
|
|||
|
||||
assert np.allclose(
|
||||
model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
@backend_checker(Backend.TENSORRT)
|
||||
@pytest.mark.parametrize('dtype', [torch.bool, torch.float32])
|
||||
@pytest.mark.parametrize('dynamic_axes',
|
||||
[None, dict(input=dict({
|
||||
0: 'dim0',
|
||||
1: 'dim1'
|
||||
}))])
|
||||
def test_cat__tensorrt(dtype, dynamic_axes):
|
||||
input = torch.rand(2, 4)
|
||||
model = WrapFunction(lambda input: torch.cat(
|
||||
[input.to(dtype), input.to(dtype)], -1))
|
||||
pytorch_output = model(input)
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
model,
|
||||
model_inputs={'input': input},
|
||||
deploy_cfg=get_trt_config(['output'],
|
||||
shape=[2, 4],
|
||||
dynamic_axes=dynamic_axes),
|
||||
run_with_backend=True)
|
||||
assert pytorch_output.dtype == rewrite_output[0].dtype
|
||||
assert torch.allclose(
|
||||
pytorch_output.cpu().float(),
|
||||
rewrite_output[0].cpu().float(),
|
||||
rtol=1e-3,
|
||||
atol=1e-5)
|
||||
|
|
Loading…
Reference in New Issue