[Enhance] Add `hparams` argument in `AutoAugment` and `RandAugment` and some other improvement. (#398)

* Add hparams argument in `AutoAugment` and `RandAugment`.

And `pad_val` supports sequence instead of tuple only.

* Add unit tests for `AutoAugment` and `hparams` in `RandAugment`.

* Use smaller test image to speed up uni tests.

* Use hparams to simplify RandAugment config in swin-transformer.

* Rename augment config name from `pipeline` to `pipelines`.

* Add some commnet ad docstring.
pull/425/head
Ma Zerun 2021-08-24 18:15:54 +08:00 committed by GitHub
parent a9d65271ab
commit 6a0a76af0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 199 additions and 84 deletions

View File

@ -1,64 +1,10 @@
_base_ = ['./pipelines/rand_aug.py']
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Invert'),
dict(
type='Rotate',
interpolation='bicubic',
magnitude_key='angle',
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
magnitude_range=(0, 30)),
dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
dict(
type='SolarizeAdd',
magnitude_key='magnitude',
magnitude_range=(0, 110)),
dict(
type='ColorTransform',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Brightness', magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Shear',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='horizontal'),
dict(
type='Shear',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='vertical'),
dict(
type='Translate',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='horizontal'),
dict(
type='Translate',
interpolation='bicubic',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
pad_val=tuple([round(x) for x in img_norm_cfg['mean'][::-1]]),
direction='vertical')
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
@ -69,11 +15,14 @@ train_pipeline = [
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies=policies,
policies={{_base_.rand_increasing_policies}},
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5),
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,

View File

@ -0,0 +1,43 @@
# Refers to `_RAND_INCREASING_TRANSFORMS` in pytorch-image-models
rand_increasing_policies = [
dict(type='AutoContrast'),
dict(type='Equalize'),
dict(type='Invert'),
dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
dict(type='Posterize', magnitude_key='bits', magnitude_range=(4, 0)),
dict(type='Solarize', magnitude_key='thr', magnitude_range=(256, 0)),
dict(
type='SolarizeAdd',
magnitude_key='magnitude',
magnitude_range=(0, 110)),
dict(
type='ColorTransform',
magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(type='Contrast', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Brightness', magnitude_key='magnitude',
magnitude_range=(0, 0.9)),
dict(
type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0.9)),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='horizontal'),
dict(
type='Shear',
magnitude_key='magnitude',
magnitude_range=(0, 0.3),
direction='vertical'),
dict(
type='Translate',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
direction='horizontal'),
dict(
type='Translate',
magnitude_key='magnitude',
magnitude_range=(0, 0.45),
direction='vertical')
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import random
from numbers import Number
from typing import Sequence
@ -10,12 +11,37 @@ import numpy as np
from ..builder import PIPELINES
from .compose import Compose
# Default hyperparameters for all Ops
_HPARAMS_DEFAULT = dict(pad_val=128)
def random_negative(value, random_negative_prob):
"""Randomly negate value based on random_negative_prob."""
return -value if np.random.rand() < random_negative_prob else value
def merge_hparams(policy: dict, hparams: dict):
"""Merge hyperparameters into policy config.
Only merge partial hyperparameters required of the policy.
Args:
policy (dict): Original policy config dict.
hparams (dict): Hyperparameters need to be merged.
Returns:
dict: Policy config dict after adding ``hparams``.
"""
op = PIPELINES.get(policy['type'])
assert op is not None, f'Invalid policy type "{policy["type"]}".'
for key, value in hparams.items():
if policy.get(key, None) is not None:
continue
if key in inspect.getfullargspec(op.__init__).args:
policy[key] = value
return policy
@PIPELINES.register_module()
class AutoAugment(object):
"""Auto augmentation.
@ -29,9 +55,12 @@ class AutoAugment(object):
composed by several augmentations (dict). When AutoAugment is
called, a random policy in ``policies`` will be selected to
augment images.
hparams (dict): Configs of hyperparameters. Hyperparameters will be
used in policies that require these arguments if these arguments
are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
"""
def __init__(self, policies):
def __init__(self, policies, hparams=_HPARAMS_DEFAULT):
assert isinstance(policies, list) and len(policies) > 0, \
'Policies must be a non-empty list.'
for policy in policies:
@ -42,7 +71,13 @@ class AutoAugment(object):
'Each specific augmentation must be a dict with key' \
' "type".'
self.policies = copy.deepcopy(policies)
self.hparams = hparams
policies = copy.deepcopy(policies)
self.policies = []
for sub in policies:
merged_sub = [merge_hparams(policy, hparams) for policy in sub]
self.policies.append(merged_sub)
self.sub_policy = [Compose(policy) for policy in self.policies]
def __call__(self, results):
@ -86,6 +121,9 @@ class RandAugment(object):
- If 0 or negative number, magnitude remains unchanged.
- If str "inf", magnitude is sampled from uniform distribution
(range=[min, magnitude]).
hparams (dict): Configs of hyperparameters. Hyperparameters will be
used in policies that require these arguments if these arguments
are not set in policy dicts. Defaults to use _HPARAMS_DEFAULT.
Note:
`magnitude_std` will introduce some randomness to policy, modified by
@ -104,7 +142,8 @@ class RandAugment(object):
num_policies,
magnitude_level,
magnitude_std=0.,
total_level=30):
total_level=30,
hparams=_HPARAMS_DEFAULT):
assert isinstance(num_policies, int), 'Number of policies must be ' \
f'of int type, got {type(num_policies)} instead.'
assert isinstance(magnitude_level, (int, float)), \
@ -131,8 +170,10 @@ class RandAugment(object):
self.magnitude_level = magnitude_level
self.magnitude_std = magnitude_std
self.total_level = total_level
self.policies = policies
self._check_policies(self.policies)
self.hparams = hparams
policies = copy.deepcopy(policies)
self._check_policies(policies)
self.policies = [merge_hparams(policy, hparams) for policy in policies]
def _check_policies(self, policies):
for policy in policies:
@ -196,8 +237,8 @@ class Shear(object):
Args:
magnitude (int | float): The magnitude used for shear.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
tuple of length 3, it is used to pad_val R, G, B channels
pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128.
prob (float): The probability for performing Shear therefore should be
in range [0, 1]. Defaults to 0.5.
@ -220,7 +261,7 @@ class Shear(object):
f'be int or float, but got {type(magnitude)} instead.'
if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple):
elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
@ -235,7 +276,7 @@ class Shear(object):
f'should be in range [0,1], got {random_negative_prob} instead.'
self.magnitude = magnitude
self.pad_val = pad_val
self.pad_val = tuple(pad_val)
self.prob = prob
self.direction = direction
self.random_negative_prob = random_negative_prob
@ -276,8 +317,8 @@ class Translate(object):
the offset is calculated by magnitude * size in the corresponding
direction. With a magnitude of 1, the whole image will be moved out
of the range.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
tuple of length 3, it is used to pad_val R, G, B channels
pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128.
prob (float): The probability for performing translate therefore should
be in range [0, 1]. Defaults to 0.5.
@ -300,7 +341,7 @@ class Translate(object):
f'be int or float, but got {type(magnitude)} instead.'
if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple):
elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
@ -315,7 +356,7 @@ class Translate(object):
f'should be in range [0,1], got {random_negative_prob} instead.'
self.magnitude = magnitude
self.pad_val = pad_val
self.pad_val = tuple(pad_val)
self.prob = prob
self.direction = direction
self.random_negative_prob = random_negative_prob
@ -363,8 +404,8 @@ class Rotate(object):
the source image. If None, the center of the image will be used.
Defaults to None.
scale (float): Isotropic scale factor. Defaults to 1.0.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If a
tuple of length 3, it is used to pad_val R, G, B channels
pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
If a sequence of length 3, it is used to pad_val R, G, B channels
respectively. Defaults to 128.
prob (float): The probability for performing Rotate therefore should be
in range [0, 1]. Defaults to 0.5.
@ -394,7 +435,7 @@ class Rotate(object):
f'got {type(scale)} instead.'
if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, tuple):
elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.'
assert all(isinstance(i, int) for i in pad_val), 'pad_val as a '\
@ -409,7 +450,7 @@ class Rotate(object):
self.angle = angle
self.center = center
self.scale = scale
self.pad_val = pad_val
self.pad_val = tuple(pad_val)
self.prob = prob
self.random_negative_prob = random_negative_prob
self.interpolation = interpolation
@ -833,8 +874,8 @@ class Cutout(object):
shape (int | float | tuple(int | float)): Expected cutout shape (h, w).
If given as a single value, the value will be used for
both h and w.
pad_val (int, tuple[int]): Pixel pad_val value for constant fill. If
it is a tuple, it must have the same length with the image
pad_val (int, Sequence[int]): Pixel pad_val value for constant fill.
If it is a sequence, it must have the same length with the image
channels. Defaults to 128.
prob (float): The probability for performing cutout therefore should
be in range [0, 1]. Defaults to 0.5.
@ -849,11 +890,16 @@ class Cutout(object):
raise TypeError(
'shape must be of '
f'type int, float or tuple, got {type(shape)} instead')
if isinstance(pad_val, int):
pad_val = tuple([pad_val] * 3)
elif isinstance(pad_val, Sequence):
assert len(pad_val) == 3, 'pad_val as a tuple must have 3 ' \
f'elements, got {len(pad_val)} instead.'
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.shape = shape
self.pad_val = pad_val
self.pad_val = tuple(pad_val)
self.prob = prob
def __call__(self, results):

