mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add BioMedicalRandomGamma (#2406)
Add the random gamma correction transform for biomedical images, which follows the design of the nnUNet.
This commit is contained in:
parent
26f3df7a45
commit
3ca690bad3
@ -1,4 +1,4 @@
|
||||
# 数据集
|
||||
# 数据集
|
||||
|
||||
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
|
||||
|
||||
|
@ -18,9 +18,10 @@ from .night_driving import NightDrivingDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
from .stare import STAREDataset
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
GenerateEdge, LoadAnnotations,
|
||||
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
|
||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||
@ -30,7 +31,6 @@ from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
# yapf: enable
|
||||
|
||||
__all__ = [
|
||||
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
|
||||
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
||||
@ -44,5 +44,6 @@ __all__ = [
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma'
|
||||
]
|
||||
|
@ -6,8 +6,9 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
BioMedicalRandomGamma, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
@ -18,5 +19,6 @@ __all__ = [
|
||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma'
|
||||
]
|
||||
|
@ -1686,3 +1686,122 @@ class BioMedicalGaussianBlur(BaseTransform):
|
||||
repr_str += 'different_sigma_per_axis='\
|
||||
f'{self.different_sigma_per_axis})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class BioMedicalRandomGamma(BaseTransform):
|
||||
"""Using random gamma correction to process the biomedical image.
|
||||
|
||||
Modified from
|
||||
https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
|
||||
With licence: Apache 2.0
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||
N is the number of modalities, and data type is float32.
|
||||
|
||||
Modified Keys:
|
||||
- img
|
||||
|
||||
Args:
|
||||
prob (float): The probability to perform this transform. Default: 0.5.
|
||||
gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
|
||||
invert_image (bool): Whether invert the image before applying gamma
|
||||
augmentation. Default: False.
|
||||
per_channel (bool): Whether perform the transform each channel
|
||||
individually. Default: False
|
||||
retain_stats (bool): Gamma transformation will alter the mean and std
|
||||
of the data in the patch. If retain_stats=True, the data will be
|
||||
transformed to match the mean and standard deviation before gamma
|
||||
augmentation. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
prob: float = 0.5,
|
||||
gamma_range: Tuple[float] = (0.5, 2),
|
||||
invert_image: bool = False,
|
||||
per_channel: bool = False,
|
||||
retain_stats: bool = False):
|
||||
assert 0 <= prob and prob <= 1
|
||||
assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
|
||||
assert isinstance(invert_image, bool)
|
||||
assert isinstance(per_channel, bool)
|
||||
assert isinstance(retain_stats, bool)
|
||||
self.prob = prob
|
||||
self.gamma_range = gamma_range
|
||||
self.invert_image = invert_image
|
||||
self.per_channel = per_channel
|
||||
self.retain_stats = retain_stats
|
||||
|
||||
@cache_randomness
|
||||
def _do_gamma(self):
|
||||
"""Whether do adjust gamma for image."""
|
||||
return np.random.rand() < self.prob
|
||||
|
||||
def _adjust_gamma(self, img: np.array):
|
||||
"""Gamma adjustment for image.
|
||||
|
||||
Args:
|
||||
img (np.array): Input image before gamma adjust.
|
||||
|
||||
Returns:
|
||||
np.arrays: Image after gamma adjust.
|
||||
"""
|
||||
|
||||
if self.invert_image:
|
||||
img = -img
|
||||
|
||||
def _do_adjust(img):
|
||||
if retain_stats_here:
|
||||
img_mean = img.mean()
|
||||
img_std = img.std()
|
||||
if np.random.random() < 0.5 and self.gamma_range[0] < 1:
|
||||
gamma = np.random.uniform(self.gamma_range[0], 1)
|
||||
else:
|
||||
gamma = np.random.uniform(
|
||||
max(self.gamma_range[0], 1), self.gamma_range[1])
|
||||
img_min = img.min()
|
||||
img_range = img.max() - img_min # range
|
||||
img = np.power(((img - img_min) / float(img_range + 1e-7)),
|
||||
gamma) * img_range + img_min
|
||||
if retain_stats_here:
|
||||
img = img - img.mean()
|
||||
img = img / (img.std() + 1e-8) * img_std
|
||||
img = img + img_mean
|
||||
return img
|
||||
|
||||
if not self.per_channel:
|
||||
retain_stats_here = self.retain_stats
|
||||
img = _do_adjust(img)
|
||||
else:
|
||||
for c in range(img.shape[0]):
|
||||
img[c] = _do_adjust(img[c])
|
||||
if self.invert_image:
|
||||
img = -img
|
||||
return img
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Call function to perform random gamma correction
|
||||
Args:
|
||||
results (dict): Result dict from loading pipeline.
|
||||
|
||||
Returns:
|
||||
dict: Result dict with random gamma correction performed.
|
||||
"""
|
||||
do_gamma = self._do_gamma()
|
||||
|
||||
if do_gamma:
|
||||
results['img'] = self._adjust_gamma(results['img'])
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob}, '
|
||||
repr_str += f'gamma_range={self.gamma_range},'
|
||||
repr_str += f'invert_image={self.invert_image},'
|
||||
repr_str += f'per_channel={self.per_channel},'
|
||||
repr_str += f'retain_stats={self.retain_stats}'
|
||||
return repr_str
|
||||
|
@ -8,7 +8,8 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||
from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
|
||||
PhotoMetricDistortion, RandomCrop)
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur():
|
||||
# the max value in the smoothed image should be less than the original one
|
||||
assert original_img.max() >= results['img'].max()
|
||||
assert original_img.min() <= results['img'].min()
|
||||
|
||||
|
||||
def test_BioMedicalRandomGamma():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma',
|
||||
prob=1.0,
|
||||
gamma_range=(0.7, 0.2, 0.3))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma',
|
||||
prob=1.0,
|
||||
gamma_range=(0.7, 2),
|
||||
invert_image=1)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma',
|
||||
prob=1.0,
|
||||
gamma_range=(0.7, 2),
|
||||
per_channel=1)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalRandomGamma',
|
||||
prob=1.0,
|
||||
gamma_range=(0.7, 2),
|
||||
retain_stats=1)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
test_img = 'tests/data/biomedical.nii.gz'
|
||||
results = dict(img_path=test_img)
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
origin_img = results['img']
|
||||
transform2 = dict(
|
||||
type='BioMedicalRandomGamma',
|
||||
prob=1.0,
|
||||
gamma_range=(0.7, 2),
|
||||
)
|
||||
transform2 = TRANSFORMS.build(transform2)
|
||||
results = transform2(results)
|
||||
transformed_img = results['img']
|
||||
assert origin_img.shape == transformed_img.shape
|
||||
|
Loading…
x
Reference in New Issue
Block a user