137 lines
4.3 KiB
Python
137 lines
4.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import collections
|
|
import copy
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
from mmengine.dataset import ConcatDataset, force_full_init
|
|
|
|
from mmseg.registry import DATASETS, TRANSFORMS
|
|
|
|
|
|
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup. For a transform in the pipeline that exposes a
    ``get_indexes`` method, the extra samples it asks for are fetched from
    the wrapped dataset and handed to it under the ``'mix_results'`` key.

    Args:
        dataset (ConcatDataset or dict): The dataset to be mixed. A dict is
            treated as a config and built through ``DATASETS``.
        pipeline (Sequence[dict]): Sequence of transform config dicts. Each
            one is built through ``TRANSFORMS``; non-dict entries raise
            ``TypeError``.
        skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
        lazy_init (bool): If False, ``full_init`` is called at the end of
            construction. Default to False.
    """

    def __init__(self,
                 dataset: Union[ConcatDataset, dict],
                 pipeline: Sequence[dict],
                 skip_type_keys: Optional[List[str]] = None,
                 lazy_init: bool = False) -> None:
        assert isinstance(pipeline, collections.abc.Sequence)

        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, ConcatDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`ConcatDataset` instance, but got {type(dataset)}')

        self._check_skip_type_keys(skip_type_keys)
        self._skip_type_keys = skip_type_keys

        # ``pipeline`` and ``pipeline_types`` are kept index-aligned so that
        # ``__getitem__`` can skip a transform by its configured type string.
        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if not isinstance(transform, dict):
                raise TypeError('pipeline must be a dict')
            self.pipeline_types.append(transform['type'])
            self.pipeline.append(TRANSFORMS.build(transform))

        self._metainfo = self.dataset.metainfo
        # NOTE(review): the length is queried here even when ``lazy_init``
        # is True; if the wrapped dataset's ``__len__`` forces a full
        # initialization, laziness is lost at this point -- confirm.
        self.num_samples = len(self.dataset)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @staticmethod
    def _check_skip_type_keys(skip_type_keys: Optional[List[str]]) -> None:
        """Assert that ``skip_type_keys`` is None or holds only strings."""
        if skip_type_keys is not None:
            assert all(isinstance(key, str) for key in skip_type_keys)

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the multi-image-mixed dataset.

        Returns:
            dict: A deep copy of the meta information taken from the wrapped
            dataset at construction time.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the wrapped dataset (no-op if already done)."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets, as returned by the
            wrapped dataset's ``get_data_info``.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Fetch sample ``idx`` and run it through the mixing pipeline.

        Args:
            idx (int): Index into the wrapped dataset.

        Returns:
            The output of the last applied transform, with any
            ``'mix_results'`` entry removed.
        """
        # Deep copies keep the transforms from mutating whatever object the
        # wrapped dataset handed back.
        results = copy.deepcopy(self.dataset[idx])
        for transform, transform_type in zip(self.pipeline,
                                             self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                # The transform chooses which additional samples to mix in.
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                results['mix_results'] = [
                    copy.deepcopy(self.dataset[index]) for index in indexes
                ]

            results = transform(results)

            # Drop the auxiliary samples so they do not leak into the next
            # transform or into the returned sample.
            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(
            self, skip_type_keys: Optional[List[str]]) -> None:
        """Update skip_type_keys.

        It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                string to be skip pipeline. ``None`` disables skipping,
                matching what ``__init__`` accepts.
        """
        self._check_skip_type_keys(skip_type_keys)
        self._skip_type_keys = skip_type_keys