View File

@ -40,6 +40,47 @@ def construct_toy_data_photometric():
return results
def test_auto_augment():
policies = [[
dict(type='Posterize', bits=4, prob=0.4),
dict(type='Rotate', angle=30., prob=0.6)
]]
# test assertion for policies
with pytest.raises(AssertionError):
# policies shouldn't be empty
transform = dict(type='AutoAugment', policies=[])
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
# policy should have type
invalid_policies = copy.deepcopy(policies)
invalid_policies[0][0].pop('type')
transform = dict(type='AutoAugment', policies=invalid_policies)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
# sub policy should be a non-empty list
invalid_policies = copy.deepcopy(policies)
invalid_policies[0] = []
transform = dict(type='AutoAugment', policies=invalid_policies)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
# policy should be valid in PIPELINES registry.
invalid_policies = copy.deepcopy(policies)
invalid_policies.append([dict(type='Wrong_policy')])
transform = dict(type='AutoAugment', policies=invalid_policies)
build_from_cfg(transform, PIPELINES)
# test hparams
transform = dict(
type='AutoAugment',
policies=policies,
hparams=dict(pad_val=15, interpolation='nearest'))
pipeline = build_from_cfg(transform, PIPELINES)
# use hparams if not set in policies config
assert pipeline.policies[0][1]['pad_val'] == 15
assert pipeline.policies[0][1]['interpolation'] == 'nearest'
def test_rand_augment():
policies = [
dict(
@ -48,12 +89,13 @@ def test_rand_augment():
magnitude_range=(0, 1),
pad_val=128,
prob=1.,
direction='horizontal'),
direction='horizontal',
interpolation='nearest'),
dict(type='Invert', prob=1.),
dict(
type='Rotate',
magnitude_key='angle',
magnitude_range=(0, 30),
magnitude_range=(0, 90),
prob=0.)
]
# test assertion for num_policies
@ -137,6 +179,15 @@ def test_rand_augment():
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
invalid_policies = copy.deepcopy(policies)
invalid_policies.append(dict(type='Wrong_policy'))
transform = dict(
type='RandAugment',
policies=invalid_policies,
num_policies=2,
magnitude_level=12)
build_from_cfg(transform, PIPELINES)
with pytest.raises(AssertionError):
invalid_policies = copy.deepcopy(policies)
invalid_policies[2].pop('type')
@ -327,6 +378,32 @@ def test_rand_augment():
axis=-1)
np.testing.assert_array_equal(results['img'], img_augmented)
# test hparams
random.seed(8)
np.random.seed(0)
results = construct_toy_data()
policies[2]['prob'] = 1.0
transform = dict(
type='RandAugment',
policies=policies,
num_policies=2,
magnitude_level=12,
magnitude_std=-1,
hparams=dict(pad_val=15, interpolation='nearest'))
pipeline = build_from_cfg(transform, PIPELINES)
# apply translate (magnitude=0.4) and rotate (angle=36)
results = pipeline(results)
img_augmented = np.array(
[[128, 128, 128, 15], [128, 128, 5, 2], [15, 9, 9, 6]], dtype=np.uint8)
img_augmented = np.stack([img_augmented, img_augmented, img_augmented],
axis=-1)
np.testing.assert_array_equal(results['img'], img_augmented)
# hparams won't override setting in policies config
assert pipeline.policies[0]['pad_val'] == 128
# use hparams if not set in policies config
assert pipeline.policies[2]['pad_val'] == 15
assert pipeline.policies[2]['interpolation'] == 'nearest'
def test_shear():
# test assertion for invalid type of magnitude
@ -705,7 +782,7 @@ def test_equalize(nb_rand_test=100):
transform = dict(type='Equalize', prob=1.)
pipeline = build_from_cfg(transform, PIPELINES)
for _ in range(nb_rand_test):
img = np.clip(np.random.normal(0, 1, (1000, 1200, 3)) * 260, 0,
img = np.clip(np.random.normal(0, 1, (256, 256, 3)) * 260, 0,
255).astype(np.uint8)
results['img'] = img
results = pipeline(copy.deepcopy(results))
@ -904,7 +981,7 @@ def test_contrast(nb_rand_test=100):
prob=1.,
random_negative_prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
img = np.clip(np.random.uniform(0, 1, (256, 256, 3)) * 260, 0,
255).astype(np.uint8)
results['img'] = img
results = pipeline(copy.deepcopy(results))
@ -1035,7 +1112,7 @@ def test_brightness(nb_rand_test=100):
prob=1.,
random_negative_prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
img = np.clip(np.random.uniform(0, 1, (256, 256, 3)) * 260, 0,
255).astype(np.uint8)
results['img'] = img
results = pipeline(copy.deepcopy(results))
@ -1097,7 +1174,7 @@ def test_sharpness(nb_rand_test=100):
prob=1.,
random_negative_prob=0.)
pipeline = build_from_cfg(transform, PIPELINES)
img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0,
img = np.clip(np.random.uniform(0, 1, (256, 256, 3)) * 260, 0,
255).astype(np.uint8)
results['img'] = img
results = pipeline(copy.deepcopy(results))