fix: fix augmentation
Fix RandomErasing, RandAugment to be consistent with Timm and compatible with earlier PaddleClas. Add ColorJitter implemented by PaddleVision and TimmAutoAugment borrowed from Timm Lib.pull/1169/head
parent
ef2fd19bb0
commit
1f8cfbd69d
|
@ -14,6 +14,7 @@
|
|||
|
||||
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
|
||||
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
|
||||
from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
|
||||
from ppcls.data.preprocess.ops.cutout import Cutout
|
||||
|
||||
from ppcls.data.preprocess.ops.hide_and_seek import HideAndSeek
|
||||
|
@ -31,7 +32,6 @@ from ppcls.data.preprocess.ops.operators import AugMix
|
|||
|
||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
|
||||
|
||||
import six
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
@ -47,20 +47,14 @@ class AutoAugment(RawImageNetPolicy):
|
|||
""" ImageNetPolicy wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if six.PY2:
|
||||
super(AutoAugment, self).__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if six.PY2:
|
||||
img = super(AutoAugment, self).__call__(img)
|
||||
else:
|
||||
img = super().__call__(img)
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
@ -72,20 +66,33 @@ class RandAugment(RawRandAugment):
|
|||
""" RandAugment wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if six.PY2:
|
||||
super(RandAugment, self).__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if six.PY2:
|
||||
img = super(RandAugment, self).__call__(img)
|
||||
else:
|
||||
img = super().__call__(img)
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class TimmAutoAugment(RawTimmAutoAugment):
|
||||
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
|
|
@ -26,6 +26,7 @@ import random
|
|||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from paddle.vision.transforms import ColorJitter as RawColorJitter
|
||||
|
||||
from .autoaugment import ImageNetPolicy
|
||||
from .functional import augmentations
|
||||
|
@ -363,3 +364,20 @@ class AugMix(object):
|
|||
|
||||
mixed = (1 - m) * image + m * mix
|
||||
return mixed.astype(np.uint8)
|
||||
|
||||
|
||||
class ColorJitter(RawColorJitter):
|
||||
"""ColorJitter.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
img = super()._apply_image(img)
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
return img
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#This code is based on https://github.com/zhunzhong07/Random-Erasing
|
||||
#This code is adapted from https://github.com/zhunzhong07/Random-Erasing, and refer to Timm.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import math
|
||||
import random
|
||||
|
@ -20,36 +22,69 @@ import random
|
|||
import numpy as np
|
||||
|
||||
|
||||
class Pixels(object):
|
||||
def __init__(self, mode="const", mean=[0., 0., 0.]):
|
||||
self._mode = mode
|
||||
self._mean = mean
|
||||
|
||||
def __call__(self, h=224, w=224, c=3):
|
||||
if self._mode == "rand":
|
||||
return np.random.normal(size=(1, 1, 3))
|
||||
elif self._mode == "pixel":
|
||||
return np.random.normal(size=(h, w, c))
|
||||
elif self._mode == "const":
|
||||
return self._mean
|
||||
else:
|
||||
raise Exception(
|
||||
"Invalid mode in RandomErasing, only support \"const\", \"rand\", \"pixel\""
|
||||
)
|
||||
|
||||
|
||||
class RandomErasing(object):
|
||||
def __init__(self, EPSILON=0.5, sl=0.02, sh=0.4, r1=0.3,
|
||||
mean=[0., 0., 0.]):
|
||||
self.EPSILON = EPSILON
|
||||
self.mean = mean
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.r1 = r1
|
||||
"""RandomErasing.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
EPSILON=0.5,
|
||||
sl=0.02,
|
||||
sh=0.4,
|
||||
r1=0.3,
|
||||
mean=[0., 0., 0.],
|
||||
attempt=100,
|
||||
use_log_aspect=False,
|
||||
mode='const'):
|
||||
self.EPSILON = eval(EPSILON) if isinstance(EPSILON, str) else EPSILON
|
||||
self.sl = eval(sl) if isinstance(sl, str) else sl
|
||||
self.sh = eval(sh) if isinstance(sh, str) else sh
|
||||
r1 = eval(r1) if isinstance(r1, str) else r1
|
||||
self.r1 = (math.log(r1), math.log(1 / r1)) if use_log_aspect else (
|
||||
r1, 1 / r1)
|
||||
self.use_log_aspect = use_log_aspect
|
||||
self.attempt = attempt
|
||||
self.get_pixels = Pixels(mode, mean)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.uniform(0, 1) > self.EPSILON:
|
||||
if random.random() > self.EPSILON:
|
||||
return img
|
||||
|
||||
for _ in range(100):
|
||||
for _ in range(self.attempt):
|
||||
area = img.shape[0] * img.shape[1]
|
||||
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
|
||||
aspect_ratio = random.uniform(*self.r1)
|
||||
if self.use_log_aspect:
|
||||
aspect_ratio = math.exp(aspect_ratio)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w < img.shape[1] and h < img.shape[0]:
|
||||
pixels = self.get_pixels(h, w, img.shape[2])
|
||||
x1 = random.randint(0, img.shape[0] - h)
|
||||
y1 = random.randint(0, img.shape[1] - w)
|
||||
if img.shape[0] == 3:
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
|
||||
img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
|
||||
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
|
||||
if img.shape[2] == 3:
|
||||
img[x1:x1 + h, y1:y1 + w, :] = pixels
|
||||
else:
|
||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = pixels[0]
|
||||
return img
|
||||
return img
|
||||
|
|
|
@ -0,0 +1,879 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code implements is borrowed from Timm: https://github.com/rwightman/pytorch-image-models.
|
||||
hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
_FILL = (128, 128, 128)
|
||||
|
||||
# This signifies the max integer that the controller RNN could predict for the
|
||||
# augmentation scheme.
|
||||
_MAX_LEVEL = 10.
|
||||
|
||||
_HPARAMS_DEFAULT = dict(
|
||||
translate_const=250,
|
||||
img_mean=_FILL, )
|
||||
|
||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
|
||||
|
||||
def _pil_interp(method):
|
||||
if method == 'bicubic':
|
||||
return Image.BICUBIC
|
||||
elif method == 'lanczos':
|
||||
return Image.LANCZOS
|
||||
elif method == 'hamming':
|
||||
return Image.HAMMING
|
||||
else:
|
||||
# default bilinear, do we want to allow nearest?
|
||||
return Image.BILINEAR
|
||||
|
||||
|
||||
def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', Image.BILINEAR)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
else:
|
||||
return interpolation
|
||||
|
||||
|
||||
def _check_args_tf(kwargs):
|
||||
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
|
||||
kwargs.pop('fillcolor')
|
||||
kwargs['resample'] = _interpolation(kwargs)
|
||||
|
||||
|
||||
def shear_x(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def shear_y(img, factor, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def translate_x_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[0]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def translate_y_rel(img, pct, **kwargs):
|
||||
pixels = pct * img.size[1]
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def translate_x_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def translate_y_abs(img, pixels, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def rotate(img, degrees, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
if _PIL_VER >= (5, 2):
|
||||
return img.rotate(degrees, **kwargs)
|
||||
elif _PIL_VER >= (5, 0):
|
||||
w, h = img.size
|
||||
post_trans = (0, 0)
|
||||
rotn_center = (w / 2.0, h / 2.0)
|
||||
angle = -math.radians(degrees)
|
||||
matrix = [
|
||||
round(math.cos(angle), 15),
|
||||
round(math.sin(angle), 15),
|
||||
0.0,
|
||||
round(-math.sin(angle), 15),
|
||||
round(math.cos(angle), 15),
|
||||
0.0,
|
||||
]
|
||||
|
||||
def transform(x, y, matrix):
|
||||
(a, b, c, d, e, f) = matrix
|
||||
return a * x + b * y + c, d * x + e * y + f
|
||||
|
||||
matrix[2], matrix[5] = transform(-rotn_center[0] - post_trans[0],
|
||||
-rotn_center[1] - post_trans[1],
|
||||
matrix)
|
||||
matrix[2] += rotn_center[0]
|
||||
matrix[5] += rotn_center[1]
|
||||
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
||||
else:
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
|
||||
|
||||
def auto_contrast(img, **__):
|
||||
return ImageOps.autocontrast(img)
|
||||
|
||||
|
||||
def invert(img, **__):
|
||||
return ImageOps.invert(img)
|
||||
|
||||
|
||||
def equalize(img, **__):
|
||||
return ImageOps.equalize(img)
|
||||
|
||||
|
||||
def solarize(img, thresh, **__):
|
||||
return ImageOps.solarize(img, thresh)
|
||||
|
||||
|
||||
def solarize_add(img, add, thresh=128, **__):
|
||||
lut = []
|
||||
for i in range(256):
|
||||
if i < thresh:
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def posterize(img, bits_to_keep, **__):
|
||||
if bits_to_keep >= 8:
|
||||
return img
|
||||
return ImageOps.posterize(img, bits_to_keep)
|
||||
|
||||
|
||||
def contrast(img, factor, **__):
|
||||
return ImageEnhance.Contrast(img).enhance(factor)
|
||||
|
||||
|
||||
def color(img, factor, **__):
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
|
||||
|
||||
def brightness(img, factor, **__):
|
||||
return ImageEnhance.Brightness(img).enhance(factor)
|
||||
|
||||
|
||||
def sharpness(img, factor, **__):
|
||||
return ImageEnhance.Sharpness(img).enhance(factor)
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
"""With 50% prob, negate the value"""
|
||||
return -v if random.random() > 0.5 else v
|
||||
|
||||
|
||||
def _rotate_level_to_arg(level, _hparams):
|
||||
# range [-30, 30]
|
||||
level = (level / _MAX_LEVEL) * 30.
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _enhance_level_to_arg(level, _hparams):
|
||||
# range [0.1, 1.9]
|
||||
return (level / _MAX_LEVEL) * 1.8 + 0.1,
|
||||
|
||||
|
||||
def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
|
||||
# range [0.1, 1.9]
|
||||
level = (level / _MAX_LEVEL) * .9
|
||||
level = 1.0 + _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _MAX_LEVEL) * 0.3
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_abs_level_to_arg(level, hparams):
|
||||
translate_const = hparams['translate_const']
|
||||
level = (level / _MAX_LEVEL) * float(translate_const)
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_rel_level_to_arg(level, hparams):
|
||||
# default range [-0.45, 0.45]
|
||||
translate_pct = hparams.get('translate_pct', 0.45)
|
||||
level = (level / _MAX_LEVEL) * translate_pct
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _posterize_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_increasing_level_to_arg(level, hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image',
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 4 - _posterize_level_to_arg(level, hparams)[0],
|
||||
|
||||
|
||||
def _posterize_original_level_to_arg(level, _hparams):
|
||||
# As per original AutoAugment paper description
|
||||
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4) + 4,
|
||||
|
||||
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 256),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 256 - _solarize_level_to_arg(level, _hparams)[0],
|
||||
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _MAX_LEVEL) * 110),
|
||||
|
||||
|
||||
LEVEL_TO_ARG = {
|
||||
'AutoContrast': None,
|
||||
'Equalize': None,
|
||||
'Invert': None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
||||
'Posterize': _posterize_level_to_arg,
|
||||
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
|
||||
'PosterizeOriginal': _posterize_original_level_to_arg,
|
||||
'Solarize': _solarize_level_to_arg,
|
||||
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
|
||||
'SolarizeAdd': _solarize_add_level_to_arg,
|
||||
'Color': _enhance_level_to_arg,
|
||||
'ColorIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
'ContrastIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Brightness': _enhance_level_to_arg,
|
||||
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Sharpness': _enhance_level_to_arg,
|
||||
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'TranslateX': _translate_abs_level_to_arg,
|
||||
'TranslateY': _translate_abs_level_to_arg,
|
||||
'TranslateXRel': _translate_rel_level_to_arg,
|
||||
'TranslateYRel': _translate_rel_level_to_arg,
|
||||
}
|
||||
|
||||
NAME_TO_OP = {
|
||||
'AutoContrast': auto_contrast,
|
||||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'PosterizeIncreasing': posterize,
|
||||
'PosterizeOriginal': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeIncreasing': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
'ColorIncreasing': color,
|
||||
'Contrast': contrast,
|
||||
'ContrastIncreasing': contrast,
|
||||
'Brightness': brightness,
|
||||
'BrightnessIncreasing': brightness,
|
||||
'Sharpness': sharpness,
|
||||
'SharpnessIncreasing': sharpness,
|
||||
'ShearX': shear_x,
|
||||
'ShearY': shear_y,
|
||||
'TranslateX': translate_x_abs,
|
||||
'TranslateY': translate_y_abs,
|
||||
'TranslateXRel': translate_x_rel,
|
||||
'TranslateYRel': translate_y_rel,
|
||||
}
|
||||
|
||||
|
||||
class AugmentOp(object):
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
self.aug_fn = NAME_TO_OP[name]
|
||||
self.level_fn = LEVEL_TO_ARG[name]
|
||||
self.prob = prob
|
||||
self.magnitude = magnitude
|
||||
self.hparams = hparams.copy()
|
||||
self.kwargs = dict(
|
||||
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
resample=hparams['interpolation']
|
||||
if 'interpolation' in hparams else _RANDOM_INTERPOLATION, )
|
||||
|
||||
# If magnitude_std is > 0, we introduce some randomness
|
||||
# in the usually fixed policy and sample magnitude from a normal distribution
|
||||
# with mean `magnitude` and std-dev of `magnitude_std`.
|
||||
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
||||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
|
||||
def __call__(self, img):
|
||||
if self.prob < 1.0 and random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std and self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
level_args = self.level_fn(
|
||||
magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
|
||||
def auto_augment_policy_v0(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
||||
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
||||
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
||||
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
||||
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
||||
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
||||
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
||||
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
||||
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)
|
||||
], # This results in black image with Tpu posterize
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_v0r(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
|
||||
# in Google research implementation (number of bits discarded increases with magnitude)
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
||||
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
||||
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
||||
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
||||
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
||||
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
||||
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
||||
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
||||
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_original(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
||||
policy = [
|
||||
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_originalr(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
||||
policy = [
|
||||
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
return auto_augment_policy_original(hparams)
|
||||
elif name == 'originalr':
|
||||
return auto_augment_policy_originalr(hparams)
|
||||
elif name == 'v0':
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
|
||||
|
||||
class AutoAugment(object):
|
||||
def __init__(self, policy):
|
||||
self.policy = policy
|
||||
|
||||
def __call__(self, img):
|
||||
sub_policy = random.choice(self.policy)
|
||||
for op in sub_policy:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
|
||||
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
The remaining sections, not order sepecific determine
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
:return: A callable Transform Op
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
else:
|
||||
assert False, 'Unknown AutoAugment config section'
|
||||
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
|
||||
return AutoAugment(aa_policy)
|
||||
|
||||
|
||||
_RAND_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
'Contrast',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
_RAND_INCREASING_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'SolarizeAdd',
|
||||
'ColorIncreasing',
|
||||
'ContrastIncreasing',
|
||||
'BrightnessIncreasing',
|
||||
'SharpnessIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
}
|
||||
|
||||
|
||||
def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
assert weight_idx == 0 # only one set of weights currently
|
||||
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
||||
probs = [rand_weights[k] for k in transforms]
|
||||
probs /= np.sum(probs)
|
||||
return probs
|
||||
|
||||
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [
|
||||
AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams)
|
||||
for name in transforms
|
||||
]
|
||||
|
||||
|
||||
class RandAugment(object):
|
||||
def __init__(self, ops, num_layers=2, choice_weights=None):
|
||||
self.ops = ops
|
||||
self.num_layers = num_layers
|
||||
self.choice_weights = choice_weights
|
||||
|
||||
def __call__(self, img):
|
||||
# no replacement when using weighted choice
|
||||
ops = np.random.choice(
|
||||
self.ops,
|
||||
self.num_layers,
|
||||
replace=self.choice_weights is None,
|
||||
p=self.choice_weights)
|
||||
for op in ops:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
:return: A callable Transform Op
|
||||
"""
|
||||
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
transforms = _RAND_TRANSFORMS
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
transforms = _RAND_INCREASING_TRANSFORMS
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'w':
|
||||
weight_idx = int(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(
|
||||
magnitude=magnitude, hparams=hparams, transforms=transforms)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(
|
||||
weight_idx)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
_AUGMIX_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'ColorIncreasing', # not in paper
|
||||
'ContrastIncreasing', # not in paper
|
||||
'BrightnessIncreasing', # not in paper
|
||||
'SharpnessIncreasing', # not in paper
|
||||
'Equalize',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [
|
||||
AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams)
|
||||
for name in transforms
|
||||
]
|
||||
|
||||
|
||||
class AugMixAugment(object):
|
||||
""" AugMix Transform
|
||||
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
"""
|
||||
|
||||
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
||||
self.ops = ops
|
||||
self.alpha = alpha
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.blended = blended # blended mode is faster but not well tested
|
||||
|
||||
def _calc_blended_weights(self, ws, m):
|
||||
ws = ws * m
|
||||
cump = 1.
|
||||
rws = []
|
||||
for w in ws[::-1]:
|
||||
alpha = w / cump
|
||||
cump *= (1 - alpha)
|
||||
rws.append(alpha)
|
||||
return np.array(rws[::-1], dtype=np.float32)
|
||||
|
||||
def _apply_blended(self, img, mixing_weights, m):
|
||||
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
|
||||
# of accumulating the mix for each chain in a Numpy array and then blending with original,
|
||||
# it recomputes the blending coefficients and applies one PIL image blend per chain.
|
||||
# TODO the results appear in the right ballpark but they differ by more than rounding.
|
||||
img_orig = img.copy()
|
||||
ws = self._calc_blended_weights(mixing_weights, m)
|
||||
for w in ws:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img_orig # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
img = Image.blend(img, img_aug, w)
|
||||
return img
|
||||
|
||||
def _apply_basic(self, img, mixing_weights, m):
|
||||
# This is a literal adaptation of the paper/official implementation without normalizations and
|
||||
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
|
||||
# typical augmentation transforms, could use a GPU / Kornia implementation.
|
||||
img_shape = img.size[0], img.size[1], len(img.getbands())
|
||||
mixed = np.zeros(img_shape, dtype=np.float32)
|
||||
for mw in mixing_weights:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
mixed += mw * np.asarray(img_aug, dtype=np.float32)
|
||||
np.clip(mixed, 0, 255., out=mixed)
|
||||
mixed = Image.fromarray(mixed.astype(np.uint8))
|
||||
return Image.blend(img, mixed, m)
|
||||
|
||||
def __call__(self, img):
|
||||
mixing_weights = np.float32(
|
||||
np.random.dirichlet([self.alpha] * self.width))
|
||||
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
||||
if self.blended:
|
||||
mixed = self._apply_blended(img, mixing_weights, m)
|
||||
else:
|
||||
mixed = self._apply_basic(img, mixing_weights, m)
|
||||
return mixed
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
""" Create AugMix transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
|
||||
:return: A callable Transform Op
|
||||
"""
|
||||
magnitude = 3
|
||||
width = 3
|
||||
depth = -1
|
||||
alpha = 1.
|
||||
blended = False
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'augmix'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'w':
|
||||
width = int(val)
|
||||
elif key == 'd':
|
||||
depth = int(val)
|
||||
elif key == 'a':
|
||||
alpha = float(val)
|
||||
elif key == 'b':
|
||||
blended = bool(val)
|
||||
else:
|
||||
assert False, 'Unknown AugMix config section'
|
||||
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
|
||||
return AugMixAugment(
|
||||
ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
||||
|
||||
|
||||
class RawTimmAutoAugment(object):
|
||||
"""TimmAutoAugment API for PaddleClas."""
|
||||
|
||||
def __init__(self,
|
||||
config_str="rand-m9-mstd0.5-inc1",
|
||||
interpolation="bicubic",
|
||||
img_size=224,
|
||||
mean=IMAGENET_DEFAULT_MEAN):
|
||||
if isinstance(img_size, (tuple, list)):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]), )
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
if config_str.startswith('rand'):
|
||||
self.augment_func = rand_augment_transform(config_str, aa_params)
|
||||
elif config_str.startswith('augmix'):
|
||||
aa_params['translate_pct'] = 0.3
|
||||
self.augment_func = augment_and_mix_transform(config_str,
|
||||
aa_params)
|
||||
elif config_str.startswith('auto'):
|
||||
self.augment_func = auto_augment_transform(config_str, aa_params)
|
||||
else:
|
||||
raise Exception(
|
||||
"ConfigError: The TimmAutoAugment Op only support RandAugment, AutoAugment, AugMix, and the config_str only starts with \"rand\", \"augmix\", \"auto\"."
|
||||
)
|
||||
|
||||
def __call__(self, img):
|
||||
return self.augment_func(img)
|
Loading…
Reference in New Issue