From 79e8578bfcafc9abf83978da1933d260ea333a84 Mon Sep 17 00:00:00 2001 From: Jin Ye <49543070+Yejin0111@users.noreply.github.com> Date: Fri, 30 Dec 2022 11:49:32 +0800 Subject: [PATCH] [Feature] Add Biomedical 3D array random crop transform (#2378) * [Feature] Add Biomedical 3D array random crop transform * fix lint * fix gen crop bbox * fix gen crop bbox * docstring * typo Co-authored-by: MeowZheng --- mmseg/datasets/__init__.py | 23 +-- mmseg/datasets/transforms/__init__.py | 14 +- mmseg/datasets/transforms/transforms.py | 197 ++++++++++++++++++++++++ tests/test_datasets/test_transform.py | 41 +++++ 4 files changed, 259 insertions(+), 16 deletions(-) diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index bd04a5a67..58f71b62a 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +# yapf: disable from .ade import ADE20KDataset from .basesegdataset import BaseSegDataset from .chase_db1 import ChaseDB1Dataset @@ -17,7 +18,8 @@ from .night_driving import NightDrivingDataset from .pascal_context import PascalContextDataset, PascalContextDataset59 from .potsdam import PotsdamDataset from .stare import STAREDataset -from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations, +from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, + GenerateEdge, LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadImageFromNDArray, PackSegInputs, PhotoMetricDistortion, RandomCrop, @@ -26,15 +28,18 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations, SegRescale) from .voc import PascalVOCDataset +# yapf: enable + __all__ = [ - 'BaseSegDataset', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', - 'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset', - 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', - 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', - 'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', - 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', - 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', - 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', + 'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset', + 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', + 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', + 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', + 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', + 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'LoadAnnotations', + 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', + 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge' diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 656f806e1..7f67acec0 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -3,17 +3,17 @@ from .formatting import PackSegInputs from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadImageFromNDArray) -from .transforms import (CLAHE, AdjustGamma, GenerateEdge, - PhotoMetricDistortion, RandomCrop, RandomCutOut, - RandomMosaic, RandomRotate, Rerange, +from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, + GenerateEdge, PhotoMetricDistortion, RandomCrop, + RandomCutOut, RandomMosaic, RandomRotate, Rerange, ResizeShortestEdge, ResizeToMultiple, RGB2Gray, SegRescale) __all__ = [ - 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', - 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', - 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', - 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', + 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', + 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', + 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'ResizeShortestEdge' ] diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 772eddb4a..5d1173f25 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import warnings from typing import Dict, Sequence, Tuple, Union import cv2 @@ -1310,3 +1311,199 @@ class ResizeShortestEdge(BaseTransform): def transform(self, results: Dict) -> Dict: self.resize.scale = self._get_output_shape(results['img'], self.scale) return self.resize(results) + + +@TRANSFORMS.register_module() +class BioMedical3DRandomCrop(BaseTransform): + """Crop the input patch for medical image & segmentation mask. + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask + with shape (Z, Y, X). + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional) + + Args: + crop_shape (Union[int, Tuple[int, int, int]]): Expected size after + cropping with the format of (z, y, x). If set to an integer, + then cropping width and height are equal to this integer. + keep_foreground (bool): If keep_foreground is True, it will sample a + voxel of foreground classes randomly, and will take it as the + center of the crop bounding-box. Default to True. + """ + + def __init__(self, + crop_shape: Union[int, Tuple[int, int, int]], + keep_foreground: bool = True): + super().__init__() + assert isinstance(crop_shape, int) or ( + isinstance(crop_shape, tuple) and len(crop_shape) == 3 + ), 'The expected crop_shape is an integer, or a tuple containing ' + 'three integers' + + if isinstance(crop_shape, int): + crop_shape = (crop_shape, crop_shape, crop_shape) + assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0 + self.crop_shape = crop_shape + self.keep_foreground = keep_foreground + + def random_sample_location(self, seg_map: np.ndarray) -> dict: + """sample foreground voxel when keep_foreground is True. + + Args: + seg_map (np.ndarray): gt seg map. + + Returns: + dict: Coordinates of selected foreground voxel. + """ + num_samples = 10000 + # at least 1% of the class voxels need to be selected, + # otherwise it may be too sparse + min_percent_coverage = 0.01 + class_locs = {} + foreground_classes = [] + all_classes = np.unique(seg_map) + for c in all_classes: + if c == 0: + # to avoid the segmentation mask full of background 0 + # and the class_locs is just void dictionary {} when it return + # there add a void list for background 0. + class_locs[c] = [] + else: + all_locs = np.argwhere(seg_map == c) + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max( + target_num_samples, + int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[np.random.choice( + len(all_locs), target_num_samples, replace=False)] + class_locs[c] = selected + foreground_classes.append(c) + + selected_voxel = None + if len(foreground_classes) > 0: + selected_class = np.random.choice(foreground_classes) + voxels_of_that_class = class_locs[selected_class] + selected_voxel = voxels_of_that_class[np.random.choice( + len(voxels_of_that_class))] + + return selected_voxel + + def random_generate_crop_bbox(self, margin_z: int, margin_y: int, + margin_x: int) -> tuple: + """Randomly get a crop bounding box. + + Args: + seg_map (np.ndarray): Ground truth segmentation map. + + Returns: + tuple: Coordinates of the cropped image. + """ + offset_z = np.random.randint(0, margin_z + 1) + offset_y = np.random.randint(0, margin_y + 1) + offset_x = np.random.randint(0, margin_x + 1) + crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0] + crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1] + crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2] + + return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 + + def generate_margin(self, results: dict) -> tuple: + """Generate margin of crop bounding-box. + + If keep_foreground is True, it will sample a voxel of foreground + classes randomly, and will take it as the center of the bounding-box, + and return the margin between of the bounding-box and image. + If keep_foreground is False, it will return the difference from crop + shape and image shape. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: The margin for 3 dimensions of crop bounding-box and image. + """ + + seg_map = results['gt_seg_map'] + if self.keep_foreground: + selected_voxel = self.random_sample_location(seg_map) + if selected_voxel is None: + # this only happens if some image does not contain + # foreground voxels at all + warnings.warn(f'case does not contain any foreground classes' + f': {results["img_path"]}') + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + else: + margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2) + margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2) + margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2) + margin_z = max( + 0, min(seg_map.shape[0] - self.crop_shape[0], margin_z)) + margin_y = max( + 0, min(seg_map.shape[1] - self.crop_shape[1], margin_y)) + margin_x = max( + 0, min(seg_map.shape[2] - self.crop_shape[2], margin_x)) + else: + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + + return margin_z, margin_y, margin_x + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + if len(img.shape) == 3: + # crop seg map + img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + else: + # crop image + assert len(img.shape) == 4 + img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + margin = self.generate_margin(results) + crop_bbox = self.random_generate_crop_bbox(*margin) + + # crop the image + img = results['img'] + results['img'] = self.crop(img, crop_bbox) + results['img_shape'] = results['img'].shape[1:] + + # crop semantic seg + seg_map = results['gt_seg_map'] + results['gt_seg_map'] = self.crop(seg_map, crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index 0833ac183..2c18b8e02 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -737,3 +737,44 @@ def test_generate_edge(): [1, 1, 0, 0, 0], [1, 0, 0, 0, 0], ])) + + +def test_biomedical3d_random_crop(): + # test assertion for invalid random crop + with pytest.raises(AssertionError): + transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0)) + transform = TRANSFORMS.build(transform) + + from mmseg.datasets.transforms import (LoadBiomedicalAnnotation, + LoadBiomedicalImageFromFile) + results = dict() + results['img_path'] = osp.join( + osp.dirname(__file__), '../data', 'biomedical.nii.gz') + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + + results['seg_map_path'] = osp.join( + osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz') + transform = LoadBiomedicalAnnotation() + results = transform(copy.deepcopy(results)) + + d, h, w = results['img_shape'] + transform = dict( + type='BioMedical3DRandomCrop', + crop_shape=(d - 20, h - 20, w - 20), + keep_foreground=True) + transform = TRANSFORMS.build(transform) + crop_results = transform(results) + assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20) + assert crop_results['img_shape'] == (d - 20, h - 20, w - 20) + assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20) + + transform = dict( + type='BioMedical3DRandomCrop', + crop_shape=(d - 20, h - 20, w - 20), + keep_foreground=False) + transform = TRANSFORMS.build(transform) + crop_results = transform(results) + assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20) + assert crop_results['img_shape'] == (d - 20, h - 20, w - 20) + assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)