cherry-pick from commit 197a7ad

pull/1276/head
RunningLeon 2022-10-27 16:20:39 +08:00 committed by lvhan028
parent 53d4668a2f
commit 2fa5095154
3 changed files with 20 additions and 22 deletions

Changed file 1 of 3:

@@ -1,10 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .conv2d_adaptive_padding import (
-    AdaptivePadOp, conv2d_adaptive_padding__forward__tensorrt)
-from .transformer import (MultiHeadAttentionop,
-                          multiheadattention__forward__ncnn)
+from . import context_block  # noqa: F401,F403
+from . import conv2d_adaptive_padding  # noqa: F401,F403
+from . import hsigmoid  # noqa: F401,F403
+from . import hswish  # noqa: F401,F403
+from .transformer import MultiHeadAttentionop
 
-__all__ = [
-    'multiheadattention__forward__ncnn', 'MultiHeadAttentionop',
-    'conv2d_adaptive_padding__forward__tensorrt', 'AdaptivePadOp'
-]
+__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']
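A note on the new import style (an illustrative aside, not part of this commit): the rewriter modules only need to be imported so that their registration decorators run, which is why they are pulled in wholesale with # noqa: F401,F403 rather than re-exporting each rewritten function by name. A minimal sketch of the pattern, with a hypothetical hswish rewriter whose target path, backend, and body are assumptions:

# hswish.py -- hypothetical rewriter module; importing it is enough to
# register the rewrite, so nothing has to be re-exported from __init__.py.
import torch

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmcv.cnn.bricks.hswish.HSwish.forward', backend='ncnn')
def hswish__forward__ncnn(ctx, self, x):
    # Export-friendly h-swish: x * clamp(x + 3, 0, 6) / 6.
    return x * (torch.clamp(x + 3, 0, 6) / 6)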

Changed file 2 of 3:

@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 
 from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import Backend, is_dynamic_shape
+from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
 
 
 def compute_padding(input_size, kernel_size, stride, dilation):
@@ -33,25 +33,22 @@ def compute_padding(input_size, kernel_size, stride, dilation):
 class AdaptivePadOp(torch.autograd.Function):
-    """AdaptivePadOp."""
+    """Dummy adaptive pad op."""
 
     @staticmethod
-    def forward(ctx, x, kernel, stride, dilation):
-        padded = compute_padding(x.shape[2:], kernel, stride, dilation)
+    def forward(ctx, x, padded):
         if padded is not None:
             x = F.pad(x, padded)
         return x
 
     @staticmethod
-    def symbolic(g, x, kernel, stride, dilation):
-        padded = compute_padding(x.type().sizes()[2:], kernel, stride,
-                                 dilation)
+    def symbolic(g, x, padded):
         if padded is None:
             return g.op('Identity', x)
         padded = g.op(
             'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
         constant_value = g.op(
-            'Constant', value_t=torch.tensor(0, dtype=torch.float32))
+            'Constant', value_t=torch.tensor(0, dtype=torch.int64))
         return g.op(
             'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
@@ -76,9 +73,12 @@ def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
     deploy_cfg = ctx.cfg
     is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    if not is_dynamic_flag:
-        x = AdaptivePadOp.apply(x, self.weight.shape[2:], self.stride,
-                                self.dilation)
+    if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
+        padded = compute_padding(x.shape[2:], self.weight.shape[2:],
+                                 self.stride, self.dilation)
+        if padded is not None:
+            padded = [int(_) for _ in padded]
+        x = AdaptivePadOp.apply(x, padded)
         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
     else:
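For context, a small self-contained sketch of what the reworked flow does (the class below is re-declared from the hunk above; the example input and pad values are invented for illustration and are not part of this commit). The point of the change is that the padding is now computed eagerly in Python and handed to AdaptivePadOp as a plain list of ints, so the traced ONNX graph ends up with a Pad node whose pads are constants:

import torch
import torch.nn.functional as F


class AdaptivePadOp(torch.autograd.Function):
    """Dummy adaptive pad op (same definition as in the hunk above)."""

    @staticmethod
    def forward(ctx, x, padded):
        if padded is not None:
            x = F.pad(x, padded)
        return x

    @staticmethod
    def symbolic(g, x, padded):
        if padded is None:
            return g.op('Identity', x)
        padded = g.op(
            'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
        constant_value = g.op(
            'Constant', value_t=torch.tensor(0, dtype=torch.int64))
        return g.op(
            'Pad', x, padded, constant_value, mode_s='constant', outputs=1)


x = torch.rand(1, 3, 223, 223)
# Example pad amounts in F.pad order (left, right, top, bottom); in the real
# rewriter they come from compute_padding(x.shape[2:], self.weight.shape[2:],
# self.stride, self.dilation) and are cast to ints before the call.
padded = [int(p) for p in (0, 1, 0, 1)]
y = AdaptivePadOp.apply(x, padded)
print(y.shape)  # torch.Size([1, 3, 224, 224])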

Changed file 3 of 3:

@@ -217,6 +217,6 @@ models:
     model_configs:
       - configs/efficientnet/efficientnet-b0_8xb32_in1k.py
     pipelines:
-      - *pipeline_trt_static_fp32
-      - *pipeline_trt_static_fp16
-      - *pipeline_trt_static_int8
+      - *pipeline_ort_static_fp32
+      - convert_image: *convert_image
+        deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py