# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union

import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine.dataset import BaseDataset, Compose

from mmseg.registry import DATASETS


@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
    """Custom dataset for semantic segmentation. An example of file structure
    is as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The img/gt_semantic_seg pair of BaseSegDataset should share the same
    name, except for the suffix. A valid img/gt_semantic_seg filename pair
    should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the
    extension is also included in the suffix). If split is given, then
    ``xxx`` is specified in the txt file. Otherwise, all files in
    ``img_dir/`` and ``ann_dir`` will be loaded. Please refer to
    ``docs/en/tutorials/new_dataset.md`` for more details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            the classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'.
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'.
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Use only the first few
            entries in the annotation file to facilitate training/testing on
            a smaller dataset. Defaults to None, which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not necessary
            to load the annotation file. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``. Defaults
            to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None image, the maximum number of extra cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255.
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
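
    Examples:
        >>> # A minimal sketch; the paths below are hypothetical, and
        >>> # ``lazy_init=True`` skips scanning the (nonexistent)
        >>> # directories.
        >>> dataset = BaseSegDataset(
        ...     data_root='data/my_dataset',
        ...     data_prefix=dict(
        ...         img_path='img_dir/train', seg_map_path='ann_dir/train'),
        ...     img_suffix='.jpg',
        ...     seg_map_suffix='.png',
        ...     lazy_init=True)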
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary whose keys are the old label ids
        and whose values are the new label ids. It is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        the old classes in cls.METAINFO are not equal to the new classes in
        self._metainfo and neither of them is None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Defaults to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo
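
        Examples:
            >>> # A minimal sketch with a hypothetical subclass whose
            >>> # METAINFO defines three classes.
            >>> class ToyDataset(BaseSegDataset):
            ...     METAINFO = dict(classes=('sky', 'road', 'car'))
            >>> ToyDataset.get_label_map(['road', 'car'])
            {0: 255, 1: 0, 2: 1}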
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette equals the number of classes, just
        return the palette. If the palette is not defined, randomly generate
        one. If the classes are updated by the user, return the corresponding
        subset of the palette.

        Returns:
            Sequence: Palette for current dataset.
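
        Example:
            A sketch of the subset case: with an inherited palette
            ``[[10, 10, 10], [20, 20, 20], [30, 30, 30]]`` and a
            ``label_map`` of ``{0: 255, 1: 0, 2: 1}``, the returned palette
            is ``[[20, 20, 20], [30, 30, 30]]`` (values are hypothetical).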
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before setting the seed, and restore the
            # random state later.
            # This prevents loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
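
        Example:
            A sketch of one returned item (paths are hypothetical)::

                dict(
                    img_path='data/my_dataset/img_dir/train/xxx.jpg',
                    seg_map_path='data/my_dataset/ann_dir/train/xxx.png',
                    label_map=None,
                    reduce_zero_label=False,
                    seg_fields=[])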
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list


@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
    """Custom dataset for change detection. An example of file structure is
    as follows.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── img_dir2
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The image names in img_dir and img_dir2 should be consistent.
    The img/gt_semantic_seg pair of BaseCDDataset should share the same
    name, except for the suffix. A valid img/gt_semantic_seg filename pair
    should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the
    extension is also included in the suffix). If split is given, then
    ``xxx`` is specified in the txt file. Otherwise, all files in
    ``img_dir/`` and ``ann_dir`` will be loaded. Please refer to
    ``docs/en/tutorials/new_dataset.md`` for more details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            the classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, img_path2=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'.
        img_suffix2 (str): Suffix of images in ``img_dir2``. Default: '.jpg'.
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'.
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Use only the first few
            entries in the annotation file to facilitate training/testing on
            a smaller dataset. Defaults to None, which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not necessary
            to load the annotation file. ``BaseDataset`` can skip loading
            annotations to save time by setting ``lazy_init=True``. Defaults
            to False.
        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
            None image, the maximum number of extra cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255.
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.html
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
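
    Examples:
        >>> # A minimal sketch; the paths below are hypothetical, and
        >>> # ``lazy_init=True`` skips scanning the (nonexistent)
        >>> # directories.
        >>> dataset = BaseCDDataset(
        ...     data_root='data/my_cd_dataset',
        ...     data_prefix=dict(
        ...         img_path='img_dir/train',
        ...         img_path2='img_dir2/train',
        ...         seg_map_path='ann_dir/train'),
        ...     lazy_init=True)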
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 img_suffix2='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(
                     img_path='', img_path2='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary whose keys are the old label ids
        and whose values are the new label ids. It is used for changing pixel
        labels in load_annotations. ``label_map`` is not None if and only if
        the old classes in cls.METAINFO are not equal to the new classes in
        self._metainfo and neither of them is None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Defaults to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If the length of the palette equals the number of classes, just
        return the palette. If the palette is not defined, randomly generate
        one. If the classes are updated by the user, return the corresponding
        subset of the palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before setting the seed, and restore the
            # random state later.
            # This prevents loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
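
        Example:
            A sketch of one returned item (paths are hypothetical)::

                dict(
                    img_path='data/my_cd_dataset/img_dir/train/xxx.png',
                    img_path2='data/my_cd_dataset/img_dir2/train/xxx.png',
                    seg_map_path='data/my_cd_dataset/ann_dir/train/xxx.png',
                    label_map=None,
                    reduce_zero_label=False,
                    seg_fields=[])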
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        img_dir2 = self.data_prefix.get('img_path2', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                if '.' in osp.basename(img_name):
                    img_name, img_ext = osp.splitext(img_name)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix),
                    img_path2=osp.join(img_dir2, img_name + self.img_suffix2))

                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                if '.' in osp.basename(img):
                    img, img_ext = osp.splitext(img)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img + self.img_suffix),
                    img_path2=osp.join(img_dir2, img + self.img_suffix2))
                if ann_dir is not None:
                    seg_map = img + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list