parent
a6e07dac2f
commit
1f8d889b36
|
@ -68,7 +68,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||
def build_dataset(self,
|
||||
dataset_cfg: Union[str, mmcv.Config],
|
||||
dataset_type: str = 'val',
|
||||
is_sort_dataset: bool = True,
|
||||
is_sort_dataset: bool = False,
|
||||
**kwargs) -> Dataset:
|
||||
"""Build dataset for different codebase.
|
||||
|
||||
|
@ -80,6 +80,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||
is_sort_dataset (bool): When 'True', the dataset will be sorted
|
||||
by image shape in ascending order if 'dataset_cfg'
|
||||
contains information about height and width.
|
||||
Default is `False`.
|
||||
|
||||
Returns:
|
||||
Dataset: The built dataset.
|
||||
|
|
|
@ -62,9 +62,19 @@ class MMDetection(MMCodebase):
|
|||
|
||||
data_cfg = dataset_cfg.data[dataset_type]
|
||||
samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1)
|
||||
if samples_per_gpu > 1:
|
||||
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
||||
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
|
||||
if isinstance(data_cfg, dict):
|
||||
data_cfg.test_mode = True
|
||||
if samples_per_gpu > 1:
|
||||
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
||||
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
|
||||
elif isinstance(data_cfg, list):
|
||||
for ds_cfg in data_cfg:
|
||||
ds_cfg.test_mode = True
|
||||
if samples_per_gpu > 1:
|
||||
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
|
||||
else:
|
||||
raise TypeError(f'Unsupported type {type(data_cfg)}')
|
||||
|
||||
dataset = build_dataset_mmdet(data_cfg)
|
||||
|
||||
return dataset
|
||||
|
|
Loading…
Reference in New Issue