[Feature] Add MultiImageMixDataset (#1105)
* Fix typo in usage example
* original MultiImageMixDataset code in mmdet
* Add MultiImageMixDataset unittests in test_dataset_wrapper
* fix lint error
* fix value name ann_file to ann_dir
* modify retrieve_data_cfg (#1)
* remove dynamic_scale & add palette
* modify retrieve_data_cfg method
* modify retrieve_data_cfg func
* fix error
* improve the unittests coverage
* fix unittests error
* Dataset (#2)
* add cfg-options
* Add unittest in test_build_dataset
* add blank line
* add blank line
* add a blank line

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Younghoon-Lee <72462227+Younghoon-Lee@users.noreply.github.com>
Co-authored-by: MeowZheng <meowzheng@outlook.com>
parent f0262fa68e
commit 6c3e63e48b
mmseg/datasets/__init__.py

@@ -6,7 +6,8 @@ from .cityscapes import CityscapesDataset
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
+                               RepeatDataset)
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .loveda import LoveDADataset
@@ -21,5 +22,5 @@ __all__ = [
     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
     'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset'
+    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
 ]
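With the wrapper re-exported here, downstream code can import it from the package root:

    from mmseg.datasets import MultiImageMixDataset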
mmseg/datasets/builder.py

@@ -64,12 +64,18 @@ def _concat_dataset(cfg, default_args=None):

 def build_dataset(cfg, default_args=None):
     """Build datasets."""
-    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    from .dataset_wrappers import (ConcatDataset, RepeatDataset,
+                                   MultiImageMixDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif cfg['type'] == 'MultiImageMixDataset':
+        cp_cfg = copy.deepcopy(cfg)
+        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+        cp_cfg.pop('type')
+        dataset = MultiImageMixDataset(**cp_cfg)
     elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
             cfg.get('split', None), (list, tuple)):
         dataset = _concat_dataset(cfg, default_args)
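To illustrate the new branch, here is a minimal sketch of a config that exercises it (the inner dataset type and paths are placeholders, not part of this commit): build_dataset deep-copies the dict, recursively builds the nested 'dataset' config first, pops 'type', and forwards the remaining keys to MultiImageMixDataset.

    cfg = dict(
        type='MultiImageMixDataset',
        # inner dataset config; built first and passed as the `dataset` argument
        dataset=dict(
            type='CityscapesDataset',
            data_root='data/cityscapes',  # placeholder path
            img_dir='leftImg8bit/train',
            ann_dir='gtFine/train',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations')
            ]),
        # mixed-image pipeline applied by the wrapper itself
        pipeline=[dict(type='RandomFlip', prob=0.5)])

    dataset = build_dataset(cfg)  # isinstance(dataset, MultiImageMixDataset)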
mmseg/datasets/dataset_wrappers.py

@@ -1,13 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import bisect
+import collections
+import copy
 from itertools import chain

 import mmcv
 import numpy as np
-from mmcv.utils import print_log
+from mmcv.utils import build_from_cfg, print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

-from .builder import DATASETS
+from .builder import DATASETS, PIPELINES
 from .cityscapes import CityscapesDataset

@@ -188,3 +190,88 @@ class RepeatDataset(object):
     def __len__(self):
         """The length is multiplied by ``times``"""
         return self.times * self._ori_len
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A wrapper of multiple images mixed dataset.
+
+    Suitable for training with multiple-image mixed data augmentations
+    like mosaic and mixup. For the augmentation pipeline of mixed image
+    data, the `get_indexes` method needs to be provided to obtain the
+    image indexes, and you can set `skip_type_keys` to change the
+    pipeline running process.
+
+    Args:
+        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform objects or
+            config dicts to be composed.
+        skip_type_keys (list[str], optional): Sequence of type strings
+            to be skipped in the pipeline. Defaults to None.
+    """
+
+    def __init__(self, dataset, pipeline, skip_type_keys=None):
+        assert isinstance(pipeline, collections.abc.Sequence)
+        if skip_type_keys is not None:
+            assert all([
+                isinstance(skip_type_key, str)
+                for skip_type_key in skip_type_keys
+            ])
+        self._skip_type_keys = skip_type_keys
+
+        self.pipeline = []
+        self.pipeline_types = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                self.pipeline_types.append(transform['type'])
+                transform = build_from_cfg(transform, PIPELINES)
+                self.pipeline.append(transform)
+            else:
+                raise TypeError('pipeline must be a dict')
+
+        self.dataset = dataset
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = dataset.PALETTE
+        self.num_samples = len(dataset)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        results = copy.deepcopy(self.dataset[idx])
+        for (transform, transform_type) in zip(self.pipeline,
+                                               self.pipeline_types):
+            if self._skip_type_keys is not None and \
+                    transform_type in self._skip_type_keys:
+                continue
+
+            if hasattr(transform, 'get_indexes'):
+                indexes = transform.get_indexes(self.dataset)
+                if not isinstance(indexes, collections.abc.Sequence):
+                    indexes = [indexes]
+                mix_results = [
+                    copy.deepcopy(self.dataset[index]) for index in indexes
+                ]
+                results['mix_results'] = mix_results
+
+            results = transform(results)
+
+            if 'mix_results' in results:
+                results.pop('mix_results')
+
+        return results
+
+    def update_skip_type_keys(self, skip_type_keys):
+        """Update skip_type_keys.
+
+        It is called by an external hook.
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of type
+                strings to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
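A minimal sketch of the contract __getitem__ relies on, with a hypothetical transform (ToyMix is not part of this commit): any registered pipeline step that exposes get_indexes is handed deep copies of the extra samples under results['mix_results'], which the wrapper strips again after the step runs.

    import numpy as np

    from mmseg.datasets import MultiImageMixDataset
    from mmseg.datasets.builder import PIPELINES


    @PIPELINES.register_module()
    class ToyMix:
        """Hypothetical transform, only to illustrate the wrapper contract."""

        def get_indexes(self, dataset):
            # MultiImageMixDataset calls this before __call__ and stores deep
            # copies of the selected samples under results['mix_results'].
            return np.random.randint(0, len(dataset))

        def __call__(self, results):
            extra = results['mix_results'][0]
            # naive 50/50 blend; assumes both images share the same shape
            results['img'] = (results['img'] + extra['img']) / 2
            return results


    # some_dataset: any dataset with CLASSES/PALETTE attributes whose items
    # are dicts with an 'img' key (e.g. a CustomDataset instance)
    mix_dataset = MultiImageMixDataset(some_dataset, [dict(type='ToyMix')])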
tests/test_data/test_dataset.py

@@ -14,7 +14,8 @@ from PIL import Image
 from mmseg.core.evaluation import get_classes, get_palette
 from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                             ConcatDataset, CustomDataset, LoveDADataset,
-                            PascalVOCDataset, RepeatDataset, build_dataset)
+                            MultiImageMixDataset, PascalVOCDataset,
+                            RepeatDataset, build_dataset)


 def test_classes():
@@ -95,6 +96,66 @@ def test_dataset_wrapper():
     assert repeat_dataset[27] == 7
     assert len(repeat_dataset) == 10 * len(dataset_a)

+    img_scale = (60, 60)
+    pipeline = [
+        # dict(type='Mosaic', img_scale=img_scale, pad_val=255),
+        # need to merge mosaic
+        dict(type='RandomFlip', prob=0.5),
+        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
+    ]
+
+    CustomDataset.load_annotations = MagicMock()
+    results = []
+    for _ in range(2):
+        height = np.random.randint(10, 30)
+        width = np.random.randint(10, 30)
+        img = np.ones((height, width, 3))
+        gt_semantic_seg = np.random.randint(5, size=(height, width))
+        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))
+
+    classes = ['0', '1', '2', '3', '4']
+    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
+    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
+    dataset_a = CustomDataset(
+        img_dir=MagicMock(),
+        pipeline=[],
+        test_mode=True,
+        classes=classes,
+        palette=palette)
+    len_a = 2
+    cat_ids_list_a = [
+        np.random.randint(0, 80, num).tolist()
+        for num in np.random.randint(1, 20, len_a)
+    ]
+    dataset_a.data_infos = MagicMock()
+    dataset_a.data_infos.__len__.return_value = len_a
+    dataset_a.get_cat_ids = MagicMock(
+        side_effect=lambda idx: cat_ids_list_a[idx])
+
+    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
+    assert len(multi_image_mix_dataset) == len(dataset_a)
+
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+
+    # test skip_type_keys
+    multi_image_mix_dataset = MultiImageMixDataset(
+        dataset_a, pipeline, skip_type_keys=('RandomFlip', ))
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)
+
+    skip_type_keys = ('RandomFlip', 'Resize')
+    multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
+    for idx in range(len_a):
+        results_ = multi_image_mix_dataset[idx]
+        assert results_['img'].shape[:2] != img_scale
+
+    # test pipeline
+    with pytest.raises(TypeError):
+        pipeline = [['Resize']]
+        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)


 def test_custom_dataset():
     img_norm_cfg = dict(
tests/test_data/test_dataset_builder.py

@@ -6,8 +6,8 @@ import pytest
 from torch.utils.data import (DistributedSampler, RandomSampler,
                               SequentialSampler)

-from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
-                            build_dataset)
+from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
+                            build_dataloader, build_dataset)


 @DATASETS.register_module()
@@ -48,6 +48,11 @@ def test_build_dataset():
     assert isinstance(dataset, ConcatDataset)
     assert len(dataset) == 10

+    cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
+    dataset = build_dataset(cfg)
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
     # with ann_dir, split
     cfg = dict(
         type='CustomDataset',
tools/browse_dataset.py

@@ -5,7 +5,7 @@ from pathlib import Path

 import mmcv
 import numpy as np
-from mmcv import Config
+from mmcv import Config, DictAction

 from mmseg.datasets.builder import build_dataset

@@ -42,6 +42,16 @@ def parse_args():
         type=float,
         default=0.5,
         help='the opacity of semantic map')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no white space is allowed.')
     args = parser.parse_args()
     return args

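A usage sketch of the flag added above (the config path and key come from the standard Cityscapes base config and are placeholders here): each key=value pair is merged into the loaded config before the dataset is built.

    python tools/browse_dataset.py configs/_base_/datasets/cityscapes.py \
        --cfg-options data.train.data_root=data/cityscapes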
@@ -122,28 +132,32 @@ def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
     ]


-def retrieve_data_cfg(config_path, skip_type, show_origin=False):
+def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
     cfg = Config.fromfile(config_path)
+    if cfg_options is not None:
+        cfg.merge_from_dict(cfg_options)
     train_data_cfg = cfg.data.train
     if isinstance(train_data_cfg, list):
         for _data_cfg in train_data_cfg:
+            while 'dataset' in _data_cfg and _data_cfg[
+                    'type'] != 'MultiImageMixDataset':
+                _data_cfg = _data_cfg['dataset']
             if 'pipeline' in _data_cfg:
                 _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
-            elif 'dataset' in _data_cfg:
-                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
-                                   show_origin)
             else:
                 raise ValueError
-    elif 'dataset' in train_data_cfg:
-        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
     else:
+        while 'dataset' in train_data_cfg and train_data_cfg[
+                'type'] != 'MultiImageMixDataset':
+            train_data_cfg = train_data_cfg['dataset']
         _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
     return cfg


 def main():
     args = parse_args()
-    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
+    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
+                            args.show_origin)
     dataset = build_dataset(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
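To see what the new while loops do, consider a hypothetical wrapped train config: the loop descends through the 'dataset' keys of wrappers such as RepeatDataset until it reaches the config that owns the 'pipeline', but stops at a MultiImageMixDataset, since that wrapper carries its own pipeline.

    # RepeatDataset wraps the real dataset; the loop descends to the inner config.
    train_data_cfg = dict(
        type='RepeatDataset',
        times=10,
        dataset=dict(
            type='CustomDataset',
            pipeline=[dict(type='LoadImageFromFile')]))

    while 'dataset' in train_data_cfg and \
            train_data_cfg['type'] != 'MultiImageMixDataset':
        train_data_cfg = train_data_cfg['dataset']

    # train_data_cfg is now the inner CustomDataset config, so its
    # 'pipeline' key can be rewritten by _retrieve_data_cfg.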