mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature]Generate Edge for dataset (#2210)
* [WIP]Generate Edge for dataset * add ut * add repr * add inti
This commit is contained in:
parent
25604a151b
commit
ac9ee8c355
@ -15,7 +15,7 @@ from .night_driving import NightDrivingDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
from .stare import STAREDataset
|
||||
from .transforms import (CLAHE, AdjustGamma, LoadAnnotations,
|
||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||
@ -33,5 +33,5 @@ __all__ = [
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
|
||||
]
|
||||
|
@ -3,14 +3,15 @@ from .formatting import PackSegInputs
|
||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeToMultiple, RGB2Gray, SegRescale)
|
||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
|
||||
RGB2Gray, SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
|
||||
]
|
||||
|
@ -1,7 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Sequence, Tuple, Union
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
@ -1147,3 +1148,81 @@ class RandomMosaic(BaseTransform):
|
||||
repr_str += f'pad_val={self.pad_val}, '
|
||||
repr_str += f'seg_pad_val={self.pad_val})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class GenerateEdge(BaseTransform):
|
||||
"""Generate Edge for CE2P approach.
|
||||
|
||||
Edge will be used to calculate loss of
|
||||
`CE2P <https://arxiv.org/abs/1809.05996>`_.
|
||||
|
||||
Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_shape
|
||||
- gt_seg_map
|
||||
|
||||
Added Keys:
|
||||
- gt_edge (np.ndarray, uint8): The edge annotation generated from the
|
||||
seg map by extracting border between different semantics.
|
||||
|
||||
Args:
|
||||
edge_width (int): The width of edge. Default to 3.
|
||||
ignore_index (int): Index that will be ignored. Default to 255.
|
||||
"""
|
||||
|
||||
def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None:
|
||||
super().__init__()
|
||||
self.edge_width = edge_width
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Call function to generate edge from segmentation map.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict.
|
||||
|
||||
Returns:
|
||||
dict: Result dict with edge mask.
|
||||
"""
|
||||
h, w = results['img_shape']
|
||||
edge = np.zeros((h, w), dtype=np.uint8)
|
||||
seg_map = results['gt_seg_map']
|
||||
|
||||
# down
|
||||
edge_down = edge[1:h, :]
|
||||
edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :])
|
||||
& (seg_map[1:h, :] != self.ignore_index) &
|
||||
(seg_map[:h - 1, :] != self.ignore_index)] = 1
|
||||
# left
|
||||
edge_left = edge[:, :w - 1]
|
||||
edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w])
|
||||
& (seg_map[:, :w - 1] != self.ignore_index) &
|
||||
(seg_map[:, 1:w] != self.ignore_index)] = 1
|
||||
# up_left
|
||||
edge_upleft = edge[:h - 1, :w - 1]
|
||||
edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w])
|
||||
& (seg_map[:h - 1, :w - 1] != self.ignore_index) &
|
||||
(seg_map[1:h, 1:w] != self.ignore_index)] = 1
|
||||
# up_right
|
||||
edge_upright = edge[:h - 1, 1:w]
|
||||
edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1])
|
||||
& (seg_map[:h - 1, 1:w] != self.ignore_index) &
|
||||
(seg_map[1:h, :w - 1] != self.ignore_index)] = 1
|
||||
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
|
||||
(self.edge_width, self.edge_width))
|
||||
edge = cv2.dilate(edge, kernel)
|
||||
|
||||
results['gt_edge'] = edge
|
||||
results['edge_width'] = self.edge_width
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'edge_width={self.edge_width}, '
|
||||
repr_str += f'ignore_index={self.ignore_index})'
|
||||
return repr_str
|
||||
|
@ -679,3 +679,30 @@ def test_resize_to_multiple():
|
||||
assert results['img'].shape == (224, 256, 3)
|
||||
assert results['gt_semantic_seg'].shape == (224, 256)
|
||||
assert results['img_shape'] == (224, 256)
|
||||
|
||||
|
||||
def test_generate_edge():
|
||||
transform = dict(type='GenerateEdge', edge_width=1)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
seg_map = np.array([
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 2],
|
||||
[1, 1, 1, 2, 2],
|
||||
[1, 1, 2, 2, 2],
|
||||
[1, 2, 2, 2, 2],
|
||||
[2, 2, 2, 2, 2],
|
||||
])
|
||||
results = dict()
|
||||
results['gt_seg_map'] = seg_map
|
||||
results['img_shape'] = seg_map.shape
|
||||
|
||||
results = transform(results)
|
||||
assert np.all(results['gt_edge'] == np.array([
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 1, 1, 1],
|
||||
[0, 1, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user