mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
commit
db04677c94
@ -69,6 +69,7 @@ Several (less common) features that I often utilize in my projects are included.
|
||||
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
|
||||
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
|
||||
* An inference script that dumps output to CSV is provided as an example
|
||||
* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
|
||||
|
||||
## Results
|
||||
|
||||
|
@ -4,3 +4,5 @@ from .dataset import Dataset, DatasetTar
|
||||
from .transforms import *
|
||||
from .loader import create_loader, create_transform
|
||||
from .mixup import mixup_target, FastCollateMixup
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
|
@ -1,17 +1,19 @@
|
||||
""" Auto Augment
|
||||
""" AutoAugment and RandAugment
|
||||
Implementation adapted from:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
|
||||
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
_FILL = (128, 128, 128)
|
||||
@ -25,11 +27,11 @@ _HPARAMS_DEFAULT = dict(
|
||||
img_mean=_FILL,
|
||||
)
|
||||
|
||||
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
|
||||
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
|
||||
|
||||
def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', Image.NEAREST)
|
||||
interpolation = kwargs.pop('resample', Image.BILINEAR)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
else:
|
||||
@ -140,7 +142,6 @@ def solarize_add(img, add, thresh=128, **__):
|
||||
def posterize(img, bits_to_keep, **__):
|
||||
if bits_to_keep >= 8:
|
||||
return img
|
||||
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
|
||||
return ImageOps.posterize(img, bits_to_keep)
|
||||
|
||||
|
||||
@ -165,61 +166,89 @@ def _randomly_negate(v):
|
||||
return -v if random.random() > 0.5 else v
|
||||
|
||||
|
||||
def _rotate_level_to_arg(level):
|
||||
def _rotate_level_to_arg(level, _hparams):
|
||||
# range [-30, 30]
|
||||
level = (level / _MAX_LEVEL) * 30.
|
||||
level = _randomly_negate(level)
|
||||
return (level,)
|
||||
return level,
|
||||
|
||||
|
||||
def _enhance_level_to_arg(level):
|
||||
def _enhance_level_to_arg(level, _hparams):
|
||||
# range [0.1, 1.9]
|
||||
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
||||
return (level / _MAX_LEVEL) * 1.8 + 0.1,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level):
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _MAX_LEVEL) * 0.3
|
||||
level = _randomly_negate(level)
|
||||
return (level,)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_abs_level_to_arg(level, translate_const):
|
||||
def _translate_abs_level_to_arg(level, hparams):
|
||||
translate_const = hparams['translate_const']
|
||||
level = (level / _MAX_LEVEL) * float(translate_const)
|
||||
level = _randomly_negate(level)
|
||||
return (level,)
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_rel_level_to_arg(level):
|
||||
def _translate_rel_level_to_arg(level, _hparams):
|
||||
# range [-0.45, 0.45]
|
||||
level = (level / _MAX_LEVEL) * 0.45
|
||||
level = _randomly_negate(level)
|
||||
return (level,)
|
||||
return level,
|
||||
|
||||
|
||||
def level_to_arg(hparams):
|
||||
return {
|
||||
'AutoContrast': lambda level: (),
|
||||
'Equalize': lambda level: (),
|
||||
'Invert': lambda level: (),
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# FIXME these are both different from original impl as I believe there is a bug,
|
||||
# not sure what is the correct alternative, hence 2 options that look better
|
||||
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
|
||||
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
|
||||
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
|
||||
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
|
||||
'Color': _enhance_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
'Brightness': _enhance_level_to_arg,
|
||||
'Sharpness': _enhance_level_to_arg,
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
|
||||
'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
|
||||
'TranslateXRel': lambda level: _translate_rel_level_to_arg(level),
|
||||
'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
|
||||
}
|
||||
def _posterize_original_level_to_arg(level, _hparams):
|
||||
# As per original AutoAugment paper description
|
||||
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
||||
return int((level / _MAX_LEVEL) * 4) + 4,
|
||||
|
||||
|
||||
def _posterize_research_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image'
|
||||
return 4 - int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_tpu_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
return int((level / _MAX_LEVEL) * 256),
|
||||
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _MAX_LEVEL) * 110),
|
||||
|
||||
|
||||
LEVEL_TO_ARG = {
|
||||
'AutoContrast': None,
|
||||
'Equalize': None,
|
||||
'Invert': None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
||||
'PosterizeOriginal': _posterize_original_level_to_arg,
|
||||
'PosterizeResearch': _posterize_research_level_to_arg,
|
||||
'PosterizeTpu': _posterize_tpu_level_to_arg,
|
||||
'Solarize': _solarize_level_to_arg,
|
||||
'SolarizeAdd': _solarize_add_level_to_arg,
|
||||
'Color': _enhance_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
'Brightness': _enhance_level_to_arg,
|
||||
'Sharpness': _enhance_level_to_arg,
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'TranslateX': _translate_abs_level_to_arg,
|
||||
'TranslateY': _translate_abs_level_to_arg,
|
||||
'TranslateXRel': _translate_rel_level_to_arg,
|
||||
'TranslateYRel': _translate_rel_level_to_arg,
|
||||
}
|
||||
|
||||
|
||||
NAME_TO_OP = {
|
||||
@ -227,8 +256,9 @@ NAME_TO_OP = {
|
||||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'Posterize2': posterize,
|
||||
'PosterizeOriginal': posterize,
|
||||
'PosterizeResearch': posterize,
|
||||
'PosterizeTpu': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
@ -246,35 +276,37 @@ NAME_TO_OP = {
|
||||
|
||||
class AutoAugmentOp:
|
||||
|
||||
def __init__(self, name, prob, magnitude, hparams={}):
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
self.aug_fn = NAME_TO_OP[name]
|
||||
self.level_fn = level_to_arg(hparams)[name]
|
||||
self.level_fn = LEVEL_TO_ARG[name]
|
||||
self.prob = prob
|
||||
self.magnitude = magnitude
|
||||
# If std deviation of magnitude is > 0, we introduce some randomness
|
||||
# in the usually fixed policy and sample magnitude from normal dist
|
||||
# with mean magnitude and std-dev of magnitude_std.
|
||||
# NOTE This is being tested as it's not in paper or reference impl.
|
||||
self.magnitude_std = 0.5 # FIXME add arg/hparam
|
||||
self.kwargs = {
|
||||
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION
|
||||
}
|
||||
self.hparams = hparams.copy()
|
||||
self.kwargs = dict(
|
||||
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
||||
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
|
||||
)
|
||||
|
||||
# If magnitude_std is > 0, we introduce some randomness
|
||||
# in the usually fixed policy and sample magnitude from a normal distribution
|
||||
# with mean `magnitude` and std-dev of `magnitude_std`.
|
||||
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
||||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
|
||||
def __call__(self, img):
|
||||
if self.prob < random.random():
|
||||
if random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std and self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude))
|
||||
level_args = self.level_fn(magnitude)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
|
||||
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
||||
# ImageNet policy from TPU EfficientNet impl, cannot find
|
||||
# a paper reference.
|
||||
def auto_augment_policy_v0(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
@ -288,7 +320,7 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
@ -298,27 +330,60 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
|
||||
def auto_augment_policy_v0r(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
||||
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
||||
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
||||
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
||||
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
||||
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
||||
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
||||
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
||||
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_original(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
||||
policy = [
|
||||
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
|
||||
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
@ -335,15 +400,53 @@ def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
|
||||
def auto_augment_policy_originalr(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
||||
policy = [
|
||||
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
||||
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
||||
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
||||
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
||||
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
return auto_augment_policy_original(hparams)
|
||||
elif name == 'originalr':
|
||||
return auto_augment_policy_originalr(hparams)
|
||||
elif name == 'v0':
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
|
||||
@ -358,3 +461,151 @@ class AutoAugment:
|
||||
for op in sub_policy:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
|
||||
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
The remaining sections, not order sepecific determine
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
else:
|
||||
assert False, 'Unknown AutoAugment config section'
|
||||
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
|
||||
return AutoAugment(aa_policy)
|
||||
|
||||
|
||||
_RAND_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeTpu',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
'Contrast',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # FIXME I implement this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'PosterizeTpu': 0,
|
||||
'Invert': 0,
|
||||
}
|
||||
|
||||
|
||||
def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
assert weight_idx == 0 # only one set of weights currently
|
||||
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
||||
probs = [rand_weights[k] for k in transforms]
|
||||
probs /= np.sum(probs)
|
||||
return probs
|
||||
|
||||
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AutoAugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class RandAugment:
|
||||
def __init__(self, ops, num_layers=2, choice_weights=None):
|
||||
self.ops = ops
|
||||
self.num_layers = num_layers
|
||||
self.choice_weights = choice_weights
|
||||
|
||||
def __call__(self, img):
|
||||
# no replacement when using weighted choice
|
||||
ops = np.random.choice(
|
||||
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
||||
for op in ops:
|
||||
img = op(img)
|
||||
return img
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'w':
|
||||
weight_idx = int(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
|
||||
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .random_erasing import RandomErasing
|
||||
from .auto_augment import AutoAugment, auto_augment_policy
|
||||
from .auto_augment import auto_augment_transform, rand_augment_transform
|
||||
|
||||
|
||||
class ToNumpy:
|
||||
@ -179,6 +179,7 @@ def transforms_imagenet_train(
|
||||
transforms.RandomHorizontalFlip()
|
||||
]
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
if isinstance(img_size, tuple):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
@ -189,8 +190,10 @@ def transforms_imagenet_train(
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
aa_policy = auto_augment_policy(auto_augment, aa_params)
|
||||
tfl += [AutoAugment(aa_policy)]
|
||||
if auto_augment.startswith('rand'):
|
||||
tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
# color jitter is enabled when not using AA
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
|
@ -25,12 +25,13 @@ def create_model(
|
||||
"""
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
# Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args
|
||||
is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not is_efficientnet:
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
kwargs.pop('drop_connect_rate', None)
|
||||
|
||||
if is_model(model_name):
|
||||
create_fn = model_entrypoint(model_name)
|
||||
|
@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
|
||||
_USE_SWISH_OPT = True
|
||||
if _USE_SWISH_OPT:
|
||||
class SwishAutoFn(torch.autograd.Function):
|
||||
""" Memory Efficient Swish
|
||||
From: https://blog.ceshine.net/post/pytorch-memory-swish/
|
||||
@torch.jit.script
|
||||
def swish_jit_fwd(x):
|
||||
return x.mul(torch.sigmoid(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish_jit_bwd(x, grad_output):
|
||||
x_sigmoid = torch.sigmoid(x)
|
||||
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
||||
|
||||
|
||||
class SwishJitAutoFn(torch.autograd.Function):
|
||||
""" torch.jit.script optimised Swish
|
||||
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
||||
https://twitter.com/jeremyphoward/status/1188251041835315200
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
result = x.mul(torch.sigmoid(x))
|
||||
ctx.save_for_backward(x)
|
||||
return result
|
||||
return swish_jit_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_variables[0]
|
||||
sigmoid_x = torch.sigmoid(x)
|
||||
return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
|
||||
x = ctx.saved_tensors[0]
|
||||
return swish_jit_bwd(x, grad_output)
|
||||
|
||||
|
||||
def swish(x, inplace=False):
|
||||
return SwishAutoFn.apply(x)
|
||||
# inplace ignored
|
||||
return SwishJitAutoFn.apply(x)
|
||||
else:
|
||||
def swish(x, inplace=False):
|
||||
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
||||
|
7
train.py
7
train.py
@ -65,6 +65,8 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
|
||||
help='Drop connect rate (default: 0.)')
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "sgd"')
|
||||
@ -87,7 +89,7 @@ parser.add_argument('--epochs', type=int, default=200, metavar='N',
|
||||
help='number of epochs to train (default: 2)')
|
||||
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
|
||||
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
|
||||
help='epoch interval to decay LR')
|
||||
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
|
||||
help='epochs to warmup LR, if scheduler supports')
|
||||
@ -208,6 +210,7 @@ def main():
|
||||
pretrained=args.pretrained,
|
||||
num_classes=args.num_classes,
|
||||
drop_rate=args.drop,
|
||||
drop_connect_rate=args.drop_connect,
|
||||
global_pool=args.gp,
|
||||
bn_tf=args.bn_tf,
|
||||
bn_momentum=args.bn_momentum,
|
||||
@ -253,7 +256,7 @@ def main():
|
||||
if args.local_rank == 0:
|
||||
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
resume_state = None
|
||||
resume_state = None # clear it
|
||||
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
|
Loading…
x
Reference in New Issue
Block a user