set test_mode for mmdet (#920)

* fix

* update
pull/941/head
RunningLeon 2022-08-19 10:55:41 +08:00 committed by GitHub
parent a6e07dac2f
commit 1f8d889b36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 4 deletions

View File

@ -68,7 +68,7 @@ class BaseTask(metaclass=ABCMeta):
def build_dataset(self,
dataset_cfg: Union[str, mmcv.Config],
dataset_type: str = 'val',
is_sort_dataset: bool = True,
is_sort_dataset: bool = False,
**kwargs) -> Dataset:
"""Build dataset for different codebase.
@ -80,6 +80,7 @@ class BaseTask(metaclass=ABCMeta):
is_sort_dataset (bool): When 'True', the dataset will be sorted
by image shape in ascending order if 'dataset_cfg'
contains information about height and width.
Default is `False`.
Returns:
Dataset: The built dataset.

View File

@ -62,9 +62,19 @@ class MMDetection(MMCodebase):
data_cfg = dataset_cfg.data[dataset_type]
samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
if isinstance(data_cfg, dict):
data_cfg.test_mode = True
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
elif isinstance(data_cfg, list):
for ds_cfg in data_cfg:
ds_cfg.test_mode = True
if samples_per_gpu > 1:
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
else:
raise TypeError(f'Unsupported type {type(data_cfg)}')
dataset = build_dataset_mmdet(data_cfg)
return dataset