153 lines
4.5 KiB
Python
153 lines
4.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
|
|
from mmengine.dataset import ConcatDataset, RepeatDataset
|
|
from mmengine.registry import init_default_scope
|
|
|
|
from mmseg.datasets import MultiImageMixDataset
|
|
from mmseg.registry import DATASETS
|
|
|
|
# Select 'mmseg' as the default registry scope up front, before the
# ``DATASETS.build`` calls further down in this module run.
init_default_scope(scope='mmseg')
|
|
|
|
|
|
@DATASETS.register_module()
class ToyDataset:
    """Minimal identity dataset registered in ``DATASETS`` for build tests.

    Args:
        cnt (int): Arbitrary value stored as ``self.cnt``; lets a test check
            that ``default_args`` given to ``DATASETS.build`` reach
            ``__init__``. Defaults to 0.
    """

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        """Return the (normalized) index itself.

        Negative indices wrap around as for a normal sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also terminates the legacy sequence-iteration
                protocol, so ``iter(dataset)`` stops after ``len(self)`` items
                instead of looping forever.
        """
        length = len(self)
        pos = idx + length if idx < 0 else idx
        if not 0 <= pos < length:
            raise IndexError(f'ToyDataset index {idx} out of range')
        return pos

    def __item__(self, idx):
        """Legacy accessor kept for backward compatibility.

        This name is not a Python protocol method, so indexing never calls
        it. Use ``dataset[idx]`` (``__getitem__``) instead. It returns
        ``idx`` unchanged.
        """
        return idx

    def __len__(self):
        """Report a fixed length of 100 items."""
        return 100
|
|
|
|
|
|
def _pseudo_seg_cfg(data_root, data_prefix, **overrides):
    """Return a fresh ``BaseSegDataset`` config dict.

    Args:
        data_root (str): Root directory passed as ``data_root``.
        data_prefix (dict): Value passed as ``data_prefix``.
        **overrides: Extra config keys (e.g. ``ann_file``, ``test_mode``,
            ``metainfo``) merged on top of the common fields.

    Returns:
        dict: Config with an empty pipeline and ``serialize_data=False``.
    """
    cfg = dict(
        type='BaseSegDataset',
        pipeline=[],
        data_root=data_root,
        data_prefix=data_prefix,
        serialize_data=False)
    cfg.update(overrides)
    return cfg


def _build_concat(first_cfg, second_cfg):
    """Build both configs, in order, and wrap them in a ``ConcatDataset``.

    Asserts the wrapper has the expected type before returning it.
    """
    concat = ConcatDataset(
        datasets=[DATASETS.build(first_cfg), DATASETS.build(second_cfg)])
    assert isinstance(concat, ConcatDataset)
    return concat


def test_build_dataset():
    """Check that datasets and dataset wrappers build through ``DATASETS``.

    Covers building the registered ``ToyDataset`` with and without
    ``default_args``, then ``RepeatDataset``, ``ConcatDataset`` and
    ``MultiImageMixDataset`` over the pseudo dataset in
    ``../data/pseudo_dataset``, including ``ann_file`` splits and
    ``test_mode``.
    """
    # Registry build of the toy dataset: default ``cnt`` first, then a
    # ``cnt`` injected through ``default_args``.
    toy_cfg = dict(type='ToyDataset')
    for build_kwargs, expected_cnt in (({}, 0),
                                       (dict(default_args=dict(cnt=1)), 1)):
        toy = DATASETS.build(toy_cfg, **build_kwargs)
        assert isinstance(toy, ToyDataset)
        assert toy.cnt == expected_cnt

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    seg_prefix = dict(img_path='imgs/', seg_map_path='gts/')

    # RepeatDataset: the pseudo dataset repeated five times.
    repeated = RepeatDataset(
        dataset=DATASETS.build(_pseudo_seg_cfg(data_root, seg_prefix)),
        times=5)
    assert isinstance(repeated, RepeatDataset)
    assert len(repeated) == 25

    # ConcatDataset: the same directory twice for simplicity, with
    # ``data_prefix.seg_map_path`` set.
    seg_cfg_a = _pseudo_seg_cfg(data_root, seg_prefix)
    seg_cfg_b = _pseudo_seg_cfg(data_root, seg_prefix)
    concat = _build_concat(seg_cfg_a, seg_cfg_b)
    assert len(concat) == 10

    # MultiImageMixDataset wrapping an already-built dataset ...
    mixed = MultiImageMixDataset(dataset=concat, pipeline=[])
    assert isinstance(mixed, MultiImageMixDataset)
    assert len(mixed) == 10

    # ... and wrapping a config dict that it builds itself.
    concat_cfg = dict(type='ConcatDataset', datasets=[seg_cfg_a, seg_cfg_b])
    mixed = MultiImageMixDataset(dataset=concat_cfg, pipeline=[])
    assert isinstance(mixed, MultiImageMixDataset)
    assert len(mixed) == 10

    # ConcatDataset restricted by ``ann_file`` splits.
    concat = _build_concat(
        _pseudo_seg_cfg(data_root, seg_prefix, ann_file='splits/train.txt'),
        _pseudo_seg_cfg(data_root, seg_prefix, ann_file='splits/val.txt'))
    assert len(concat) == 5

    def _test_mode_cfg(**extra):
        # Image-only, test-mode config with a single pseudo class; built
        # fresh on every call so no dicts are shared between datasets.
        return _pseudo_seg_cfg(
            data_root,
            dict(img_path='imgs/'),
            test_mode=True,
            metainfo=dict(classes=('pseudo_class', )),
            **extra)

    # Test mode without an annotation file.
    concat = _build_concat(_test_mode_cfg(), _test_mode_cfg())
    assert len(concat) == 10

    # Test mode with an annotation file.
    concat = _build_concat(
        _test_mode_cfg(ann_file='splits/val.txt'),
        _test_mode_cfg(ann_file='splits/val.txt'))
    assert len(concat) == 2
|