mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
parent
a6e07dac2f
commit
1f8d889b36
@ -68,7 +68,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||||||
def build_dataset(self,
|
def build_dataset(self,
|
||||||
dataset_cfg: Union[str, mmcv.Config],
|
dataset_cfg: Union[str, mmcv.Config],
|
||||||
dataset_type: str = 'val',
|
dataset_type: str = 'val',
|
||||||
is_sort_dataset: bool = True,
|
is_sort_dataset: bool = False,
|
||||||
**kwargs) -> Dataset:
|
**kwargs) -> Dataset:
|
||||||
"""Build dataset for different codebase.
|
"""Build dataset for different codebase.
|
||||||
|
|
||||||
@ -80,6 +80,7 @@ class BaseTask(metaclass=ABCMeta):
|
|||||||
is_sort_dataset (bool): When 'True', the dataset will be sorted
|
is_sort_dataset (bool): When 'True', the dataset will be sorted
|
||||||
by image shape in ascending order if 'dataset_cfg'
|
by image shape in ascending order if 'dataset_cfg'
|
||||||
contains information about height and width.
|
contains information about height and width.
|
||||||
|
Default is `False`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset: The built dataset.
|
Dataset: The built dataset.
|
||||||
|
@ -62,9 +62,19 @@ class MMDetection(MMCodebase):
|
|||||||
|
|
||||||
data_cfg = dataset_cfg.data[dataset_type]
|
data_cfg = dataset_cfg.data[dataset_type]
|
||||||
samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1)
|
samples_per_gpu = dataset_cfg.data.get('samples_per_gpu', 1)
|
||||||
if samples_per_gpu > 1:
|
if isinstance(data_cfg, dict):
|
||||||
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
data_cfg.test_mode = True
|
||||||
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
|
if samples_per_gpu > 1:
|
||||||
|
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
||||||
|
data_cfg.pipeline = replace_ImageToTensor(data_cfg.pipeline)
|
||||||
|
elif isinstance(data_cfg, list):
|
||||||
|
for ds_cfg in data_cfg:
|
||||||
|
ds_cfg.test_mode = True
|
||||||
|
if samples_per_gpu > 1:
|
||||||
|
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
|
||||||
|
else:
|
||||||
|
raise TypeError(f'Unsupported type {type(data_cfg)}')
|
||||||
|
|
||||||
dataset = build_dataset_mmdet(data_cfg)
|
dataset = build_dataset_mmdet(data_cfg)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user