parent
f2be2abeb5
commit
197a7ad425
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import context_block # noqa: F401,F403
|
||||
from . import conv2d_adaptive_padding # noqa: F401,F403
|
||||
from . import hsigmoid # noqa: F401,F403
|
||||
from . import hswish # noqa: F401,F403
|
||||
from .conv2d_adaptive_padding import AdaptivePadOp
|
||||
from .transformer import MultiHeadAttentionop
|
||||
|
||||
__all__ = ['AdaptivePadOp', 'MultiHeadAttentionop']
|
||||
__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import Backend, is_dynamic_shape
|
||||
from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
|
||||
|
||||
|
||||
def compute_padding(input_size, kernel_size, stride, dilation):
|
||||
|
@ -33,25 +33,22 @@ def compute_padding(input_size, kernel_size, stride, dilation):
|
|||
|
||||
|
||||
class AdaptivePadOp(torch.autograd.Function):
|
||||
"""AdaptivePadOp."""
|
||||
"""Dummy adaptive pad op."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, kernel, stride, dilation):
|
||||
padded = compute_padding(x.shape[2:], kernel, stride, dilation)
|
||||
def forward(ctx, x, padded):
|
||||
if padded is not None:
|
||||
x = F.pad(x, padded)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, x, kernel, stride, dilation):
|
||||
padded = compute_padding(x.type().sizes()[2:], kernel, stride,
|
||||
dilation)
|
||||
def symbolic(g, x, padded):
|
||||
if padded is None:
|
||||
return g.op('Identity', x)
|
||||
padded = g.op(
|
||||
'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
|
||||
constant_value = g.op(
|
||||
'Constant', value_t=torch.tensor(0, dtype=torch.float32))
|
||||
'Constant', value_t=torch.tensor(0, dtype=torch.int64))
|
||||
return g.op(
|
||||
'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
|
||||
|
||||
|
@ -76,9 +73,12 @@ def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
|
|||
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
if not is_dynamic_flag:
|
||||
x = AdaptivePadOp.apply(x, self.weight.shape[2:], self.stride,
|
||||
self.dilation)
|
||||
if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
|
||||
padded = compute_padding(x.shape[2:], self.weight.shape[2:],
|
||||
self.stride, self.dilation)
|
||||
if padded is not None:
|
||||
padded = [int(_) for _ in padded]
|
||||
x = AdaptivePadOp.apply(x, padded)
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups)
|
||||
else:
|
||||
|
|
|
@ -227,6 +227,6 @@ models:
|
|||
model_configs:
|
||||
- configs/efficientnet/efficientnet-b0_8xb32_in1k.py
|
||||
pipelines:
|
||||
- *pipeline_trt_static_fp32
|
||||
- *pipeline_trt_static_fp16
|
||||
- *pipeline_trt_static_int8
|
||||
- *pipeline_ort_static_fp32
|
||||
- convert_image: *convert_image
|
||||
deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py
|
||||
|
|
Loading…
Reference in New Issue