# Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from mmengine.dataset import ConcatDataset, RepeatDataset from mmengine.registry import init_default_scope from mmseg.datasets import MultiImageMixDataset from mmseg.registry import DATASETS init_default_scope('mmseg') @DATASETS.register_module() class ToyDataset: def __init__(self, cnt=0): self.cnt = cnt def __item__(self, idx): return idx def __len__(self): return 100 def test_build_dataset(): cfg = dict(type='ToyDataset') dataset = DATASETS.build(cfg) assert isinstance(dataset, ToyDataset) assert dataset.cnt == 0 dataset = DATASETS.build(cfg, default_args=dict(cnt=1)) assert isinstance(dataset, ToyDataset) assert dataset.cnt == 1 data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset') data_prefix = dict(img_path='imgs/', seg_map_path='gts/') # test RepeatDataset cfg = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=data_prefix, serialize_data=False) dataset = DATASETS.build(cfg) dataset_repeat = RepeatDataset(dataset=dataset, times=5) assert isinstance(dataset_repeat, RepeatDataset) assert len(dataset_repeat) == 25 # test ConcatDataset # We use same dir twice for simplicity # with data_prefix.seg_map_path cfg1 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=data_prefix, serialize_data=False) cfg2 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=data_prefix, serialize_data=False) dataset1 = DATASETS.build(cfg1) dataset2 = DATASETS.build(cfg2) dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) assert isinstance(dataset_concat, ConcatDataset) assert len(dataset_concat) == 10 # test MultiImageMixDataset dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[]) assert isinstance(dataset, MultiImageMixDataset) assert len(dataset) == 10 cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2]) dataset = MultiImageMixDataset(dataset=cfg, pipeline=[]) assert isinstance(dataset, MultiImageMixDataset) assert len(dataset) == 10 # with data_prefix.seg_map_path, ann_file cfg1 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=data_prefix, ann_file='splits/train.txt', serialize_data=False) cfg2 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=data_prefix, ann_file='splits/val.txt', serialize_data=False) dataset1 = DATASETS.build(cfg1) dataset2 = DATASETS.build(cfg2) dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) assert isinstance(dataset_concat, ConcatDataset) assert len(dataset_concat) == 5 # test mode cfg1 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=dict(img_path='imgs/'), test_mode=True, metainfo=dict(classes=('pseudo_class', )), serialize_data=False) cfg2 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=dict(img_path='imgs/'), test_mode=True, metainfo=dict(classes=('pseudo_class', )), serialize_data=False) dataset1 = DATASETS.build(cfg1) dataset2 = DATASETS.build(cfg2) dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) assert isinstance(dataset_concat, ConcatDataset) assert len(dataset_concat) == 10 # test mode with ann_files cfg1 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=dict(img_path='imgs/'), ann_file='splits/val.txt', test_mode=True, metainfo=dict(classes=('pseudo_class', )), serialize_data=False) cfg2 = dict( type='BaseSegDataset', pipeline=[], data_root=data_root, data_prefix=dict(img_path='imgs/'), ann_file='splits/val.txt', test_mode=True, metainfo=dict(classes=('pseudo_class', )), serialize_data=False) dataset1 = DATASETS.build(cfg1) dataset2 = DATASETS.build(cfg2) dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) assert isinstance(dataset_concat, ConcatDataset) assert len(dataset_concat) == 2