fix pad to square (#1436)

* fix pad to square

* fix topk

* remove comment

* recovery topk
This commit is contained in:
q.yao 2022-11-25 17:31:44 +08:00 committed by GitHub
parent 4e1c83ab5b
commit 0d16f6ec30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,7 @@ import torch
from mmcv.parallel import DataContainer
from torch.utils.data import Dataset
from mmdeploy.utils import Task, get_input_shape, is_dynamic_shape
from mmdeploy.utils import Task, get_input_shape
from ...base import BaseTask
from .mmdetection import MMDET_TASK
@ -44,7 +44,10 @@ def process_model_config(model_cfg: mmcv.Config,
if trans_type == 'Resize' and len(input_shape) != 1:
trans['keep_ratio'] = False
elif trans_type == 'Pad':
if 'size_divisor' in trans:
if trans.get('pad_to_square', False):
# pad_to_square is mutually exclusive with size and divisor
pass
elif 'size_divisor' in trans:
trans['size_divisor'] = 1
else:
trans['size'] = tuple(input_shape)
@ -122,22 +125,11 @@ class ObjectDetection(BaseTask):
from mmdet.datasets.pipelines import Compose
if isinstance(imgs, (str, np.ndarray)):
imgs = [imgs]
dynamic_flag = is_dynamic_shape(self.deploy_cfg)
model_cfg = self.model_cfg
if pipeline_updater is not None:
model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
cfg = process_model_config(model_cfg, imgs, input_shape)
# Drop pad_to_square when static shape. Because static shape should
# ensure the shape before input image.
if not dynamic_flag:
transform = cfg.data.test.pipeline[1]
if 'transforms' in transform:
transform_list = transform['transforms']
for i, step in enumerate(transform_list):
if step['type'] == 'Pad' and 'pad_to_square' in step \
and step['pad_to_square']:
transform_list.pop(i)
break
test_pipeline = Compose(cfg.data.test.pipeline)
data_list = []
for img in imgs: