[Feature] nnUNet-style Gaussian Noise and Blur (#2373)

## Motivation

implement nnUNet-style Gaussian Noise and Blur
pull/2460/head
Haoyu Wang 2023-01-02 20:43:15 +08:00 committed by GitHub
parent 6eb1a95a48
commit 26f3df7a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 294 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset from .potsdam import PotsdamDataset
from .stare import STAREDataset from .stare import STAREDataset
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, LoadAnnotations, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray, LoadBiomedicalImageFromFile, LoadImageFromNDArray,
@ -42,5 +43,6 @@ __all__ = [
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge' 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
] ]

View File

@ -3,17 +3,20 @@ from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray) LoadImageFromNDArray)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
GenerateEdge, PhotoMetricDistortion, RandomCrop, GenerateEdge, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange, RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray, ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale) SegRescale)
# yapf: enable
__all__ = [ __all__ = [
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge' 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
] ]

View File

@ -10,6 +10,7 @@ from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of from mmengine.utils import is_tuple_of
from numpy import random from numpy import random
from scipy.ndimage import gaussian_filter
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
@ -1507,3 +1508,181 @@ class BioMedical3DRandomCrop(BaseTransform):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
@TRANSFORMS.register_module()
class BioMedicalGaussianNoise(BaseTransform):
"""Add random Gaussian noise to image.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501
Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
Modified Keys:
- img
Args:
prob (float): Probability to add Gaussian noise for
each sample. Default to 0.1.
mean (float): Mean or centre of the distribution. Default to 0.0.
std (float): Standard deviation of distribution. Default to 0.1.
"""
def __init__(self,
prob: float = 0.1,
mean: float = 0.0,
std: float = 0.1) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0 and std >= 0.0
self.prob = prob
self.mean = mean
self.std = std
def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian noise to image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
rand_std = np.random.uniform(0, self.std)
noise = np.random.normal(
self.mean, rand_std, size=results['img'].shape)
# noise is float64 array, convert to the results['img'].dtype
noise = noise.astype(results['img'].dtype)
results['img'] = results['img'] + noise
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'mean={self.mean}, '
repr_str += f'std={self.std})'
return repr_str
@TRANSFORMS.register_module()
class BioMedicalGaussianBlur(BaseTransform):
"""Add Gaussian blur with random sigma to image.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501
Copyright (c) German Cancer Research Center (DKFZ)
Licensed under the Apache License, Version 2.0
Required Keys:
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
N is the number of modalities, and data type is float32.
Modified Keys:
- img
Args:
sigma_range (Tuple[float, float]|float): range to randomly
select sigma value. Default to (0.5, 1.0).
prob (float): Probability to apply Gaussian blur
for each sample. Default to 0.2.
prob_per_channel (float): Probability to apply Gaussian blur
for each channel (axis N of the image). Default to 0.5.
different_sigma_per_channel (bool): whether to use different
sigma for each channel (axis N of the image). Default to True.
different_sigma_per_axis (bool): whether to use different
sigma for axis Z, X and Y of the image. Default to True.
"""
def __init__(self,
sigma_range: Tuple[float, float] = (0.5, 1.0),
prob: float = 0.2,
prob_per_channel: float = 0.5,
different_sigma_per_channel: bool = True,
different_sigma_per_axis: bool = True) -> None:
super().__init__()
assert 0.0 <= prob <= 1.0
assert 0.0 <= prob_per_channel <= 1.0
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2
self.sigma_range = sigma_range
self.prob = prob
self.prob_per_channel = prob_per_channel
self.different_sigma_per_channel = different_sigma_per_channel
self.different_sigma_per_axis = different_sigma_per_axis
def _get_valid_sigma(self, value_range) -> Tuple[float, ...]:
"""Ensure the `value_range` to be either a single value or a sequence
of two values. If the `value_range` is a sequence, generate a random
value with `[value_range[0], value_range[1]]` based on uniform
sampling.
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501
Args:
value_range (tuple|list|float|int): the input value range
"""
if (isinstance(value_range, (list, tuple))):
if (value_range[0] == value_range[1]):
value = value_range[0]
else:
orig_type = type(value_range[0])
value = np.random.uniform(value_range[0], value_range[1])
value = orig_type(value)
return value
def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
"""Random generate sigma and apply Gaussian Blur to the data
Args:
data_sample (np.ndarray): data sample with multiple modalities,
the data shape is (N, Z, Y, X)
"""
sigma = None
for c in range(data_sample.shape[0]):
if np.random.rand() < self.prob_per_channel:
# if no `sigma` is generated, generate one
# if `self.different_sigma_per_channel` is True,
# re-generate random sigma for each channel
if (sigma is None or self.different_sigma_per_channel):
if (not self.different_sigma_per_axis):
sigma = self._get_valid_sigma(self.sigma_range)
else:
sigma = [
self._get_valid_sigma(self.sigma_range)
for _ in data_sample.shape[1:]
]
# apply gaussian filter with `sigma`
data_sample[c] = gaussian_filter(
data_sample[c], sigma, order=0)
return data_sample
def transform(self, results: Dict) -> Dict:
"""Call function to add random Gaussian blur to image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with random Gaussian noise.
"""
if np.random.rand() < self.prob:
results['img'] = self._gaussian_blur(results['img'])
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'prob_per_channel={self.prob_per_channel}, '
repr_str += f'sigma_range={self.sigma_range}, '
repr_str += 'different_sigma_per_channel='\
f'{self.different_sigma_per_channel}, '
repr_str += 'different_sigma_per_axis='\
f'{self.different_sigma_per_axis})'
return repr_str

View File

@ -778,3 +778,111 @@ def test_biomedical3d_random_crop():
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20) assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20) assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20) assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
def test_biomedical_gaussian_noise():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianNoise', prob=1.5)
TRANSFORMS.build(transform)
# test assertion for invalid std
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5)
TRANSFORMS.build(transform)
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
assert str(noise_module) == 'BioMedicalGaussianNoise'\
'(prob=1.0, ' \
'mean=0.0, ' \
'std=0.1)'
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
noise_module = TRANSFORMS.build(transform)
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = noise_module(results)
assert original_img.shape == results['img'].shape
def test_biomedical_gaussian_blur():
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='BioMedicalGaussianBlur', prob=-1.5)
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6)
smooth_module = TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6))
smooth_module = TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9))
TRANSFORMS.build(transform)
with pytest.raises(AssertionError):
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16')
TRANSFORMS.build(transform)
transform = dict(
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.7, 0.8), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'
transform = dict(type='BioMedicalGaussianBlur', prob=1.0)
smooth_module = TRANSFORMS.build(transform)
assert str(
smooth_module
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
'prob_per_channel=0.5, '\
'sigma_range=(0.5, 1.0), ' \
'different_sigma_per_channel=True, '\
'different_sigma_per_axis=True)'
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()
transform = dict(
type='BioMedicalGaussianBlur',
prob=1.0,
different_sigma_per_axis=False)
smooth_module = TRANSFORMS.build(transform)
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
transform = LoadBiomedicalImageFromFile()
results = transform(copy.deepcopy(results))
original_img = copy.deepcopy(results['img'])
results = smooth_module(results)
assert original_img.shape == results['img'].shape
# the max value in the smoothed image should be less than the original one
assert original_img.max() >= results['img'].max()
assert original_img.min() <= results['img'].min()