mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add Biomedical 3D array random crop transform (#2378)
* [Feature] Add Biomedical 3D array random crop transform * fix lint * fix gen crop bbox * fix gen crop bbox * docstring * typo Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
parent
ad99ad1444
commit
79e8578bfc
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .ade import ADE20KDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
from .chase_db1 import ChaseDB1Dataset
|
||||
@ -17,7 +18,8 @@ from .night_driving import NightDrivingDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
from .stare import STAREDataset
|
||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
GenerateEdge, LoadAnnotations,
|
||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||
@ -26,15 +28,18 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
||||
SegRescale)
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
# yapf: enable
|
||||
|
||||
__all__ = [
|
||||
'BaseSegDataset', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
|
||||
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
|
||||
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
|
||||
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
|
||||
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
|
||||
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
||||
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
|
||||
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
|
||||
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
|
||||
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations',
|
||||
'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
||||
|
@ -3,17 +3,17 @@ from .formatting import PackSegInputs
|
||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, Rerange,
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge'
|
||||
]
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
@ -1310,3 +1311,199 @@ class ResizeShortestEdge(BaseTransform):
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
self.resize.scale = self._get_output_shape(results['img'], self.scale)
|
||||
return self.resize(results)
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class BioMedical3DRandomCrop(BaseTransform):
|
||||
"""Crop the input patch for medical image & segmentation mask.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||
N is the number of modalities, and data type is float32.
|
||||
- gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask
|
||||
with shape (Z, Y, X).
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_seg_map (optional)
|
||||
|
||||
Args:
|
||||
crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
|
||||
cropping with the format of (z, y, x). If set to an integer,
|
||||
then cropping width and height are equal to this integer.
|
||||
keep_foreground (bool): If keep_foreground is True, it will sample a
|
||||
voxel of foreground classes randomly, and will take it as the
|
||||
center of the crop bounding-box. Default to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
crop_shape: Union[int, Tuple[int, int, int]],
|
||||
keep_foreground: bool = True):
|
||||
super().__init__()
|
||||
assert isinstance(crop_shape, int) or (
|
||||
isinstance(crop_shape, tuple) and len(crop_shape) == 3
|
||||
), 'The expected crop_shape is an integer, or a tuple containing '
|
||||
'three integers'
|
||||
|
||||
if isinstance(crop_shape, int):
|
||||
crop_shape = (crop_shape, crop_shape, crop_shape)
|
||||
assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0
|
||||
self.crop_shape = crop_shape
|
||||
self.keep_foreground = keep_foreground
|
||||
|
||||
def random_sample_location(self, seg_map: np.ndarray) -> dict:
|
||||
"""sample foreground voxel when keep_foreground is True.
|
||||
|
||||
Args:
|
||||
seg_map (np.ndarray): gt seg map.
|
||||
|
||||
Returns:
|
||||
dict: Coordinates of selected foreground voxel.
|
||||
"""
|
||||
num_samples = 10000
|
||||
# at least 1% of the class voxels need to be selected,
|
||||
# otherwise it may be too sparse
|
||||
min_percent_coverage = 0.01
|
||||
class_locs = {}
|
||||
foreground_classes = []
|
||||
all_classes = np.unique(seg_map)
|
||||
for c in all_classes:
|
||||
if c == 0:
|
||||
# to avoid the segmentation mask full of background 0
|
||||
# and the class_locs is just void dictionary {} when it return
|
||||
# there add a void list for background 0.
|
||||
class_locs[c] = []
|
||||
else:
|
||||
all_locs = np.argwhere(seg_map == c)
|
||||
target_num_samples = min(num_samples, len(all_locs))
|
||||
target_num_samples = max(
|
||||
target_num_samples,
|
||||
int(np.ceil(len(all_locs) * min_percent_coverage)))
|
||||
|
||||
selected = all_locs[np.random.choice(
|
||||
len(all_locs), target_num_samples, replace=False)]
|
||||
class_locs[c] = selected
|
||||
foreground_classes.append(c)
|
||||
|
||||
selected_voxel = None
|
||||
if len(foreground_classes) > 0:
|
||||
selected_class = np.random.choice(foreground_classes)
|
||||
voxels_of_that_class = class_locs[selected_class]
|
||||
selected_voxel = voxels_of_that_class[np.random.choice(
|
||||
len(voxels_of_that_class))]
|
||||
|
||||
return selected_voxel
|
||||
|
||||
def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
|
||||
margin_x: int) -> tuple:
|
||||
"""Randomly get a crop bounding box.
|
||||
|
||||
Args:
|
||||
seg_map (np.ndarray): Ground truth segmentation map.
|
||||
|
||||
Returns:
|
||||
tuple: Coordinates of the cropped image.
|
||||
"""
|
||||
offset_z = np.random.randint(0, margin_z + 1)
|
||||
offset_y = np.random.randint(0, margin_y + 1)
|
||||
offset_x = np.random.randint(0, margin_x + 1)
|
||||
crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
|
||||
crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
|
||||
crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]
|
||||
|
||||
return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2
|
||||
|
||||
def generate_margin(self, results: dict) -> tuple:
|
||||
"""Generate margin of crop bounding-box.
|
||||
|
||||
If keep_foreground is True, it will sample a voxel of foreground
|
||||
classes randomly, and will take it as the center of the bounding-box,
|
||||
and return the margin between of the bounding-box and image.
|
||||
If keep_foreground is False, it will return the difference from crop
|
||||
shape and image shape.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
tuple: The margin for 3 dimensions of crop bounding-box and image.
|
||||
"""
|
||||
|
||||
seg_map = results['gt_seg_map']
|
||||
if self.keep_foreground:
|
||||
selected_voxel = self.random_sample_location(seg_map)
|
||||
if selected_voxel is None:
|
||||
# this only happens if some image does not contain
|
||||
# foreground voxels at all
|
||||
warnings.warn(f'case does not contain any foreground classes'
|
||||
f': {results["img_path"]}')
|
||||
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
|
||||
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
|
||||
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
|
||||
else:
|
||||
margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2)
|
||||
margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2)
|
||||
margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2)
|
||||
margin_z = max(
|
||||
0, min(seg_map.shape[0] - self.crop_shape[0], margin_z))
|
||||
margin_y = max(
|
||||
0, min(seg_map.shape[1] - self.crop_shape[1], margin_y))
|
||||
margin_x = max(
|
||||
0, min(seg_map.shape[2] - self.crop_shape[2], margin_x))
|
||||
else:
|
||||
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
|
||||
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
|
||||
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
|
||||
|
||||
return margin_z, margin_y, margin_x
|
||||
|
||||
def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
|
||||
"""Crop from ``img``
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Original input image.
|
||||
crop_bbox (tuple): Coordinates of the cropped image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The cropped image.
|
||||
"""
|
||||
crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
|
||||
if len(img.shape) == 3:
|
||||
# crop seg map
|
||||
img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
else:
|
||||
# crop image
|
||||
assert len(img.shape) == 4
|
||||
img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
return img
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to randomly crop images, semantic segmentation
|
||||
maps.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Randomly cropped results, 'img_shape' key in result dict is
|
||||
updated according to crop size.
|
||||
"""
|
||||
margin = self.generate_margin(results)
|
||||
crop_bbox = self.random_generate_crop_bbox(*margin)
|
||||
|
||||
# crop the image
|
||||
img = results['img']
|
||||
results['img'] = self.crop(img, crop_bbox)
|
||||
results['img_shape'] = results['img'].shape[1:]
|
||||
|
||||
# crop semantic seg
|
||||
seg_map = results['gt_seg_map']
|
||||
results['gt_seg_map'] = self.crop(seg_map, crop_bbox)
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
||||
|
@ -737,3 +737,44 @@ def test_generate_edge():
|
||||
[1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
]))
|
||||
|
||||
|
||||
def test_biomedical3d_random_crop():
|
||||
# test assertion for invalid random crop
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0))
|
||||
transform = TRANSFORMS.build(transform)
|
||||
|
||||
from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalImageFromFile)
|
||||
results = dict()
|
||||
results['img_path'] = osp.join(
|
||||
osp.dirname(__file__), '../data', 'biomedical.nii.gz')
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
|
||||
results['seg_map_path'] = osp.join(
|
||||
osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz')
|
||||
transform = LoadBiomedicalAnnotation()
|
||||
results = transform(copy.deepcopy(results))
|
||||
|
||||
d, h, w = results['img_shape']
|
||||
transform = dict(
|
||||
type='BioMedical3DRandomCrop',
|
||||
crop_shape=(d - 20, h - 20, w - 20),
|
||||
keep_foreground=True)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
crop_results = transform(results)
|
||||
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||
|
||||
transform = dict(
|
||||
type='BioMedical3DRandomCrop',
|
||||
crop_shape=(d - 20, h - 20, w - 20),
|
||||
keep_foreground=False)
|
||||
transform = TRANSFORMS.build(transform)
|
||||
crop_results = transform(results)
|
||||
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||
|
Loading…
x
Reference in New Issue
Block a user