[Feature] nnUNet-style Gaussian Noise and Blur (#2373)

## Motivation

implement nnUNet-style Gaussian Noise and Blur
pull/2460/head
Haoyu Wang 2023-01-02 20:43:15 +08:00 committed by GitHub
parent 6eb1a95a48
commit 26f3df7a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 294 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
@ -42,5 +43,6 @@ __all__ = [
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
]

View File

@ -3,17 +3,20 @@ from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
# yapf: enable
__all__ = [
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
]

View File

@ -10,6 +10,7 @@ from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of
from numpy import random
from scipy.ndimage import gaussian_filter
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS
@ -1507,3 +1508,181 @@ class BioMedical3DRandomCrop(BaseTransform):
def __repr__(self):
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
@TRANSFORMS.register_module()
class BioMedicalGaussianNoise(BaseTransform):
"""Add random Gaussian noise to image.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501
Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
Modified Keys:
- img
Args:
prob (float): Probability to add Gaussian noise for
each sample. Default to 0.1.
mean (float): Mean or centre of the distribution. Default to 0.0.
std (float): Standard deviation of distribution. Default to 0.1.
"""
def __init__(self,
prob: float = 0.1,
mean: float = 0.0,
std: float = 0.1) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0 and std >= 0.0
self.prob = prob
self.mean = mean
self.std = std
def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian noise to image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
rand_std = np.random.uniform(0, self.std)
noise = np.random.normal(
self.mean, rand_std, size=results['img'].shape)
# noise is float64 array, convert to the results['img'].dtype
noise = noise.astype(results['img'].dtype)
results['img'] = results['img'] + noise
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'mean={self.mean}, '
repr_str += f'std={self.std})'
return repr_str
@TRANSFORMS.register_module()
class BioMedicalGaussianBlur(BaseTransform):
"""Add Gaussian blur with random sigma to image.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501
Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
Modified Keys:
- img
Args:
sigma_range (Tuple[float, float]|float): range to randomly
select sigma value. Default to (0.5, 1.0).
prob (float): Probability to apply Gaussian blur
for each sample. Default to 0.2.
prob_per_channel (float): Probability to apply Gaussian blur
for each channel (axis N of the image). Default to 0.5.
different_sigma_per_channel (bool): whether to use different
sigma for each channel (axis N of the image). Default to True.
different_sigma_per_axis (bool): whether to use different
sigma for axis Z, X and Y of the image. Default to True.
"""
def __init__(self,
sigma_range: Tuple[float, float] = (0.5, 1.0),
prob: float = 0.2,
prob_per_channel: float = 0.5,
different_sigma_per_channel: bool = True,
different_sigma_per_axis: bool = True) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0
assert 0.0 <= prob_per_channel <= 1.0
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2
self.sigma_range = sigma_range
self.prob = prob
self.prob_per_channel = prob_per_channel
self.different_sigma_per_channel = different_sigma_per_channel
self.different_sigma_per_axis = different_sigma_per_axis
def _get_valid_sigma(self, value_range) -> Tuple[float, ...]:
"""Ensure the `value_range` to be either a single value or a sequence
of two values. If the `value_range` is a sequence, generate a random
value with `[value_range[0], value_range[1]]` based on uniform
sampling.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501
Args:
value_range (tuple|list|float|int): the input value range
"""
if (isinstance(value_range, (list, tuple))):
if (value_range[0] == value_range[1]):
value = value_range[0]
else:
orig_type = type(value_range[0])
value = np.random.uniform(value_range[0], value_range[1])
value = orig_type(value)
return value
def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
"""Random generate sigma and apply Gaussian Blur to the data
Args:
data_sample (np.ndarray): data sample with multiple modalities,
the data shape is (N, Z, Y, X)
"""
sigma = None
for c in range(data_sample.shape[0]):
if np.random.rand() < self.prob_per_channel:
# if no `sigma` is generated, generate one
# if `self.different_sigma_per_channel` is True,
# re-generate random sigma for each channel
if (sigma is None or self.different_sigma_per_channel):
if (not self.different_sigma_per_axis):
sigma = self._get_valid_sigma(self.sigma_range)
else:
sigma = [
self._get_valid_sigma(self.sigma_range)
for _ in data_sample.shape[1:]
]
# apply gaussian filter with `sigma`
data_sample[c] = gaussian_filter(
data_sample[c], sigma, order=0)
return data_sample
def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian blur to image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
results['img'] = self._gaussian_blur(results['img'])
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'prob_per_channel={self.prob_per_channel}, '
repr_str += f'sigma_range={self.sigma_range}, '
repr_str += 'different_sigma_per_channel='\
f'{self.different_sigma_per_channel}, '
repr_str += 'different_sigma_per_axis='\
f'{self.different_sigma_per_axis})'
return repr_str

View File

@ -778,3 +778,111 @@ def test_biomedical3d_random_crop():
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
def test_biomedical_gaussian_noise():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianNoise', prob=1.5)
TRANSFORMS.build(transform)
# test assertion for invalid std
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5)
TRANSFORMS.build(transform)
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
assert str(noise_module) == 'BioMedicalGaussianNoise'\
'(prob=1.0, ' \
'mean=0.0, ' \
'std=0.1)'
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = noise_module(results)
assert original_img.shape == results['img'].shape
def test_biomedical_gaussian_blur():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianBlur', prob=-1.5)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6)
smooth_module = TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6))
smooth_module = TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16')
TRANSFORMS.build(transform)
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.7, 0.8), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'
transform = dict(type='BioMedicalGaussianBlur', prob=1.0)
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.5, 1.0), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()
transform = dict(
type='BioMedicalGaussianBlur',
prob=1.0,
different_sigma_per_axis=False)
smooth_module = TRANSFORMS.build(transform)
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()