[Feature] Add Cutout transform (#1022)
* Fix typo in usage example * [Feature] Add CutOut transform * CutOut repr covered by unittests * Cutout ignore index, test * ignore_index -> seg_fill_in, default is None * seg_fill_in is added to repr * test is modified for seg_fill_in is None * seg_fill_in (int), 0-255 * add seg_fill_in test * doc string for seg_fill_in * rename CutOut to RandomCutOut, add prob * Add unittest when cutout is False
pull/1095/head
parent
08272b6208
commit
78a6ff689d
|
@ -5,13 +5,14 @@ from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
||||||
from .loading import LoadAnnotations, LoadImageFromFile
|
from .loading import LoadAnnotations, LoadImageFromFile
|
||||||
from .test_time_aug import MultiScaleFlipAug
|
from .test_time_aug import MultiScaleFlipAug
|
||||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
||||||
PhotoMetricDistortion, RandomCrop, RandomFlip,
|
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||||
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
|
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
|
||||||
|
SegRescale)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
||||||
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
||||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
|
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut'
|
||||||
]
|
]
|
||||||
|
|
|
@ -948,3 +948,95 @@ class PhotoMetricDistortion(object):
|
||||||
f'{self.saturation_upper}), '
|
f'{self.saturation_upper}), '
|
||||||
f'hue_delta={self.hue_delta})')
|
f'hue_delta={self.hue_delta})')
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
class RandomCutOut(object):
    """Randomly erase rectangular regions of an image (CutOut).

    Implements the augmentation described in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Args:
        prob (float): Probability, in [0, 1], of applying cutout to a sample.
        n_holes (int | tuple[int, int] | list[int]): Number of regions to be
            dropped. If a pair is given, the number of holes is drawn from
            the closed interval [`n_holes[0]`, `n_holes[1]`], which requires
            `0 <= n_holes[0] < n_holes[1]`.
        cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
            (width, height) of dropped regions, in pixels. It can be a single
            pair to use a fixed cutout shape, or a sequence of pairs to
            randomly choose a shape for every hole.
        cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
            candidate (width, height) of dropped regions, expressed as a
            fraction of the image width and height. It can be a single pair
            to use a fixed ratio, or a sequence of pairs to randomly choose a
            ratio for every hole. Exactly one of `cutout_shape` and
            `cutout_ratio` must be given.
        fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
            of pixel to fill in the dropped regions. Default: (0, 0, 0).
        seg_fill_in (int): The label, in [0, 255], written into the dropped
            regions of every map listed in `results['seg_fields']`.
            If seg_fill_in is None, the maps are left untouched.
            Default: None.
    """

    def __init__(self,
                 prob,
                 n_holes,
                 cutout_shape=None,
                 cutout_ratio=None,
                 fill_in=(0, 0, 0),
                 seg_fill_in=None):
        # Validation deliberately uses ``assert``: the accompanying unit
        # tests expect AssertionError for invalid arguments.
        assert 0 <= prob <= 1
        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'

        with_ratio = cutout_ratio is not None
        candidates = cutout_ratio if with_ratio else cutout_shape
        # Reject scalars and empty sequences here rather than failing with an
        # obscure error on the first call.
        assert isinstance(candidates, (list, tuple)) and len(candidates) > 0
        if isinstance(candidates[0], (list, tuple)):
            # Already a sequence of (w, h) pairs.
            candidates = list(candidates)
        else:
            # A single (w, h) pair, written either as a tuple or as a list.
            candidates = [tuple(candidates)]

        # Lists are accepted alongside tuples, as the docstring promises.
        if isinstance(n_holes, (list, tuple)):
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
            n_holes = tuple(n_holes)
        else:
            n_holes = (n_holes, n_holes)

        if seg_fill_in is not None:
            assert isinstance(seg_fill_in, int) and 0 <= seg_fill_in <= 255

        self.prob = prob
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.seg_fill_in = seg_fill_in
        self.with_ratio = with_ratio
        self.candidates = candidates

    def __call__(self, results):
        """Call function to drop some regions of image.

        Args:
            results (dict): Pipeline dict. `results['img']` and, when
                `seg_fill_in` is set, every array named in
                `results['seg_fields']` are modified in place.

        Returns:
            dict: The same `results` object.
        """
        if np.random.rand() >= self.prob:
            return results

        img = results['img']
        h, w = img.shape[:2]
        n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        seg_keys = (
            results.get('seg_fields', [])
            if self.seg_fill_in is not None else [])

        for _ in range(n_holes):
            # The order of the random draws (x, y, candidate) is kept stable
            # so that seeded runs stay reproducible.
            x1 = np.random.randint(0, w)
            y1 = np.random.randint(0, h)
            index = np.random.randint(0, len(self.candidates))
            if self.with_ratio:
                cutout_w = int(self.candidates[index][0] * w)
                cutout_h = int(self.candidates[index][1] * h)
            else:
                cutout_w, cutout_h = self.candidates[index]

            # Holes that would cross the right/bottom border are truncated.
            x2 = np.clip(x1 + cutout_w, 0, w)
            y2 = np.clip(y1 + cutout_h, 0, h)
            img[y1:y2, x1:x2, ...] = self.fill_in

            for key in seg_keys:
                results[key][y1:y2, x1:x2] = self.seg_fill_in

        return results

    def __repr__(self):
        size_name = 'cutout_ratio' if self.with_ratio else 'cutout_shape'
        return (f'{self.__class__.__name__}(prob={self.prob}, '
                f'n_holes={self.n_holes}, '
                f'{size_name}={self.candidates}, '
                f'fill_in={self.fill_in}, '
                f'seg_fill_in={self.seg_fill_in})')
