[Feature] Add Biomedical 3D array random crop transform (#2378)

* [Feature] Add Biomedical 3D array random crop transform

* fix lint

* fix gen crop bbox

* fix gen crop bbox

* docstring

* typo

Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
Jin Ye 2022-12-30 11:49:32 +08:00 committed by MeowZheng
parent ad99ad1444
commit 79e8578bfc
4 changed files with 259 additions and 16 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
@ -17,7 +18,8 @@ from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
@ -26,15 +28,18 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
SegRescale)
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations',
'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'

View File

@ -3,17 +3,17 @@ from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange,
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
__all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import Dict, Sequence, Tuple, Union
import cv2
@ -1310,3 +1311,199 @@ class ResizeShortestEdge(BaseTransform):
def transform(self, results: Dict) -> Dict:
self.resize.scale = self._get_output_shape(results['img'], self.scale)
return self.resize(results)
@TRANSFORMS.register_module()
class BioMedical3DRandomCrop(BaseTransform):
"""Crop the input patch for medical image & segmentation mask.
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
- gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask
with shape (Z, Y, X).
Modified Keys:
- img
- img_shape
- gt_seg_map (optional)
Args:
crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
cropping with the format of (z, y, x). If set to an integer,
then cropping width and height are equal to this integer.
keep_foreground (bool): If keep_foreground is True, it will sample a
voxel of foreground classes randomly, and will take it as the
center of the crop bounding-box. Default to True.
"""
def __init__(self,
crop_shape: Union[int, Tuple[int, int, int]],
keep_foreground: bool = True):
super().__init__()
assert isinstance(crop_shape, int) or (
isinstance(crop_shape, tuple) and len(crop_shape) == 3
), 'The expected crop_shape is an integer, or a tuple containing '
'three integers'
if isinstance(crop_shape, int):
crop_shape = (crop_shape, crop_shape, crop_shape)
assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0
self.crop_shape = crop_shape
self.keep_foreground = keep_foreground
def random_sample_location(self, seg_map: np.ndarray) -> dict:
"""sample foreground voxel when keep_foreground is True.
Args:
seg_map (np.ndarray): gt seg map.
Returns:
dict: Coordinates of selected foreground voxel.
"""
num_samples = 10000
# at least 1% of the class voxels need to be selected,
# otherwise it may be too sparse
min_percent_coverage = 0.01
class_locs = {}
foreground_classes = []
all_classes = np.unique(seg_map)
for c in all_classes:
if c == 0:
# to avoid the segmentation mask full of background 0
# and the class_locs is just void dictionary {} when it return
# there add a void list for background 0.
class_locs[c] = []
else:
all_locs = np.argwhere(seg_map == c)
target_num_samples = min(num_samples, len(all_locs))
target_num_samples = max(
target_num_samples,
int(np.ceil(len(all_locs) * min_percent_coverage)))
selected = all_locs[np.random.choice(
len(all_locs), target_num_samples, replace=False)]
class_locs[c] = selected
foreground_classes.append(c)
selected_voxel = None
if len(foreground_classes) > 0:
selected_class = np.random.choice(foreground_classes)
voxels_of_that_class = class_locs[selected_class]
selected_voxel = voxels_of_that_class[np.random.choice(
len(voxels_of_that_class))]
return selected_voxel
def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
margin_x: int) -> tuple:
"""Randomly get a crop bounding box.
Args:
seg_map (np.ndarray): Ground truth segmentation map.
Returns:
tuple: Coordinates of the cropped image.
"""
offset_z = np.random.randint(0, margin_z + 1)
offset_y = np.random.randint(0, margin_y + 1)
offset_x = np.random.randint(0, margin_x + 1)
crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]
return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2
def generate_margin(self, results: dict) -> tuple:
"""Generate margin of crop bounding-box.
If keep_foreground is True, it will sample a voxel of foreground
classes randomly, and will take it as the center of the bounding-box,
and return the margin between of the bounding-box and image.
If keep_foreground is False, it will return the difference from crop
shape and image shape.
Args:
results (dict): Result dict from loading pipeline.
Returns:
tuple: The margin for 3 dimensions of crop bounding-box and image.
"""
seg_map = results['gt_seg_map']
if self.keep_foreground:
selected_voxel = self.random_sample_location(seg_map)
if selected_voxel is None:
# this only happens if some image does not contain
# foreground voxels at all
warnings.warn(f'case does not contain any foreground classes'
f': {results["img_path"]}')
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
else:
margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2)
margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2)
margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2)
margin_z = max(
0, min(seg_map.shape[0] - self.crop_shape[0], margin_z))
margin_y = max(
0, min(seg_map.shape[1] - self.crop_shape[1], margin_y))
margin_x = max(
0, min(seg_map.shape[2] - self.crop_shape[2], margin_x))
else:
margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0)
margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0)
margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0)
return margin_z, margin_y, margin_x
def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
"""Crop from ``img``
Args:
img (np.ndarray): Original input image.
crop_bbox (tuple): Coordinates of the cropped image.
Returns:
np.ndarray: The cropped image.
"""
crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
if len(img.shape) == 3:
# crop seg map
img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
else:
# crop image
assert len(img.shape) == 4
img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
return img
def transform(self, results: dict) -> dict:
"""Transform function to randomly crop images, semantic segmentation
maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
"""
margin = self.generate_margin(results)
crop_bbox = self.random_generate_crop_bbox(*margin)
# crop the image
img = results['img']
results['img'] = self.crop(img, crop_bbox)
results['img_shape'] = results['img'].shape[1:]
# crop semantic seg
seg_map = results['gt_seg_map']
results['gt_seg_map'] = self.crop(seg_map, crop_bbox)
return results
def __repr__(self):
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'

View File

@ -737,3 +737,44 @@ def test_generate_edge():
[1, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
]))
def test_biomedical3d_random_crop():
# test assertion for invalid random crop
with pytest.raises(AssertionError):
transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0))
transform = TRANSFORMS.build(transform)
from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
LoadBiomedicalImageFromFile)
results = dict()
results['img_path'] = osp.join(
osp.dirname(__file__), '../data', 'biomedical.nii.gz')
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
results['seg_map_path'] = osp.join(
osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz')
transform = LoadBiomedicalAnnotation()
results = transform(copy.deepcopy(results))
d, h, w = results['img_shape']
transform = dict(
type='BioMedical3DRandomCrop',
crop_shape=(d - 20, h - 20, w - 20),
keep_foreground=True)
transform = TRANSFORMS.build(transform)
crop_results = transform(results)
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
transform = dict(
type='BioMedical3DRandomCrop',
crop_shape=(d - 20, h - 20, w - 20),
keep_foreground=False)
transform = TRANSFORMS.build(transform)
crop_results = transform(results)
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)