mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add 3-Augment support to auto_augment.py, clean up weighted choice handling, and allow adjust per op prob via arg string
This commit is contained in:
parent
e98c93264c
commit
e3b2f5be0a
@ -1,4 +1,4 @@
|
||||
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch
|
||||
|
||||
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||
does not include any of the search code.
|
||||
@ -9,18 +9,24 @@ AA and RA Implementation adapted from:
|
||||
AugMix adapted from:
|
||||
https://github.com/google-research/augmix
|
||||
|
||||
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
|
||||
|
||||
Papers:
|
||||
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
|
||||
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
@ -175,6 +181,24 @@ def sharpness(img, factor, **__):
|
||||
return ImageEnhance.Sharpness(img).enhance(factor)
|
||||
|
||||
|
||||
def gaussian_blur(img, factor, **__):
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
|
||||
return img
|
||||
|
||||
|
||||
def gaussian_blur_rand(img, factor, **__):
|
||||
radius_min = 0.1
|
||||
radius_max = 2.0
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
|
||||
return img
|
||||
|
||||
|
||||
def desaturate(img, factor, **_):
|
||||
factor = min(1., max(0., 1. - factor))
|
||||
# enhance factor 0 = grayscale, 1.0 = no-change
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
"""With 50% prob, negate the value"""
|
||||
return -v if random.random() > 0.5 else v
|
||||
@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
return level,
|
||||
|
||||
|
||||
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
|
||||
level = (level / _LEVEL_DENOM)
|
||||
min_val + (max_val - min_val) * level
|
||||
if clamp:
|
||||
level = min(min_val, max(max_val, level))
|
||||
return level,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _LEVEL_DENOM) * 0.3
|
||||
@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams):
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _LEVEL_DENOM) * 256),
|
||||
return min(256, int((level / _LEVEL_DENOM) * 256)),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _LEVEL_DENOM) * 110),
|
||||
return min(128, int((level / _LEVEL_DENOM) * 110)),
|
||||
|
||||
|
||||
LEVEL_TO_ARG = {
|
||||
@ -286,6 +318,9 @@ LEVEL_TO_ARG = {
|
||||
'TranslateY': _translate_abs_level_to_arg,
|
||||
'TranslateXRel': _translate_rel_level_to_arg,
|
||||
'TranslateYRel': _translate_rel_level_to_arg,
|
||||
'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
|
||||
'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
|
||||
'GaussianBlurRand': _minmax_level_to_arg,
|
||||
}
|
||||
|
||||
|
||||
@ -314,6 +349,9 @@ NAME_TO_OP = {
|
||||
'TranslateY': translate_y_abs,
|
||||
'TranslateXRel': translate_x_rel,
|
||||
'TranslateYRel': translate_y_rel,
|
||||
'Desaturate': desaturate,
|
||||
'GaussianBlur': gaussian_blur,
|
||||
'GaussianBlurRand': gaussian_blur_rand,
|
||||
}
|
||||
|
||||
|
||||
@ -347,6 +385,7 @@ class AugmentOp:
|
||||
if self.magnitude_std > 0:
|
||||
# magnitude randomization enabled
|
||||
if self.magnitude_std == float('inf'):
|
||||
# inf == uniform sampling
|
||||
magnitude = random.uniform(0, magnitude)
|
||||
elif self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams):
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_3a(hparams):
|
||||
policy = [
|
||||
[('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude
|
||||
[('Desaturate', 1.0, 10)], # grayscale at 10 magnitude
|
||||
[('GaussianBlurRand', 1.0, 10)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None):
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
elif name == '3a':
|
||||
return auto_augment_policy_3a(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
|
||||
@ -534,19 +585,23 @@ class AutoAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
|
||||
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
The remaining sections, not order sepecific determine
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
Args:
|
||||
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-').
|
||||
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
The remaining sections:
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [
|
||||
]
|
||||
|
||||
|
||||
_RAND_3A = [
|
||||
'SolarizeIncreasing',
|
||||
'Desaturate',
|
||||
'GaussianBlur',
|
||||
]
|
||||
|
||||
|
||||
_RAND_CHOICE_3A = {
|
||||
'SolarizeIncreasing': 6,
|
||||
'Desaturate': 6,
|
||||
'GaussianBlur': 6,
|
||||
'Rotate': 3,
|
||||
'ShearX': 2,
|
||||
'ShearY': 2,
|
||||
'PosterizeIncreasing': 1,
|
||||
'AutoContrast': 1,
|
||||
'ColorIncreasing': 1,
|
||||
'SharpnessIncreasing': 1,
|
||||
'ContrastIncreasing': 1,
|
||||
'BrightnessIncreasing': 1,
|
||||
'Equalize': 1,
|
||||
'Invert': 1,
|
||||
}
|
||||
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
'Rotate': 3,
|
||||
'ShearX': 2,
|
||||
'ShearY': 2,
|
||||
'TranslateXRel': 1,
|
||||
'TranslateYRel': 1,
|
||||
'ColorIncreasing': .25,
|
||||
'SharpnessIncreasing': 0.25,
|
||||
'AutoContrast': 0.25,
|
||||
'SolarizeIncreasing': .05,
|
||||
'SolarizeAdd': .05,
|
||||
'ContrastIncreasing': .05,
|
||||
'BrightnessIncreasing': .05,
|
||||
'Equalize': .05,
|
||||
'PosterizeIncreasing': 0.05,
|
||||
'Invert': 0.05,
|
||||
}
|
||||
|
||||
|
||||
def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
assert weight_idx == 0 # only one set of weights currently
|
||||
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
||||
probs = [rand_weights[k] for k in transforms]
|
||||
probs /= np.sum(probs)
|
||||
return probs
|
||||
def _get_weighted_transforms(transforms: Dict):
|
||||
transforms, probs = list(zip(*transforms.items()))
|
||||
probs = np.array(probs)
|
||||
probs = probs / np.sum(probs)
|
||||
return transforms, probs
|
||||
|
||||
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
def rand_augment_choices(name: str, increasing=True):
|
||||
if name == 'weights':
|
||||
return _RAND_CHOICE_WEIGHTS_0
|
||||
elif name == '3aw':
|
||||
return _RAND_CHOICE_3A
|
||||
elif name == '3a':
|
||||
return _RAND_3A
|
||||
else:
|
||||
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
|
||||
|
||||
def rand_augment_ops(
|
||||
magnitude: Union[int, float] = 10,
|
||||
prob: float = 0.5,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[Dict, List]] = None,
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class RandAugment:
|
||||
@ -648,11 +741,16 @@ class RandAugment:
|
||||
self.ops = ops
|
||||
self.num_layers = num_layers
|
||||
self.choice_weights = choice_weights
|
||||
print(self.ops, self.choice_weights)
|
||||
|
||||
def __call__(self, img):
|
||||
# no replacement when using weighted choice
|
||||
ops = np.random.choice(
|
||||
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
||||
self.ops,
|
||||
self.num_layers,
|
||||
replace=self.choice_weights is None,
|
||||
p=self.choice_weights,
|
||||
)
|
||||
for op in ops:
|
||||
img = op(img)
|
||||
return img
|
||||
@ -665,61 +763,84 @@ class RandAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
def rand_augment_transform(
|
||||
config_str: str,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[str, Dict, List]] = None,
|
||||
):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
Args:
|
||||
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
|
||||
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
|
||||
The remaining sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'p' - float probability of applying each layer (default 0.5)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
't' - str name of transform set to use
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
||||
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
transforms = _RAND_TRANSFORMS
|
||||
increasing = False
|
||||
prob = 0.5
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param / randomization of magnitude values
|
||||
mstd = float(val)
|
||||
if mstd > 100:
|
||||
# use uniform sampling in 0 to magnitude if mstd is > 100
|
||||
mstd = float('inf')
|
||||
hparams.setdefault('magnitude_std', mstd)
|
||||
elif key == 'mmax':
|
||||
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
||||
hparams.setdefault('magnitude_max', int(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
transforms = _RAND_INCREASING_TRANSFORMS
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'w':
|
||||
weight_idx = int(val)
|
||||
if c.startswith('t'):
|
||||
# NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
|
||||
val = str(c[1:])
|
||||
if transforms is None:
|
||||
transforms = val
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
# numeric options
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param / randomization of magnitude values
|
||||
mstd = float(val)
|
||||
if mstd > 100:
|
||||
# use uniform sampling in 0 to magnitude if mstd is > 100
|
||||
mstd = float('inf')
|
||||
hparams.setdefault('magnitude_std', mstd)
|
||||
elif key == 'mmax':
|
||||
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
||||
hparams.setdefault('magnitude_max', int(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
increasing = True
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'p':
|
||||
prob = float(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
|
||||
if isinstance(transforms, str):
|
||||
transforms = rand_augment_choices(transforms, increasing=increasing)
|
||||
elif transforms is None:
|
||||
transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
|
||||
choice_weights = None
|
||||
if isinstance(transforms, Dict):
|
||||
transforms, choice_weights = _get_weighted_transforms(transforms)
|
||||
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
def augmix_ops(
|
||||
magnitude: Union[int, float] = 10,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[str, Dict, List]] = None,
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
name,
|
||||
prob=1.0,
|
||||
magnitude=magnitude,
|
||||
hparams=hparams
|
||||
) for name in transforms]
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
@ -820,22 +949,24 @@ class AugMixAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
|
||||
""" Create AugMix PyTorch transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
Args:
|
||||
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
|
||||
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
|
||||
The remaining sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = 3
|
||||
width = 3
|
||||
|
@ -59,6 +59,7 @@ def transforms_imagenet_train(
|
||||
re_count=1,
|
||||
re_num_splits=0,
|
||||
separate=False,
|
||||
force_color_jitter=False,
|
||||
):
|
||||
"""
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
@ -77,8 +78,12 @@ def transforms_imagenet_train(
|
||||
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
|
||||
|
||||
secondary_tfl = []
|
||||
disable_color_jitter = False
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
# color jitter is typically disabled if AA/RA on,
|
||||
# this allows override without breaking old hparm cfgs
|
||||
disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
|
||||
if isinstance(img_size, (tuple, list)):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
@ -96,8 +101,9 @@ def transforms_imagenet_train(
|
||||
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
elif color_jitter is not None:
|
||||
# color jitter is enabled when not using AA
|
||||
|
||||
if color_jitter is not None and not disable_color_jitter:
|
||||
# color jitter is enabled when not using AA or when forced
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
|
Loading…
x
Reference in New Issue
Block a user