mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add BioMedicalRandomGamma (#2406)
Add a random gamma correction transform for biomedical images, following the design of nnU-Net.
This commit is contained in:
parent
26f3df7a45
commit
3ca690bad3
@ -1,4 +1,4 @@
|
|||||||
# 数据集
|
# 数据集
|
||||||
|
|
||||||
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
|
在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系.
|
||||||
|
|
||||||
|
@ -18,9 +18,10 @@ from .night_driving import NightDrivingDataset
|
|||||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||||
from .potsdam import PotsdamDataset
|
from .potsdam import PotsdamDataset
|
||||||
from .stare import STAREDataset
|
from .stare import STAREDataset
|
||||||
|
# yapf: disable
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
GenerateEdge, LoadAnnotations,
|
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
|
||||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||||
@ -30,7 +31,6 @@ from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
|||||||
from .voc import PascalVOCDataset
|
from .voc import PascalVOCDataset
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
|
'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset',
|
||||||
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
||||||
@ -44,5 +44,6 @@ __all__ = [
|
|||||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||||
|
'BioMedicalRandomGamma'
|
||||||
]
|
]
|
||||||
|
@ -6,8 +6,9 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
|||||||
# yapf: disable
|
# yapf: disable
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
BioMedicalRandomGamma, GenerateEdge,
|
||||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||||
|
RandomMosaic, RandomRotate, Rerange,
|
||||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||||
SegRescale)
|
SegRescale)
|
||||||
|
|
||||||
@ -18,5 +19,6 @@ __all__ = [
|
|||||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||||
|
'BioMedicalRandomGamma'
|
||||||
]
|
]
|
||||||
|
@ -1686,3 +1686,122 @@ class BioMedicalGaussianBlur(BaseTransform):
|
|||||||
repr_str += 'different_sigma_per_axis='\
|
repr_str += 'different_sigma_per_axis='\
|
||||||
f'{self.different_sigma_per_axis})'
|
f'{self.different_sigma_per_axis})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class BioMedicalRandomGamma(BaseTransform):
    """Using random gamma correction to process the biomedical image.

    Modified from
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
    With licence: Apache 2.0

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.

    Modified Keys:
    - img

    Args:
        prob (float): The probability to perform this transform. Default: 0.5.
        gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2).
        invert_image (bool): Whether invert the image before applying gamma
            augmentation. Default: False.
        per_channel (bool): Whether perform the transform each channel
            individually. Default: False
        retain_stats (bool): Gamma transformation will alter the mean and std
            of the data in the patch. If retain_stats=True, the data will be
            transformed to match the mean and standard deviation before gamma
            augmentation. Default: False.
    """

    def __init__(self,
                 prob: float = 0.5,
                 gamma_range: Tuple[float, float] = (0.5, 2),
                 invert_image: bool = False,
                 per_channel: bool = False,
                 retain_stats: bool = False):
        # Validation deliberately stays assert-based: callers (and the unit
        # tests) expect ``AssertionError`` for invalid arguments.
        assert 0 <= prob <= 1
        assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
        assert isinstance(invert_image, bool)
        assert isinstance(per_channel, bool)
        assert isinstance(retain_stats, bool)
        self.prob = prob
        self.gamma_range = gamma_range
        self.invert_image = invert_image
        self.per_channel = per_channel
        self.retain_stats = retain_stats

    @cache_randomness
    def _do_gamma(self):
        """Whether do adjust gamma for image."""
        return np.random.rand() < self.prob

    def _gamma_correct(self, img: np.ndarray) -> np.ndarray:
        """Apply one randomly drawn gamma curve to ``img``.

        With probability 0.5 (and only if ``gamma_range[0] < 1``) gamma is
        sampled from ``[gamma_range[0], 1)``; otherwise it is sampled from
        ``[max(gamma_range[0], 1), gamma_range[1])``.

        Args:
            img (np.ndarray): Array to be corrected as a whole.

        Returns:
            np.ndarray: A new array holding the gamma-corrected values.
        """
        if self.retain_stats:
            # Remember the statistics so they can be restored afterwards.
            img_mean = img.mean()
            img_std = img.std()

        if np.random.random() < 0.5 and self.gamma_range[0] < 1:
            gamma = np.random.uniform(self.gamma_range[0], 1)
        else:
            gamma = np.random.uniform(
                max(self.gamma_range[0], 1), self.gamma_range[1])

        # Rescale to [0, 1) before exponentiation, then map back to the
        # original value range. The epsilon guards against a constant image.
        img_min = img.min()
        img_range = img.max() - img_min
        img = np.power((img - img_min) / float(img_range + 1e-7),
                       gamma) * img_range + img_min

        if self.retain_stats:
            img = img - img.mean()
            img = img / (img.std() + 1e-8) * img_std
            img = img + img_mean
        return img

    def _adjust_gamma(self, img: np.array):
        """Gamma adjustment for image.

        Args:
            img (np.array): Input image before gamma adjust.

        Returns:
            np.arrays: Image after gamma adjust.
        """
        if self.invert_image:
            img = -img

        if self.per_channel:
            # Every channel along the first axis gets its own gamma (and its
            # own statistics when ``retain_stats`` is set). Results go into a
            # fresh buffer of the input dtype so the caller's array is never
            # modified in place.
            adjusted = np.empty_like(img)
            for c in range(img.shape[0]):
                adjusted[c] = self._gamma_correct(img[c])
            img = adjusted
        else:
            img = self._gamma_correct(img)

        if self.invert_image:
            # Undo the inversion applied before the correction.
            img = -img
        return img

    def transform(self, results: dict) -> dict:
        """Call function to perform random gamma correction
        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with random gamma correction performed.
        """
        if self._do_gamma():
            results['img'] = self._adjust_gamma(results['img'])
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'gamma_range={self.gamma_range},'
        repr_str += f'invert_image={self.invert_image},'
        repr_str += f'per_channel={self.per_channel},'
        # The closing parenthesis was previously missing.
        repr_str += f'retain_stats={self.retain_stats})'
        return repr_str
