110 lines
3.7 KiB
Python
110 lines
3.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import mmcv
|
|
import numpy as np
|
|
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
|
|
|
from mmseg.registry import TRANSFORMS
|
|
|
|
|
|
@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Optional Keys:

    - label_map (dict): If present and not None, every pixel whose value
      equals a key is replaced by the corresponding value.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool): Whether to reduce all label values by 1.
            Label 0 is mapped to 255 and label 255 is kept as 255. Usually
            used for datasets where 0 is the background label.
            Defaults to False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'pillow'.
    """

    def __init__(
        self,
        reduce_zero_label: bool = False,
        file_client_args: dict = dict(backend='disk'),
        imdecode_backend: str = 'pillow',
    ) -> None:
        # Only segmentation maps are loaded by this transform.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            file_client_args=file_client_args)
        self.reduce_zero_label = reduce_zero_label
        # Keep a private copy so later mutation of the caller's dict (or of
        # the shared default dict) cannot change this transform's settings.
        self.file_client_args = file_client_args.copy()
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`. Must
                contain ``seg_map_path``; may contain ``label_map``.

        Returns:
            None: ``results`` is updated in place. ``results['gt_seg_map']``
            is set to the loaded uint8 map and ``'gt_seg_map'`` is appended
            to ``results['seg_fields']`` (the list is created if missing).
        """
        img_bytes = self.file_client.get(results['seg_map_path'])
        # flag='unchanged' keeps the stored pixel values as-is; squeeze()
        # drops any singleton axes (e.g. a trailing single channel).
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # modify if custom classes
        label_map = results.get('label_map', None)
        if label_map is not None:
            # Compare against an untouched copy so a pixel rewritten by one
            # mapping entry is not rewritten again by a later entry. This
            # fixes the bug reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        # reduce zero_label
        if self.reduce_zero_label:
            # Map 0 to 255 first so the uint8 subtraction below never
            # underflows, shift every label down by one, then turn the
            # resulting 254 (originally 0 or 255) back into 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_seg_map'] = gt_semantic_seg
        # 'seg_fields' is documented as an added key, so create it when the
        # incoming results dict does not provide one.
        results.setdefault('seg_fields', []).append('gt_seg_map')

    def __repr__(self) -> str:
        """Return a readable summary of this transform's configuration."""
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str
|