mmsegmentation/tests/test_datasets/test_dataset_builder.py

153 lines
4.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmengine.dataset import ConcatDataset, RepeatDataset
from mmengine.registry import init_default_scope
from mmseg.datasets import MultiImageMixDataset
from mmseg.registry import DATASETS
init_default_scope('mmseg')
@DATASETS.register_module()
class ToyDataset:
def __init__(self, cnt=0):
self.cnt = cnt
def __item__(self, idx):
return idx
def __len__(self):
return 100
def test_build_dataset():
cfg = dict(type='ToyDataset')
dataset = DATASETS.build(cfg)
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 0
dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 1
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
# test RepeatDataset
cfg = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
dataset = DATASETS.build(cfg)
dataset_repeat = RepeatDataset(dataset=dataset, times=5)
assert isinstance(dataset_repeat, RepeatDataset)
assert len(dataset_repeat) == 25
# test ConcatDataset
# We use same dir twice for simplicity
# with data_prefix.seg_map_path
cfg1 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
cfg2 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 10
# test MultiImageMixDataset
dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10
cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10
# with data_prefix.seg_map_path, ann_file
cfg1 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
ann_file='splits/train.txt',
serialize_data=False)
cfg2 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
ann_file='splits/val.txt',
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 5
# test mode
cfg1 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
cfg2 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 10
# test mode with ann_files
cfg1 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
ann_file='splits/val.txt',
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
cfg2 = dict(
type='BaseSegDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
ann_file='splits/val.txt',
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 2