[Feature] Add MultiImageMixDataset (#1105)
* Fix typo in usage example
* original MultiImageMixDataset code in mmdet
* Add MultiImageMixDataset unittests in test_dataset_wrapper
* fix lint error
* fix value name ann_file to ann_dir
* modify retrieve_data_cfg (#1)
* remove dynamic_scale & add palette
* modify retrieve_data_cfg method
* modify retrieve_data_cfg func
* fix error
* improve the unittests coverage
* fix unittests error
* Dataset (#2)
* add cfg-options
* Add unittest in test_build_dataset
* add blank line
* add blank line
* add a blank line

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Younghoon-Lee <72462227+Younghoon-Lee@users.noreply.github.com>
Co-authored-by: MeowZheng <meowzheng@outlook.com>
parent f0262fa68e
commit 6c3e63e48b
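The net effect of this commit: a training config can wrap any dataset in the new MultiImageMixDataset so that wrapper-level transforms may pull extra samples (via a get_indexes hook) before each pipeline step runs. A minimal config sketch, assuming illustrative dataset paths and a plain flip/resize pipeline; none of these values come from this commit:

    # Hypothetical training config; img_dir/ann_dir are placeholder paths.
    train = dict(
        type='MultiImageMixDataset',
        dataset=dict(
            type='CustomDataset',
            img_dir='data/example/images',  # placeholder
            ann_dir='data/example/annotations',  # placeholder
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations'),
            ]),
        pipeline=[
            dict(type='RandomFlip', prob=0.5),
            dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
        ])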
@@ -6,7 +6,8 @@ from .cityscapes import CityscapesDataset
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
+                               RepeatDataset)
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .loveda import LoveDADataset
@@ -21,5 +22,5 @@ __all__ = [
     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
     'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset'
+    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
 ]
@@ -64,12 +64,18 @@ def _concat_dataset(cfg, default_args=None):

 def build_dataset(cfg, default_args=None):
     """Build datasets."""
-    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
+                                   MultiImageMixDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif cfg['type'] == 'MultiImageMixDataset':
+        cp_cfg = copy.deepcopy(cfg)
+        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+        cp_cfg.pop('type')
+        dataset = MultiImageMixDataset(**cp_cfg)
     elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
             cfg.get('split', None), (list, tuple)):
         dataset = _concat_dataset(cfg, default_args)
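The new elif branch builds the wrapped dataset first, then forwards every remaining config key to the wrapper. A sketch of the expansion it performs, with a placeholder inner config (note the recursive build_dataset call does not forward default_args):

    # Illustrative only; the inner dataset dict is a placeholder.
    cfg = dict(
        type='MultiImageMixDataset',
        dataset=dict(type='CustomDataset', img_dir='data/example', pipeline=[]),
        pipeline=[dict(type='RandomFlip', prob=0.5)])
    dataset = build_dataset(cfg)
    # roughly equivalent to:
    #   inner = build_dataset(cfg['dataset'])
    #   dataset = MultiImageMixDataset(dataset=inner, pipeline=cfg['pipeline'])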
@@ -1,13 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import bisect
+import collections
+import copy
 from itertools import chain

 import mmcv
 import numpy as np
-from mmcv.utils import print_log
+from mmcv.utils import build_from_cfg, print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

-from .builder import DATASETS
+from .builder import DATASETS, PIPELINES
 from .cityscapes import CityscapesDataset


@@ -188,3 +190,88 @@ class RepeatDataset(object):
     def __len__(self):
         """The length is multiplied by ``times``"""
         return self.times * self._ori_len
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A wrapper of multiple images mixed dataset.
+
+    Suitable for training on multiple images mixed data augmentation like
+    mosaic and mixup. For the augmentation pipeline of mixed image data,
+    the `get_indexes` method needs to be provided to obtain the image
+    indexes, and you can set `skip_type_keys` to change the pipeline running
+    process.
+
+
+    Args:
+        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform objects or
+            config dicts to be composed.
+        skip_type_keys (list[str], optional): Sequence of type strings of
+            transforms to be skipped in the pipeline. Defaults to None.
+    """
+
+    def __init__(self, dataset, pipeline, skip_type_keys=None):
+        assert isinstance(pipeline, collections.abc.Sequence)
+        if skip_type_keys is not None:
+            assert all([
+                isinstance(skip_type_key, str)
+                for skip_type_key in skip_type_keys
+            ])
+        self._skip_type_keys = skip_type_keys
+
+        self.pipeline = []
+        self.pipeline_types = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                self.pipeline_types.append(transform['type'])
+                transform = build_from_cfg(transform, PIPELINES)
+                self.pipeline.append(transform)
+            else:
+                raise TypeError('pipeline must be a dict')
+
+        self.dataset = dataset
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = dataset.PALETTE
+        self.num_samples = len(dataset)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        results = copy.deepcopy(self.dataset[idx])
+        for (transform, transform_type) in zip(self.pipeline,
+                                               self.pipeline_types):
+            if self._skip_type_keys is not None and \
+                    transform_type in self._skip_type_keys:
+                continue
+
+            if hasattr(transform, 'get_indexes'):
+                indexes = transform.get_indexes(self.dataset)
+                if not isinstance(indexes, collections.abc.Sequence):
+                    indexes = [indexes]
+                mix_results = [
+                    copy.deepcopy(self.dataset[index]) for index in indexes
+                ]
+                results['mix_results'] = mix_results
+
+            results = transform(results)
+
+            if 'mix_results' in results:
+                results.pop('mix_results')
+
+        return results
+
+    def update_skip_type_keys(self, skip_type_keys):
+        """Update skip_type_keys.
+
+        It is called by an external hook.
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of type
+                strings of transforms to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
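To make the get_indexes contract concrete, here is a hypothetical wrapper-level transform; the class name and logic are invented for illustration and are not part of this commit. MultiImageMixDataset calls get_indexes with the wrapped dataset, loads the returned indexes, and exposes the fetched samples to the transform under results['mix_results']:

    import random

    from mmseg.datasets.builder import PIPELINES


    @PIPELINES.register_module()
    class FakeMixTransform:
        """Illustrative transform relying on the `get_indexes` hook."""

        def get_indexes(self, dataset):
            # Ask the wrapper to fetch one random partner sample.
            return [random.randint(0, len(dataset) - 1)]

        def __call__(self, results):
            # The partner sample loaded by MultiImageMixDataset is available
            # here; a real mosaic/mixup would blend the images and labels.
            partner = results['mix_results'][0]
            assert 'img' in partner
            return results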
@@ -14,7 +14,8 @@ from PIL import Image
 from mmseg.core.evaluation import get_classes, get_palette
 from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                             ConcatDataset, CustomDataset, LoveDADataset,
-                            PascalVOCDataset, RepeatDataset, build_dataset)
+                            MultiImageMixDataset, PascalVOCDataset,
+                            RepeatDataset, build_dataset)


 def test_classes():
@@ -95,6 +96,66 @@ def test_dataset_wrapper():
     assert repeat_dataset[27] == 7
     assert len(repeat_dataset) == 10 * len(dataset_a)
+
+    img_scale = (60, 60)
+    pipeline = [
+        # dict(type='Mosaic', img_scale=img_scale, pad_val=255),
+        # need to merge mosaic
+        dict(type='RandomFlip', prob=0.5),
+        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
+    ]
+
+    CustomDataset.load_annotations = MagicMock()
+    results = []
+    for _ in range(2):
+        height = np.random.randint(10, 30)
+        width = np.random.randint(10, 30)
+        img = np.ones((height, width, 3))
+        gt_semantic_seg = np.random.randint(5, size=(height, width))
+        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))
+
+    classes = ['0', '1', '2', '3', '4']
+    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
+    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
+    dataset_a = CustomDataset(
+        img_dir=MagicMock(),
+        pipeline=[],
+        test_mode=True,
+        classes=classes,
+        palette=palette)
+    len_a = 2
+    cat_ids_list_a = [
+        np.random.randint(0, 80, num).tolist()
+        for num in np.random.randint(1, 20, len_a)
+    ]
+    dataset_a.data_infos = MagicMock()
+    dataset_a.data_infos.__len__.return_value = len_a
+    dataset_a.get_cat_ids = MagicMock(
+        side_effect=lambda idx: cat_ids_list_a[idx])
+
+    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
+    assert len(multi_image_mix_dataset) == len(dataset_a)
+
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+
+    # test skip_type_keys
+    multi_image_mix_dataset = MultiImageMixDataset(
+        dataset_a, pipeline, skip_type_keys=('RandomFlip', ))
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)
+
+    skip_type_keys = ('RandomFlip', 'Resize')
+    multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape[:2] != img_scale
+
+    # test pipeline
+    with pytest.raises(TypeError):
+        pipeline = [['Resize']]
+        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)


 def test_custom_dataset():
     img_norm_cfg = dict(
@@ -6,8 +6,8 @@ import pytest
 from torch.utils.data import (DistributedSampler, RandomSampler,
                               SequentialSampler)

-from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
-                            build_dataset)
+from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
+                            build_dataloader, build_dataset)


 @DATASETS.register_module()
@@ -48,6 +48,11 @@ def test_build_dataset():
     assert isinstance(dataset, ConcatDataset)
     assert len(dataset) == 10

+    cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
+    dataset = build_dataset(cfg)
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
     # with ann_dir, split
     cfg = dict(
         type='CustomDataset',
@@ -5,7 +5,7 @@ from pathlib import Path

 import mmcv
 import numpy as np
-from mmcv import Config
+from mmcv import Config, DictAction

 from mmseg.datasets.builder import build_dataset

@@ -42,6 +42,16 @@ def parse_args():
         type=float,
         default=0.5,
         help='the opacity of semantic map')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
     args = parser.parse_args()
     return args
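For reference, a self-contained sketch of how DictAction-parsed overrides merge into a config, mirroring what retrieve_data_cfg now does with args.cfg_options; the config values and override keys here are invented:

    import argparse

    from mmcv import Config, DictAction

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg-options', nargs='+', action=DictAction)
    args = parser.parse_args(
        ['--cfg-options', 'data.samples_per_gpu=2', 'data.workers_per_gpu=0'])

    cfg = Config(dict(data=dict(samples_per_gpu=1, workers_per_gpu=2)))
    cfg.merge_from_dict(args.cfg_options)  # dotted keys update nested fields
    assert cfg.data.samples_per_gpu == 2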
@@ -122,28 +132,32 @@ def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
     ]


-def retrieve_data_cfg(config_path, skip_type, show_origin=False):
+def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
     cfg = Config.fromfile(config_path)
+    if cfg_options is not None:
+        cfg.merge_from_dict(cfg_options)
     train_data_cfg = cfg.data.train
     if isinstance(train_data_cfg, list):
         for _data_cfg in train_data_cfg:
+            while 'dataset' in _data_cfg and _data_cfg[
+                    'type'] != 'MultiImageMixDataset':
+                _data_cfg = _data_cfg['dataset']
             if 'pipeline' in _data_cfg:
                 _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
-            elif 'dataset' in _data_cfg:
-                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
-                                   show_origin)
             else:
                 raise ValueError
-    elif 'dataset' in train_data_cfg:
-        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
     else:
+        while 'dataset' in train_data_cfg and train_data_cfg[
+                'type'] != 'MultiImageMixDataset':
+            train_data_cfg = train_data_cfg['dataset']
         _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
     return cfg


 def main():
     args = parse_args()
-    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
+    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
+                            args.show_origin)
     dataset = build_dataset(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
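The reworked traversal unwraps nested wrapper configs until it reaches either a MultiImageMixDataset (whose own pipeline is the one to modify) or the innermost dataset carrying a pipeline. A sketch of the while loop on a hypothetical nested config:

    # Illustrative nested train config: RepeatDataset wrapping a concrete
    # dataset. The loop descends to the inner dict, whose 'pipeline' is then
    # rewritten by _retrieve_data_cfg.
    train_data_cfg = dict(
        type='RepeatDataset',
        times=10,
        dataset=dict(
            type='CustomDataset',
            pipeline=[dict(type='LoadImageFromFile')]))
    while 'dataset' in train_data_cfg and train_data_cfg[
            'type'] != 'MultiImageMixDataset':
        train_data_cfg = train_data_cfg['dataset']
    assert train_data_cfg['type'] == 'CustomDataset'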