rewrite torch.cat for TensorRT when input is dynamic (#1851)
parent 847a906e6f
commit 8e2f6556be
@@ -2,6 +2,7 @@
 from . import adaptive_pool  # noqa: F401,F403
 from . import any  # noqa: F401,F403
 from . import atan2  # noqa: F401,F403
+from . import cat  # noqa: F401,F403
 from . import chunk  # noqa: F401,F403
 from . import clip  # noqa: F401,F403
 from . import expand  # noqa: F401,F403
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+from torch import Tensor
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import get_dynamic_axes
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='torch.cat', backend='tensorrt')
+def cat__tensorrt(tensors: Sequence[Tensor], *args, **kwargs) -> torch.Tensor:
+    """Rewrite `cat` for TensorRT backend.
+
+    cat in TensorRT does not support bool or uint8 type when input is dynamic.
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+    if get_dynamic_axes(ctx.cfg) is None:
+        return ctx.origin_func(tensors, *args, **kwargs)
+    if len(tensors) > 0 and (tensors[0].dtype in [torch.bool, torch.uint8]):
+        original_dtype = tensors[0].dtype
+        tensors = [i.to(torch.int32) for i in tensors]
+        return ctx.origin_func(tensors, *args, **kwargs).to(original_dtype)
+    return ctx.origin_func(tensors, *args, **kwargs)
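For reference, a standalone eager-mode sketch of the workaround the new rewriter applies (the helper name cat_bool_uint8_workaround is illustrative, not part of the patch): TensorRT's concat rejects bool/uint8 inputs when shapes are dynamic, so those tensors are concatenated as int32 and the result is cast back to the original dtype.

# Illustrative only: same dtype round-trip as cat__tensorrt, in plain PyTorch.
import torch

def cat_bool_uint8_workaround(tensors, dim=0):
    if len(tensors) > 0 and tensors[0].dtype in (torch.bool, torch.uint8):
        original_dtype = tensors[0].dtype
        # concatenate as int32, then restore the original dtype
        tensors = [t.to(torch.int32) for t in tensors]
        return torch.cat(tensors, dim).to(original_dtype)
    return torch.cat(tensors, dim)

# cat_bool_uint8_workaround([torch.zeros(2, 4, dtype=torch.bool)] * 2, -1)
# returns a bool tensor of shape (2, 8), matching eager torch.cat behaviour.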
@@ -19,10 +19,13 @@ deploy_cfg_ncnn = Config(
         codebase_config=dict(type='mmdet', task='ObjectDetection')))
 
 
-def get_trt_config(output_names, shape):
+def get_trt_config(output_names, shape, dynamic_axes=None):
     deploy_cfg_tensorrt = Config(
         dict(
-            onnx_config=dict(input_shape=None, output_names=output_names),
+            onnx_config=dict(
+                input_shape=None,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes),
             backend_config=dict(
                 type='tensorrt',
                 common_config=dict(
@@ -615,3 +618,30 @@ def test_linspace__default():
 
     assert np.allclose(
         model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
+
+
+@backend_checker(Backend.TENSORRT)
+@pytest.mark.parametrize('dtype', [torch.bool, torch.float32])
+@pytest.mark.parametrize('dynamic_axes',
+                         [None, dict(input=dict({
+                             0: 'dim0',
+                             1: 'dim1'
+                         }))])
+def test_cat__tensorrt(dtype, dynamic_axes):
+    input = torch.rand(2, 4)
+    model = WrapFunction(lambda input: torch.cat(
+        [input.to(dtype), input.to(dtype)], -1))
+    pytorch_output = model(input)
+    rewrite_output, _ = get_rewrite_outputs(
+        model,
+        model_inputs={'input': input},
+        deploy_cfg=get_trt_config(['output'],
+                                  shape=[2, 4],
+                                  dynamic_axes=dynamic_axes),
+        run_with_backend=True)
+    assert pytorch_output.dtype == rewrite_output[0].dtype
+    assert torch.allclose(
+        pytorch_output.cpu().float(),
+        rewrite_output[0].cpu().float(),
+        rtol=1e-3,
+        atol=1e-5)
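The rewrite only takes effect when the deploy config declares dynamic axes; with dynamic_axes unset, get_dynamic_axes(ctx.cfg) returns None and the original torch.cat is called unchanged. A minimal onnx_config fragment matching the dynamic case exercised by test_cat__tensorrt above (the axis names dim0/dim1 are just the labels used in the test):

# Presence of dynamic_axes is what routes torch.cat through cat__tensorrt's
# int32 cast; omit it (or pass None) to fall back to the original torch.cat.
onnx_config = dict(
    input_shape=None,
    output_names=['output'],
    dynamic_axes=dict(input={0: 'dim0', 1: 'dim1'}))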