[Refactor] Refactor DatasetWrapper

parent eef12a064b
commit f2bac79f03
@@ -26,14 +26,17 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
+
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
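Note: after this refactor the RepeatDataset wrapper lives inside the dataloader's `dataset` field, so it is built through the registry like any other dataset. Below is a minimal sketch of the resulting behavior, modeled on this commit's new test file; the `ADE20KDataset` type and data paths mirror the config above (the `images/training` prefix suggests an ADE20K config), and an on-disk dataset is assumed.

# Sketch: build the inner dataset from its config, then wrap it with
# mmengine's RepeatDataset, as the refactored config above now expresses.
# Assumes mmseg dev-1.x with mmengine installed and ADE20K data present.
from mmengine.dataset import RepeatDataset

from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules

register_all_modules()

inner = DATASETS.build(
    dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        data_prefix=dict(
            img_path='images/training', seg_map_path='annotations/training'),
        pipeline=[]))
repeated = RepeatDataset(dataset=inner, times=40000)
assert len(repeated) == 40000 * len(inner)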
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -1,9 +1,62 @@
 _base_ = './pascal_voc12.py'
 # dataset settings
-train_dataloader = dict(
-    dataset=dict(
-        ann_dir=['SegmentationClass', 'SegmentationClassAug'],
-        ann_file=[
-            'ImageSets/Segmentation/train.txt',
-            'ImageSets/Segmentation/aug.txt'
-        ]))
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Pad', size=crop_size),
+    dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadAnnotations'),
+    dict(type='PackSegInputs')
+]
+
+dataset_train = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_prefix=dict(img_path='JPEGImages', seg_map_path='SegmentationClass'),
+    ann_file='ImageSets/Segmentation/train.txt',
+    pipeline=train_pipeline)
+
+dataset_aug = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_prefix=dict(
+        img_path='JPEGImages', seg_map_path='SegmentationClassAug'),
+    ann_file='ImageSets/Segmentation/aug.txt',
+    pipeline=train_pipeline)
+
+train_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='InfiniteSampler', shuffle=True),
+    dataset=dict(type='ConcatDataset', datasets=[dataset_train, dataset_aug]))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_prefix=dict(
+            img_path='JPEGImages', seg_map_path='SegmentationClass'),
+        ann_file='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
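The aug variant now concatenates two explicit branch configs instead of passing list-valued ann_dir/ann_file to the removed `_concat_dataset` helper. A hedged sketch of what the ConcatDataset entry above expands to at build time, reusing the `dataset_train` and `dataset_aug` dicts defined in this config (run in a context where those dicts are in scope, with the VOC data on disk):

# Sketch: how the train_dataloader's ConcatDataset entry resolves.
from mmengine.dataset import ConcatDataset

from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules

register_all_modules()

# dataset_train and dataset_aug are the config dicts defined above.
branch_train = DATASETS.build(dataset_train)
branch_aug = DATASETS.build(dataset_aug)
concat = ConcatDataset(datasets=[branch_train, branch_aug])
assert len(concat) == len(branch_train) + len(branch_aug)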
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -1,13 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmseg.registry import DATASETS, TRANSFORMS
 from .ade import ADE20KDataset
-from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
 from .chase_db1 import ChaseDB1Dataset
 from .cityscapes import CityscapesDataset
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
-                               RepeatDataset)
+from .dataset_wrappers import MultiImageMixDataset
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .isaid import iSAIDDataset
@@ -20,11 +21,10 @@ from .stare import STAREDataset
 from .voc import PascalVOCDataset
 
 __all__ = [
-    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
-    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
-    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
-    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
-    'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
-    'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
+    'CustomDataset', 'ConcatDataset', 'RepeatDataset', 'DATASETS',
+    'TRANSFORMS', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
+    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
+    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
+    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
+    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
 ]
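With this change ConcatDataset and RepeatDataset are still re-exported from `mmseg.datasets` but are defined in mmengine. A quick sanity check of the new import surface (a sketch, assuming mmseg dev-1.x with mmengine installed):

# Both import paths now resolve to the same mmengine classes.
from mmengine.dataset import ConcatDataset as EngineConcatDataset
from mmengine.dataset import RepeatDataset as EngineRepeatDataset

from mmseg.datasets import ConcatDataset, RepeatDataset

assert ConcatDataset is EngineConcatDataset
assert RepeatDataset is EngineRepeatDataset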
@@ -1,191 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import platform
-import random
-from functools import partial
-
-import numpy as np
-import torch
-from mmcv.parallel import collate
-from mmcv.runner import get_dist_info
-from mmcv.utils import digit_version
-from torch.utils.data import DataLoader
-
-from mmseg.registry import DATASETS, TRANSFORMS
-from .samplers import DistributedSampler
-
-if platform.system() != 'Windows':
-    # https://github.com/pytorch/pytorch/issues/973
-    import resource
-    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
-    base_soft_limit = rlimit[0]
-    hard_limit = rlimit[1]
-    soft_limit = min(max(4096, base_soft_limit), hard_limit)
-    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
-
-PIPELINES = TRANSFORMS
-
-
-def _concat_dataset(cfg, default_args=None):
-    """Build :obj:`ConcatDataset by."""
-    from .dataset_wrappers import ConcatDataset
-    img_dir = cfg['img_dir']
-    ann_dir = cfg.get('ann_dir', None)
-    split = cfg.get('split', None)
-    # pop 'separate_eval' since it is not a valid key for common datasets.
-    separate_eval = cfg.pop('separate_eval', True)
-    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
-    if ann_dir is not None:
-        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
-    else:
-        num_ann_dir = 0
-    if split is not None:
-        num_split = len(split) if isinstance(split, (list, tuple)) else 1
-    else:
-        num_split = 0
-    if num_img_dir > 1:
-        assert num_img_dir == num_ann_dir or num_ann_dir == 0
-        assert num_img_dir == num_split or num_split == 0
-    else:
-        assert num_split == num_ann_dir or num_ann_dir <= 1
-    num_dset = max(num_split, num_img_dir)
-
-    datasets = []
-    for i in range(num_dset):
-        data_cfg = copy.deepcopy(cfg)
-        if isinstance(img_dir, (list, tuple)):
-            data_cfg['img_dir'] = img_dir[i]
-        if isinstance(ann_dir, (list, tuple)):
-            data_cfg['ann_dir'] = ann_dir[i]
-        if isinstance(split, (list, tuple)):
-            data_cfg['split'] = split[i]
-        datasets.append(build_dataset(data_cfg, default_args))
-
-    return ConcatDataset(datasets, separate_eval)
-
-
-def build_dataset(cfg, default_args=None):
-    """Build datasets."""
-    from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
-                                   RepeatDataset)
-    if isinstance(cfg, (list, tuple)):
-        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
-    elif cfg['type'] == 'RepeatDataset':
-        dataset = RepeatDataset(
-            build_dataset(cfg['dataset'], default_args), cfg['times'])
-    elif cfg['type'] == 'MultiImageMixDataset':
-        cp_cfg = copy.deepcopy(cfg)
-        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
-        cp_cfg.pop('type')
-        dataset = MultiImageMixDataset(**cp_cfg)
-    elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
-            cfg.get('split', None), (list, tuple)):
-        dataset = _concat_dataset(cfg, default_args)
-    else:
-        dataset = DATASETS.build(cfg, default_args=default_args)
-
-    return dataset
-
-
-def build_dataloader(dataset,
-                     samples_per_gpu,
-                     workers_per_gpu,
-                     num_gpus=1,
-                     dist=True,
-                     shuffle=True,
-                     seed=None,
-                     drop_last=False,
-                     pin_memory=True,
-                     persistent_workers=True,
-                     **kwargs):
-    """Build PyTorch DataLoader.
-
-    In distributed training, each GPU/process has a dataloader.
-    In non-distributed training, there is only one dataloader for all GPUs.
-
-    Args:
-        dataset (Dataset): A PyTorch dataset.
-        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
-            batch size of each GPU.
-        workers_per_gpu (int): How many subprocesses to use for data loading
-            for each GPU.
-        num_gpus (int): Number of GPUs. Only used in non-distributed training.
-        dist (bool): Distributed training/test or not. Default: True.
-        shuffle (bool): Whether to shuffle the data at every epoch.
-            Default: True.
-        seed (int | None): Seed to be used. Default: None.
-        drop_last (bool): Whether to drop the last incomplete batch in epoch.
-            Default: False
-        pin_memory (bool): Whether to use pin_memory in DataLoader.
-            Default: True
-        persistent_workers (bool): If True, the data loader will not shutdown
-            the worker processes after a dataset has been consumed once.
-            This allows to maintain the workers Dataset instances alive.
-            The argument also has effect in PyTorch>=1.7.0.
-            Default: True
-        kwargs: any keyword argument to be used to initialize DataLoader
-
-    Returns:
-        DataLoader: A PyTorch dataloader.
-    """
-    rank, world_size = get_dist_info()
-    if dist:
-        sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle, seed=seed)
-        shuffle = False
-        batch_size = samples_per_gpu
-        num_workers = workers_per_gpu
-    else:
-        sampler = None
-        batch_size = num_gpus * samples_per_gpu
-        num_workers = num_gpus * workers_per_gpu
-
-    init_fn = partial(
-        worker_init_fn, num_workers=num_workers, rank=rank,
-        seed=seed) if seed is not None else None
-
-    if digit_version(torch.__version__) >= digit_version('1.8.0'):
-        data_loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            sampler=sampler,
-            num_workers=num_workers,
-            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-            pin_memory=pin_memory,
-            shuffle=shuffle,
-            worker_init_fn=init_fn,
-            drop_last=drop_last,
-            persistent_workers=persistent_workers,
-            **kwargs)
-    else:
-        data_loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            sampler=sampler,
-            num_workers=num_workers,
-            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-            pin_memory=pin_memory,
-            shuffle=shuffle,
-            worker_init_fn=init_fn,
-            drop_last=drop_last,
-            **kwargs)
-
-    return data_loader
-
-
-def worker_init_fn(worker_id, num_workers, rank, seed):
-    """Worker init func for dataloader.
-
-    The seed of each worker equals to num_worker * rank + worker_id + user_seed
-
-    Args:
-        worker_id (int): Worker id.
-        num_workers (int): Number of workers.
-        rank (int): The rank of current process.
-        seed (int): The random seed to use.
-    """
-
-    worker_seed = num_workers * rank + worker_id + seed
-    np.random.seed(worker_seed)
-    random.seed(worker_seed)
-    torch.manual_seed(worker_seed)
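builder.py is removed outright: `build_dataset` collapses to a registry call, and dataloader assembly moves to mmengine (the Runner consumes the train_dataloader/val_dataloader dicts shown in the config hunks above, so a hand-rolled `build_dataloader` is no longer needed). A hedged before/after sketch of the migration for downstream code; the ADE20K config dict is illustrative:

# Before this commit (mmseg 0.x style):
#   from mmseg.datasets import build_dataset
#   dataset = build_dataset(cfg)
# After this commit, the registry does the same job directly, and dataset
# wrappers build their own inner dataset from the nested config:
from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules

register_all_modules()

cfg = dict(
    type='RepeatDataset',
    times=40000,
    dataset=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        data_prefix=dict(
            img_path='images/training',
            seg_map_path='annotations/training'),
        pipeline=[]))
dataset = DATASETS.build(cfg)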
@@ -1,195 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import bisect
 import collections
 import copy
-from itertools import chain
+from typing import List, Optional, Sequence, Union
 
-import mmcv
-import numpy as np
-from mmcv.utils import print_log
-from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+from mmengine.dataset import ConcatDataset, force_full_init
 
 from mmseg.registry import DATASETS, TRANSFORMS
-from .cityscapes import CityscapesDataset
-
-
-@DATASETS.register_module()
-class ConcatDataset(_ConcatDataset):
-    """A wrapper of concatenated dataset.
-
-    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
-    support evaluation and formatting results
-
-    Args:
-        datasets (list[:obj:`Dataset`]): A list of datasets.
-        separate_eval (bool): Whether to evaluate the concatenated
-            dataset results separately, Defaults to True.
-    """
-
-    def __init__(self, datasets, separate_eval=True):
-        super(ConcatDataset, self).__init__(datasets)
-        self.CLASSES = datasets[0].CLASSES
-        self.PALETTE = datasets[0].PALETTE
-        self.separate_eval = separate_eval
-        assert separate_eval in [True, False], \
-            f'separate_eval can only be True or False,' \
-            f'but get {separate_eval}'
-        if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
-            raise NotImplementedError(
-                'Evaluating ConcatDataset containing CityscapesDataset'
-                'is not supported!')
-
-    def evaluate(self, results, logger=None, **kwargs):
-        """Evaluate the results.
-
-        Args:
-            results (list[tuple[torch.Tensor]] | list[str]]): per image
-                pre_eval results or predict segmentation map for
-                computing evaluation metric.
-            logger (logging.Logger | str | None): Logger used for printing
-                related information during evaluation. Default: None.
-
-        Returns:
-            dict[str: float]: evaluate results of the total dataset
-                or each separate
-            dataset if `self.separate_eval=True`.
-        """
-        assert len(results) == self.cumulative_sizes[-1], \
-            ('Dataset and results have different sizes: '
-             f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
-
-        # Check whether all the datasets support evaluation
-        for dataset in self.datasets:
-            assert hasattr(dataset, 'evaluate'), \
-                f'{type(dataset)} does not implement evaluate function'
-
-        if self.separate_eval:
-            dataset_idx = -1
-            total_eval_results = dict()
-            for size, dataset in zip(self.cumulative_sizes, self.datasets):
-                start_idx = 0 if dataset_idx == -1 else \
-                    self.cumulative_sizes[dataset_idx]
-                end_idx = self.cumulative_sizes[dataset_idx + 1]
-
-                results_per_dataset = results[start_idx:end_idx]
-                print_log(
-                    f'\nEvaluateing {dataset.img_dir} with '
-                    f'{len(results_per_dataset)} images now',
-                    logger=logger)
-
-                eval_results_per_dataset = dataset.evaluate(
-                    results_per_dataset, logger=logger, **kwargs)
-                dataset_idx += 1
-                for k, v in eval_results_per_dataset.items():
-                    total_eval_results.update({f'{dataset_idx}_{k}': v})
-
-            return total_eval_results
-
-        if len(set([type(ds) for ds in self.datasets])) != 1:
-            raise NotImplementedError(
-                'All the datasets should have same types when '
-                'self.separate_eval=False')
-        else:
-            if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
-                    results, str):
-                # merge the generators of gt_seg_maps
-                gt_seg_maps = chain(
-                    *[dataset.get_gt_seg_maps() for dataset in self.datasets])
-            else:
-                # if the results are `pre_eval` results,
-                # we do not need gt_seg_maps to evaluate
-                gt_seg_maps = None
-            eval_results = self.datasets[0].evaluate(
-                results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
-            return eval_results
-
-    def get_dataset_idx_and_sample_idx(self, indice):
-        """Return dataset and sample index when given an indice of
-        ConcatDataset.
-
-        Args:
-            indice (int): indice of sample in ConcatDataset
-
-        Returns:
-            int: the index of sub dataset the sample belong to
-            int: the index of sample in its corresponding subset
-        """
-        if indice < 0:
-            if -indice > len(self):
-                raise ValueError(
-                    'absolute value of index should not exceed dataset length')
-            indice = len(self) + indice
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
-        if dataset_idx == 0:
-            sample_idx = indice
-        else:
-            sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
-        return dataset_idx, sample_idx
-
-    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
-        """format result for every sample of ConcatDataset."""
-        if indices is None:
-            indices = list(range(len(self)))
-
-        assert isinstance(results, list), 'results must be a list.'
-        assert isinstance(indices, list), 'indices must be a list.'
-
-        ret_res = []
-        for i, indice in enumerate(indices):
-            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
-                indice)
-            res = self.datasets[dataset_idx].format_results(
-                [results[i]],
-                imgfile_prefix + f'/{dataset_idx}',
-                indices=[sample_idx],
-                **kwargs)
-            ret_res.append(res)
-        return sum(ret_res, [])
-
-    def pre_eval(self, preds, indices):
-        """do pre eval for every sample of ConcatDataset."""
-        # In order to compat with batch inference
-        if not isinstance(indices, list):
-            indices = [indices]
-        if not isinstance(preds, list):
-            preds = [preds]
-        ret_res = []
-        for i, indice in enumerate(indices):
-            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
-                indice)
-            res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
-            ret_res.append(res)
-        return sum(ret_res, [])
-
-
-@DATASETS.register_module()
-class RepeatDataset(object):
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-        self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get item from original dataset."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """The length is multiplied by ``times``"""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
@@ -197,22 +13,32 @@ class MultiImageMixDataset:
     """A wrapper of multiple images mixed dataset.
 
     Suitable for training on multiple images mixed data augmentation like
-    mosaic and mixup. For the augmentation pipeline of mixed image data,
-    the `get_indexes` method needs to be provided to obtain the image
-    indexes, and you can set `skip_flags` to change the pipeline running
-    process.
+    mosaic and mixup.
 
     Args:
-        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        dataset (ConcatDataset or dict): The dataset to be mixed.
         pipeline (Sequence[dict]): Sequence of transform object or
             config dict to be composed.
         skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
     """
 
-    def __init__(self, dataset, pipeline, skip_type_keys=None):
+    def __init__(self,
+                 dataset: Union[ConcatDataset, dict],
+                 pipeline: Sequence[dict],
+                 skip_type_keys: Optional[List[str]] = None,
+                 lazy_init: bool = False) -> None:
         assert isinstance(pipeline, collections.abc.Sequence)
+
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, ConcatDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'elements in datasets sequence should be config or '
+                f'`ConcatDataset` instance, but got {type(dataset)}')
+
         if skip_type_keys is not None:
             assert all([
                 isinstance(skip_type_key, str)
@@ -230,11 +56,44 @@ class MultiImageMixDataset:
         else:
             raise TypeError('pipeline must be a dict')
 
-        self.dataset = dataset
-        self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
-        self.num_samples = len(dataset)
+        self._metainfo = self.dataset.metainfo
+        self.num_samples = len(self.dataset)
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def metainfo(self) -> dict:
+        """Get the meta information of the multi-image-mixed dataset.
+
+        Returns:
+            dict: The meta information of multi-image-mixed dataset.
+        """
+        return copy.deepcopy(self._metainfo)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+        self._ori_len = len(self.dataset)
+        self._fully_initialized = True
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``ConcatDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        return self.dataset.get_data_info(idx)
+
+    @force_full_init
     def __len__(self):
         return self.num_samples
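The rewritten MultiImageMixDataset follows mmengine's dataset contract: it accepts either a built ConcatDataset instance or a config dict, exposes `metainfo` instead of CLASSES/PALETTE, and gains `full_init`/`get_data_info` guarded by `force_full_init`. A small usage sketch mirroring the commit's own test; the fixture path is an assumption based on the tests:

from mmseg.datasets import MultiImageMixDataset
from mmseg.utils import register_all_modules

register_all_modules()

branch = dict(
    type='CustomDataset',
    pipeline=[],
    data_root='tests/data/pseudo_dataset',  # path assumed; mirrors the tests
    data_prefix=dict(img_path='imgs/', seg_map_path='gts/'),
    serialize_data=False)

# A config dict is built through the DATASETS registry inside the wrapper.
mixed = MultiImageMixDataset(
    dataset=dict(type='ConcatDataset', datasets=[branch, branch]),
    pipeline=[])
assert len(mixed) == 10  # two 5-image branches in the test fixture
info = mixed.get_data_info(0)  # force_full_init ensures the scan has run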
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmseg.datasets import DATASETS, MultiImageMixDataset
+from mmseg.utils import register_all_modules
+
+register_all_modules()
+
+
+@DATASETS.register_module()
+class ToyDataset(object):
+
+    def __init__(self, cnt=0):
+        self.cnt = cnt
+
+    def __item__(self, idx):
+        return idx
+
+    def __len__(self):
+        return 100
+
+
+def test_build_dataset():
+    cfg = dict(type='ToyDataset')
+    dataset = DATASETS.build(cfg)
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 0
+    dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 1
+
+    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
+    data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
+
+    # test RepeatDataset
+    cfg = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset = DATASETS.build(cfg)
+    dataset_repeat = RepeatDataset(dataset=dataset, times=5)
+    assert isinstance(dataset_repeat, RepeatDataset)
+    assert len(dataset_repeat) == 25
+
+    # test ConcatDataset
+    # We use same dir twice for simplicity
+    # with data_prefix.seg_map_path
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 10
+
+    # test MultiImageMixDataset
+    dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
+
+    dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    # with data_prefix.seg_map_path, ann_file
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/train.txt',
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/val.txt',
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 5
+
+    # test mode
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 10
+
+    # test mode with ann_files
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        ann_file='splits/val.txt',
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        ann_file='splits/val.txt',
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 2
@@ -25,7 +25,7 @@ class TestLoading(object):
         assert results['img'].dtype == np.uint8
         assert results['ori_shape'] == results['img'].shape[:2]
         assert repr(transform) == transform.__class__.__name__ + \
-            "(to_float32=False, color_type='color'," + \
+            "(ignore_empty=False, to_float32=False, color_type='color'," + \
             " imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
 
         # to_float32
@@ -8,7 +8,7 @@ import mmcv
 import numpy as np
 from mmcv import Config, DictAction
 
-from mmseg.datasets.builder import build_dataset
+from mmseg.datasets import DATASETS
 
 
 def parse_args():
@@ -159,7 +159,7 @@ def main():
     args = parse_args()
     cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                             args.show_origin)
-    dataset = build_dataset(cfg.data.train)
+    dataset = DATASETS.build(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
         filename = os.path.join(args.output_dir,