[Refactor] Refactor DatasetWrapper

pull/1801/head
limengzhang.vendor 2022-06-27 14:36:18 +00:00 committed by zhengmiao
parent eef12a064b
commit f2bac79f03
11 changed files with 316 additions and 435 deletions

View File

@@ -26,14 +26,17 @@ train_dataloader = dict(
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training',
seg_map_path='annotations/training'),
pipeline=train_pipeline)))
val_dataloader = dict(
batch_size=1,
num_workers=4,
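
The change above repeats in the next few dataset configs: the RepeatDataset wrapper moves from the top level of train_dataloader into its dataset field, so the whole wrapper chain is described by one nested config. A cleaned-up sketch of the new layout, with the indentation the diff rendering dropped (batch size and other untouched keys omitted; dataset_type, data_root and train_pipeline come from the surrounding base config):

train_dataloader = dict(
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    # the wrapper is now nested under `dataset=` instead of being spliced
    # into the dataloader dict itself
    dataset=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='images/training',
                seg_map_path='annotations/training'),
            pipeline=train_pipeline)))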

View File

@@ -25,14 +25,16 @@ train_dataloader = dict(
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training',
seg_map_path='annotations/training'),
pipeline=train_pipeline)))
val_dataloader = dict(
batch_size=1,
num_workers=4,

View File

@@ -25,14 +25,16 @@ train_dataloader = dict(
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training',
seg_map_path='annotations/training'),
pipeline=train_pipeline)))
val_dataloader = dict(
batch_size=1,
num_workers=4,

View File

@@ -1,9 +1,62 @@
_base_ = './pascal_voc12.py'
# dataset settings
dataset_type = 'PascalVOCDataset'
data_root = 'data/VOCdevkit/VOC2012'
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Pad', size=crop_size),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
dataset_train = dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img_path='JPEGImages', seg_map_path='SegmentationClass'),
ann_file='ImageSets/Segmentation/train.txt',
pipeline=train_pipeline)
dataset_aug = dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClassAug'),
ann_file='ImageSets/Segmentation/aug.txt',
pipeline=train_pipeline)
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(type='ConcatDataset', datasets=[dataset_train, dataset_aug]))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
ann_dir=['SegmentationClass', 'SegmentationClassAug'],
ann_file=[
'ImageSets/Segmentation/train.txt',
'ImageSets/Segmentation/aug.txt'
]))
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClass'),
ann_file='ImageSets/Segmentation/val.txt',
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
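
With the aug split expressed as its own config dict, the two VOC training sources are combined through MMEngine's ConcatDataset instead of list-valued ann_dir/ann_file fields. A minimal sketch of how that entry resolves (assumed registry behaviour, mirroring the unit test added later in this commit; dataset_train and dataset_aug are the config dicts defined above):

from mmseg.registry import DATASETS

concat_cfg = dict(type='ConcatDataset', datasets=[dataset_train, dataset_aug])
voc_train = DATASETS.build(concat_cfg)
# len(voc_train) == len(train split) + len(aug split)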

View File

@@ -25,14 +25,16 @@ train_dataloader = dict(
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
type='RepeatDataset',
times=40000,
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training',
seg_map_path='annotations/training'),
pipeline=train_pipeline)))
val_dataloader = dict(
batch_size=1,
num_workers=4,

View File

@@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import ConcatDataset, RepeatDataset
from mmseg.registry import DATASETS, TRANSFORMS
from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .custom import CustomDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
RepeatDataset)
from .dataset_wrappers import MultiImageMixDataset
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
@@ -20,11 +21,10 @@ from .stare import STAREDataset
from .voc import PascalVOCDataset
__all__ = [
'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
'CustomDataset', 'ConcatDataset', 'RepeatDataset', 'DATASETS',
'TRANSFORMS', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
]
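
ConcatDataset and RepeatDataset are now re-exported from mmengine.dataset rather than implemented in mmseg.datasets.dataset_wrappers; only MultiImageMixDataset stays mmseg-specific. Downstream imports keep working, a sketch:

# resolves to the mmengine implementations after this refactor
from mmseg.datasets import ConcatDataset, MultiImageMixDataset, RepeatDataset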

View File

@@ -1,191 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import digit_version
from torch.utils.data import DataLoader
from mmseg.registry import DATASETS, TRANSFORMS
from .samplers import DistributedSampler
if platform.system() != 'Windows':
# https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
base_soft_limit = rlimit[0]
hard_limit = rlimit[1]
soft_limit = min(max(4096, base_soft_limit), hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
PIPELINES = TRANSFORMS
def _concat_dataset(cfg, default_args=None):
"""Build :obj:`ConcatDataset by."""
from .dataset_wrappers import ConcatDataset
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
# pop 'separate_eval' since it is not a valid key for common datasets.
separate_eval = cfg.pop('separate_eval', True)
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
else:
num_ann_dir = 0
if split is not None:
num_split = len(split) if isinstance(split, (list, tuple)) else 1
else:
num_split = 0
if num_img_dir > 1:
assert num_img_dir == num_ann_dir or num_ann_dir == 0
assert num_img_dir == num_split or num_split == 0
else:
assert num_split == num_ann_dir or num_ann_dir <= 1
num_dset = max(num_split, num_img_dir)
datasets = []
for i in range(num_dset):
data_cfg = copy.deepcopy(cfg)
if isinstance(img_dir, (list, tuple)):
data_cfg['img_dir'] = img_dir[i]
if isinstance(ann_dir, (list, tuple)):
data_cfg['ann_dir'] = ann_dir[i]
if isinstance(split, (list, tuple)):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))
return ConcatDataset(datasets, separate_eval)
def build_dataset(cfg, default_args=None):
"""Build datasets."""
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
RepeatDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'MultiImageMixDataset':
cp_cfg = copy.deepcopy(cfg)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
cp_cfg.pop('type')
dataset = MultiImageMixDataset(**cp_cfg)
elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
cfg.get('split', None), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = DATASETS.build(cfg, default_args=default_args)
return dataset
def build_dataloader(dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=1,
dist=True,
shuffle=True,
seed=None,
drop_last=False,
pin_memory=True,
persistent_workers=True,
**kwargs):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
seed (int | None): Seed to be used. Default: None.
drop_last (bool): Whether to drop the last incomplete batch in epoch.
Default: False
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default: True
persistent_workers (bool): If True, the data loader will not shutdown
the worker processes after a dataset has been consumed once.
This keeps the workers' Dataset instances alive.
The argument also has effect in PyTorch>=1.7.0.
Default: True
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
if dist:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=shuffle, seed=seed)
shuffle = False
batch_size = samples_per_gpu
num_workers = workers_per_gpu
else:
sampler = None
batch_size = num_gpus * samples_per_gpu
num_workers = num_gpus * workers_per_gpu
init_fn = partial(
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None
if digit_version(torch.__version__) >= digit_version('1.8.0'):
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
drop_last=drop_last,
persistent_workers=persistent_workers,
**kwargs)
else:
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
pin_memory=pin_memory,
shuffle=shuffle,
worker_init_fn=init_fn,
drop_last=drop_last,
**kwargs)
return data_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
"""Worker init func for dataloader.
The seed of each worker equals num_workers * rank + worker_id + user_seed.
Args:
worker_id (int): Worker id.
num_workers (int): Number of workers.
rank (int): The rank of current process.
seed (int): The random seed to use.
"""
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
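
The deleted build_dataset/build_dataloader helpers are superseded by the registry and by MMEngine's runner, which assembles samplers, workers and collation directly from the train_dataloader/val_dataloader dicts shown in the configs above. A rough sketch of the replacement path (assumed MMEngine API, not part of this diff):

from mmengine.runner import Runner
from mmseg.registry import DATASETS

# a dataset (including nested wrappers) is built straight from its config ...
dataset = DATASETS.build(cfg.train_dataloader['dataset'])
# ... and a full dataloader is assembled by the runner from the same dict
train_loader = Runner.build_dataloader(cfg.train_dataloader)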

View File

@@ -1,195 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
from itertools import chain
from typing import List, Optional, Sequence, Union
import mmcv
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmengine.dataset import ConcatDataset, force_full_init
from mmseg.registry import DATASETS, TRANSFORMS
from .cityscapes import CityscapesDataset
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
support evaluation and formatting results
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the concatenated
dataset results separately. Defaults to True.
"""
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
assert separate_eval in [True, False], \
f'separate_eval can only be True or False, ' \
f'but got {separate_eval}'
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating ConcatDataset containing CityscapesDataset'
'is not supported!')
def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.
Args:
results (list[tuple[torch.Tensor]] | list[str]]): per image
pre_eval results or predict segmentation map for
computing evaluation metric.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str: float]: evaluate results of the total dataset
or each separate
dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'
if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]
results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)
eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})
return total_eval_results
if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have same types when '
'self.separate_eval=False')
else:
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
# merge the generators of gt_seg_maps
gt_seg_maps = chain(
*[dataset.get_gt_seg_maps() for dataset in self.datasets])
else:
# if the results are `pre_eval` results,
# we do not need gt_seg_maps to evaluate
gt_seg_maps = None
eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results
def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of
ConcatDataset.
Args:
indice (int): indice of sample in ConcatDataset
Returns:
int: the index of sub dataset the sample belong to
int: the index of sample in its corresponding subset
"""
if indice < 0:
if -indice > len(self):
raise ValueError(
'absolute value of index should not exceed dataset length')
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset."""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].format_results(
[results[i]],
imgfile_prefix + f'/{dataset_idx}',
indices=[sample_idx],
**kwargs)
ret_res.append(res)
return sum(ret_res, [])
def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset."""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])
@DATASETS.register_module()
class RepeatDataset(object):
"""A wrapper of repeated dataset.
The length of repeated dataset will be `times` larger than the original
dataset. This is useful when the data loading time is long but the dataset
is small. Using RepeatDataset can reduce the data loading time between
epochs.
Args:
dataset (:obj:`Dataset`): The dataset to be repeated.
times (int): Repeat times.
"""
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self._ori_len = len(self.dataset)
def __getitem__(self, idx):
"""Get item from original dataset."""
return self.dataset[idx % self._ori_len]
def __len__(self):
"""The length is multiplied by ``times``"""
return self.times * self._ori_len
@DATASETS.register_module()
@@ -197,22 +13,32 @@ class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup. For the augmentation pipeline of mixed image data,
the `get_indexes` method needs to be provided to obtain the image
indexes, and you can set `skip_flags` to change the pipeline running
process.
mosaic and mixup.
Args:
dataset (:obj:`CustomDataset`): The dataset to be mixed.
dataset (ConcatDataset or dict): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
skip_type_keys (list[str], optional): Sequence of type string to
be skip pipeline. Default to None.
"""
def __init__(self, dataset, pipeline, skip_type_keys=None):
def __init__(self,
dataset: Union[ConcatDataset, dict],
pipeline: Sequence[dict],
skip_type_keys: Optional[List[str]] = None,
lazy_init: bool = False) -> None:
assert isinstance(pipeline, collections.abc.Sequence)
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, ConcatDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`ConcatDataset` instance, but got {type(dataset)}')
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
@@ -230,11 +56,44 @@ class MultiImageMixDataset:
else:
raise TypeError('pipeline must be a dict')
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self.num_samples = len(dataset)
self._metainfo = self.dataset.metainfo
self.num_samples = len(self.dataset)
self._fully_initialized = False
if not lazy_init:
self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of the multi-image-mixed dataset.
Returns:
dict: The meta information of multi-image-mixed dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self):
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
self._ori_len = len(self.dataset)
self._fully_initialized = True
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``ConcatDataset``.
Returns:
dict: The idx-th annotation of the datasets.
"""
return self.dataset.get_data_info(idx)
@force_full_init
def __len__(self):
return self.num_samples
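
The rewritten MultiImageMixDataset follows MMEngine's dataset protocol: it exposes metainfo, defers loading until full_init() when lazy_init=True, and accepts either an already-built ConcatDataset or a config dict for one. A small usage sketch (the cfg1/cfg2 dataset configs are placeholders, as in the test below):

from mmseg.datasets import MultiImageMixDataset

concat_cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
mix_dataset = MultiImageMixDataset(dataset=concat_cfg, pipeline=[])
print(mix_dataset.metainfo)  # copied from the wrapped dataset
print(len(mix_dataset))      # length of the concatenated dataset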

View File

@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmengine.dataset import ConcatDataset, RepeatDataset
from mmseg.datasets import DATASETS, MultiImageMixDataset
from mmseg.utils import register_all_modules
register_all_modules()
@DATASETS.register_module()
class ToyDataset(object):
def __init__(self, cnt=0):
self.cnt = cnt
def __getitem__(self, idx):
return idx
def __len__(self):
return 100
def test_build_dataset():
cfg = dict(type='ToyDataset')
dataset = DATASETS.build(cfg)
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 0
dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
assert isinstance(dataset, ToyDataset)
assert dataset.cnt == 1
data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
# test RepeatDataset
cfg = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
dataset = DATASETS.build(cfg)
dataset_repeat = RepeatDataset(dataset=dataset, times=5)
assert isinstance(dataset_repeat, RepeatDataset)
assert len(dataset_repeat) == 25
# test ConcatDataset
# We use same dir twice for simplicity
# with data_prefix.seg_map_path
cfg1 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
cfg2 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 10
# test MultiImageMixDataset
dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10
cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10
# with data_prefix.seg_map_path, ann_file
cfg1 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
ann_file='splits/train.txt',
serialize_data=False)
cfg2 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=data_prefix,
ann_file='splits/val.txt',
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 5
# test mode
cfg1 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
cfg2 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 10
# test mode with ann_files
cfg1 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
ann_file='splits/val.txt',
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
cfg2 = dict(
type='CustomDataset',
pipeline=[],
data_root=data_root,
data_prefix=dict(img_path='imgs/'),
ann_file='splits/val.txt',
test_mode=True,
metainfo=dict(classes=('pseudo_class', )),
serialize_data=False)
dataset1 = DATASETS.build(cfg1)
dataset2 = DATASETS.build(cfg2)
dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
assert isinstance(dataset_concat, ConcatDataset)
assert len(dataset_concat) == 2

View File

@@ -25,7 +25,7 @@ class TestLoading(object):
assert results['img'].dtype == np.uint8
assert results['ori_shape'] == results['img'].shape[:2]
assert repr(transform) == transform.__class__.__name__ + \
"(to_float32=False, color_type='color'," + \
"(ignore_empty=False, to_float32=False, color_type='color'," + \
" imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
# to_float32

View File

@@ -8,7 +8,7 @@ import mmcv
import numpy as np
from mmcv import Config, DictAction
from mmseg.datasets.builder import build_dataset
from mmseg.datasets import DATASETS
def parse_args():
@@ -159,7 +159,7 @@ def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
args.show_origin)
dataset = build_dataset(cfg.data.train)
dataset = DATASETS.build(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,