From f2bac79f03308a8f18ff322704216c2bd5fd21bb Mon Sep 17 00:00:00 2001 From: "limengzhang.vendor" Date: Mon, 27 Jun 2022 14:36:18 +0000 Subject: [PATCH] [Refactor] Refactor DatasetWrapper --- configs/_base_/datasets/chase_db1.py | 17 +- configs/_base_/datasets/drive.py | 16 +- configs/_base_/datasets/hrf.py | 16 +- configs/_base_/datasets/pascal_voc12_aug.py | 65 ++++- configs/_base_/datasets/stare.py | 16 +- mmseg/datasets/__init__.py | 20 +- mmseg/datasets/builder.py | 191 --------------- mmseg/datasets/dataset_wrappers.py | 253 +++++--------------- tests/test_datasets/test_dataset_builder.py | 151 ++++++++++++ tests/test_datasets/test_loading.py | 2 +- tools/browse_dataset.py | 4 +- 11 files changed, 316 insertions(+), 435 deletions(-) delete mode 100644 mmseg/datasets/builder.py create mode 100644 tests/test_datasets/test_dataset_builder.py diff --git a/configs/_base_/datasets/chase_db1.py b/configs/_base_/datasets/chase_db1.py index 96de3bc1d..2e6093ec5 100644 --- a/configs/_base_/datasets/chase_db1.py +++ b/configs/_base_/datasets/chase_db1.py @@ -26,14 +26,17 @@ train_dataloader = dict( num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), - type='RepeatDataset', - times=40000, dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict( - img_path='images/training', seg_map_path='annotations/training'), - pipeline=train_pipeline)) + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/training', + seg_map_path='annotations/training'), + pipeline=train_pipeline))) + val_dataloader = dict( batch_size=1, num_workers=4, diff --git a/configs/_base_/datasets/drive.py b/configs/_base_/datasets/drive.py index cc89d8a71..71f9b619b 100644 --- a/configs/_base_/datasets/drive.py +++ b/configs/_base_/datasets/drive.py @@ -25,14 +25,16 @@ train_dataloader = dict( num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), - type='RepeatDataset', - times=40000, dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict( - img_path='images/training', seg_map_path='annotations/training'), - pipeline=train_pipeline)) + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/training', + seg_map_path='annotations/training'), + pipeline=train_pipeline))) val_dataloader = dict( batch_size=1, num_workers=4, diff --git a/configs/_base_/datasets/hrf.py b/configs/_base_/datasets/hrf.py index 5f1f23326..a423de889 100644 --- a/configs/_base_/datasets/hrf.py +++ b/configs/_base_/datasets/hrf.py @@ -25,14 +25,16 @@ train_dataloader = dict( num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), - type='RepeatDataset', - times=40000, dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict( - img_path='images/training', seg_map_path='annotations/training'), - pipeline=train_pipeline)) + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/training', + seg_map_path='annotations/training'), + pipeline=train_pipeline))) val_dataloader = dict( batch_size=1, num_workers=4, diff --git a/configs/_base_/datasets/pascal_voc12_aug.py b/configs/_base_/datasets/pascal_voc12_aug.py index 24ebb74d3..b9401b138 100644 --- a/configs/_base_/datasets/pascal_voc12_aug.py +++ b/configs/_base_/datasets/pascal_voc12_aug.py @@ -1,9 
+1,62 @@ -_base_ = './pascal_voc12.py' # dataset settings +dataset_type = 'PascalVOCDataset' +data_root = 'data/VOCdevkit/VOC2012' +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Pad', size=crop_size), + dict(type='PackSegInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=(2048, 512), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') +] + +dataset_train = dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img_path='JPEGImages', seg_map_path='SegmentationClass'), + ann_file='ImageSets/Segmentation/train.txt', + pipeline=train_pipeline) + +dataset_aug = dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='JPEGImages', seg_map_path='SegmentationClassAug'), + ann_file='ImageSets/Segmentation/aug.txt', + pipeline=train_pipeline) + train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='InfiniteSampler', shuffle=True), + dataset=dict(type='ConcatDataset', datasets=[dataset_train, dataset_aug])) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), dataset=dict( - ann_dir=['SegmentationClass', 'SegmentationClassAug'], - ann_file=[ - 'ImageSets/Segmentation/train.txt', - 'ImageSets/Segmentation/aug.txt' - ])) + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='JPEGImages', seg_map_path='SegmentationClass'), + ann_file='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/configs/_base_/datasets/stare.py b/configs/_base_/datasets/stare.py index 09f731d66..f41ebd835 100644 --- a/configs/_base_/datasets/stare.py +++ b/configs/_base_/datasets/stare.py @@ -25,14 +25,16 @@ train_dataloader = dict( num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), - type='RepeatDataset', - times=40000, dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict( - img_path='images/training', seg_map_path='annotations/training'), - pipeline=train_pipeline)) + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='images/training', + seg_map_path='annotations/training'), + pipeline=train_pipeline))) val_dataloader = dict( batch_size=1, num_workers=4, diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index 5d42a11c2..626f0e398 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.dataset import ConcatDataset, RepeatDataset + +from mmseg.registry import DATASETS, TRANSFORMS from .ade import ADE20KDataset -from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset from .chase_db1 import ChaseDB1Dataset from .cityscapes import CityscapesDataset from .coco_stuff import COCOStuffDataset from .custom import CustomDataset from .dark_zurich import DarkZurichDataset -from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, - RepeatDataset) +from .dataset_wrappers import MultiImageMixDataset from .drive import DRIVEDataset from .hrf import HRFDataset from .isaid import iSAIDDataset @@ -20,11 +21,10 @@ from .stare import STAREDataset from .voc import PascalVOCDataset __all__ = [ - 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', - 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', - 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', - 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', - 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', - 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', - 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset' + 'CustomDataset', 'ConcatDataset', 'RepeatDataset', 'DATASETS', + 'TRANSFORMS', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', + 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', + 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', + 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', + 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset' ] diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py deleted file mode 100644 index 0eba0e1dd..000000000 --- a/mmseg/datasets/builder.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import platform -import random -from functools import partial - -import numpy as np -import torch -from mmcv.parallel import collate -from mmcv.runner import get_dist_info -from mmcv.utils import digit_version -from torch.utils.data import DataLoader - -from mmseg.registry import DATASETS, TRANSFORMS -from .samplers import DistributedSampler - -if platform.system() != 'Windows': - # https://github.com/pytorch/pytorch/issues/973 - import resource - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) - base_soft_limit = rlimit[0] - hard_limit = rlimit[1] - soft_limit = min(max(4096, base_soft_limit), hard_limit) - resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) - -PIPELINES = TRANSFORMS - - -def _concat_dataset(cfg, default_args=None): - """Build :obj:`ConcatDataset by.""" - from .dataset_wrappers import ConcatDataset - img_dir = cfg['img_dir'] - ann_dir = cfg.get('ann_dir', None) - split = cfg.get('split', None) - # pop 'separate_eval' since it is not a valid key for common datasets. 
- separate_eval = cfg.pop('separate_eval', True) - num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 - if ann_dir is not None: - num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 - else: - num_ann_dir = 0 - if split is not None: - num_split = len(split) if isinstance(split, (list, tuple)) else 1 - else: - num_split = 0 - if num_img_dir > 1: - assert num_img_dir == num_ann_dir or num_ann_dir == 0 - assert num_img_dir == num_split or num_split == 0 - else: - assert num_split == num_ann_dir or num_ann_dir <= 1 - num_dset = max(num_split, num_img_dir) - - datasets = [] - for i in range(num_dset): - data_cfg = copy.deepcopy(cfg) - if isinstance(img_dir, (list, tuple)): - data_cfg['img_dir'] = img_dir[i] - if isinstance(ann_dir, (list, tuple)): - data_cfg['ann_dir'] = ann_dir[i] - if isinstance(split, (list, tuple)): - data_cfg['split'] = split[i] - datasets.append(build_dataset(data_cfg, default_args)) - - return ConcatDataset(datasets, separate_eval) - - -def build_dataset(cfg, default_args=None): - """Build datasets.""" - from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, - RepeatDataset) - if isinstance(cfg, (list, tuple)): - dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) - elif cfg['type'] == 'RepeatDataset': - dataset = RepeatDataset( - build_dataset(cfg['dataset'], default_args), cfg['times']) - elif cfg['type'] == 'MultiImageMixDataset': - cp_cfg = copy.deepcopy(cfg) - cp_cfg['dataset'] = build_dataset(cp_cfg['dataset']) - cp_cfg.pop('type') - dataset = MultiImageMixDataset(**cp_cfg) - elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( - cfg.get('split', None), (list, tuple)): - dataset = _concat_dataset(cfg, default_args) - else: - dataset = DATASETS.build(cfg, default_args=default_args) - - return dataset - - -def build_dataloader(dataset, - samples_per_gpu, - workers_per_gpu, - num_gpus=1, - dist=True, - shuffle=True, - seed=None, - drop_last=False, - pin_memory=True, - persistent_workers=True, - **kwargs): - """Build PyTorch DataLoader. - - In distributed training, each GPU/process has a dataloader. - In non-distributed training, there is only one dataloader for all GPUs. - - Args: - dataset (Dataset): A PyTorch dataset. - samples_per_gpu (int): Number of training samples on each GPU, i.e., - batch size of each GPU. - workers_per_gpu (int): How many subprocesses to use for data loading - for each GPU. - num_gpus (int): Number of GPUs. Only used in non-distributed training. - dist (bool): Distributed training/test or not. Default: True. - shuffle (bool): Whether to shuffle the data at every epoch. - Default: True. - seed (int | None): Seed to be used. Default: None. - drop_last (bool): Whether to drop the last incomplete batch in epoch. - Default: False - pin_memory (bool): Whether to use pin_memory in DataLoader. - Default: True - persistent_workers (bool): If True, the data loader will not shutdown - the worker processes after a dataset has been consumed once. - This allows to maintain the workers Dataset instances alive. - The argument also has effect in PyTorch>=1.7.0. - Default: True - kwargs: any keyword argument to be used to initialize DataLoader - - Returns: - DataLoader: A PyTorch dataloader. 
- """ - rank, world_size = get_dist_info() - if dist: - sampler = DistributedSampler( - dataset, world_size, rank, shuffle=shuffle, seed=seed) - shuffle = False - batch_size = samples_per_gpu - num_workers = workers_per_gpu - else: - sampler = None - batch_size = num_gpus * samples_per_gpu - num_workers = num_gpus * workers_per_gpu - - init_fn = partial( - worker_init_fn, num_workers=num_workers, rank=rank, - seed=seed) if seed is not None else None - - if digit_version(torch.__version__) >= digit_version('1.8.0'): - data_loader = DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), - pin_memory=pin_memory, - shuffle=shuffle, - worker_init_fn=init_fn, - drop_last=drop_last, - persistent_workers=persistent_workers, - **kwargs) - else: - data_loader = DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), - pin_memory=pin_memory, - shuffle=shuffle, - worker_init_fn=init_fn, - drop_last=drop_last, - **kwargs) - - return data_loader - - -def worker_init_fn(worker_id, num_workers, rank, seed): - """Worker init func for dataloader. - - The seed of each worker equals to num_worker * rank + worker_id + user_seed - - Args: - worker_id (int): Worker id. - num_workers (int): Number of workers. - rank (int): The rank of current process. - seed (int): The random seed to use. - """ - - worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) - torch.manual_seed(worker_seed) diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py index 54b8fe856..57136e33f 100644 --- a/mmseg/datasets/dataset_wrappers.py +++ b/mmseg/datasets/dataset_wrappers.py @@ -1,195 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import bisect import collections import copy -from itertools import chain +from typing import List, Optional, Sequence, Union -import mmcv -import numpy as np -from mmcv.utils import print_log -from torch.utils.data.dataset import ConcatDataset as _ConcatDataset +from mmengine.dataset import ConcatDataset, force_full_init from mmseg.registry import DATASETS, TRANSFORMS -from .cityscapes import CityscapesDataset - - -@DATASETS.register_module() -class ConcatDataset(_ConcatDataset): - """A wrapper of concatenated dataset. - - Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but - support evaluation and formatting results - - Args: - datasets (list[:obj:`Dataset`]): A list of datasets. - separate_eval (bool): Whether to evaluate the concatenated - dataset results separately, Defaults to True. - """ - - def __init__(self, datasets, separate_eval=True): - super(ConcatDataset, self).__init__(datasets) - self.CLASSES = datasets[0].CLASSES - self.PALETTE = datasets[0].PALETTE - self.separate_eval = separate_eval - assert separate_eval in [True, False], \ - f'separate_eval can only be True or False,' \ - f'but get {separate_eval}' - if any([isinstance(ds, CityscapesDataset) for ds in datasets]): - raise NotImplementedError( - 'Evaluating ConcatDataset containing CityscapesDataset' - 'is not supported!') - - def evaluate(self, results, logger=None, **kwargs): - """Evaluate the results. - - Args: - results (list[tuple[torch.Tensor]] | list[str]]): per image - pre_eval results or predict segmentation map for - computing evaluation metric. 
- logger (logging.Logger | str | None): Logger used for printing - related information during evaluation. Default: None. - - Returns: - dict[str: float]: evaluate results of the total dataset - or each separate - dataset if `self.separate_eval=True`. - """ - assert len(results) == self.cumulative_sizes[-1], \ - ('Dataset and results have different sizes: ' - f'{self.cumulative_sizes[-1]} v.s. {len(results)}') - - # Check whether all the datasets support evaluation - for dataset in self.datasets: - assert hasattr(dataset, 'evaluate'), \ - f'{type(dataset)} does not implement evaluate function' - - if self.separate_eval: - dataset_idx = -1 - total_eval_results = dict() - for size, dataset in zip(self.cumulative_sizes, self.datasets): - start_idx = 0 if dataset_idx == -1 else \ - self.cumulative_sizes[dataset_idx] - end_idx = self.cumulative_sizes[dataset_idx + 1] - - results_per_dataset = results[start_idx:end_idx] - print_log( - f'\nEvaluateing {dataset.img_dir} with ' - f'{len(results_per_dataset)} images now', - logger=logger) - - eval_results_per_dataset = dataset.evaluate( - results_per_dataset, logger=logger, **kwargs) - dataset_idx += 1 - for k, v in eval_results_per_dataset.items(): - total_eval_results.update({f'{dataset_idx}_{k}': v}) - - return total_eval_results - - if len(set([type(ds) for ds in self.datasets])) != 1: - raise NotImplementedError( - 'All the datasets should have same types when ' - 'self.separate_eval=False') - else: - if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( - results, str): - # merge the generators of gt_seg_maps - gt_seg_maps = chain( - *[dataset.get_gt_seg_maps() for dataset in self.datasets]) - else: - # if the results are `pre_eval` results, - # we do not need gt_seg_maps to evaluate - gt_seg_maps = None - eval_results = self.datasets[0].evaluate( - results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs) - return eval_results - - def get_dataset_idx_and_sample_idx(self, indice): - """Return dataset and sample index when given an indice of - ConcatDataset. - - Args: - indice (int): indice of sample in ConcatDataset - - Returns: - int: the index of sub dataset the sample belong to - int: the index of sample in its corresponding subset - """ - if indice < 0: - if -indice > len(self): - raise ValueError( - 'absolute value of index should not exceed dataset length') - indice = len(self) + indice - dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice) - if dataset_idx == 0: - sample_idx = indice - else: - sample_idx = indice - self.cumulative_sizes[dataset_idx - 1] - return dataset_idx, sample_idx - - def format_results(self, results, imgfile_prefix, indices=None, **kwargs): - """format result for every sample of ConcatDataset.""" - if indices is None: - indices = list(range(len(self))) - - assert isinstance(results, list), 'results must be a list.' - assert isinstance(indices, list), 'indices must be a list.' 
- - ret_res = [] - for i, indice in enumerate(indices): - dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( - indice) - res = self.datasets[dataset_idx].format_results( - [results[i]], - imgfile_prefix + f'/{dataset_idx}', - indices=[sample_idx], - **kwargs) - ret_res.append(res) - return sum(ret_res, []) - - def pre_eval(self, preds, indices): - """do pre eval for every sample of ConcatDataset.""" - # In order to compat with batch inference - if not isinstance(indices, list): - indices = [indices] - if not isinstance(preds, list): - preds = [preds] - ret_res = [] - for i, indice in enumerate(indices): - dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx( - indice) - res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx) - ret_res.append(res) - return sum(ret_res, []) - - -@DATASETS.register_module() -class RepeatDataset(object): - """A wrapper of repeated dataset. - - The length of repeated dataset will be `times` larger than the original - dataset. This is useful when the data loading time is long but the dataset - is small. Using RepeatDataset can reduce the data loading time between - epochs. - - Args: - dataset (:obj:`Dataset`): The dataset to be repeated. - times (int): Repeat times. - """ - - def __init__(self, dataset, times): - self.dataset = dataset - self.times = times - self.CLASSES = dataset.CLASSES - self.PALETTE = dataset.PALETTE - self._ori_len = len(self.dataset) - - def __getitem__(self, idx): - """Get item from original dataset.""" - return self.dataset[idx % self._ori_len] - - def __len__(self): - """The length is multiplied by ``times``""" - return self.times * self._ori_len @DATASETS.register_module() @@ -197,22 +13,32 @@ class MultiImageMixDataset: """A wrapper of multiple images mixed dataset. Suitable for training on multiple images mixed data augmentation like - mosaic and mixup. For the augmentation pipeline of mixed image data, - the `get_indexes` method needs to be provided to obtain the image - indexes, and you can set `skip_flags` to change the pipeline running - process. - + mosaic and mixup. Args: - dataset (:obj:`CustomDataset`): The dataset to be mixed. + dataset (ConcatDataset or dict): The dataset to be mixed. pipeline (Sequence[dict]): Sequence of transform object or config dict to be composed. skip_type_keys (list[str], optional): Sequence of type string to be skip pipeline. Default to None. """ - def __init__(self, dataset, pipeline, skip_type_keys=None): + def __init__(self, + dataset: Union[ConcatDataset, dict], + pipeline: Sequence[dict], + skip_type_keys: Optional[List[str]] = None, + lazy_init: bool = False) -> None: assert isinstance(pipeline, collections.abc.Sequence) + + if isinstance(dataset, dict): + self.dataset = DATASETS.build(dataset) + elif isinstance(dataset, ConcatDataset): + self.dataset = dataset + else: + raise TypeError( + 'elements in datasets sequence should be config or ' + f'`ConcatDataset` instance, but got {type(dataset)}') + if skip_type_keys is not None: assert all([ isinstance(skip_type_key, str) @@ -230,11 +56,44 @@ class MultiImageMixDataset: else: raise TypeError('pipeline must be a dict') - self.dataset = dataset - self.CLASSES = dataset.CLASSES - self.PALETTE = dataset.PALETTE - self.num_samples = len(dataset) + self._metainfo = self.dataset.metainfo + self.num_samples = len(self.dataset) + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def metainfo(self) -> dict: + """Get the meta information of the multi-image-mixed dataset. 
+
+        Returns:
+            dict: The meta information of multi-image-mixed dataset.
+        """
+        return copy.deepcopy(self._metainfo)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+        self._ori_len = len(self.dataset)
+        self._fully_initialized = True
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``ConcatDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        return self.dataset.get_data_info(idx)
+
+    @force_full_init
     def __len__(self):
         return self.num_samples
diff --git a/tests/test_datasets/test_dataset_builder.py b/tests/test_datasets/test_dataset_builder.py
new file mode 100644
index 000000000..7954f3a1a
--- /dev/null
+++ b/tests/test_datasets/test_dataset_builder.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmseg.datasets import DATASETS, MultiImageMixDataset
+from mmseg.utils import register_all_modules
+
+register_all_modules()
+
+
+@DATASETS.register_module()
+class ToyDataset(object):
+
+    def __init__(self, cnt=0):
+        self.cnt = cnt
+
+    def __getitem__(self, idx):
+        return idx
+
+    def __len__(self):
+        return 100
+
+
+def test_build_dataset():
+    cfg = dict(type='ToyDataset')
+    dataset = DATASETS.build(cfg)
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 0
+    dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 1
+
+    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
+    data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
+
+    # test RepeatDataset
+    cfg = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset = DATASETS.build(cfg)
+    dataset_repeat = RepeatDataset(dataset=dataset, times=5)
+    assert isinstance(dataset_repeat, RepeatDataset)
+    assert len(dataset_repeat) == 25
+
+    # test ConcatDataset
+    # We use same dir twice for simplicity
+    # with data_prefix.seg_map_path
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 10
+
+    # test MultiImageMixDataset
+    dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
+
+    dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    # with data_prefix.seg_map_path, ann_file
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/train.txt',
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/val.txt',
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert 
isinstance(dataset_concat, ConcatDataset) + assert len(dataset_concat) == 5 + + # test mode + cfg1 = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + data_prefix=dict(img_path='imgs/'), + test_mode=True, + metainfo=dict(classes=('pseudo_class', )), + serialize_data=False) + cfg2 = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + data_prefix=dict(img_path='imgs/'), + test_mode=True, + metainfo=dict(classes=('pseudo_class', )), + serialize_data=False) + + dataset1 = DATASETS.build(cfg1) + dataset2 = DATASETS.build(cfg2) + dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) + assert isinstance(dataset_concat, ConcatDataset) + assert len(dataset_concat) == 10 + + # test mode with ann_files + cfg1 = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + data_prefix=dict(img_path='imgs/'), + ann_file='splits/val.txt', + test_mode=True, + metainfo=dict(classes=('pseudo_class', )), + serialize_data=False) + cfg2 = dict( + type='CustomDataset', + pipeline=[], + data_root=data_root, + data_prefix=dict(img_path='imgs/'), + ann_file='splits/val.txt', + test_mode=True, + metainfo=dict(classes=('pseudo_class', )), + serialize_data=False) + + dataset1 = DATASETS.build(cfg1) + dataset2 = DATASETS.build(cfg2) + dataset_concat = ConcatDataset(datasets=[dataset1, dataset2]) + assert isinstance(dataset_concat, ConcatDataset) + assert len(dataset_concat) == 2 diff --git a/tests/test_datasets/test_loading.py b/tests/test_datasets/test_loading.py index 937dcf6df..77029bb7f 100644 --- a/tests/test_datasets/test_loading.py +++ b/tests/test_datasets/test_loading.py @@ -25,7 +25,7 @@ class TestLoading(object): assert results['img'].dtype == np.uint8 assert results['ori_shape'] == results['img'].shape[:2] assert repr(transform) == transform.__class__.__name__ + \ - "(to_float32=False, color_type='color'," + \ + "(ignore_empty=False, to_float32=False, color_type='color'," + \ " imdecode_backend='cv2', file_client_args={'backend': 'disk'})" # to_float32 diff --git a/tools/browse_dataset.py b/tools/browse_dataset.py index 0aa9430ea..64fe69585 100644 --- a/tools/browse_dataset.py +++ b/tools/browse_dataset.py @@ -8,7 +8,7 @@ import mmcv import numpy as np from mmcv import Config, DictAction -from mmseg.datasets.builder import build_dataset +from mmseg.datasets import DATASETS def parse_args(): @@ -159,7 +159,7 @@ def main(): args = parse_args() cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options, args.show_origin) - dataset = build_dataset(cfg.data.train) + dataset = DATASETS.build(cfg.data.train) progress_bar = mmcv.ProgressBar(len(dataset)) for item in dataset: filename = os.path.join(args.output_dir,