[Feature] Add BioMedicalRandomGamma (#2406)

Add the random gamma correction transform for biomedical images, which
follows the design of the nnUNet.
This commit is contained in:
Fivethousand 2023-01-02 21:29:03 +08:00 committed by GitHub
parent 26f3df7a45
commit 3ca690bad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 195 additions and 8 deletions

View File

@ -1,4 +1,4 @@
# 数据集
# 数据集
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.

View File

@ -18,9 +18,10 @@ from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, LoadAnnotations,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
@ -30,7 +31,6 @@ from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
from .voc import PascalVOCDataset
# yapf: enable
__all__ = [
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
@ -44,5 +44,6 @@ __all__ = [
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
]

View File

@ -6,8 +6,9 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
BioMedicalRandomGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
@ -18,5 +19,6 @@ __all__ = [
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma'
]

View File

@ -1686,3 +1686,122 @@ class BioMedicalGaussianBlur(BaseTransform):
repr_str += 'different_sigma_per_axis='\
f'{self.different_sigma_per_axis})'
return repr_str
@TRANSFORMS.register_module()
class BioMedicalRandomGamma(BaseTransform):
"""Using random gamma correction to process the biomedical image.
Modified from
https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
With licence: Apache 2.0
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
Modified Keys:
- img
Args:
prob (float): The probability to perform this transform. Default: 0.5.
gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
invert_image (bool): Whether invert the image before applying gamma
augmentation. Default: False.
per_channel (bool): Whether perform the transform each channel
individually. Default: False
retain_stats (bool): Gamma transformation will alter the mean and std
of the data in the patch. If retain_stats=True, the data will be
transformed to match the mean and standard deviation before gamma
augmentation. Default: False.
"""
def __init__(self,
prob: float = 0.5,
gamma_range: Tuple[float] = (0.5, 2),
invert_image: bool = False,
per_channel: bool = False,
retain_stats: bool = False):
assert 0 <= prob and prob <= 1
assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
assert isinstance(invert_image, bool)
assert isinstance(per_channel, bool)
assert isinstance(retain_stats, bool)
self.prob = prob
self.gamma_range = gamma_range
self.invert_image = invert_image
self.per_channel = per_channel
self.retain_stats = retain_stats
@cache_randomness
def _do_gamma(self):
"""Whether do adjust gamma for image."""
return np.random.rand() < self.prob
def _adjust_gamma(self, img: np.array):
"""Gamma adjustment for image.
Args:
img (np.array): Input image before gamma adjust.
Returns:
np.arrays: Image after gamma adjust.
"""
if self.invert_image:
img = -img
def _do_adjust(img):
if retain_stats_here:
img_mean = img.mean()
img_std = img.std()
if np.random.random() < 0.5 and self.gamma_range[0] < 1:
gamma = np.random.uniform(self.gamma_range[0], 1)
else:
gamma = np.random.uniform(
max(self.gamma_range[0], 1), self.gamma_range[1])
img_min = img.min()
img_range = img.max() - img_min # range
img = np.power(((img - img_min) / float(img_range + 1e-7)),
gamma) * img_range + img_min
if retain_stats_here:
img = img - img.mean()
img = img / (img.std() + 1e-8) * img_std
img = img + img_mean
return img
if not self.per_channel:
retain_stats_here = self.retain_stats
img = _do_adjust(img)
else:
for c in range(img.shape[0]):
img[c] = _do_adjust(img[c])
if self.invert_image:
img = -img
return img
def transform(self, results: dict) -> dict:
"""Call function to perform random gamma correction
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with random gamma correction performed.
"""
do_gamma = self._do_gamma()
if do_gamma:
results['img'] = self._adjust_gamma(results['img'])
else:
pass
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'gamma_range={self.gamma_range},'
repr_str += f'invert_image={self.invert_image},'
repr_str += f'per_channel={self.per_channel},'
repr_str += f'retain_stats={self.retain_stats}'
return repr_str

View File

@ -8,7 +8,8 @@ import pytest
from PIL import Image
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur():
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()
def test_BioMedicalRandomGamma():
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma',
prob=1.0,
gamma_range=(0.7, 0.2, 0.3))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma',
prob=1.0,
gamma_range=(0.7, 2),
invert_image=1)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma',
prob=1.0,
gamma_range=(0.7, 2),
per_channel=1)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalRandomGamma',
prob=1.0,
gamma_range=(0.7, 2),
retain_stats=1)
TRANSFORMS.build(transform)
test_img = 'tests/data/biomedical.nii.gz'
results = dict(img_path=test_img)
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
origin_img = results['img']
transform2 = dict(
type='BioMedicalRandomGamma',
prob=1.0,
gamma_range=(0.7, 2),
)
transform2 = TRANSFORMS.build(transform2)
results = transform2(results)
transformed_img = results['img']
assert origin_img.shape == transformed_img.shape