[Feature] Support random augmentation (#201)

* support random augmentation

* minor fix on posterize

* minor fix on posterize

* minor fix on cutout

* minor fix on cutout

* fix bug in solarize add

* revised according to comments
pull/207/head
LXXXXR 2021-04-09 14:02:50 +08:00 committed by GitHub
parent 2bd28435cf
commit 5195932952
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 470 additions and 14 deletions

View File

@ -1,7 +1,7 @@
from .auto_augment import (AutoAugment, AutoContrast, Brightness,
ColorTransform, Contrast, Equalize, Invert,
Posterize, Rotate, Sharpness, Shear, Solarize,
Translate)
ColorTransform, Contrast, Cutout, Equalize, Invert,
Posterize, RandAugment, Rotate, Sharpness, Shear,
Solarize, SolarizeAdd, Translate)
from .compose import Compose
from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
Transpose, to_tensor)
@ -16,6 +16,6 @@ __all__ = [
'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'Lighting',
'ColorJitter'
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter'
]

View File

@ -1,4 +1,5 @@
import copy
import random
import mmcv
import numpy as np
@ -41,7 +42,7 @@ class AutoAugment(object):
self.sub_policy = [Compose(policy) for policy in self.policies]
def __call__(self, results):
sub_policy = np.random.choice(self.sub_policy)
sub_policy = random.choice(self.sub_policy)
return sub_policy(results)
def __repr__(self):
@ -50,6 +51,87 @@ class AutoAugment(object):
return repr_str
@PIPELINES.register_module()
class RandAugment(object):
"""Random augmentation.
This data augmentation is proposed in `RandAugment: Practical automated
data augmentation with a reduced search space
<https://arxiv.org/abs/1909.13719>`_.
Args:
policies (list[dict]): The policies of random augmentation. Each
policy in ``policies`` is one specific augmentation policy (dict).
The policy shall at least have key `type`, indicating the type of
augmentation. For those which have magnitude, (given to the fact
they are named differently in different augmentation, )
`magnitude_key` and `magnitude_range` shall be the magnitude
argument (str) and the range of magnitude (tuple in the format or
(minval, maxval)), respectively.
num_policies (int): Number of policies to select from policies each
time.
magnitude_level (int | float): Magnitude level for all the augmentation
selected.
total_level (int | float): Total level for the magnitude. Defaults to
30.
"""
def __init__(self,
policies,
num_policies,
magnitude_level,
total_level=30):
assert isinstance(num_policies, int), 'Number of policies must be ' \
f'of int type, got {type(num_policies)} instead.'
assert isinstance(magnitude_level, (int, float)), \
'Magnitude level must be of int or float type, ' \
f'got {type(magnitude_level)} instead.'
assert isinstance(total_level, (int, float)), 'Total level must be ' \
f'of int or float type, got {type(total_level)} instead.'
assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.'
for policy in policies:
assert isinstance(policy, dict) and 'type' in policy, \
'Each policy must be a dict with key "type".'
assert num_policies > 0, 'num_policies must be greater than 0.'
assert magnitude_level >= 0, 'magnitude_level must be no less than 0.'
assert total_level > 0, 'total_level must be greater than 0.'
self.num_policies = num_policies
self.magnitude_level = magnitude_level
self.total_level = total_level
self.policies = self._process_policies(policies)
def _process_policies(self, policies):
processed_policies = []
for policy in policies:
processed_policy = copy.deepcopy(policy)
magnitude_key = processed_policy.pop('magnitude_key', None)
if magnitude_key is not None:
minval, maxval = processed_policy.pop('magnitude_range')
magnitude_value = (float(self.magnitude_level) /
self.total_level) * float(maxval -
minval) + minval
processed_policy.update({magnitude_key: magnitude_value})
processed_policies.append(processed_policy)
return processed_policies
def __call__(self, results):
if self.num_policies == 0:
return results
sub_policy = random.choices(self.policies, k=self.num_policies)
sub_policy = Compose(sub_policy)
return sub_policy(results)
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(policies={self.policies}, '
repr_str += f'num_policies={self.num_policies}, '
repr_str += f'magnitude_level={self.magnitude_level}, '
repr_str += f'total_level={self.total_level})'
return repr_str
@PIPELINES.register_module()
class Shear(object):
"""Shear images.
@ -428,25 +510,66 @@ class Solarize(object):
return repr_str
@PIPELINES.register_module()
class SolarizeAdd(object):
"""SolarizeAdd images (add a certain value to pixels below a threshold).
Args:
magnitude (int | float): The value to be added to pixels below the thr.
thr (int | float): The threshold below which the pixels value will be
adjusted.
prob (float): The probability for solarizing therefore should be in
range [0, 1]. Defaults to 0.5.
"""
def __init__(self, magnitude, thr=128, prob=0.5):
assert isinstance(magnitude, (int, float)), 'The thr magnitude must '\
f'be int or float, but got {type(magnitude)} instead.'
assert isinstance(thr, (int, float)), 'The thr type must '\
f'be int or float, but got {type(thr)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.magnitude = magnitude
self.thr = thr
self.prob = prob
def __call__(self, results):
if np.random.rand() > self.prob:
return results
for key in results.get('img_fields', ['img']):
img = results[key]
img_solarized = np.where(img < self.thr,
np.minimum(img + self.magnitude, 255),
img)
results[key] = img_solarized.astype(img.dtype)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(magnitude={self.magnitude}, '
repr_str += f'thr={self.thr}, '
repr_str += f'prob={self.prob})'
return repr_str
@PIPELINES.register_module()
class Posterize(object):
"""Posterize images (reduce the number of bits for each color channel).
Args:
bits (int): Number of bits for each pixel in the output img, which
should be less or equal to 8.
bits (int | float): Number of bits for each pixel in the output img,
which should be less or equal to 8.
prob (float): The probability for posterizing therefore should be in
range [0, 1]. Defaults to 0.5.
"""
def __init__(self, bits, prob=0.5):
assert isinstance(bits, int), 'The bits type must be int, '\
f'but got {type(bits)} instead.'
assert bits <= 8, f'The bits must be less than 8, got {bits} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.bits = bits
self.bits = int(bits)
self.prob = prob
def __call__(self, results):
@ -642,3 +765,51 @@ class Sharpness(object):
repr_str += f'prob={self.prob}, '
repr_str += f'random_negative_prob={self.random_negative_prob})'
return repr_str
@PIPELINES.register_module()
class Cutout(object):
"""Cutout images.
Args:
shape (int | float | tuple(int | float)): Expected cutout shape (h, w).
If given as a single value, the value will be used for
both h and w.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If
it is a tuple, it must have the same length with the image
channels. Defaults to 128.
prob (float): The probability for performing cutout therefore should
be in range [0, 1]. Defaults to 0.5.
"""
def __init__(self, shape, pad_val=128, prob=0.5):
if isinstance(shape, float):
shape = int(shape)
elif isinstance(shape, tuple):
shape = tuple(int(i) for i in shape)
elif not isinstance(shape, int):
raise TypeError(
'shape must be of '
f'type int, float or tuple, got {type(shape)} instead')
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.shape = shape
self.pad_val = pad_val
self.prob = prob
def __call__(self, results):
if np.random.rand() > self.prob:
return results
for key in results.get('img_fields', ['img']):
img = results[key]
img_cutout = mmcv.cutout(img, self.shape, pad_val=self.pad_val)
results[key] = img_cutout.astype(img.dtype)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(shape={self.shape}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'prob={self.prob})'
return repr_str

View File

@ -1,4 +1,5 @@
import copy
import random
import mmcv
import numpy as np
@ -38,6 +39,189 @@ def construct_toy_data_photometric():
return results
def test_rand_augment():
policies = [
dict(
type='Translate',
magnitude_key='magnitude',
magnitude_range=(0, 1),
pad_val=128,
prob=1.,
direction='horizontal'),
dict(type='Invert', prob=1.),
dict(
type='Rotate',
magnitude_key='angle',
magnitude_range=(0, 30),
prob=0.)
]
# test assertion for num_policies
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1.5,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=-1,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
# test assertion for magnitude_level
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=None)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=-1)
build_from_cfg(transform, PIPELINES)
# test assertion for total_level
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=12,
total_level=None)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=12,
total_level=-30)
build_from_cfg(transform, PIPELINES)
# test assertion for policies
with pytest.raises(AssertionError):
transform = dict(
type='RandAugment',
policies=[],
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
invalid_policies = copy.deepcopy(policies)
invalid_policies.append(('Wrong_policy'))
transform = dict(
type='RandAugment',
policies=invalid_policies,
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
invalid_policies = copy.deepcopy(policies)
invalid_policies[2].pop('type')
transform = dict(
type='RandAugment',
policies=invalid_policies,
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(KeyError):
invalid_policies = copy.deepcopy(policies)
invalid_policies[2].pop('magnitude_range')
transform = dict(
type='RandAugment',
policies=invalid_policies,
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
# test case where num_policies = 1
random.seed(1)
np.random.seed(0)
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=12)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply translate
img_augmented = np.array(
[[128, 128, 1, 2], [128, 128, 5, 6], [128, 128, 9, 10]],
dtype=np.uint8)
img_augmented = np.stack([img_augmented, img_augmented, img_augmented],
axis=-1)
assert (results['img'] == img_augmented).all()
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=1,
magnitude_level=12)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply rotation with prob=0.
assert (results['img'] == results['ori_img']).all()
# test case where num_policies = 2
random.seed(0)
np.random.seed(0)
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=2,
magnitude_level=12)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply rotate and rotate with prob=0
assert (results['img'] == results['ori_img']).all()
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=2,
magnitude_level=12)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply invert and translate
img_augmented = np.array(
[[252, 251, 128, 128], [248, 247, 128, 128], [244, 243, 128, 128]],
dtype=np.uint8)
img_augmented = np.stack([img_augmented, img_augmented, img_augmented],
axis=-1)
assert (results['img'] == img_augmented).all()
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=2,
magnitude_level=0)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply invert and invert
assert (results['img'] == results['ori_img']).all()
# test case where magnitude_level = 0
results = construct_toy_data()
transform = dict(
type='RandAugment',
policies=policies,
num_policies=2,
magnitude_level=0)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
# apply rotate and translate
assert (results['img'] == results['ori_img']).all()
def test_shear():
# test assertion for invalid type of magnitude
with pytest.raises(AssertionError):
@ -467,12 +651,57 @@ def test_solarize():
assert (results['img'] == results['img2']).all()
def test_posterize():
# test assertion for invalid type of bits
def test_solarize_add():
# test assertion for invalid type of magnitude
with pytest.raises(AssertionError):
transform = dict(type='Posterize', bits=4.5)
transform = dict(type='SolarizeAdd', magnitude=(1, 2))
build_from_cfg(transform, PIPELINES)
# test assertion for invalid type of thr
with pytest.raises(AssertionError):
transform = dict(type='SolarizeAdd', magnitude=100, thr=(1, 2))
build_from_cfg(transform, PIPELINES)
# test case when prob=0, therefore no solarize
results = construct_toy_data_photometric()
transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when thr=0, therefore no solarize
results = construct_toy_data_photometric()
transform = dict(type='SolarizeAdd', magnitude=100, thr=0, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when thr=128, magnitude=100
results = construct_toy_data_photometric()
transform = dict(type='SolarizeAdd', magnitude=100, thr=128, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_solarized = np.array(
[[100, 128, 255], [101, 227, 254], [102, 129, 253]], dtype=np.uint8)
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
axis=-1)
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()
# test case when thr=100, magnitude=50
results = construct_toy_data_photometric()
transform = dict(type='SolarizeAdd', magnitude=50, thr=100, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_solarized = np.array([[50, 128, 255], [51, 127, 254], [52, 129, 253]],
dtype=np.uint8)
img_solarized = np.stack([img_solarized, img_solarized, img_solarized],
axis=-1)
assert (results['img'] == img_solarized).all()
assert (results['img'] == results['img2']).all()
def test_posterize():
# test assertion for invalid value of bits
with pytest.raises(AssertionError):
transform = dict(type='Posterize', bits=10)
@ -771,3 +1000,59 @@ def test_sharpness(nb_rand_test=100):
_adjust_sharpness(img, 1 + magnitude)[1:-1, 1:-1],
rtol=0,
atol=1)
def test_cutout():
# test assertion for invalid type of shape
with pytest.raises(TypeError):
transform = dict(type='Cutout', shape=None)
build_from_cfg(transform, PIPELINES)
# test assertion for invalid value of prob
with pytest.raises(AssertionError):
transform = dict(type='Cutout', shape=1, prob=100)
build_from_cfg(transform, PIPELINES)
# test case when prob=0, therefore no cutout
results = construct_toy_data()
transform = dict(type='Cutout', shape=2, prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when shape=0, therefore no cutout
results = construct_toy_data()
transform = dict(type='Cutout', shape=0, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == results['ori_img']).all()
# test case when shape=6, therefore the whole img has been cut
results = construct_toy_data()
transform = dict(type='Cutout', shape=6, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
assert (results['img'] == np.ones_like(results['ori_img']) * 128).all()
# test case when shape is int
np.random.seed(0)
results = construct_toy_data()
transform = dict(type='Cutout', shape=1, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_cutout = np.array([[1, 2, 3, 4], [5, 128, 7, 8], [9, 10, 11, 12]],
dtype=np.uint8)
img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1)
assert (results['img'] == img_cutout).all()
# test case when shape is tuple
np.random.seed(0)
results = construct_toy_data()
transform = dict(type='Cutout', shape=(1, 2), pad_val=0, prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
results = pipeline(results)
img_cutout = np.array([[1, 2, 3, 4], [5, 0, 0, 8], [9, 10, 11, 12]],
dtype=np.uint8)
img_cutout = np.stack([img_cutout, img_cutout, img_cutout], axis=-1)
assert (results['img'] == img_cutout).all()