[Feature] Add BioMedicalRandomGamma (#2406)

Add a random gamma correction transform for biomedical images, which
follows the design of nnUNet.
This commit is contained in:
Fivethousand 2023-01-02 21:29:03 +08:00 committed by GitHub
parent 26f3df7a45
commit 3ca690bad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 195 additions and 8 deletions

View File

@ -1,4 +1,4 @@
# 数据集 # 数据集
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系. 在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.

View File

@ -18,9 +18,10 @@ from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59 from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset from .potsdam import PotsdamDataset
from .stare import STAREDataset from .stare import STAREDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise, BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, LoadAnnotations, BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray, LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop, PackSegInputs, PhotoMetricDistortion, RandomCrop,
@ -30,7 +31,6 @@ from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
from .voc import PascalVOCDataset from .voc import PascalVOCDataset
# yapf: enable # yapf: enable
__all__ = [ __all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset', 'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
@ -44,5 +44,6 @@ __all__ = [
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur' 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
] ]

View File

@ -6,8 +6,9 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
# yapf: disable # yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise, BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, PhotoMetricDistortion, RandomCrop, BioMedicalRandomGamma, GenerateEdge,
RandomCutOut, RandomMosaic, RandomRotate, Rerange, PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray, ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale) SegRescale)
@ -18,5 +19,6 @@ __all__ = [
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur' 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
] ]

View File

@ -1686,3 +1686,122 @@ class BioMedicalGaussianBlur(BaseTransform):
repr_str += 'different_sigma_per_axis='\ repr_str += 'different_sigma_per_axis='\
f'{self.different_sigma_per_axis})' f'{self.different_sigma_per_axis})'
return repr_str return repr_str
@TRANSFORMS.register_module()
class BioMedicalRandomGamma(BaseTransform):
    """Using random gamma correction to process the biomedical image.

    Modified from
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501

    With licence: Apache 2.0

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.

    Modified Keys:

    - img

    Args:
        prob (float): The probability to perform this transform. Default: 0.5.
        gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
        invert_image (bool): Whether invert the image before applying gamma
            augmentation. Default: False.
        per_channel (bool): Whether perform the transform each channel
            individually. Default: False
        retain_stats (bool): Gamma transformation will alter the mean and std
            of the data in the patch. If retain_stats=True, the data will be
            transformed to match the mean and standard deviation before gamma
            augmentation. Default: False.
    """

    def __init__(self,
                 prob: float = 0.5,
                 gamma_range: Tuple[float, float] = (0.5, 2),
                 invert_image: bool = False,
                 per_channel: bool = False,
                 retain_stats: bool = False):
        # Validation deliberately raises AssertionError: the unit tests of
        # this transform expect that exception type for bad arguments.
        assert 0 <= prob <= 1
        assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
        assert isinstance(invert_image, bool)
        assert isinstance(per_channel, bool)
        assert isinstance(retain_stats, bool)
        self.prob = prob
        self.gamma_range = gamma_range
        self.invert_image = invert_image
        self.per_channel = per_channel
        self.retain_stats = retain_stats

    @cache_randomness
    def _do_gamma(self):
        """Whether do adjust gamma for image."""
        return np.random.rand() < self.prob

    def _gamma_correct(self, img: np.ndarray) -> np.ndarray:
        """Apply one randomly sampled gamma curve to ``img``.

        The array is rescaled to [0, 1] using its own min/max, raised to the
        sampled gamma and mapped back to the original intensity range. When
        ``self.retain_stats`` is set, the result is afterwards shifted and
        scaled so its mean and std equal those of the input.

        Args:
            img (np.ndarray): Whole image or a single channel of it.

        Returns:
            np.ndarray: A new array holding the gamma corrected data.
        """
        if self.retain_stats:
            img_mean = img.mean()
            img_std = img.std()

        # With probability 0.5 (and only when the lower bound is below 1)
        # draw gamma from [gamma_range[0], 1); otherwise draw it from
        # [max(gamma_range[0], 1), gamma_range[1]).
        # NOTE(review): if gamma_range[1] < 1 the second branch samples
        # between gamma_range[1] and 1, i.e. outside the configured range;
        # behaviour kept unchanged here - confirm whether that is intended.
        if np.random.random() < 0.5 and self.gamma_range[0] < 1:
            gamma = np.random.uniform(self.gamma_range[0], 1)
        else:
            gamma = np.random.uniform(
                max(self.gamma_range[0], 1), self.gamma_range[1])

        img_min = img.min()
        img_range = img.max() - img_min
        # The small epsilon avoids division by zero on constant input.
        img = np.power((img - img_min) / float(img_range + 1e-7),
                       gamma) * img_range + img_min

        if self.retain_stats:
            img = img - img.mean()
            img = img / (img.std() + 1e-8) * img_std
            img = img + img_mean
        return img

    def _adjust_gamma(self, img: np.ndarray) -> np.ndarray:
        """Gamma adjustment for image.

        Args:
            img (np.ndarray): Input image before gamma adjust.

        Returns:
            np.ndarray: Image after gamma adjust.
        """
        if self.invert_image:
            img = -img

        if not self.per_channel:
            img = self._gamma_correct(img)
        else:
            # Write into a fresh buffer so the caller's array is left
            # untouched, matching the non per-channel path. ``empty_like``
            # keeps the input dtype, as the former in-place write did.
            adjusted = np.empty_like(img)
            for c, channel in enumerate(img):
                # Each channel gets its own randomly sampled gamma.
                adjusted[c] = self._gamma_correct(channel)
            img = adjusted

        if self.invert_image:
            # Undo the inversion applied before the correction.
            img = -img
        return img

    def transform(self, results: dict) -> dict:
        """Call function to perform random gamma correction

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with random gamma correction performed.
        """
        if self._do_gamma():
            results['img'] = self._adjust_gamma(results['img'])
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'gamma_range={self.gamma_range},'
        repr_str += f'invert_image={self.invert_image},'
        repr_str += f'per_channel={self.per_channel},'
        repr_str += f'retain_stats={self.retain_stats})'
        return repr_str

