# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations

from mmseg.registry import TRANSFORMS


@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # in str
            'seg_fields': List
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - seg_fields (List)
    - gt_seg_map (np.uint8)

    Args:
        reduce_zero_label (bool): Whether reduce all label value by 1.
            Usually used for datasets where 0 is background label.
            Defaults to False.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'pillow'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(
        self,
        reduce_zero_label: bool = False,
        file_client_args: dict = dict(backend='disk'),
        imdecode_backend: str = 'pillow',
    ) -> None:
        # Only the segmentation branch of the parent loader is enabled.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            imdecode_backend=imdecode_backend,
            file_client_args=file_client_args)
        self.reduce_zero_label = reduce_zero_label
        # Keep a private copy so the shared default dict (or the caller's
        # dict) is never aliased by this instance.
        self.file_client_args = file_client_args.copy()
        self.imdecode_backend = imdecode_backend

    def _load_seg_map(self, results: dict) -> None:
        """Private function to load semantic segmentation annotations.

        The segmentation map is read from ``results['seg_map_path']``,
        decoded, squeezed and cast to ``np.uint8``. It is then optionally
        remapped through ``results['label_map']`` and shifted when
        ``reduce_zero_label`` is set.

        Args:
            results (dict): Result dict from :obj:`mmcv.BaseDataset`.
                Must contain ``'seg_map_path'`` and an existing
                ``'seg_fields'`` list; may contain ``'label_map'``.

        Returns:
            None: ``results`` is updated in place. ``'gt_seg_map'`` is set
            to the loaded map and ``'gt_seg_map'`` is appended to
            ``results['seg_fields']``.
        """
        img_bytes = self.file_client.get(results['seg_map_path'])
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # modify if custom classes
        if results.get('label_map', None) is not None:
            # Compare against an untouched copy so that a pixel already
            # remapped to `new_id` is not remapped again by a later entry.
            # This repeated-replacement bug is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id

        # reduce zero_label
        if self.reduce_zero_label:
            # Avoid relying on uint8 underflow (0 - 1): move label 0 to 255
            # first, shift everything down by one, then restore the pixels
            # that were 255 before the shift (now 254) back to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255

        results['gt_seg_map'] = gt_semantic_seg
        results['seg_fields'].append('gt_seg_map')

    def __repr__(self) -> str:
        """Return a constructor-like summary of the configured arguments."""
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str