fix efficientnet from mmcls (#1260)

* fix efficientnet from mmcls

* update
pull/1266/head
RunningLeon 2022-10-27 16:20:39 +08:00 committed by GitHub
parent f2be2abeb5
commit 197a7ad425
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 16 deletions

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import context_block # noqa: F401,F403
from . import conv2d_adaptive_padding # noqa: F401,F403
from . import hsigmoid # noqa: F401,F403
from . import hswish # noqa: F401,F403
from .conv2d_adaptive_padding import AdaptivePadOp
from .transformer import MultiHeadAttentionop
__all__ = ['AdaptivePadOp', 'MultiHeadAttentionop']
__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
def compute_padding(input_size, kernel_size, stride, dilation):
@ -33,25 +33,22 @@ def compute_padding(input_size, kernel_size, stride, dilation):
class AdaptivePadOp(torch.autograd.Function):
"""AdaptivePadOp."""
"""Dummy adaptive pad op."""
@staticmethod
def forward(ctx, x, kernel, stride, dilation):
padded = compute_padding(x.shape[2:], kernel, stride, dilation)
def forward(ctx, x, padded):
if padded is not None:
x = F.pad(x, padded)
return x
@staticmethod
def symbolic(g, x, kernel, stride, dilation):
padded = compute_padding(x.type().sizes()[2:], kernel, stride,
dilation)
def symbolic(g, x, padded):
if padded is None:
return g.op('Identity', x)
padded = g.op(
'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
constant_value = g.op(
'Constant', value_t=torch.tensor(0, dtype=torch.float32))
'Constant', value_t=torch.tensor(0, dtype=torch.int64))
return g.op(
'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
@ -76,9 +73,12 @@ def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
if not is_dynamic_flag:
x = AdaptivePadOp.apply(x, self.weight.shape[2:], self.stride,
self.dilation)
if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
padded = compute_padding(x.shape[2:], self.weight.shape[2:],
self.stride, self.dilation)
if padded is not None:
padded = [int(_) for _ in padded]
x = AdaptivePadOp.apply(x, padded)
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
else:

View File

@ -227,6 +227,6 @@ models:
model_configs:
- configs/efficientnet/efficientnet-b0_8xb32_in1k.py
pipelines:
- *pipeline_trt_static_fp32
- *pipeline_trt_static_fp16
- *pipeline_trt_static_int8
- *pipeline_ort_static_fp32
- convert_image: *convert_image
deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py