mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #74 from rwightman/augmix-jsd
AugMix, JSD loss, SplitBatchNorm (Auxiliary BN), and more
This commit is contained in:
commit
d9a6a9d0af
@ -1,8 +1,9 @@
|
||||
from .constants import *
|
||||
from .config import resolve_data_config
|
||||
from .dataset import Dataset, DatasetTar
|
||||
from .dataset import Dataset, DatasetTar, AugMixDataset
|
||||
from .transforms import *
|
||||
from .loader import create_loader, create_transform
|
||||
from .mixup import mixup_target, FastCollateMixup
|
||||
from .loader import create_loader
|
||||
from .transforms_factory import create_transform
|
||||
from .mixup import mixup_batch, FastCollateMixup
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
|
@ -1,19 +1,30 @@
|
||||
""" AutoAugment and RandAugment
|
||||
Implementation adapted from:
|
||||
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||
|
||||
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||
does not include any of the search code.
|
||||
|
||||
AA and RA Implementation adapted from:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
|
||||
|
||||
AugMix adapted from:
|
||||
https://github.com/google-research/augmix
|
||||
|
||||
Papers:
|
||||
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
_FILL = (128, 128, 128)
|
||||
@ -178,6 +189,14 @@ def _enhance_level_to_arg(level, _hparams):
|
||||
return (level / _MAX_LEVEL) * 1.8 + 0.1,
|
||||
|
||||
|
||||
def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
|
||||
# range [0.1, 1.9]
|
||||
level = (level / _MAX_LEVEL) * .9
|
||||
level = 1.0 + _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _MAX_LEVEL) * 0.3
|
||||
@ -192,36 +211,47 @@ def _translate_abs_level_to_arg(level, hparams):
|
||||
return level,
|
||||
|
||||
|
||||
def _translate_rel_level_to_arg(level, _hparams):
|
||||
# range [-0.45, 0.45]
|
||||
level = (level / _MAX_LEVEL) * 0.45
|
||||
def _translate_rel_level_to_arg(level, hparams):
|
||||
# default range [-0.45, 0.45]
|
||||
translate_pct = hparams.get('translate_pct', 0.45)
|
||||
level = (level / _MAX_LEVEL) * translate_pct
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _posterize_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_increasing_level_to_arg(level, hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image',
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 4 - _posterize_level_to_arg(level, hparams)[0],
|
||||
|
||||
|
||||
def _posterize_original_level_to_arg(level, _hparams):
|
||||
# As per original AutoAugment paper description
|
||||
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4) + 4,
|
||||
|
||||
|
||||
def _posterize_research_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image'
|
||||
return 4 - int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_tpu_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 256),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 256 - _solarize_level_to_arg(level, _hparams)[0],
|
||||
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _MAX_LEVEL) * 110),
|
||||
@ -233,15 +263,20 @@ LEVEL_TO_ARG = {
|
||||
'Invert': None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
||||
'Posterize': _posterize_level_to_arg,
|
||||
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
|
||||
'PosterizeOriginal': _posterize_original_level_to_arg,
|
||||
'PosterizeResearch': _posterize_research_level_to_arg,
|
||||
'PosterizeTpu': _posterize_tpu_level_to_arg,
|
||||
'Solarize': _solarize_level_to_arg,
|
||||
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
|
||||
'SolarizeAdd': _solarize_add_level_to_arg,
|
||||
'Color': _enhance_level_to_arg,
|
||||
'ColorIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
'ContrastIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Brightness': _enhance_level_to_arg,
|
||||
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'Sharpness': _enhance_level_to_arg,
|
||||
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
|
||||
'ShearX': _shear_level_to_arg,
|
||||
'ShearY': _shear_level_to_arg,
|
||||
'TranslateX': _translate_abs_level_to_arg,
|
||||
@ -256,15 +291,20 @@ NAME_TO_OP = {
|
||||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'PosterizeIncreasing': posterize,
|
||||
'PosterizeOriginal': posterize,
|
||||
'PosterizeResearch': posterize,
|
||||
'PosterizeTpu': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeIncreasing': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
'ColorIncreasing': color,
|
||||
'Contrast': contrast,
|
||||
'ContrastIncreasing': contrast,
|
||||
'Brightness': brightness,
|
||||
'BrightnessIncreasing': brightness,
|
||||
'Sharpness': sharpness,
|
||||
'SharpnessIncreasing': sharpness,
|
||||
'ShearX': shear_x,
|
||||
'ShearY': shear_y,
|
||||
'TranslateX': translate_x_abs,
|
||||
@ -274,7 +314,7 @@ NAME_TO_OP = {
|
||||
}
|
||||
|
||||
|
||||
class AutoAugmentOp:
|
||||
class AugmentOp:
|
||||
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
@ -295,12 +335,12 @@ class AutoAugmentOp:
|
||||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() > self.prob:
|
||||
if self.prob < 1.0 and random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std and self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
@ -320,7 +360,7 @@ def auto_augment_policy_v0(hparams):
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
@ -330,16 +370,17 @@ def auto_augment_policy_v0(hparams):
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
||||
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_v0r(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
|
||||
# in Google research implementation (number of bits discarded increases with magnitude)
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
@ -353,7 +394,7 @@ def auto_augment_policy_v0r(hparams):
|
||||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
@ -363,11 +404,11 @@ def auto_augment_policy_v0r(hparams):
|
||||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
@ -400,23 +441,23 @@ def auto_augment_policy_original(hparams):
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_originalr(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
||||
policy = [
|
||||
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
|
||||
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
@ -433,7 +474,7 @@ def auto_augment_policy_originalr(hparams):
|
||||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
@ -499,7 +540,7 @@ _RAND_TRANSFORMS = [
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeTpu',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
@ -510,10 +551,31 @@ _RAND_TRANSFORMS = [
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # FIXME I implement this as random erasing separately
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
_RAND_INCREASING_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'SolarizeAdd',
|
||||
'ColorIncreasing',
|
||||
'ContrastIncreasing',
|
||||
'BrightnessIncreasing',
|
||||
'SharpnessIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
@ -530,7 +592,7 @@ _RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'PosterizeTpu': 0,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
}
|
||||
|
||||
@ -547,7 +609,7 @@ def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AutoAugmentOp(
|
||||
return [AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
@ -577,6 +639,7 @@ def rand_augment_transform(config_str, hparams):
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
|
||||
@ -587,6 +650,7 @@ def rand_augment_transform(config_str, hparams):
|
||||
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
transforms = _RAND_TRANSFORMS
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
@ -598,6 +662,9 @@ def rand_augment_transform(config_str, hparams):
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
transforms = _RAND_INCREASING_TRANSFORMS
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
@ -606,6 +673,145 @@ def rand_augment_transform(config_str, hparams):
|
||||
weight_idx = int(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
_AUGMIX_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'ColorIncreasing', # not in paper
|
||||
'ContrastIncreasing', # not in paper
|
||||
'BrightnessIncreasing', # not in paper
|
||||
'SharpnessIncreasing', # not in paper
|
||||
'Equalize',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
""" AugMix Transform
|
||||
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
"""
|
||||
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
||||
self.ops = ops
|
||||
self.alpha = alpha
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.blended = blended # blended mode is faster but not well tested
|
||||
|
||||
def _calc_blended_weights(self, ws, m):
|
||||
ws = ws * m
|
||||
cump = 1.
|
||||
rws = []
|
||||
for w in ws[::-1]:
|
||||
alpha = w / cump
|
||||
cump *= (1 - alpha)
|
||||
rws.append(alpha)
|
||||
return np.array(rws[::-1], dtype=np.float32)
|
||||
|
||||
def _apply_blended(self, img, mixing_weights, m):
|
||||
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
|
||||
# of accumulating the mix for each chain in a Numpy array and then blending with original,
|
||||
# it recomputes the blending coefficients and applies one PIL image blend per chain.
|
||||
# TODO the results appear in the right ballpark but they differ by more than rounding.
|
||||
img_orig = img.copy()
|
||||
ws = self._calc_blended_weights(mixing_weights, m)
|
||||
for w in ws:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img_orig # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
img = Image.blend(img, img_aug, w)
|
||||
return img
|
||||
|
||||
def _apply_basic(self, img, mixing_weights, m):
|
||||
# This is a literal adaptation of the paper/official implementation without normalizations and
|
||||
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
|
||||
# typical augmentation transforms, could use a GPU / Kornia implementation.
|
||||
img_shape = img.size[0], img.size[1], len(img.getbands())
|
||||
mixed = np.zeros(img_shape, dtype=np.float32)
|
||||
for mw in mixing_weights:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
mixed += mw * np.asarray(img_aug, dtype=np.float32)
|
||||
np.clip(mixed, 0, 255., out=mixed)
|
||||
mixed = Image.fromarray(mixed.astype(np.uint8))
|
||||
return Image.blend(img, mixed, m)
|
||||
|
||||
def __call__(self, img):
|
||||
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
|
||||
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
||||
if self.blended:
|
||||
mixed = self._apply_blended(img, mixing_weights, m)
|
||||
else:
|
||||
mixed = self._apply_basic(img, mixing_weights, m)
|
||||
return mixed
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
""" Create AugMix PyTorch transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = 3
|
||||
width = 3
|
||||
depth = -1
|
||||
alpha = 1.
|
||||
blended = False
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'augmix'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param injected via hparams for now
|
||||
hparams.setdefault('magnitude_std', float(val))
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'w':
|
||||
width = int(val)
|
||||
elif key == 'd':
|
||||
depth = int(val)
|
||||
elif key == 'a':
|
||||
alpha = float(val)
|
||||
elif key == 'b':
|
||||
blended = bool(val)
|
||||
else:
|
||||
assert False, 'Unknown AugMix config section'
|
||||
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
|
||||
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
||||
|
@ -140,3 +140,42 @@ class DatasetTar(data.Dataset):
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
|
||||
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
|
||||
|
||||
def __init__(self, dataset, num_splits=2):
|
||||
self.augmentation = None
|
||||
self.normalize = None
|
||||
self.dataset = dataset
|
||||
if self.dataset.transform is not None:
|
||||
self._set_transforms(self.dataset.transform)
|
||||
self.num_splits = num_splits
|
||||
|
||||
def _set_transforms(self, x):
|
||||
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
|
||||
self.dataset.transform = x[0]
|
||||
self.augmentation = x[1]
|
||||
self.normalize = x[2]
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self.dataset.transform
|
||||
|
||||
@transform.setter
|
||||
def transform(self, x):
|
||||
self._set_transforms(x)
|
||||
|
||||
def _normalize(self, x):
|
||||
return x if self.normalize is None else self.normalize(x)
|
||||
|
||||
def __getitem__(self, i):
|
||||
x, y = self.dataset[i] # all splits share the same dataset base transform
|
||||
x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
|
||||
# run the full augmentation on the remaining splits
|
||||
for _ in range(self.num_splits - 1):
|
||||
x_list.append(self._normalize(self.augmentation(x)))
|
||||
return tuple(x_list), y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
@ -1,29 +1,59 @@
|
||||
import torch.utils.data
|
||||
from .transforms import *
|
||||
import numpy as np
|
||||
|
||||
from .transforms_factory import create_transform
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .distributed_sampler import OrderedDistributedSampler
|
||||
from .random_erasing import RandomErasing
|
||||
from .mixup import FastCollateMixup
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
batch_size = len(targets)
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
|
||||
return tensor, targets
|
||||
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
|
||||
assert isinstance(batch[0], tuple)
|
||||
batch_size = len(batch)
|
||||
if isinstance(batch[0][0], tuple):
|
||||
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
|
||||
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
|
||||
inner_tuple_size = len(batch[0][0])
|
||||
flattened_batch_size = batch_size * inner_tuple_size
|
||||
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
|
||||
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
|
||||
for j in range(inner_tuple_size):
|
||||
targets[i + j * batch_size] = batch[i][1]
|
||||
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], np.ndarray):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], torch.Tensor):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i].copy_(batch[i][0])
|
||||
return tensor, targets
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
class PrefetchLoader:
|
||||
|
||||
def __init__(self,
|
||||
loader,
|
||||
rand_erase_prob=0.,
|
||||
rand_erase_mode='const',
|
||||
rand_erase_count=1,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
fp16=False):
|
||||
loader,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
fp16=False,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_num_splits=0):
|
||||
self.loader = loader
|
||||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||
@ -31,9 +61,9 @@ class PrefetchLoader:
|
||||
if fp16:
|
||||
self.mean = self.mean.half()
|
||||
self.std = self.std.half()
|
||||
if rand_erase_prob > 0.:
|
||||
if re_prob > 0.:
|
||||
self.random_erasing = RandomErasing(
|
||||
probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
|
||||
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
|
||||
@ -87,60 +117,19 @@ class PrefetchLoader:
|
||||
self.loader.collate_fn.mixup_enabled = x
|
||||
|
||||
|
||||
def create_transform(
|
||||
input_size,
|
||||
is_training=False,
|
||||
use_prefetcher=False,
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='bilinear',
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
crop_pct=None,
|
||||
tf_preprocessing=False):
|
||||
|
||||
if isinstance(input_size, tuple):
|
||||
img_size = input_size[-2:]
|
||||
else:
|
||||
img_size = input_size
|
||||
|
||||
if tf_preprocessing and use_prefetcher:
|
||||
from timm.data.tf_preprocessing import TfPreprocessTransform
|
||||
transform = TfPreprocessTransform(
|
||||
is_training=is_training, size=img_size, interpolation=interpolation)
|
||||
else:
|
||||
if is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
img_size,
|
||||
color_jitter=color_jitter,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std)
|
||||
else:
|
||||
transform = transforms_imagenet_eval(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct)
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def create_loader(
|
||||
dataset,
|
||||
input_size,
|
||||
batch_size,
|
||||
is_training=False,
|
||||
use_prefetcher=True,
|
||||
rand_erase_prob=0.,
|
||||
rand_erase_mode='const',
|
||||
rand_erase_count=1,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_split=False,
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
num_aug_splits=0,
|
||||
interpolation='bilinear',
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
@ -152,6 +141,10 @@ def create_loader(
|
||||
fp16=False,
|
||||
tf_preprocessing=False,
|
||||
):
|
||||
re_num_splits = 0
|
||||
if re_split:
|
||||
# apply RE to second half of batch if no aug split otherwise line up with aug split
|
||||
re_num_splits = num_aug_splits or 2
|
||||
dataset.transform = create_transform(
|
||||
input_size,
|
||||
is_training=is_training,
|
||||
@ -163,6 +156,11 @@ def create_loader(
|
||||
std=std,
|
||||
crop_pct=crop_pct,
|
||||
tf_preprocessing=tf_preprocessing,
|
||||
re_prob=re_prob,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
re_num_splits=re_num_splits,
|
||||
separate=num_aug_splits > 0,
|
||||
)
|
||||
|
||||
sampler = None
|
||||
@ -190,11 +188,13 @@ def create_loader(
|
||||
if use_prefetcher:
|
||||
loader = PrefetchLoader(
|
||||
loader,
|
||||
rand_erase_prob=rand_erase_prob if is_training else 0.,
|
||||
rand_erase_mode=rand_erase_mode,
|
||||
rand_erase_count=rand_erase_count,
|
||||
mean=mean,
|
||||
std=std,
|
||||
fp16=fp16)
|
||||
fp16=fp16,
|
||||
re_prob=re_prob if is_training else 0.,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
re_num_splits=re_num_splits
|
||||
)
|
||||
|
||||
return loader
|
||||
|
@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
|
||||
return lam*y1 + (1. - lam)*y2
|
||||
|
||||
|
||||
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
|
||||
lam = 1.
|
||||
if not disable:
|
||||
lam = np.random.beta(alpha, alpha)
|
||||
input = input.mul(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, num_classes, lam, smoothing)
|
||||
return input, target
|
||||
|
||||
|
||||
class FastCollateMixup:
|
||||
|
||||
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
|
||||
|
@ -23,13 +23,13 @@ class RandomErasing:
|
||||
This variant of RandomErasing is intended to be applied to either a batch
|
||||
or single image tensor after it has been normalized by dataset mean and std.
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
probability: Probability that the Random Erasing operation will be performed.
|
||||
min_area: Minimum percentage of erased area wrt input image area.
|
||||
max_area: Maximum percentage of erased area wrt input image area.
|
||||
min_aspect: Minimum aspect ratio of erased area.
|
||||
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
|
||||
'const' - erase block is constant color of 0 for all channels
|
||||
'rand' - erase block is same per-cannel random (normal) color
|
||||
'rand' - erase block is same per-channel random (normal) color
|
||||
'pixel' - erase block is per-pixel random (normal) color
|
||||
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
|
||||
per-image count is randomly chosen between 1 and this value.
|
||||
@ -37,14 +37,16 @@ class RandomErasing:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
mode='const', max_count=1, device='cuda'):
|
||||
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
|
||||
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
|
||||
self.probability = probability
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.min_count = 1
|
||||
self.max_count = max_count
|
||||
self.min_area = min_area
|
||||
self.max_area = max_area
|
||||
max_aspect = max_aspect or 1 / min_aspect
|
||||
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
||||
self.min_count = min_count
|
||||
self.max_count = max_count or min_count
|
||||
self.num_splits = num_splits
|
||||
mode = mode.lower()
|
||||
self.rand_color = False
|
||||
self.per_pixel = False
|
||||
@ -64,9 +66,8 @@ class RandomErasing:
|
||||
random.randint(self.min_count, self.max_count)
|
||||
for _ in range(count):
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(self.sl, self.sh) * area / count
|
||||
log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
target_area = random.uniform(self.min_area, self.max_area) * area / count
|
||||
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if w < img_w and h < img_h:
|
||||
@ -82,6 +83,8 @@ class RandomErasing:
|
||||
self._erase(input, *input.size(), input.dtype)
|
||||
else:
|
||||
batch_size, chan, img_h, img_w = input.size()
|
||||
for i in range(batch_size):
|
||||
# skip first slice of batch if num_splits is set (for clean portion of samples)
|
||||
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
|
||||
for i in range(batch_start, batch_size):
|
||||
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
||||
return input
|
||||
|
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
import warnings
|
||||
@ -7,10 +6,6 @@ import math
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .random_erasing import RandomErasing
|
||||
from .auto_augment import auto_augment_transform, rand_augment_transform
|
||||
|
||||
|
||||
class ToNumpy:
|
||||
|
||||
@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation:
|
||||
return format_string
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.08, 1.0),
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='random',
|
||||
random_erasing=0.4,
|
||||
random_erasing_mode='const',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD
|
||||
):
|
||||
tfl = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
transforms.RandomHorizontalFlip()
|
||||
]
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
if isinstance(img_size, tuple):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
if auto_augment.startswith('rand'):
|
||||
tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
# color jitter is enabled when not using AA
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
assert len(color_jitter) in (3, 4)
|
||||
else:
|
||||
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
||||
color_jitter = (float(color_jitter),) * 3
|
||||
tfl += [transforms.ColorJitter(*color_jitter)]
|
||||
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if random_erasing > 0.:
|
||||
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
def transforms_imagenet_eval(
|
||||
img_size=224,
|
||||
crop_pct=None,
|
||||
interpolation='bilinear',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
|
||||
if isinstance(img_size, tuple):
|
||||
assert len(img_size) == 2
|
||||
if img_size[-1] == img_size[-2]:
|
||||
# fall-back to older behaviour so Resize scales to shortest edge if target is square
|
||||
scale_size = int(math.floor(img_size[0] / crop_pct))
|
||||
else:
|
||||
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
|
||||
tfl = [
|
||||
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
||||
transforms.CenterCrop(img_size),
|
||||
]
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
|
184
timm/data/transforms_factory.py
Normal file
184
timm/data/transforms_factory.py
Normal file
@ -0,0 +1,184 @@
|
||||
""" Transforms Factory
|
||||
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
||||
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
||||
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
|
||||
from timm.data.random_erasing import RandomErasing
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.08, 1.0),
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='random',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_num_splits=0,
|
||||
separate=False,
|
||||
):
|
||||
"""
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
for use in a mixing dataset that passes
|
||||
* all data through the first (primary) transform, called the 'clean' data
|
||||
* a portion of the data through the secondary transform
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
"""
|
||||
primary_tfl = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
transforms.RandomHorizontalFlip()
|
||||
]
|
||||
|
||||
secondary_tfl = []
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
if isinstance(img_size, tuple):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
if auto_augment.startswith('rand'):
|
||||
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||
elif auto_augment.startswith('augmix'):
|
||||
aa_params['translate_pct'] = 0.3
|
||||
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
elif color_jitter is not None:
|
||||
# color jitter is enabled when not using AA
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
assert len(color_jitter) in (3, 4)
|
||||
else:
|
||||
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
||||
color_jitter = (float(color_jitter),) * 3
|
||||
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
|
||||
|
||||
final_tfl = []
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
final_tfl += [ToNumpy()]
|
||||
else:
|
||||
final_tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if re_prob > 0.:
|
||||
final_tfl.append(
|
||||
RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
|
||||
|
||||
if separate:
|
||||
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
|
||||
else:
|
||||
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
||||
|
||||
|
||||
def transforms_imagenet_eval(
|
||||
img_size=224,
|
||||
crop_pct=None,
|
||||
interpolation='bilinear',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
|
||||
if isinstance(img_size, tuple):
|
||||
assert len(img_size) == 2
|
||||
if img_size[-1] == img_size[-2]:
|
||||
# fall-back to older behaviour so Resize scales to shortest edge if target is square
|
||||
scale_size = int(math.floor(img_size[0] / crop_pct))
|
||||
else:
|
||||
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
|
||||
tfl = [
|
||||
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
||||
transforms.CenterCrop(img_size),
|
||||
]
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
def create_transform(
|
||||
input_size,
|
||||
is_training=False,
|
||||
use_prefetcher=False,
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='bilinear',
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_num_splits=0,
|
||||
crop_pct=None,
|
||||
tf_preprocessing=False,
|
||||
separate=False):
|
||||
|
||||
if isinstance(input_size, tuple):
|
||||
img_size = input_size[-2:]
|
||||
else:
|
||||
img_size = input_size
|
||||
|
||||
if tf_preprocessing and use_prefetcher:
|
||||
assert not separate, "Separate transforms not supported for TF preprocessing"
|
||||
from timm.data.tf_preprocessing import TfPreprocessTransform
|
||||
transform = TfPreprocessTransform(
|
||||
is_training=is_training, size=img_size, interpolation=interpolation)
|
||||
else:
|
||||
if is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
img_size,
|
||||
color_jitter=color_jitter,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
re_prob=re_prob,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
re_num_splits=re_num_splits,
|
||||
separate=separate)
|
||||
else:
|
||||
assert not separate, "Separate transforms not supported for validation preprocessing"
|
||||
transform = transforms_imagenet_eval(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct)
|
||||
|
||||
return transform
|
@ -1 +1,2 @@
|
||||
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from .jsd import JsdCrossEntropy
|
39
timm/loss/jsd.py
Normal file
39
timm/loss/jsd.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .cross_entropy import LabelSmoothingCrossEntropy
|
||||
|
||||
|
||||
class JsdCrossEntropy(nn.Module):
|
||||
""" Jensen-Shannon Divergence + Cross-Entropy Loss
|
||||
|
||||
Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
|
||||
super().__init__()
|
||||
self.num_splits = num_splits
|
||||
self.alpha = alpha
|
||||
if smoothing is not None and smoothing > 0:
|
||||
self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
|
||||
else:
|
||||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def __call__(self, output, target):
|
||||
split_size = output.shape[0] // self.num_splits
|
||||
assert split_size * self.num_splits == output.shape[0]
|
||||
logits_split = torch.split(output, split_size)
|
||||
|
||||
# Cross-entropy is only computed on clean images
|
||||
loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
|
||||
probs = [F.softmax(logits, dim=1) for logits in logits_split]
|
||||
|
||||
# Clamp mixture distribution to avoid exploding KL divergence
|
||||
logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
|
||||
loss += self.alpha * sum([F.kl_div(
|
||||
logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
|
||||
return loss
|
@ -21,3 +21,4 @@ from .registry import *
|
||||
from .factory import create_model
|
||||
from .helpers import load_checkpoint, resume_checkpoint
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from .split_batchnorm import convert_splitbn_model
|
||||
|
75
timm/models/split_batchnorm.py
Normal file
75
timm/models/split_batchnorm.py
Normal file
@ -0,0 +1,75 @@
|
||||
""" Split BatchNorm
|
||||
|
||||
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
|
||||
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
|
||||
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
|
||||
namespace.
|
||||
|
||||
This allows easily removing the auxiliary BN layers after training to efficiently
|
||||
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
|
||||
'Disentangled Learning via An Auxiliary BN'
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
|
||||
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
|
||||
track_running_stats=True, num_splits=2):
|
||||
super().__init__(num_features, eps, momentum, affine, track_running_stats)
|
||||
assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
|
||||
self.num_splits = num_splits
|
||||
self.aux_bn = nn.ModuleList([
|
||||
nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
|
||||
|
||||
def forward(self, input: torch.Tensor):
|
||||
if self.training: # aux BN only relevant while training
|
||||
split_size = input.shape[0] // self.num_splits
|
||||
assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
|
||||
split_input = input.split(split_size)
|
||||
x = [super().forward(split_input[0])]
|
||||
for i, a in enumerate(self.aux_bn):
|
||||
x.append(a(split_input[i + 1]))
|
||||
return torch.cat(x, dim=0)
|
||||
else:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
def convert_splitbn_model(module, num_splits=2):
|
||||
"""
|
||||
Recursively traverse module and its children to replace all instances of
|
||||
``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
|
||||
Args:
|
||||
module (torch.nn.Module): input module
|
||||
num_splits: number of separate batchnorm layers to split input across
|
||||
Example::
|
||||
>>> # model is an instance of torch.nn.Module
|
||||
>>> model = timm.models.convert_splitbn_model(model, num_splits=2)
|
||||
"""
|
||||
mod = module
|
||||
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
|
||||
return module
|
||||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
|
||||
mod = SplitBatchNorm2d(
|
||||
module.num_features, module.eps, module.momentum, module.affine,
|
||||
module.track_running_stats, num_splits=num_splits)
|
||||
mod.running_mean = module.running_mean
|
||||
mod.running_var = module.running_var
|
||||
mod.num_batches_tracked = module.num_batches_tracked
|
||||
if module.affine:
|
||||
mod.weight.data = module.weight.data.clone().detach()
|
||||
mod.bias.data = module.bias.data.clone().detach()
|
||||
for aux in mod.aux_bn:
|
||||
aux.running_mean = module.running_mean.clone()
|
||||
aux.running_var = module.running_var.clone()
|
||||
aux.num_batches_tracked = module.num_batches_tracked.clone()
|
||||
if module.affine:
|
||||
aux.weight.data = module.weight.data.clone().detach()
|
||||
aux.bias.data = module.bias.data.clone().detach()
|
||||
for name, child in module.named_children():
|
||||
mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
|
||||
del module
|
||||
return mod
|
56
train.py
56
train.py
@ -16,7 +16,6 @@ Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
@ -29,10 +28,10 @@ except ImportError:
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
has_apex = False
|
||||
|
||||
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
||||
from timm.models import create_model, resume_checkpoint
|
||||
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
|
||||
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
|
||||
from timm.utils import *
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
||||
from timm.optim import create_optimizer
|
||||
from timm.scheduler import create_scheduler
|
||||
|
||||
@ -84,6 +83,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
|
||||
help='Drop connect rate (default: 0.)')
|
||||
parser.add_argument('--jsd', action='store_true', default=False,
|
||||
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "sgd"')
|
||||
@ -119,18 +120,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
|
||||
help='Color jitter factor (default: 0.4)')
|
||||
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
|
||||
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
||||
parser.add_argument('--aug-splits', type=int, default=0,
|
||||
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
|
||||
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
||||
help='Random erase prob (default: 0.)')
|
||||
parser.add_argument('--remode', type=str, default='const',
|
||||
help='Random erase mode (default: "const")')
|
||||
parser.add_argument('--recount', type=int, default=1,
|
||||
help='Random erase count (default: 1)')
|
||||
parser.add_argument('--resplit', action='store_true', default=False,
|
||||
help='Do not random erase first (clean) augmentation split')
|
||||
parser.add_argument('--mixup', type=float, default=0.0,
|
||||
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
|
||||
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
|
||||
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='label smoothing (default: 0.1)')
|
||||
parser.add_argument('--train-interpolation', type=str, default='random',
|
||||
help='Training interpolation (random, bilinear, bicubic default: "random")')
|
||||
# Batch norm parameters (only works with gen_efficientnet based models currently)
|
||||
parser.add_argument('--bn-tf', action='store_true', default=False,
|
||||
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
|
||||
@ -142,6 +149,8 @@ parser.add_argument('--sync-bn', action='store_true',
|
||||
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
|
||||
parser.add_argument('--dist-bn', type=str, default='',
|
||||
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
||||
parser.add_argument('--split-bn', action='store_true',
|
||||
help='Enable separate BN layers per augmentation split.')
|
||||
# Model Exponential Moving Average
|
||||
parser.add_argument('--model-ema', action='store_true', default=False,
|
||||
help='Enable tracking moving average of model weights')
|
||||
@ -244,6 +253,15 @@ def main():
|
||||
|
||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||
|
||||
num_aug_splits = 0
|
||||
if args.aug_splits > 0:
|
||||
assert args.aug_splits > 1, 'A split of 1 makes no sense'
|
||||
num_aug_splits = args.aug_splits
|
||||
|
||||
if args.split_bn:
|
||||
assert num_aug_splits > 1 or args.resplit
|
||||
model = convert_splitbn_model(model, max(num_aug_splits, 2))
|
||||
|
||||
if args.num_gpu > 1:
|
||||
if args.amp:
|
||||
logging.warning(
|
||||
@ -290,6 +308,7 @@ def main():
|
||||
|
||||
if args.distributed:
|
||||
if args.sync_bn:
|
||||
assert not args.split_bn
|
||||
try:
|
||||
if has_apex:
|
||||
model = convert_syncbn_model(model)
|
||||
@ -328,20 +347,26 @@ def main():
|
||||
|
||||
collate_fn = None
|
||||
if args.prefetcher and args.mixup > 0:
|
||||
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
|
||||
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
|
||||
|
||||
if num_aug_splits > 1:
|
||||
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
||||
|
||||
loader_train = create_loader(
|
||||
dataset_train,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
is_training=True,
|
||||
use_prefetcher=args.prefetcher,
|
||||
rand_erase_prob=args.reprob,
|
||||
rand_erase_mode=args.remode,
|
||||
rand_erase_count=args.recount,
|
||||
re_prob=args.reprob,
|
||||
re_mode=args.remode,
|
||||
re_count=args.recount,
|
||||
re_split=args.resplit,
|
||||
color_jitter=args.color_jitter,
|
||||
auto_augment=args.aa,
|
||||
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
|
||||
num_aug_splits=num_aug_splits,
|
||||
interpolation=args.train_interpolation,
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
@ -373,7 +398,11 @@ def main():
|
||||
pin_memory=args.pin_mem,
|
||||
)
|
||||
|
||||
if args.mixup > 0.:
|
||||
if args.jsd:
|
||||
assert num_aug_splits > 1 # JSD only valid with aug splits set
|
||||
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
elif args.mixup > 0.:
|
||||
# smoothing is handled with mixup label transform
|
||||
train_loss_fn = SoftTargetCrossEntropy().cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
@ -471,11 +500,10 @@ def train_epoch(
|
||||
if not args.prefetcher:
|
||||
input, target = input.cuda(), target.cuda()
|
||||
if args.mixup > 0.:
|
||||
lam = 1.
|
||||
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
|
||||
lam = np.random.beta(args.mixup, args.mixup)
|
||||
input = input.mul(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, args.num_classes, lam, args.smoothing)
|
||||
input, target = mixup_batch(
|
||||
input, target,
|
||||
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
|
||||
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
|
||||
|
||||
output = model(input)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user