View File

@ -8,7 +8,8 @@ import pytest
from PIL import Image from PIL import Image
from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules from mmseg.utils import register_all_modules
@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur():
# the max value in the smoothed image should be less than the original one # the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max() assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min() assert original_img.min() <= results['img'].min()
def test_BioMedicalRandomGamma():
    """Check argument validation and shape preservation of
    ``BioMedicalRandomGamma``."""
    # Every configuration below violates exactly one constructor assertion.
    invalid_cfgs = [
        # ``prob`` outside of [0, 1]
        dict(prob=-1, gamma_range=(0.7, 1.2)),
        dict(prob=1.2, gamma_range=(0.7, 1.2)),
        # ``gamma_range`` is not a tuple: ``(0.7)`` is just the float 0.7
        dict(prob=1.0, gamma_range=(0.7)),
        # ``gamma_range`` is a tuple of the wrong length
        dict(prob=1.0, gamma_range=(0.7, )),
        dict(prob=1.0, gamma_range=(0.7, 0.2, 0.3)),
        # the boolean switches must be real ``bool`` objects
        dict(prob=1.0, gamma_range=(0.7, 2), invert_image=1),
        dict(prob=1.0, gamma_range=(0.7, 2), per_channel=1),
        dict(prob=1.0, gamma_range=(0.7, 2), retain_stats=1),
    ]
    for invalid_cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(
                dict(type='BioMedicalRandomGamma', **invalid_cfg))

    # Load a real biomedical volume and run the transform on it.
    test_img = 'tests/data/biomedical.nii.gz'
    results = dict(img_path=test_img)
    transform = LoadBiomedicalImageFromFile()
    results = transform(copy.deepcopy(results))
    origin_img = results['img']

    # ``prob=1.0`` guarantees that the gamma correction is applied.
    transform2 = dict(
        type='BioMedicalRandomGamma',
        prob=1.0,
        gamma_range=(0.7, 2),
    )
    transform2 = TRANSFORMS.build(transform2)
    results = transform2(results)
    transformed_img = results['img']

    # Gamma correction only changes intensities, never the shape.
    assert origin_img.shape == transformed_img.shape