|
||||||
|
|
|
@ -497,3 +497,120 @@ def test_seg_rescale():
|
||||||
rescale_module = build_from_cfg(transform, PIPELINES)
|
rescale_module = build_from_cfg(transform, PIPELINES)
|
||||||
rescale_results = rescale_module(results.copy())
|
rescale_results = rescale_module(results.copy())
|
||||||
assert rescale_results['gt_semantic_seg'].shape == (h, w)
|
assert rescale_results['gt_semantic_seg'].shape == (h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cutout():
    """Check RandomCutOut argument validation, repr and cut-out behaviour."""
    # Every config below must be rejected at construction time. Each one is
    # valid apart from the argument under test, so the AssertionError can
    # only come from the check being exercised.
    invalid_cfgs = [
        # prob outside [0, 1]; cutout_shape is supplied so that the
        # shape/ratio check cannot be what raises.
        dict(prob=1.5, n_holes=1, cutout_shape=(8, 8)),
        # n_holes: reversed interval, then wrong length
        dict(prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8)),
        dict(prob=0.5, n_holes=(3, 4, 5), cutout_shape=(8, 8)),
        # cutout_shape / cutout_ratio must be sequences
        dict(prob=0.5, n_holes=1, cutout_shape=8),
        dict(prob=0.5, n_holes=1, cutout_ratio=0.2),
        # either of cutout_shape and cutout_ratio should be given
        dict(prob=0.5, n_holes=1),
        dict(
            prob=0.5,
            n_holes=1,
            cutout_shape=(2, 2),
            cutout_ratio=(0.4, 0.4)),
        # seg_fill_in: wrong type, then out of range
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in='a'),
        dict(prob=0.5, n_holes=1, cutout_shape=(8, 8), seg_fill_in=256),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            build_from_cfg(dict(type='RandomCutOut', **cfg), PIPELINES)

    data_dir = osp.join(osp.dirname(__file__), '../data')
    img = mmcv.imread(osp.join(data_dir, 'color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(data_dir, 'seg.png')))

    results = dict(
        img=img,
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'],
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        img_fields=['img'])

    # fixed shape, default (black) fill
    transform = dict(
        type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10))
    cutout_module = build_from_cfg(transform, PIPELINES)
    assert 'cutout_shape' in repr(cutout_module)
    # deepcopy: the transform writes into the arrays in place
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() < img.sum()

    # fixed ratio, default (black) fill
    transform = dict(
        type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
    cutout_module = build_from_cfg(transform, PIPELINES)
    assert 'cutout_ratio' in repr(cutout_module)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() < img.sum()

    # prob=0: image and segmentation map must be unchanged
    transform = dict(
        type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() == img.sum()
    assert cutout_result['gt_semantic_seg'].sum() == seg.sum()

    # seg_fill_in=None: only the image is modified
    transform = dict(
        type='RandomCutOut',
        prob=1,
        n_holes=(2, 4),
        cutout_shape=[(10, 10), (15, 15)],
        fill_in=(255, 255, 255),
        seg_fill_in=None)
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() > img.sum()
    assert cutout_result['gt_semantic_seg'].sum() == seg.sum()

    # seg_fill_in=255: the segmentation map is filled as well
    transform = dict(
        type='RandomCutOut',
        prob=1,
        n_holes=1,
        cutout_ratio=(0.8, 0.8),
        fill_in=(255, 255, 255),
        seg_fill_in=255)
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() > img.sum()
    assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
|
||||||
|
|
Loading…
Reference in New Issue