[Feature] nnUNet-style Gaussian Noise and Blur (#2373)
## Motivation implement nnUNet-style Gaussian Noise and Blurpull/2460/head
parent
6eb1a95a48
commit
26f3df7a45
|
@ -19,6 +19,7 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
|
|||
from .potsdam import PotsdamDataset
|
||||
from .stare import STAREDataset
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
GenerateEdge, LoadAnnotations,
|
||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
|
@ -42,5 +43,6 @@ __all__ = [
|
|||
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||
]
|
||||
|
|
|
@ -3,17 +3,20 @@ from .formatting import PackSegInputs
|
|||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
# yapf: enable
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge'
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||
]
|
||||
|
|
|
@ -10,6 +10,7 @@ from mmcv.transforms.base import BaseTransform
|
|||
from mmcv.transforms.utils import cache_randomness
|
||||
from mmengine.utils import is_tuple_of
|
||||
from numpy import random
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
@ -1507,3 +1508,181 @@ class BioMedical3DRandomCrop(BaseTransform):
|
|||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class BioMedicalGaussianNoise(BaseTransform):
|
||||
"""Add random Gaussian noise to image.
|
||||
|
||||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501
|
||||
|
||||
Copyright (c) German Cancer Research Center (DKFZ)
|
||||
Licensed under the Apache License, Version 2.0
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||
N is the number of modalities, and data type is float32.
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
prob (float): Probability to add Gaussian noise for
|
||||
each sample. Default to 0.1.
|
||||
mean (float): Mean or “centre” of the distribution. Default to 0.0.
|
||||
std (float): Standard deviation of distribution. Default to 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
prob: float = 0.1,
|
||||
mean: float = 0.0,
|
||||
std: float = 0.1) -> None:
|
||||
super().__init__()
|
||||
assert 0.0 <= prob <= 1.0 and std >= 0.0
|
||||
self.prob = prob
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Call function to add random Gaussian noise to image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict.
|
||||
|
||||
Returns:
|
||||
dict: Result dict with random Gaussian noise.
|
||||
"""
|
||||
if np.random.rand() < self.prob:
|
||||
rand_std = np.random.uniform(0, self.std)
|
||||
noise = np.random.normal(
|
||||
self.mean, rand_std, size=results['img'].shape)
|
||||
# noise is float64 array, convert to the results['img'].dtype
|
||||
noise = noise.astype(results['img'].dtype)
|
||||
results['img'] = results['img'] + noise
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob}, '
|
||||
repr_str += f'mean={self.mean}, '
|
||||
repr_str += f'std={self.std})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class BioMedicalGaussianBlur(BaseTransform):
|
||||
"""Add Gaussian blur with random sigma to image.
|
||||
|
||||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501
|
||||
|
||||
Copyright (c) German Cancer Research Center (DKFZ)
|
||||
Licensed under the Apache License, Version 2.0
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||
N is the number of modalities, and data type is float32.
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
|
||||
Args:
|
||||
sigma_range (Tuple[float, float]|float): range to randomly
|
||||
select sigma value. Default to (0.5, 1.0).
|
||||
prob (float): Probability to apply Gaussian blur
|
||||
for each sample. Default to 0.2.
|
||||
prob_per_channel (float): Probability to apply Gaussian blur
|
||||
for each channel (axis N of the image). Default to 0.5.
|
||||
different_sigma_per_channel (bool): whether to use different
|
||||
sigma for each channel (axis N of the image). Default to True.
|
||||
different_sigma_per_axis (bool): whether to use different
|
||||
sigma for axis Z, X and Y of the image. Default to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sigma_range: Tuple[float, float] = (0.5, 1.0),
|
||||
prob: float = 0.2,
|
||||
prob_per_channel: float = 0.5,
|
||||
different_sigma_per_channel: bool = True,
|
||||
different_sigma_per_axis: bool = True) -> None:
|
||||
super().__init__()
|
||||
assert 0.0 <= prob <= 1.0
|
||||
assert 0.0 <= prob_per_channel <= 1.0
|
||||
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2
|
||||
self.sigma_range = sigma_range
|
||||
self.prob = prob
|
||||
self.prob_per_channel = prob_per_channel
|
||||
self.different_sigma_per_channel = different_sigma_per_channel
|
||||
self.different_sigma_per_axis = different_sigma_per_axis
|
||||
|
||||
def _get_valid_sigma(self, value_range) -> Tuple[float, ...]:
|
||||
"""Ensure the `value_range` to be either a single value or a sequence
|
||||
of two values. If the `value_range` is a sequence, generate a random
|
||||
value with `[value_range[0], value_range[1]]` based on uniform
|
||||
sampling.
|
||||
|
||||
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501
|
||||
|
||||
Args:
|
||||
value_range (tuple|list|float|int): the input value range
|
||||
"""
|
||||
if (isinstance(value_range, (list, tuple))):
|
||||
if (value_range[0] == value_range[1]):
|
||||
value = value_range[0]
|
||||
else:
|
||||
orig_type = type(value_range[0])
|
||||
value = np.random.uniform(value_range[0], value_range[1])
|
||||
value = orig_type(value)
|
||||
return value
|
||||
|
||||
def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
|
||||
"""Random generate sigma and apply Gaussian Blur to the data
|
||||
Args:
|
||||
data_sample (np.ndarray): data sample with multiple modalities,
|
||||
the data shape is (N, Z, Y, X)
|
||||
"""
|
||||
sigma = None
|
||||
for c in range(data_sample.shape[0]):
|
||||
if np.random.rand() < self.prob_per_channel:
|
||||
# if no `sigma` is generated, generate one
|
||||
# if `self.different_sigma_per_channel` is True,
|
||||
# re-generate random sigma for each channel
|
||||
if (sigma is None or self.different_sigma_per_channel):
|
||||
if (not self.different_sigma_per_axis):
|
||||
sigma = self._get_valid_sigma(self.sigma_range)
|
||||
else:
|
||||
sigma = [
|
||||
self._get_valid_sigma(self.sigma_range)
|
||||
for _ in data_sample.shape[1:]
|
||||
]
|
||||
# apply gaussian filter with `sigma`
|
||||
data_sample[c] = gaussian_filter(
|
||||
data_sample[c], sigma, order=0)
|
||||
return data_sample
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Call function to add random Gaussian blur to image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict.
|
||||
|
||||
Returns:
|
||||
dict: Result dict with random Gaussian noise.
|
||||
"""
|
||||
if np.random.rand() < self.prob:
|
||||
results['img'] = self._gaussian_blur(results['img'])
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(prob={self.prob}, '
|
||||
repr_str += f'prob_per_channel={self.prob_per_channel}, '
|
||||
repr_str += f'sigma_range={self.sigma_range}, '
|
||||
repr_str += 'different_sigma_per_channel='\
|
||||
f'{self.different_sigma_per_channel}, '
|
||||
repr_str += 'different_sigma_per_axis='\
|
||||
f'{self.different_sigma_per_axis})'
|
||||
return repr_str
|
||||
|
|
|
@ -778,3 +778,111 @@ def test_biomedical3d_random_crop():
|
|||
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||
|
||||
|
||||
def test_biomedical_gaussian_noise():
|
||||
# test assertion for invalid prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='BioMedicalGaussianNoise', prob=1.5)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
# test assertion for invalid std
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5)
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
|
||||
noise_module = TRANSFORMS.build(transform)
|
||||
assert str(noise_module) == 'BioMedicalGaussianNoise'\
|
||||
'(prob=1.0, ' \
|
||||
'mean=0.0, ' \
|
||||
'std=0.1)'
|
||||
|
||||
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
|
||||
noise_module = TRANSFORMS.build(transform)
|
||||
results = dict(
|
||||
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
original_img = copy.deepcopy(results['img'])
|
||||
results = noise_module(results)
|
||||
assert original_img.shape == results['img'].shape
|
||||
|
||||
|
||||
def test_biomedical_gaussian_blur():
|
||||
# test assertion for invalid prob
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(type='BioMedicalGaussianBlur', prob=-1.5)
|
||||
TRANSFORMS.build(transform)
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6)
|
||||
smooth_module = TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6))
|
||||
smooth_module = TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9))
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16')
|
||||
TRANSFORMS.build(transform)
|
||||
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
|
||||
smooth_module = TRANSFORMS.build(transform)
|
||||
assert str(
|
||||
smooth_module
|
||||
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
|
||||
'prob_per_channel=0.5, '\
|
||||
'sigma_range=(0.7, 0.8), ' \
|
||||
'different_sigma_per_channel=True, '\
|
||||
'different_sigma_per_axis=True)'
|
||||
|
||||
transform = dict(type='BioMedicalGaussianBlur', prob=1.0)
|
||||
smooth_module = TRANSFORMS.build(transform)
|
||||
assert str(
|
||||
smooth_module
|
||||
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
|
||||
'prob_per_channel=0.5, '\
|
||||
'sigma_range=(0.5, 1.0), ' \
|
||||
'different_sigma_per_channel=True, '\
|
||||
'different_sigma_per_axis=True)'
|
||||
|
||||
results = dict(
|
||||
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
original_img = copy.deepcopy(results['img'])
|
||||
results = smooth_module(results)
|
||||
assert original_img.shape == results['img'].shape
|
||||
# the max value in the smoothed image should be less than the original one
|
||||
assert original_img.max() >= results['img'].max()
|
||||
assert original_img.min() <= results['img'].min()
|
||||
|
||||
transform = dict(
|
||||
type='BioMedicalGaussianBlur',
|
||||
prob=1.0,
|
||||
different_sigma_per_axis=False)
|
||||
smooth_module = TRANSFORMS.build(transform)
|
||||
|
||||
results = dict(
|
||||
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
original_img = copy.deepcopy(results['img'])
|
||||
results = smooth_module(results)
|
||||
assert original_img.shape == results['img'].shape
|
||||
# the max value in the smoothed image should be less than the original one
|
||||
assert original_img.max() >= results['img'].max()
|
||||
assert original_img.min() <= results['img'].min()
|
||||
|
|
Loading…
Reference in New Issue