|
||||||
|
@ -8,7 +8,8 @@ import pytest
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from mmseg.datasets.transforms import * # noqa
|
from mmseg.datasets.transforms import * # noqa
|
||||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile,
|
||||||
|
PhotoMetricDistortion, RandomCrop)
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
from mmseg.utils import register_all_modules
|
from mmseg.utils import register_all_modules
|
||||||
|
|
||||||
@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur():
|
|||||||
# the max value in the smoothed image should be less than the original one
|
# the max value in the smoothed image should be less than the original one
|
||||||
assert original_img.max() >= results['img'].max()
|
assert original_img.max() >= results['img'].max()
|
||||||
assert original_img.min() <= results['img'].min()
|
assert original_img.min() <= results['img'].min()
|
||||||
|
|
||||||
|
|
||||||
|
def test_BioMedicalRandomGamma():
    """Check argument validation and shape preservation of the transform."""
    # Every entry must be rejected with an AssertionError at build time.
    invalid_kwargs = [
        # probability outside [0, 1]
        dict(prob=-1, gamma_range=(0.7, 1.2)),
        dict(prob=1.2, gamma_range=(0.7, 1.2)),
        # ``(0.7)`` is a plain float, not a tuple
        dict(prob=1.0, gamma_range=(0.7)),
        # tuple of the wrong length
        dict(prob=1.0, gamma_range=(0.7, 0.2, 0.3)),
        # non-bool flags
        dict(prob=1.0, gamma_range=(0.7, 2), invert_image=1),
        dict(prob=1.0, gamma_range=(0.7, 2), per_channel=1),
        dict(prob=1.0, gamma_range=(0.7, 2), retain_stats=1),
    ]
    for bad_kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(dict(type='BioMedicalRandomGamma', **bad_kwargs))

    # Load a sample volume and run the transform with prob=1.0 so that it
    # is always applied.
    loader = LoadBiomedicalImageFromFile()
    results = loader(
        copy.deepcopy(dict(img_path='tests/data/biomedical.nii.gz')))
    origin_img = results['img']

    gamma_transform = TRANSFORMS.build(
        dict(
            type='BioMedicalRandomGamma',
            prob=1.0,
            gamma_range=(0.7, 2),
        ))
    results = gamma_transform(results)
    transformed_img = results['img']

    assert origin_img.shape == transformed_img.shape
|
||||||
|
Loading…
x
Reference in New Issue
Block a user