mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix pad to square (#1436)
* fix pad to square * fix topk * remove comment * recovery topk
This commit is contained in:
parent
4e1c83ab5b
commit
0d16f6ec30
@ -7,7 +7,7 @@ import torch
|
||||
from mmcv.parallel import DataContainer
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmdeploy.utils import Task, get_input_shape, is_dynamic_shape
|
||||
from mmdeploy.utils import Task, get_input_shape
|
||||
from ...base import BaseTask
|
||||
from .mmdetection import MMDET_TASK
|
||||
|
||||
@ -44,7 +44,10 @@ def process_model_config(model_cfg: mmcv.Config,
|
||||
if trans_type == 'Resize' and len(input_shape) != 1:
|
||||
trans['keep_ratio'] = False
|
||||
elif trans_type == 'Pad':
|
||||
if 'size_divisor' in trans:
|
||||
if trans.get('pad_to_square', False):
|
||||
# pad_to_square is mutually exclusive with size and divisor
|
||||
pass
|
||||
elif 'size_divisor' in trans:
|
||||
trans['size_divisor'] = 1
|
||||
else:
|
||||
trans['size'] = tuple(input_shape)
|
||||
@ -122,22 +125,11 @@ class ObjectDetection(BaseTask):
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
if isinstance(imgs, (str, np.ndarray)):
|
||||
imgs = [imgs]
|
||||
dynamic_flag = is_dynamic_shape(self.deploy_cfg)
|
||||
model_cfg = self.model_cfg
|
||||
if pipeline_updater is not None:
|
||||
model_cfg = pipeline_updater(self.deploy_cfg, model_cfg)
|
||||
cfg = process_model_config(model_cfg, imgs, input_shape)
|
||||
# Drop pad_to_square when static shape. Because static shape should
|
||||
# ensure the shape before input image.
|
||||
if not dynamic_flag:
|
||||
transform = cfg.data.test.pipeline[1]
|
||||
if 'transforms' in transform:
|
||||
transform_list = transform['transforms']
|
||||
for i, step in enumerate(transform_list):
|
||||
if step['type'] == 'Pad' and 'pad_to_square' in step \
|
||||
and step['pad_to_square']:
|
||||
transform_list.pop(i)
|
||||
break
|
||||
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
data_list = []
|
||||
for img in imgs:
|
||||
|
Loading…
x
Reference in New Issue
Block a user