diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
index d4d017fed..19eda98b1 100644
--- a/mmseg/datasets/__init__.py
+++ b/mmseg/datasets/__init__.py
@@ -15,7 +15,7 @@ from .night_driving import NightDrivingDataset
 from .pascal_context import PascalContextDataset, PascalContextDataset59
 from .potsdam import PotsdamDataset
 from .stare import STAREDataset
-from .transforms import (CLAHE, AdjustGamma, LoadAnnotations,
+from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
                          LoadBiomedicalAnnotation, LoadBiomedicalData,
                          LoadBiomedicalImageFromFile, LoadImageFromNDArray,
                          PackSegInputs, PhotoMetricDistortion, RandomCrop,
@@ -33,5 +33,5 @@ __all__ = [
     'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
     'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
     'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
-    'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
+    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
 ]
diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py
index f46aa8fba..09f6c655a 100644
--- a/mmseg/datasets/transforms/__init__.py
+++ b/mmseg/datasets/transforms/__init__.py
@@ -3,14 +3,15 @@ from .formatting import PackSegInputs
 from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
                       LoadBiomedicalData, LoadBiomedicalImageFromFile,
                       LoadImageFromNDArray)
-from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
-                         RandomCutOut, RandomMosaic, RandomRotate, Rerange,
-                         ResizeToMultiple, RGB2Gray, SegRescale)
+from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
+                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
+                         RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
+                         RGB2Gray, SegRescale)
 
 __all__ = [
     'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
     'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
     'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
     'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
-    'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
+    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
 ]
diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py
index 3cb173539..46d3a66e0 100644
--- a/mmseg/datasets/transforms/transforms.py
+++ b/mmseg/datasets/transforms/transforms.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-from typing import Sequence, Tuple, Union
+from typing import Dict, Sequence, Tuple, Union
 
+import cv2
 import mmcv
 import numpy as np
 from mmcv.transforms.base import BaseTransform
@@ -1147,3 +1148,82 @@ class RandomMosaic(BaseTransform):
         repr_str += f'pad_val={self.pad_val}, '
         repr_str += f'seg_pad_val={self.pad_val})'
         return repr_str
+
+
+@TRANSFORMS.register_module()
+class GenerateEdge(BaseTransform):
+    """Generate Edge for CE2P approach.
+
+    Edge will be used to calculate loss of
+    `CE2P <https://github.com/liutinglt/CE2P>`_.
+
+    Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501
+
+    Required Keys:
+
+        - img_shape
+        - gt_seg_map
+
+    Added Keys:
+        - gt_edge (np.ndarray, uint8): The edge annotation generated from the
+            seg map by extracting border between different semantics.
+        - edge_width (int): The ``edge_width`` used to dilate ``gt_edge``.
+
+    Args:
+        edge_width (int): The width of edge. Default to 3.
+        ignore_index (int): Index that will be ignored. Default to 255.
+    """
+
+    def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None:
+        super().__init__()
+        self.edge_width = edge_width
+        self.ignore_index = ignore_index
+
+    def transform(self, results: Dict) -> Dict:
+        """Call function to generate edge from segmentation map.
+
+        Args:
+            results (dict): Result dict.
+
+        Returns:
+            dict: Result dict with edge mask.
+        """
+        h, w = results['img_shape']
+        edge = np.zeros((h, w), dtype=np.uint8)
+        seg_map = results['gt_seg_map']
+
+        # down: flag the lower pixel of each vertical pair of unequal labels
+        edge_down = edge[1:h, :]
+        edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :])
+                  & (seg_map[1:h, :] != self.ignore_index) &
+                  (seg_map[:h - 1, :] != self.ignore_index)] = 1
+        # left: flag the left pixel of each horizontal pair of unequal labels
+        edge_left = edge[:, :w - 1]
+        edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w])
+                  & (seg_map[:, :w - 1] != self.ignore_index) &
+                  (seg_map[:, 1:w] != self.ignore_index)] = 1
+        # up_left: flag the upper-left pixel of each main-diagonal pair
+        edge_upleft = edge[:h - 1, :w - 1]
+        edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w])
+                    & (seg_map[:h - 1, :w - 1] != self.ignore_index) &
+                    (seg_map[1:h, 1:w] != self.ignore_index)] = 1
+        # up_right: flag the upper-right pixel of each anti-diagonal pair
+        edge_upright = edge[:h - 1, 1:w]
+        edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1])
+                     & (seg_map[:h - 1, 1:w] != self.ignore_index) &
+                     (seg_map[1:h, :w - 1] != self.ignore_index)] = 1
+
+        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
+                                           (self.edge_width, self.edge_width))
+        edge = cv2.dilate(edge, kernel)
+
+        results['gt_edge'] = edge
+        results['edge_width'] = self.edge_width
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(edge_width={self.edge_width}, '
+        repr_str += f'ignore_index={self.ignore_index})'
+        return repr_str
diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py
index f314050b7..bd9c05ac4 100644
--- a/tests/test_datasets/test_transform.py
+++ b/tests/test_datasets/test_transform.py
@@ -679,3 +679,32 @@ def test_resize_to_multiple():
     assert results['img'].shape == (224, 256, 3)
     assert results['gt_semantic_seg'].shape == (224, 256)
     assert results['img_shape'] == (224, 256)
+
+
+def test_generate_edge():
+    transform = dict(type='GenerateEdge', edge_width=1)
+    transform = TRANSFORMS.build(transform)
+
+    seg_map = np.array([
+        [1, 1, 1, 1, 1],
+        [1, 1, 1, 1, 2],
+        [1, 1, 1, 2, 2],
+        [1, 1, 2, 2, 2],
+        [1, 2, 2, 2, 2],
+        [2, 2, 2, 2, 2],
+    ])
+    results = dict()
+    results['gt_seg_map'] = seg_map
+    results['img_shape'] = seg_map.shape
+
+    results = transform(results)
+    assert np.all(results['gt_edge'] == np.array([
+        [0, 0, 0, 1, 0],
+        [0, 0, 1, 1, 1],
+        [0, 1, 1, 1, 0],
+        [1, 1, 1, 0, 0],
+        [1, 1, 0, 0, 0],
+        [1, 0, 0, 0, 0],
+    ]))
+
+    assert repr(transform) == 'GenerateEdge(edge_width=1, ignore_index=255)'