mmsegmentation/mmseg/datasets/basesegdataset.py

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Callable, Dict, List, Optional, Sequence, Union

import mmengine
import mmengine.fileio as fileio
import numpy as np
from mmengine.dataset import BaseDataset, Compose

from mmseg.registry import DATASETS


@DATASETS.register_module()
class BaseSegDataset(BaseDataset):
"""Custom dataset for semantic segmentation. An example of file structure
is as followed.
2020-07-07 20:52:19 +08:00
.. code-block:: none
data
my_dataset
img_dir
train
xxx{img_suffix}
yyy{img_suffix}
zzz{img_suffix}
val
ann_dir
train
xxx{seg_map_suffix}
yyy{seg_map_suffix}
zzz{seg_map_suffix}
val
The img/gt_semantic_seg pair of BaseSegDataset should be of the same
2020-07-07 20:52:19 +08:00
except suffix. A valid img/gt_semantic_seg filename pair should be like
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
in the suffix). If split is given, then ``xxx`` is specified in txt file.
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
2020-07-07 20:52:19 +08:00
    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            which classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the
            first few data items in the annotation file to facilitate
            training/testing on a smaller dataset. Defaults to None, which
            means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not
            necessary to load the annotation file. ``Basedataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img, the maximum number of extra cycles to try to get a
            valid image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
"""
    METAINFO: dict = dict()

    def __init__(self,
ann_file: str = '',
img_suffix='.jpg',
seg_map_suffix='.png',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img_path='', seg_map_path=''),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000,
ignore_index: int = 255,
reduce_zero_label: bool = False,
backend_args: Optional[dict] = None) -> None:
self.img_suffix = img_suffix
self.seg_map_suffix = seg_map_suffix
self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
self.data_prefix = copy.copy(data_prefix)
self.ann_file = ann_file
self.filter_cfg = copy.deepcopy(filter_cfg)
self._indices = indices
self.serialize_data = serialize_data
self.test_mode = test_mode
self.max_refetch = max_refetch
self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

@classmethod
def get_label_map(cls,
new_classes: Optional[Sequence] = None
) -> Union[Dict, None]:
"""Require label mapping.
The ``label_map`` is a dictionary, its keys are the old label ids and
its values are the new label ids, and is used for changing pixel
labels in load_annotations. If and only if old classes in cls.METAINFO
is not equal to new classes in self._metainfo and nether of them is not
None, `label_map` is not None.
2020-07-07 20:52:19 +08:00
Args:
2022-05-26 17:13:40 +08:00
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
2020-07-07 20:52:19 +08:00
Returns:
2022-05-26 17:13:40 +08:00
dict, optional: The mapping from old classes in cls.METAINFO to
new classes in self._metainfo
2020-07-07 20:52:19 +08:00
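
        Examples:
            A hypothetical subclass to illustrate the mapping; the class
            names are placeholders, not real dataset metadata:

            >>> class ToyDataset(BaseSegDataset):
            ...     METAINFO = dict(classes=('person', 'car', 'tree'))
            >>> ToyDataset.get_label_map(('tree', 'person'))
            {0: 1, 1: 255, 2: 0}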
"""
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):
            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

def _update_palette(self) -> list:
"""Update palette after loading metainfo.
2020-07-07 20:52:19 +08:00
2022-05-26 17:13:40 +08:00
If length of palette is equal to classes, just return the palette.
If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
2020-07-07 20:52:19 +08:00
Returns:
2022-05-26 17:13:40 +08:00
Sequence: Palette for current dataset.
2020-07-07 20:52:19 +08:00
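
        Examples:
            A worked illustration with made-up values: if the user restricts
            classes from an original ``('person', 'car', 'tree')`` with
            palette ``[[10, 10, 10], [20, 20, 20], [30, 30, 30]]`` to
            ``('tree', 'person')``, so ``label_map={0: 1, 1: 255, 2: 0}``,
            the entries are reordered by new id and the ignored class is
            dropped, giving ``[[30, 30, 30], [10, 10, 10]]``.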
"""
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before setting the seed, and restore the
            # random state later. This prevents loss of randomness, as the
            # palette may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

def load_data_list(self) -> List[dict]:
"""Load annotation from directory or annotation file.
2020-07-07 20:52:19 +08:00
Returns:
2022-05-26 17:13:40 +08:00
list[dict]: All data info of dataset.
2020-07-07 20:52:19 +08:00
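
        Each returned item is a dict shaped roughly as follows (the paths
        are hypothetical placeholders)::

            dict(
                img_path='data/my_dataset/img_dir/train/xxx.jpg',
                seg_map_path='data/my_dataset/ann_dir/train/xxx.png',
                label_map=None,
                reduce_zero_label=False,
                seg_fields=[])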
"""
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix))
                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                data_info = dict(img_path=osp.join(img_dir, img))
                if ann_dir is not None:
                    seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list


@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
"""Custom dataset for change detection. An example of file structure is as
followed.
.. code-block:: none
data
my_dataset
img_dir
train
xxx{img_suffix}
yyy{img_suffix}
zzz{img_suffix}
val
img_dir2
train
xxx{img_suffix}
yyy{img_suffix}
zzz{img_suffix}
val
ann_dir
train
xxx{seg_map_suffix}
yyy{seg_map_suffix}
zzz{seg_map_suffix}
val
The image names in img_dir and img_dir2 should be consistent.
The img/gt_semantic_seg pair of BaseSegDataset should be of the same
except suffix. A valid img/gt_semantic_seg filename pair should be like
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
in the suffix). If split is given, then ``xxx`` is specified in txt file.
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            which classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, img_path2=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        img_suffix2 (str): Suffix of the second set of images.
            Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filtering data. Defaults to
            None.
        indices (int or Sequence[int], optional): Support using only the
            first few data items in the annotation file to facilitate
            training/testing on a smaller dataset. Defaults to None, which
            means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects. When enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed, so it is not
            necessary to load the annotation file. ``Basedataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` gets a
            None img, the maximum number of extra cycles to try to get a
            valid image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Defaults to False.
        backend_args (dict, optional): Arguments to instantiate a file
            backend. See
            https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
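
    Examples:
        A minimal instantiation sketch. The paths below are hypothetical
        placeholders; ``lazy_init=True`` keeps the sketch from requiring
        the directories to exist:

        >>> dataset = BaseCDDataset(
        ...     data_root='data/my_dataset',
        ...     data_prefix=dict(
        ...         img_path='img_dir/train',
        ...         img_path2='img_dir2/train',
        ...         seg_map_path='ann_dir/train'),
        ...     img_suffix='.png',
        ...     img_suffix2='.png',
        ...     seg_map_suffix='.png',
        ...     lazy_init=True)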
"""
    METAINFO: dict = dict()

    def __init__(self,
ann_file: str = '',
img_suffix='.jpg',
img_suffix2='.jpg',
seg_map_suffix='.png',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(
img_path='', img_path2='', seg_map_path=''),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000,
ignore_index: int = 255,
reduce_zero_label: bool = False,
backend_args: Optional[dict] = None) -> None:
        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Fully initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

@classmethod
def get_label_map(cls,
new_classes: Optional[Sequence] = None
) -> Union[Dict, None]:
"""Require label mapping.
The ``label_map`` is a dictionary, its keys are the old label ids and
its values are the new label ids, and is used for changing pixel
labels in load_annotations. If and only if old classes in cls.METAINFO
is not equal to new classes in self._metainfo and nether of them is not
None, `label_map` is not None.
Args:
new_classes (list, tuple, optional): The new classes name from
metainfo. Default to None.
Returns:
dict, optional: The mapping from old classes in cls.METAINFO to
new classes in self._metainfo
"""
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):
            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

def _update_palette(self) -> list:
"""Update palette after loading metainfo.
If length of palette is equal to classes, just return the palette.
If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
Returns:
Sequence: Palette for current dataset.
"""
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before setting the seed, and restore the
            # random state later. This prevents loss of randomness, as the
            # palette may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

def load_data_list(self) -> List[dict]:
"""Load annotation from directory or annotation file.
Returns:
list[dict]: All data info of dataset.
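
        Each returned item is a dict shaped roughly as follows (the paths
        are hypothetical placeholders)::

            dict(
                img_path='data/my_dataset/img_dir/train/xxx.png',
                img_path2='data/my_dataset/img_dir2/train/xxx.png',
                seg_map_path='data/my_dataset/ann_dir/train/xxx.png',
                label_map=None,
                reduce_zero_label=False,
                seg_fields=[])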
"""
data_list = []
img_dir = self.data_prefix.get('img_path', None)
img_dir2 = self.data_prefix.get('img_path2', None)
ann_dir = self.data_prefix.get('seg_map_path', None)
if osp.isfile(self.ann_file):
lines = mmengine.list_from_file(
self.ann_file, backend_args=self.backend_args)
for line in lines:
img_name = line.strip()
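                # If the split line already carries a file extension, infer
                # both image suffixes from it instead of the configured ones.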
if '.' in osp.basename(img_name):
img_name, img_ext = osp.splitext(img_name)
self.img_suffix = img_ext
self.img_suffix2 = img_ext
data_info = dict(
img_path=osp.join(img_dir, img_name + self.img_suffix),
img_path2=osp.join(img_dir2, img_name + self.img_suffix2))
if ann_dir is not None:
seg_map = img_name + self.seg_map_suffix
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_info['reduce_zero_label'] = self.reduce_zero_label
data_info['seg_fields'] = []
data_list.append(data_info)
else:
for img in fileio.list_dir_or_file(
dir_path=img_dir,
list_dir=False,
suffix=self.img_suffix,
recursive=True,
backend_args=self.backend_args):
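                # Listed files carry their own extension; if present, infer
                # both image suffixes from the filename.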
if '.' in osp.basename(img):
img, img_ext = osp.splitext(img)
self.img_suffix = img_ext
self.img_suffix2 = img_ext
data_info = dict(
img_path=osp.join(img_dir, img + self.img_suffix),
img_path2=osp.join(img_dir2, img + self.img_suffix2))
if ann_dir is not None:
seg_map = img + self.seg_map_suffix
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
data_info['label_map'] = self.label_map
data_info['reduce_zero_label'] = self.reduce_zero_label
data_info['seg_fields'] = []
data_list.append(data_info)
data_list = sorted(data_list, key=lambda x: x['img_path'])
return data_list