mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add Biomedical 3D array random crop transform (#2378)
* [Feature] Add Biomedical 3D array random crop transform * fix lint * fix gen crop bbox * fix gen crop bbox * docstring * typo Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
parent
ad99ad1444
commit
79e8578bfc
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
# yapf: disable
|
||||||
from .ade import ADE20KDataset
|
from .ade import ADE20KDataset
|
||||||
from .basesegdataset import BaseSegDataset
|
from .basesegdataset import BaseSegDataset
|
||||||
from .chase_db1 import ChaseDB1Dataset
|
from .chase_db1 import ChaseDB1Dataset
|
||||||
@ -17,7 +18,8 @@ from .night_driving import NightDrivingDataset
|
|||||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||||
from .potsdam import PotsdamDataset
|
from .potsdam import PotsdamDataset
|
||||||
from .stare import STAREDataset
|
from .stare import STAREDataset
|
||||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
|
GenerateEdge, LoadAnnotations,
|
||||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||||
@ -26,15 +28,18 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
|||||||
SegRescale)
|
SegRescale)
|
||||||
from .voc import PascalVOCDataset
|
from .voc import PascalVOCDataset
|
||||||
|
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseSegDataset', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
|
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
|
||||||
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
|
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
||||||
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
|
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
|
||||||
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
|
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
|
||||||
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
|
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
|
||||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations',
|
||||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||||
|
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
||||||
|
@ -3,17 +3,17 @@ from .formatting import PackSegInputs
|
|||||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||||
LoadImageFromNDArray)
|
LoadImageFromNDArray)
|
||||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
||||||
RandomMosaic, RandomRotate, Rerange,
|
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||||
SegRescale)
|
SegRescale)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'ResizeShortestEdge'
|
'ResizeShortestEdge'
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
|
import warnings
|
||||||
from typing import Dict, Sequence, Tuple, Union
|
from typing import Dict, Sequence, Tuple, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -1310,3 +1311,199 @@ class ResizeShortestEdge(BaseTransform):
|
|||||||
def transform(self, results: Dict) -> Dict:
|
def transform(self, results: Dict) -> Dict:
|
||||||
self.resize.scale = self._get_output_shape(results['img'], self.scale)
|
self.resize.scale = self._get_output_shape(results['img'], self.scale)
|
||||||
return self.resize(results)
|
return self.resize(results)
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
|
||||||
|
class BioMedical3DRandomCrop(BaseTransform):
|
||||||
|
"""Crop the input patch for medical image & segmentation mask.
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||||
|
N is the number of modalities, and data type is float32.
|
||||||
|
- gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask
|
||||||
|
with shape (Z, Y, X).
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
- img_shape
|
||||||
|
- gt_seg_map (optional)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
|
||||||
|
cropping with the format of (z, y, x). If set to an integer,
|
||||||
|
then cropping width and height are equal to this integer.
|
||||||
|
keep_foreground (bool): If keep_foreground is True, it will sample a
|
||||||
|
voxel of foreground classes randomly, and will take it as the
|
||||||
|
center of the crop bounding-box. Default to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
crop_shape: Union[int, Tuple[int, int, int]],
|
||||||
|
keep_foreground: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(crop_shape, int) or (
|
||||||
|
isinstance(crop_shape, tuple) and len(crop_shape) == 3
|
||||||
|
), 'The expected crop_shape is an integer, or a tuple containing '
|
||||||
|
'three integers'
|
||||||
|
|
||||||
|
if isinstance(crop_shape, int):
|
||||||
|
crop_shape = (crop_shape, crop_shape, crop_shape)
|
||||||
|
assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0
|
||||||
|
self.crop_shape = crop_shape
|
||||||
|
self.keep_foreground = keep_foreground
|
||||||
|
|
||||||
|
def random_sample_location(self, seg_map: np.ndarray) -> dict:
|
||||||
|
"""sample foreground voxel when keep_foreground is True.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seg_map (np.ndarray): gt seg map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Coordinates of selected foreground voxel.
|
||||||
|
"""
|
||||||
|
num_samples = 10000
|
||||||
|
# at least 1% of the class voxels need to be selected,
|
||||||
|
# otherwise it may be too sparse
|
||||||
|
min_percent_coverage = 0.01
|
||||||
|
class_locs = {}
|
||||||
|
foreground_classes = []
|
||||||
|
all_classes = np.unique(seg_map)
|
||||||
|
for c in all_classes:
|
||||||
|
if c == 0:
|
||||||
|
# to avoid the segmentation mask full of background 0
|
||||||
|
# and the class_locs is just void dictionary {} when it return
|
||||||
|
# there add a void list for background 0.
|
||||||
|
class_locs[c] = []
|
||||||
|
else:
|
||||||
|
all_locs = np.argwhere(seg_map == c)
|
||||||
|
target_num_samples = min(num_samples, len(all_locs))
|
||||||
|
target_num_samples = max(
|
||||||
|
target_num_samples,
|
||||||
|
int(np.ceil(len(all_locs) * min_percent_coverage)))
|
||||||
|
|
||||||
|
selected = all_locs[np.random.choice(
|
||||||
|
len(all_locs), target_num_samples, replace=False)]
|
||||||
|
class_locs[c] = selected
|
||||||
|
foreground_classes.append(c)
|
||||||
|
|
||||||
|
selected_voxel = None
|
||||||
|
if len(foreground_classes) > 0:
|
||||||
|
selected_class = np.random.choice(foreground_classes)
|
||||||
|
voxels_of_that_class = class_locs[selected_class]
|
||||||
|
selected_voxel = voxels_of_that_class[np.random.choice(
|
||||||
|
len(voxels_of_that_class))]
|
||||||
|
|
||||||
|
return selected_voxel
|
||||||
|
|
||||||
|
def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
|
||||||
|
margin_x: int) -> tuple:
|
||||||
|
"""Randomly get a crop bounding box.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seg_map (np.ndarray): Ground truth segmentation map.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Coordinates of the cropped image.
|
||||||
|
"""
|
||||||
|
offset_z = np.random.randint(0, margin_z + 1)
|
||||||
|
offset_y = np.random.randint(0, margin_y + 1)
|
||||||
|
offset_x = np.random.randint(0, margin_x + 1)
|
||||||
|
crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
|
||||||
|
crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
|
||||||
|
crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]
|
||||||
|
|
||||||
|
return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2
|
||||||
|
|
||||||
|
def generate_margin(self, results: dict) -> tuple:
|
||||||
|
"""Generate margin of crop bounding-box.
|
||||||
|
|
||||||
|
If keep_foreground is True, it will sample a voxel of foreground
|
||||||
|
classes randomly, and will take it as the center of the bounding-box,
|
||||||
|
and return the margin between of the bounding-box and image.
|
||||||
|
If keep_foreground is False, it will return the difference from crop
|
||||||
|
shape and image shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict from loading pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: The margin for 3 dimensions of crop bounding-box and image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
seg_map = results['gt_seg_map']
|
||||||
|
if self.keep_foreground:
|
||||||
|
selected_voxel = self.random_sample_location(seg_map)
|
||||||
|
if selected_voxel is None:
|
||||||
|
# this only happens if some image does not contain
|
||||||
|
# foreground voxels at all
|
||||||
|
warnings.warn(f'case does not contain any foreground classes'
|
||||||
|
f': {results["img_path"]}')
|
||||||
|
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
|
||||||
|
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
|
||||||
|
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
|
||||||
|
else:
|
||||||
|
margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2)
|
||||||
|
margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2)
|
||||||
|
margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2)
|
||||||
|
margin_z = max(
|
||||||
|
0, min(seg_map.shape[0] - self.crop_shape[0], margin_z))
|
||||||
|
margin_y = max(
|
||||||
|
0, min(seg_map.shape[1] - self.crop_shape[1], margin_y))
|
||||||
|
margin_x = max(
|
||||||
|
0, min(seg_map.shape[2] - self.crop_shape[2], margin_x))
|
||||||
|
else:
|
||||||
|
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
|
||||||
|
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
|
||||||
|
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
|
||||||
|
|
||||||
|
return margin_z, margin_y, margin_x
|
||||||
|
|
||||||
|
def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
|
||||||
|
"""Crop from ``img``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (np.ndarray): Original input image.
|
||||||
|
crop_bbox (tuple): Coordinates of the cropped image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: The cropped image.
|
||||||
|
"""
|
||||||
|
crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
|
||||||
|
if len(img.shape) == 3:
|
||||||
|
# crop seg map
|
||||||
|
img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
|
||||||
|
else:
|
||||||
|
# crop image
|
||||||
|
assert len(img.shape) == 4
|
||||||
|
img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
|
||||||
|
return img
|
||||||
|
|
||||||
|
def transform(self, results: dict) -> dict:
|
||||||
|
"""Transform function to randomly crop images, semantic segmentation
|
||||||
|
maps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict from loading pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Randomly cropped results, 'img_shape' key in result dict is
|
||||||
|
updated according to crop size.
|
||||||
|
"""
|
||||||
|
margin = self.generate_margin(results)
|
||||||
|
crop_bbox = self.random_generate_crop_bbox(*margin)
|
||||||
|
|
||||||
|
# crop the image
|
||||||
|
img = results['img']
|
||||||
|
results['img'] = self.crop(img, crop_bbox)
|
||||||
|
results['img_shape'] = results['img'].shape[1:]
|
||||||
|
|
||||||
|
# crop semantic seg
|
||||||
|
seg_map = results['gt_seg_map']
|
||||||
|
results['gt_seg_map'] = self.crop(seg_map, crop_bbox)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
||||||
|
@ -737,3 +737,44 @@ def test_generate_edge():
|
|||||||
[1, 1, 0, 0, 0],
|
[1, 1, 0, 0, 0],
|
||||||
[1, 0, 0, 0, 0],
|
[1, 0, 0, 0, 0],
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_biomedical3d_random_crop():
|
||||||
|
# test assertion for invalid random crop
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0))
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
|
||||||
|
LoadBiomedicalImageFromFile)
|
||||||
|
results = dict()
|
||||||
|
results['img_path'] = osp.join(
|
||||||
|
osp.dirname(__file__), '../data', 'biomedical.nii.gz')
|
||||||
|
transform = LoadBiomedicalImageFromFile()
|
||||||
|
results = transform(copy.deepcopy(results))
|
||||||
|
|
||||||
|
results['seg_map_path'] = osp.join(
|
||||||
|
osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz')
|
||||||
|
transform = LoadBiomedicalAnnotation()
|
||||||
|
results = transform(copy.deepcopy(results))
|
||||||
|
|
||||||
|
d, h, w = results['img_shape']
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedical3DRandomCrop',
|
||||||
|
crop_shape=(d - 20, h - 20, w - 20),
|
||||||
|
keep_foreground=True)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
crop_results = transform(results)
|
||||||
|
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||||
|
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||||
|
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedical3DRandomCrop',
|
||||||
|
crop_shape=(d - 20, h - 20, w - 20),
|
||||||
|
keep_foreground=False)
|
||||||
|
transform = TRANSFORMS.build(transform)
|
||||||
|
crop_results = transform(results)
|
||||||
|
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||||
|
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||||
|
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user