# Source: mmsegmentation/mmseg/datasets/pipelines/loading.py
# (110 lines, 3.7 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmseg.registry import TRANSFORMS
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Optional Keys:

    - label_map (dict): Mapping ``old_id -> new_id`` applied to the loaded
      map before ``reduce_zero_label``.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool): Whether reduce all label value by 1.
            Usually used for datasets where 0 is background label.
            Defaults to False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'pillow'.
    """

    def __init__(
        self,
        reduce_zero_label: bool = False,
        file_client_args: dict = dict(backend='disk'),
        imdecode_backend: str = 'pillow',
    ) -> None:
        # Work on a private copy so the shared default dict (or a dict owned
        # by the caller) is never handed to code that could mutate it.
        file_client_args = file_client_args.copy()
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            file_client_args=file_client_args)
        self.reduce_zero_label = reduce_zero_label
        self.file_client_args = file_client_args.copy()
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        The file at ``results['seg_map_path']`` is read through
        ``self.file_client``, decoded with ``self.imdecode_backend``,
        squeezed and cast to ``np.uint8``. It is then optionally remapped
        with ``results['label_map']`` and, if ``reduce_zero_label`` is set,
        shifted down by one.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.

        Returns:
            None: ``results`` is updated in place. The decoded map is stored
            under ``'gt_seg_map'`` and ``'gt_seg_map'`` is appended to
            ``results['seg_fields']`` (created if missing).
        """
        img_bytes = self.file_client.get(results['seg_map_path'])
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # modify if custom classes
        label_map = results.get('label_map', None)
        if label_map is not None:
            # Match against an untouched copy so a value written by one
            # (old_id, new_id) pair is not remapped again by a later pair;
            # see https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id

        # reduce zero_label
        if self.reduce_zero_label:
            # Send 0 to 255 first so the uint8 subtraction never underflows.
            # Both former 0s and former 255s end up at 254 and are then
            # restored to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        results['gt_seg_map'] = gt_semantic_seg
        # Datasets are not guaranteed to pre-populate 'seg_fields'.
        results.setdefault('seg_fields', []).append('gt_seg_map')

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str