[Feature] nnUNet-style Gaussian Noise and Blur (#2373)
## Motivation implement nnUNet-style Gaussian Noise and Blurpull/2460/head
parent
6eb1a95a48
commit
26f3df7a45
|
@ -19,6 +19,7 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||||
from .potsdam import PotsdamDataset
|
from .potsdam import PotsdamDataset
|
||||||
from .stare import STAREDataset
|
from .stare import STAREDataset
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
GenerateEdge, LoadAnnotations,
|
GenerateEdge, LoadAnnotations,
|
||||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||||
|
@ -42,5 +43,6 @@ __all__ = [
|
||||||
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||||
|
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,17 +3,20 @@ from .formatting import PackSegInputs
|
||||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||||
LoadImageFromNDArray)
|
LoadImageFromNDArray)
|
||||||
|
# yapf: disable
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop,
|
||||||
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
GenerateEdge, PhotoMetricDistortion, RandomCrop,
|
||||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||||
SegRescale)
|
SegRescale)
|
||||||
|
|
||||||
|
# yapf: enable
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||||
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'ResizeShortestEdge'
|
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur'
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,6 +10,7 @@ from mmcv.transforms.base import BaseTransform
|
||||||
from mmcv.transforms.utils import cache_randomness
|
from mmcv.transforms.utils import cache_randomness
|
||||||
from mmengine.utils import is_tuple_of
|
from mmengine.utils import is_tuple_of
|
||||||
from numpy import random
|
from numpy import random
|
||||||
|
from scipy.ndimage import gaussian_filter
|
||||||
|
|
||||||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
@ -1507,3 +1508,181 @@ class BioMedical3DRandomCrop(BaseTransform):
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
|
||||||
|
class BioMedicalGaussianNoise(BaseTransform):
|
||||||
|
"""Add random Gaussian noise to image.
|
||||||
|
|
||||||
|
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501
|
||||||
|
|
||||||
|
Copyright (c) German Cancer Research Center (DKFZ)
|
||||||
|
Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||||
|
N is the number of modalities, and data type is float32.
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prob (float): Probability to add Gaussian noise for
|
||||||
|
each sample. Default to 0.1.
|
||||||
|
mean (float): Mean or “centre” of the distribution. Default to 0.0.
|
||||||
|
std (float): Standard deviation of distribution. Default to 0.1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
prob: float = 0.1,
|
||||||
|
mean: float = 0.0,
|
||||||
|
std: float = 0.1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert 0.0 <= prob <= 1.0 and std >= 0.0
|
||||||
|
self.prob = prob
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
|
||||||
|
def transform(self, results: Dict) -> Dict:
|
||||||
|
"""Call function to add random Gaussian noise to image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Result dict with random Gaussian noise.
|
||||||
|
"""
|
||||||
|
if np.random.rand() < self.prob:
|
||||||
|
rand_std = np.random.uniform(0, self.std)
|
||||||
|
noise = np.random.normal(
|
||||||
|
self.mean, rand_std, size=results['img'].shape)
|
||||||
|
# noise is float64 array, convert to the results['img'].dtype
|
||||||
|
noise = noise.astype(results['img'].dtype)
|
||||||
|
results['img'] = results['img'] + noise
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__
|
||||||
|
repr_str += f'(prob={self.prob}, '
|
||||||
|
repr_str += f'mean={self.mean}, '
|
||||||
|
repr_str += f'std={self.std})'
|
||||||
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
|
||||||
|
class BioMedicalGaussianBlur(BaseTransform):
|
||||||
|
"""Add Gaussian blur with random sigma to image.
|
||||||
|
|
||||||
|
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501
|
||||||
|
|
||||||
|
Copyright (c) German Cancer Research Center (DKFZ)
|
||||||
|
Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
Required Keys:
|
||||||
|
|
||||||
|
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
|
||||||
|
N is the number of modalities, and data type is float32.
|
||||||
|
|
||||||
|
Modified Keys:
|
||||||
|
|
||||||
|
- img
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sigma_range (Tuple[float, float]|float): range to randomly
|
||||||
|
select sigma value. Default to (0.5, 1.0).
|
||||||
|
prob (float): Probability to apply Gaussian blur
|
||||||
|
for each sample. Default to 0.2.
|
||||||
|
prob_per_channel (float): Probability to apply Gaussian blur
|
||||||
|
for each channel (axis N of the image). Default to 0.5.
|
||||||
|
different_sigma_per_channel (bool): whether to use different
|
||||||
|
sigma for each channel (axis N of the image). Default to True.
|
||||||
|
different_sigma_per_axis (bool): whether to use different
|
||||||
|
sigma for axis Z, X and Y of the image. Default to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
sigma_range: Tuple[float, float] = (0.5, 1.0),
|
||||||
|
prob: float = 0.2,
|
||||||
|
prob_per_channel: float = 0.5,
|
||||||
|
different_sigma_per_channel: bool = True,
|
||||||
|
different_sigma_per_axis: bool = True) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert 0.0 <= prob <= 1.0
|
||||||
|
assert 0.0 <= prob_per_channel <= 1.0
|
||||||
|
assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2
|
||||||
|
self.sigma_range = sigma_range
|
||||||
|
self.prob = prob
|
||||||
|
self.prob_per_channel = prob_per_channel
|
||||||
|
self.different_sigma_per_channel = different_sigma_per_channel
|
||||||
|
self.different_sigma_per_axis = different_sigma_per_axis
|
||||||
|
|
||||||
|
def _get_valid_sigma(self, value_range) -> Tuple[float, ...]:
|
||||||
|
"""Ensure the `value_range` to be either a single value or a sequence
|
||||||
|
of two values. If the `value_range` is a sequence, generate a random
|
||||||
|
value with `[value_range[0], value_range[1]]` based on uniform
|
||||||
|
sampling.
|
||||||
|
|
||||||
|
Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value_range (tuple|list|float|int): the input value range
|
||||||
|
"""
|
||||||
|
if (isinstance(value_range, (list, tuple))):
|
||||||
|
if (value_range[0] == value_range[1]):
|
||||||
|
value = value_range[0]
|
||||||
|
else:
|
||||||
|
orig_type = type(value_range[0])
|
||||||
|
value = np.random.uniform(value_range[0], value_range[1])
|
||||||
|
value = orig_type(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
|
||||||
|
"""Random generate sigma and apply Gaussian Blur to the data
|
||||||
|
Args:
|
||||||
|
data_sample (np.ndarray): data sample with multiple modalities,
|
||||||
|
the data shape is (N, Z, Y, X)
|
||||||
|
"""
|
||||||
|
sigma = None
|
||||||
|
for c in range(data_sample.shape[0]):
|
||||||
|
if np.random.rand() < self.prob_per_channel:
|
||||||
|
# if no `sigma` is generated, generate one
|
||||||
|
# if `self.different_sigma_per_channel` is True,
|
||||||
|
# re-generate random sigma for each channel
|
||||||
|
if (sigma is None or self.different_sigma_per_channel):
|
||||||
|
if (not self.different_sigma_per_axis):
|
||||||
|
sigma = self._get_valid_sigma(self.sigma_range)
|
||||||
|
else:
|
||||||
|
sigma = [
|
||||||
|
self._get_valid_sigma(self.sigma_range)
|
||||||
|
for _ in data_sample.shape[1:]
|
||||||
|
]
|
||||||
|
# apply gaussian filter with `sigma`
|
||||||
|
data_sample[c] = gaussian_filter(
|
||||||
|
data_sample[c], sigma, order=0)
|
||||||
|
return data_sample
|
||||||
|
|
||||||
|
def transform(self, results: Dict) -> Dict:
|
||||||
|
"""Call function to add random Gaussian blur to image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results (dict): Result dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Result dict with random Gaussian noise.
|
||||||
|
"""
|
||||||
|
if np.random.rand() < self.prob:
|
||||||
|
results['img'] = self._gaussian_blur(results['img'])
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__
|
||||||
|
repr_str += f'(prob={self.prob}, '
|
||||||
|
repr_str += f'prob_per_channel={self.prob_per_channel}, '
|
||||||
|
repr_str += f'sigma_range={self.sigma_range}, '
|
||||||
|
repr_str += 'different_sigma_per_channel='\
|
||||||
|
f'{self.different_sigma_per_channel}, '
|
||||||
|
repr_str += 'different_sigma_per_axis='\
|
||||||
|
f'{self.different_sigma_per_axis})'
|
||||||
|
return repr_str
|
||||||
|
|
|
@ -778,3 +778,111 @@ def test_biomedical3d_random_crop():
|
||||||
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
assert crop_results['img'].shape[1:] == (d - 20, h - 20, w - 20)
|
||||||
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
assert crop_results['img_shape'] == (d - 20, h - 20, w - 20)
|
||||||
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
assert crop_results['gt_seg_map'].shape == (d - 20, h - 20, w - 20)
|
||||||
|
|
||||||
|
|
||||||
|
def test_biomedical_gaussian_noise():
|
||||||
|
# test assertion for invalid prob
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='BioMedicalGaussianNoise', prob=1.5)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
# test assertion for invalid std
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianNoise', prob=0.2, mean=0.5, std=-0.5)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
|
||||||
|
noise_module = TRANSFORMS.build(transform)
|
||||||
|
assert str(noise_module) == 'BioMedicalGaussianNoise'\
|
||||||
|
'(prob=1.0, ' \
|
||||||
|
'mean=0.0, ' \
|
||||||
|
'std=0.1)'
|
||||||
|
|
||||||
|
transform = dict(type='BioMedicalGaussianNoise', prob=1.0)
|
||||||
|
noise_module = TRANSFORMS.build(transform)
|
||||||
|
results = dict(
|
||||||
|
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||||
|
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||||
|
transform = LoadBiomedicalImageFromFile()
|
||||||
|
results = transform(copy.deepcopy(results))
|
||||||
|
original_img = copy.deepcopy(results['img'])
|
||||||
|
results = noise_module(results)
|
||||||
|
assert original_img.shape == results['img'].shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_biomedical_gaussian_blur():
|
||||||
|
# test assertion for invalid prob
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(type='BioMedicalGaussianBlur', prob=-1.5)
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6)
|
||||||
|
smooth_module = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6))
|
||||||
|
smooth_module = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9))
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16')
|
||||||
|
TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
|
||||||
|
smooth_module = TRANSFORMS.build(transform)
|
||||||
|
assert str(
|
||||||
|
smooth_module
|
||||||
|
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
|
||||||
|
'prob_per_channel=0.5, '\
|
||||||
|
'sigma_range=(0.7, 0.8), ' \
|
||||||
|
'different_sigma_per_channel=True, '\
|
||||||
|
'different_sigma_per_axis=True)'
|
||||||
|
|
||||||
|
transform = dict(type='BioMedicalGaussianBlur', prob=1.0)
|
||||||
|
smooth_module = TRANSFORMS.build(transform)
|
||||||
|
assert str(
|
||||||
|
smooth_module
|
||||||
|
) == 'BioMedicalGaussianBlur(prob=1.0, ' \
|
||||||
|
'prob_per_channel=0.5, '\
|
||||||
|
'sigma_range=(0.5, 1.0), ' \
|
||||||
|
'different_sigma_per_channel=True, '\
|
||||||
|
'different_sigma_per_axis=True)'
|
||||||
|
|
||||||
|
results = dict(
|
||||||
|
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||||
|
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||||
|
transform = LoadBiomedicalImageFromFile()
|
||||||
|
results = transform(copy.deepcopy(results))
|
||||||
|
original_img = copy.deepcopy(results['img'])
|
||||||
|
results = smooth_module(results)
|
||||||
|
assert original_img.shape == results['img'].shape
|
||||||
|
# the max value in the smoothed image should be less than the original one
|
||||||
|
assert original_img.max() >= results['img'].max()
|
||||||
|
assert original_img.min() <= results['img'].min()
|
||||||
|
|
||||||
|
transform = dict(
|
||||||
|
type='BioMedicalGaussianBlur',
|
||||||
|
prob=1.0,
|
||||||
|
different_sigma_per_axis=False)
|
||||||
|
smooth_module = TRANSFORMS.build(transform)
|
||||||
|
|
||||||
|
results = dict(
|
||||||
|
img_path=osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz'))
|
||||||
|
from mmseg.datasets.transforms import LoadBiomedicalImageFromFile
|
||||||
|
transform = LoadBiomedicalImageFromFile()
|
||||||
|
results = transform(copy.deepcopy(results))
|
||||||
|
original_img = copy.deepcopy(results['img'])
|
||||||
|
results = smooth_module(results)
|
||||||
|
assert original_img.shape == results['img'].shape
|
||||||
|
# the max value in the smoothed image should be less than the original one
|
||||||
|
assert original_img.max() >= results['img'].max()
|
||||||
|
assert original_img.min() <= results['img'].min()
|
||||||
|
|
Loading…
Reference in New Issue