mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More AutoAugment work. Ready to roll...
This commit is contained in:
parent
25d2088d9e
commit
b750b76f67
@ -1,7 +1,13 @@
|
|||||||
|
""" Auto Augment
|
||||||
|
Implementation adapted from:
|
||||||
|
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||||
|
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
from torchvision import transforms
|
from PIL import Image, ImageOps, ImageEnhance
|
||||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
|
|
||||||
import PIL
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
def posterize(img, bits, **__):
|
def posterize(img, bits_to_keep, **__):
|
||||||
return ImageOps.posterize(img, 4 - bits)
|
if bits_to_keep >= 8:
|
||||||
|
return img
|
||||||
|
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
|
||||||
|
return ImageOps.posterize(img, bits_to_keep)
|
||||||
|
|
||||||
|
|
||||||
def contrast(img, factor, **__):
|
def contrast(img, factor, **__):
|
||||||
@ -157,16 +166,19 @@ def _randomly_negate(v):
|
|||||||
|
|
||||||
|
|
||||||
def _rotate_level_to_arg(level):
|
def _rotate_level_to_arg(level):
|
||||||
|
# range [-30, 30]
|
||||||
level = (level / _MAX_LEVEL) * 30.
|
level = (level / _MAX_LEVEL) * 30.
|
||||||
level = _randomly_negate(level)
|
level = _randomly_negate(level)
|
||||||
return (level,)
|
return (level,)
|
||||||
|
|
||||||
|
|
||||||
def _enhance_level_to_arg(level):
|
def _enhance_level_to_arg(level):
|
||||||
|
# range [0.1, 1.9]
|
||||||
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
||||||
|
|
||||||
|
|
||||||
def _shear_level_to_arg(level):
|
def _shear_level_to_arg(level):
|
||||||
|
# range [-0.3, 0.3]
|
||||||
level = (level / _MAX_LEVEL) * 0.3
|
level = (level / _MAX_LEVEL) * 0.3
|
||||||
level = _randomly_negate(level)
|
level = _randomly_negate(level)
|
||||||
return (level,)
|
return (level,)
|
||||||
@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
|
|||||||
|
|
||||||
|
|
||||||
def _translate_rel_level_to_arg(level):
|
def _translate_rel_level_to_arg(level):
|
||||||
|
# range [-0.45, 0.45]
|
||||||
level = (level / _MAX_LEVEL) * 0.45
|
level = (level / _MAX_LEVEL) * 0.45
|
||||||
level = _randomly_negate(level)
|
level = _randomly_negate(level)
|
||||||
return (level,)
|
return (level,)
|
||||||
@ -190,9 +203,12 @@ def level_to_arg(hparams):
|
|||||||
'Equalize': lambda level: (),
|
'Equalize': lambda level: (),
|
||||||
'Invert': lambda level: (),
|
'Invert': lambda level: (),
|
||||||
'Rotate': _rotate_level_to_arg,
|
'Rotate': _rotate_level_to_arg,
|
||||||
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
|
# FIXME these are both different from original impl as I believe there is a bug,
|
||||||
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
|
# not sure what is the correct alternative, hence 2 options that look better
|
||||||
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
|
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
|
||||||
|
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
|
||||||
|
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
|
||||||
|
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
|
||||||
'Color': _enhance_level_to_arg,
|
'Color': _enhance_level_to_arg,
|
||||||
'Contrast': _enhance_level_to_arg,
|
'Contrast': _enhance_level_to_arg,
|
||||||
'Brightness': _enhance_level_to_arg,
|
'Brightness': _enhance_level_to_arg,
|
||||||
@ -212,6 +228,7 @@ NAME_TO_OP = {
|
|||||||
'Invert': invert,
|
'Invert': invert,
|
||||||
'Rotate': rotate,
|
'Rotate': rotate,
|
||||||
'Posterize': posterize,
|
'Posterize': posterize,
|
||||||
|
'Posterize2': posterize,
|
||||||
'Solarize': solarize,
|
'Solarize': solarize,
|
||||||
'SolarizeAdd': solarize_add,
|
'SolarizeAdd': solarize_add,
|
||||||
'Color': color,
|
'Color': color,
|
||||||
@ -252,10 +269,8 @@ class AutoAugmentOp:
|
|||||||
|
|
||||||
|
|
||||||
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
||||||
"""Autoaugment policy that was used in AutoAugment Paper."""
|
# ImageNet policy from TPU EfficientNet impl, cannot find
|
||||||
# Each tuple is an augmentation operation of the form
|
# a paper reference.
|
||||||
# (operation, probability, magnitude). Each element in policy is a
|
|
||||||
# sub-policy that will be applied sequentially on the image.
|
|
||||||
policy = [
|
policy = [
|
||||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||||
@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
|||||||
return pc
|
return pc
|
||||||
|
|
||||||
|
|
||||||
|
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
|
||||||
|
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
||||||
|
policy = [
|
||||||
|
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||||
|
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||||
|
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||||
|
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
|
||||||
|
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||||
|
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||||
|
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||||
|
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||||
|
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||||
|
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
|
||||||
|
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||||
|
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||||
|
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||||
|
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||||
|
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||||
|
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||||
|
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||||
|
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||||
|
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||||
|
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||||
|
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||||
|
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||||
|
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||||
|
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||||
|
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||||
|
]
|
||||||
|
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
|
||||||
|
return pc
|
||||||
|
|
||||||
|
|
||||||
|
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
|
||||||
|
if name == 'original':
|
||||||
|
return auto_augment_policy_original(hparams)
|
||||||
|
elif name == 'v0':
|
||||||
|
return auto_augment_policy_v0(hparams)
|
||||||
|
else:
|
||||||
|
assert False, 'Unknown AA policy (%s)' % name
|
||||||
|
|
||||||
|
|
||||||
class AutoAugment:
|
class AutoAugment:
|
||||||
|
|
||||||
def __init__(self, policy):
|
def __init__(self, policy):
|
||||||
|
@ -92,6 +92,7 @@ def create_transform(
|
|||||||
is_training=False,
|
is_training=False,
|
||||||
use_prefetcher=False,
|
use_prefetcher=False,
|
||||||
color_jitter=0.4,
|
color_jitter=0.4,
|
||||||
|
auto_augment=None,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
mean=IMAGENET_DEFAULT_MEAN,
|
||||||
std=IMAGENET_DEFAULT_STD,
|
std=IMAGENET_DEFAULT_STD,
|
||||||
@ -109,21 +110,14 @@ def create_transform(
|
|||||||
is_training=is_training, size=img_size, interpolation=interpolation)
|
is_training=is_training, size=img_size, interpolation=interpolation)
|
||||||
else:
|
else:
|
||||||
if is_training:
|
if is_training:
|
||||||
if True:
|
transform = transforms_imagenet_train(
|
||||||
transform = transforms_imagenet_aa(
|
img_size,
|
||||||
img_size,
|
color_jitter=color_jitter,
|
||||||
interpolation=interpolation,
|
auto_augment=auto_augment,
|
||||||
use_prefetcher=use_prefetcher,
|
interpolation=interpolation,
|
||||||
mean=mean,
|
use_prefetcher=use_prefetcher,
|
||||||
std=std)
|
mean=mean,
|
||||||
else:
|
std=std)
|
||||||
transform = transforms_imagenet_train(
|
|
||||||
img_size,
|
|
||||||
color_jitter=color_jitter,
|
|
||||||
interpolation=interpolation,
|
|
||||||
use_prefetcher=use_prefetcher,
|
|
||||||
mean=mean,
|
|
||||||
std=std)
|
|
||||||
else:
|
else:
|
||||||
transform = transforms_imagenet_eval(
|
transform = transforms_imagenet_eval(
|
||||||
img_size,
|
img_size,
|
||||||
@ -146,6 +140,7 @@ def create_loader(
|
|||||||
rand_erase_mode='const',
|
rand_erase_mode='const',
|
||||||
rand_erase_count=1,
|
rand_erase_count=1,
|
||||||
color_jitter=0.4,
|
color_jitter=0.4,
|
||||||
|
auto_augment=None,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
mean=IMAGENET_DEFAULT_MEAN,
|
||||||
std=IMAGENET_DEFAULT_STD,
|
std=IMAGENET_DEFAULT_STD,
|
||||||
@ -161,6 +156,7 @@ def create_loader(
|
|||||||
is_training=is_training,
|
is_training=is_training,
|
||||||
use_prefetcher=use_prefetcher,
|
use_prefetcher=use_prefetcher,
|
||||||
color_jitter=color_jitter,
|
color_jitter=color_jitter,
|
||||||
|
auto_augment=auto_augment,
|
||||||
interpolation=interpolation,
|
interpolation=interpolation,
|
||||||
mean=mean,
|
mean=mean,
|
||||||
std=std,
|
std=std,
|
||||||
|
@ -9,7 +9,7 @@ import numpy as np
|
|||||||
|
|
||||||
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from .random_erasing import RandomErasing
|
from .random_erasing import RandomErasing
|
||||||
from .auto_augment import AutoAugment, auto_augment_policy_v0
|
from .auto_augment import AutoAugment, auto_augment_policy
|
||||||
|
|
||||||
|
|
||||||
class ToNumpy:
|
class ToNumpy:
|
||||||
@ -57,10 +57,10 @@ def _pil_interp(method):
|
|||||||
return Image.BILINEAR
|
return Image.BILINEAR
|
||||||
|
|
||||||
|
|
||||||
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||||
|
|
||||||
|
|
||||||
class RandomResizedCropAndInterpolation(object):
|
class RandomResizedCropAndInterpolation:
|
||||||
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
|
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
|
||||||
|
|
||||||
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
||||||
@ -85,7 +85,7 @@ class RandomResizedCropAndInterpolation(object):
|
|||||||
warnings.warn("range should be of kind (min, max)")
|
warnings.warn("range should be of kind (min, max)")
|
||||||
|
|
||||||
if interpolation == 'random':
|
if interpolation == 'random':
|
||||||
self.interpolation = RANDOM_INTERPOLATION
|
self.interpolation = _RANDOM_INTERPOLATION
|
||||||
else:
|
else:
|
||||||
self.interpolation = _pil_interp(interpolation)
|
self.interpolation = _pil_interp(interpolation)
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -161,52 +161,11 @@ class RandomResizedCropAndInterpolation(object):
|
|||||||
return format_string
|
return format_string
|
||||||
|
|
||||||
|
|
||||||
def transforms_imagenet_aa(
|
|
||||||
img_size=224,
|
|
||||||
scale=(0.08, 1.0),
|
|
||||||
interpolation='random',
|
|
||||||
random_erasing=0.4,
|
|
||||||
random_erasing_mode='const',
|
|
||||||
use_prefetcher=False,
|
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
|
||||||
std=IMAGENET_DEFAULT_STD
|
|
||||||
):
|
|
||||||
aa_params = dict(
|
|
||||||
cutout_max_pad_fraction=0.75,
|
|
||||||
cutout_const=100,
|
|
||||||
translate_const=img_size[-1] // 2 - 1,
|
|
||||||
img_mean=tuple([min(255, round(255*x)) for x in mean]),
|
|
||||||
)
|
|
||||||
if interpolation and interpolation != 'random':
|
|
||||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
|
||||||
aa_policy = auto_augment_policy_v0(aa_params)
|
|
||||||
|
|
||||||
tfl = [
|
|
||||||
RandomResizedCropAndInterpolation(
|
|
||||||
img_size, scale=scale, interpolation=interpolation),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
AutoAugment(aa_policy)
|
|
||||||
]
|
|
||||||
|
|
||||||
if use_prefetcher:
|
|
||||||
# prefetcher and collate will handle tensor conversion and norm
|
|
||||||
tfl += [ToNumpy()]
|
|
||||||
else:
|
|
||||||
tfl += [
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize(
|
|
||||||
mean=torch.tensor(mean),
|
|
||||||
std=torch.tensor(std))
|
|
||||||
]
|
|
||||||
if random_erasing > 0.:
|
|
||||||
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
|
|
||||||
return transforms.Compose(tfl)
|
|
||||||
|
|
||||||
|
|
||||||
def transforms_imagenet_train(
|
def transforms_imagenet_train(
|
||||||
img_size=224,
|
img_size=224,
|
||||||
scale=(0.08, 1.0),
|
scale=(0.08, 1.0),
|
||||||
color_jitter=0.4,
|
color_jitter=0.4,
|
||||||
|
auto_augment=None,
|
||||||
interpolation='random',
|
interpolation='random',
|
||||||
random_erasing=0.4,
|
random_erasing=0.4,
|
||||||
random_erasing_mode='const',
|
random_erasing_mode='const',
|
||||||
@ -214,20 +173,30 @@ def transforms_imagenet_train(
|
|||||||
mean=IMAGENET_DEFAULT_MEAN,
|
mean=IMAGENET_DEFAULT_MEAN,
|
||||||
std=IMAGENET_DEFAULT_STD
|
std=IMAGENET_DEFAULT_STD
|
||||||
):
|
):
|
||||||
if isinstance(color_jitter, (list, tuple)):
|
|
||||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
|
||||||
# or 4 if also augmenting hue
|
|
||||||
assert len(color_jitter) in (3, 4)
|
|
||||||
else:
|
|
||||||
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
|
||||||
color_jitter = (float(color_jitter),) * 3
|
|
||||||
|
|
||||||
tfl = [
|
tfl = [
|
||||||
RandomResizedCropAndInterpolation(
|
RandomResizedCropAndInterpolation(
|
||||||
img_size, scale=scale, interpolation=interpolation),
|
img_size, scale=scale, interpolation=interpolation),
|
||||||
transforms.RandomHorizontalFlip(),
|
transforms.RandomHorizontalFlip()
|
||||||
transforms.ColorJitter(*color_jitter),
|
|
||||||
]
|
]
|
||||||
|
if auto_augment:
|
||||||
|
aa_params = dict(
|
||||||
|
translate_const=img_size[-1] // 2 - 1,
|
||||||
|
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||||
|
)
|
||||||
|
if interpolation and interpolation != 'random':
|
||||||
|
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||||
|
aa_policy = auto_augment_policy(auto_augment, aa_params)
|
||||||
|
tfl += [AutoAugment(aa_policy)]
|
||||||
|
else:
|
||||||
|
# color jitter is enabled when not using AA
|
||||||
|
if isinstance(color_jitter, (list, tuple)):
|
||||||
|
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||||
|
# or 4 if also augmenting hue
|
||||||
|
assert len(color_jitter) in (3, 4)
|
||||||
|
else:
|
||||||
|
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
||||||
|
color_jitter = (float(color_jitter),) * 3
|
||||||
|
tfl += [transforms.ColorJitter(*color_jitter)]
|
||||||
|
|
||||||
if use_prefetcher:
|
if use_prefetcher:
|
||||||
# prefetcher and collate will handle tensor conversion and norm
|
# prefetcher and collate will handle tensor conversion and norm
|
||||||
|
3
train.py
3
train.py
@ -89,6 +89,8 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
|
|||||||
# Augmentation parameters
|
# Augmentation parameters
|
||||||
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
|
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
|
||||||
help='Color jitter factor (default: 0.4)')
|
help='Color jitter factor (default: 0.4)')
|
||||||
|
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
|
||||||
|
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
||||||
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
||||||
help='Random erase prob (default: 0.)')
|
help='Random erase prob (default: 0.)')
|
||||||
parser.add_argument('--remode', type=str, default='const',
|
parser.add_argument('--remode', type=str, default='const',
|
||||||
@ -287,6 +289,7 @@ def main():
|
|||||||
rand_erase_mode=args.remode,
|
rand_erase_mode=args.remode,
|
||||||
rand_erase_count=args.recount,
|
rand_erase_count=args.recount,
|
||||||
color_jitter=args.color_jitter,
|
color_jitter=args.color_jitter,
|
||||||
|
auto_augment=args.aa,
|
||||||
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
|
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
|
||||||
mean=data_config['mean'],
|
mean=data_config['mean'],
|
||||||
std=data_config['std'],
|
std=data_config['std'],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user