diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 49c4bc60..ee2240b4 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,8 +1,9 @@ from .constants import * from .config import resolve_data_config -from .dataset import Dataset, DatasetTar +from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * -from .loader import create_loader, create_transform -from .mixup import mixup_target, FastCollateMixup +from .loader import create_loader +from .transforms_factory import create_transform +from .mixup import mixup_batch, FastCollateMixup from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index d730c266..e355eef5 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,19 +1,30 @@ -""" AutoAugment and RandAugment -Implementation adapted from: +""" AutoAugment, RandAugment, and AugMix for PyTorch + +This code implements the searched ImageNet policies with various tweaks and improvements and +does not include any of the search code. + +AA and RA Implementation adapted from: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py -Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 + +AugMix adapted from: + https://github.com/google-research/augmix + +Papers: + AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 Hacked together by Ross Wightman """ import random import math import re -from PIL import Image, ImageOps, ImageEnhance +from PIL import Image, ImageOps, ImageEnhance, ImageChops import PIL import numpy as np - _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _FILL = (128, 128, 128) @@ -178,6 +189,14 @@ def _enhance_level_to_arg(level, _hparams): return (level / _MAX_LEVEL) * 1.8 + 0.1, +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * .9 + level = 1.0 + _randomly_negate(level) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 @@ -192,36 +211,47 @@ def _translate_abs_level_to_arg(level, hparams): return level, -def _translate_rel_level_to_arg(level, _hparams): - # range [-0.45, 0.45] - level = (level / _MAX_LEVEL) * 0.45 +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get('translate_pct', 0.45) + level = (level / _MAX_LEVEL) * translate_pct level = _randomly_negate(level) return level, +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4), + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return 4 - _posterize_level_to_arg(level, hparams)[0], + + def _posterize_original_level_to_arg(level, _hparams): # As per original AutoAugment paper description # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level return int((level / _MAX_LEVEL) * 4) + 4, -def _posterize_research_level_to_arg(level, _hparams): - # As per Tensorflow models research and UDA impl - # range [4, 0], 'keep 4 down to 0 MSB of original image' - return 4 - int((level / _MAX_LEVEL) * 4), - - -def _posterize_tpu_level_to_arg(level, _hparams): - # As per Tensorflow TPU EfficientNet impl - # range [0, 4], 'keep 0 up to 4 MSB of original image' - return int((level / _MAX_LEVEL) * 4), - - def _solarize_level_to_arg(level, _hparams): # range [0, 256] + # intensity/severity of augmentation decreases with level return int((level / _MAX_LEVEL) * 256), +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return 256 - _solarize_level_to_arg(level, _hparams)[0], + + def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] return int((level / _MAX_LEVEL) * 110), @@ -233,15 +263,20 @@ LEVEL_TO_ARG = { 'Invert': None, 'Rotate': _rotate_level_to_arg, # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'Posterize': _posterize_level_to_arg, + 'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg, - 'PosterizeResearch': _posterize_research_level_to_arg, - 'PosterizeTpu': _posterize_tpu_level_to_arg, 'Solarize': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg, 'Color': _enhance_level_to_arg, + 'ColorIncreasing': _enhance_increasing_level_to_arg, 'Contrast': _enhance_level_to_arg, + 'ContrastIncreasing': _enhance_increasing_level_to_arg, 'Brightness': _enhance_level_to_arg, + 'BrightnessIncreasing': _enhance_increasing_level_to_arg, 'Sharpness': _enhance_level_to_arg, + 'SharpnessIncreasing': _enhance_increasing_level_to_arg, 'ShearX': _shear_level_to_arg, 'ShearY': _shear_level_to_arg, 'TranslateX': _translate_abs_level_to_arg, @@ -256,15 +291,20 @@ NAME_TO_OP = { 'Equalize': equalize, 'Invert': invert, 'Rotate': rotate, + 'Posterize': posterize, + 'PosterizeIncreasing': posterize, 'PosterizeOriginal': posterize, - 'PosterizeResearch': posterize, - 'PosterizeTpu': posterize, 'Solarize': solarize, + 'SolarizeIncreasing': solarize, 'SolarizeAdd': solarize_add, 'Color': color, + 'ColorIncreasing': color, 'Contrast': contrast, + 'ContrastIncreasing': contrast, 'Brightness': brightness, + 'BrightnessIncreasing': brightness, 'Sharpness': sharpness, + 'SharpnessIncreasing': sharpness, 'ShearX': shear_x, 'ShearY': shear_y, 'TranslateX': translate_x_abs, @@ -274,7 +314,7 @@ NAME_TO_OP = { } -class AutoAugmentOp: +class AugmentOp: def __init__(self, name, prob=0.5, magnitude=10, hparams=None): hparams = hparams or _HPARAMS_DEFAULT @@ -295,12 +335,12 @@ class AutoAugmentOp: self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): - if random.random() > self.prob: + if self.prob < 1.0 and random.random() > self.prob: return img magnitude = self.magnitude if self.magnitude_std and self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) - magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) @@ -320,7 +360,7 @@ def auto_augment_policy_v0(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -330,16 +370,17 @@ def auto_augment_policy_v0(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_v0r(hparams): - # ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize + # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used + # in Google research implementation (number of bits discarded increases with magnitude) policy = [ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)], @@ -353,7 +394,7 @@ def auto_augment_policy_v0r(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -363,11 +404,11 @@ def auto_augment_policy_v0r(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)], + [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -400,23 +441,23 @@ def auto_augment_policy_original(hparams): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_originalr(hparams): # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation policy = [ - [('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)], + [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], - [('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)], + [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], - [('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)], + [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], - [('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)], + [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], @@ -433,7 +474,7 @@ def auto_augment_policy_originalr(hparams): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -499,7 +540,7 @@ _RAND_TRANSFORMS = [ 'Equalize', 'Invert', 'Rotate', - 'PosterizeTpu', + 'Posterize', 'Solarize', 'SolarizeAdd', 'Color', @@ -510,10 +551,31 @@ _RAND_TRANSFORMS = [ 'ShearY', 'TranslateXRel', 'TranslateYRel', - #'Cutout' # FIXME I implement this as random erasing separately + #'Cutout' # NOTE I've implement this as random erasing separately ] +_RAND_INCREASING_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'SolarizeAdd', + 'ColorIncreasing', + 'ContrastIncreasing', + 'BrightnessIncreasing', + 'SharpnessIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # NOTE I've implement this as random erasing separately +] + + + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. _RAND_CHOICE_WEIGHTS_0 = { @@ -530,7 +592,7 @@ _RAND_CHOICE_WEIGHTS_0 = { 'Contrast': .005, 'Brightness': .005, 'Equalize': .005, - 'PosterizeTpu': 0, + 'Posterize': 0, 'Invert': 0, } @@ -547,7 +609,7 @@ def _select_rand_weights(weight_idx=0, transforms=None): def rand_augment_ops(magnitude=10, hparams=None, transforms=None): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS - return [AutoAugmentOp( + return [AugmentOp( name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] @@ -577,6 +639,7 @@ def rand_augment_transform(config_str, hparams): 'n' - integer num layers (number of transform ops selected per image) 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 @@ -587,6 +650,7 @@ def rand_augment_transform(config_str, hparams): magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) num_layers = 2 # default to 2 ops per image weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS config = config_str.split('-') assert config[0] == 'rand' config = config[1:] @@ -598,6 +662,9 @@ def rand_augment_transform(config_str, hparams): if key == 'mstd': # noise param injected via hparams for now hparams.setdefault('magnitude_std', float(val)) + elif key == 'inc': + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS elif key == 'm': magnitude = int(val) elif key == 'n': @@ -606,6 +673,145 @@ def rand_augment_transform(config_str, hparams): weight_idx = int(val) else: assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + + +_AUGMIX_TRANSFORMS = [ + 'AutoContrast', + 'ColorIncreasing', # not in paper + 'ContrastIncreasing', # not in paper + 'BrightnessIncreasing', # not in paper + 'SharpnessIncreasing', # not in paper + 'Equalize', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] + + +def augmix_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _AUGMIX_TRANSFORMS + return [AugmentOp( + name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class AugMixAugment: + """ AugMix Transform + Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + """ + def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): + self.ops = ops + self.alpha = alpha + self.width = width + self.depth = depth + self.blended = blended # blended mode is faster but not well tested + + def _calc_blended_weights(self, ws, m): + ws = ws * m + cump = 1. + rws = [] + for w in ws[::-1]: + alpha = w / cump + cump *= (1 - alpha) + rws.append(alpha) + return np.array(rws[::-1], dtype=np.float32) + + def _apply_blended(self, img, mixing_weights, m): + # This is my first crack and implementing a slightly faster mixed augmentation. Instead + # of accumulating the mix for each chain in a Numpy array and then blending with original, + # it recomputes the blending coefficients and applies one PIL image blend per chain. + # TODO the results appear in the right ballpark but they differ by more than rounding. + img_orig = img.copy() + ws = self._calc_blended_weights(mixing_weights, m) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img_orig # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img = Image.blend(img, img_aug, w) + return img + + def _apply_basic(self, img, mixing_weights, m): + # This is a literal adaptation of the paper/official implementation without normalizations and + # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the + # typical augmentation transforms, could use a GPU / Kornia implementation. + img_shape = img.size[0], img.size[1], len(img.getbands()) + mixed = np.zeros(img_shape, dtype=np.float32) + for mw in mixing_weights: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + mixed += mw * np.asarray(img_aug, dtype=np.float32) + np.clip(mixed, 0, 255., out=mixed) + mixed = Image.fromarray(mixed.astype(np.uint8)) + return Image.blend(img, mixed, m) + + def __call__(self, img): + mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) + m = np.float32(np.random.beta(self.alpha, self.alpha)) + if self.blended: + mixed = self._apply_blended(img, mixing_weights, m) + else: + mixed = self._apply_basic(img, mixing_weights, m) + return mixed + + +def augment_and_mix_transform(config_str, hparams): + """ Create AugMix PyTorch transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + :param hparams: Other hparams (kwargs) for the Augmentation transforms + + :return: A PyTorch compatible Transform + """ + magnitude = 3 + width = 3 + depth = -1 + alpha = 1. + blended = False + config = config_str.split('-') + assert config[0] == 'augmix' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'w': + width = int(val) + elif key == 'd': + depth = int(val) + elif key == 'a': + alpha = float(val) + elif key == 'b': + blended = bool(val) + else: + assert False, 'Unknown AugMix config section' + ops = augmix_ops(magnitude=magnitude, hparams=hparams) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 47437d5e..fc252d9e 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -140,3 +140,42 @@ class DatasetTar(data.Dataset): def __len__(self): return len(self.imgs) + +class AugMixDataset(torch.utils.data.Dataset): + """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" + + def __init__(self, dataset, num_splits=2): + self.augmentation = None + self.normalize = None + self.dataset = dataset + if self.dataset.transform is not None: + self._set_transforms(self.dataset.transform) + self.num_splits = num_splits + + def _set_transforms(self, x): + assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' + self.dataset.transform = x[0] + self.augmentation = x[1] + self.normalize = x[2] + + @property + def transform(self): + return self.dataset.transform + + @transform.setter + def transform(self, x): + self._set_transforms(x) + + def _normalize(self, x): + return x if self.normalize is None else self.normalize(x) + + def __getitem__(self, i): + x, y = self.dataset[i] # all splits share the same dataset base transform + x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) + # run the full augmentation on the remaining splits + for _ in range(self.num_splits - 1): + x_list.append(self._normalize(self.augmentation(x))) + return tuple(x_list), y + + def __len__(self): + return len(self.dataset) diff --git a/timm/data/loader.py b/timm/data/loader.py index bbb71eca..f3faf7b9 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -1,29 +1,59 @@ import torch.utils.data -from .transforms import * +import numpy as np + +from .transforms_factory import create_transform +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .distributed_sampler import OrderedDistributedSampler +from .random_erasing import RandomErasing from .mixup import FastCollateMixup def fast_collate(batch): - targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) - batch_size = len(targets) - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - tensor[i] += torch.from_numpy(batch[i][0]) - - return tensor, targets + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False class PrefetchLoader: def __init__(self, - loader, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False): + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): self.loader = loader self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) @@ -31,9 +61,9 @@ class PrefetchLoader: if fp16: self.mean = self.mean.half() self.std = self.std.half() - if rand_erase_prob > 0.: + if re_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) else: self.random_erasing = None @@ -87,60 +117,19 @@ class PrefetchLoader: self.loader.collate_fn.mixup_enabled = x -def create_transform( - input_size, - is_training=False, - use_prefetcher=False, - color_jitter=0.4, - auto_augment=None, - interpolation='bilinear', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - crop_pct=None, - tf_preprocessing=False): - - if isinstance(input_size, tuple): - img_size = input_size[-2:] - else: - img_size = input_size - - if tf_preprocessing and use_prefetcher: - from timm.data.tf_preprocessing import TfPreprocessTransform - transform = TfPreprocessTransform( - is_training=is_training, size=img_size, interpolation=interpolation) - else: - if is_training: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - auto_augment=auto_augment, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) - else: - transform = transforms_imagenet_eval( - img_size, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std, - crop_pct=crop_pct) - - return transform - - def create_loader( dataset, input_size, batch_size, is_training=False, use_prefetcher=True, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, + re_prob=0., + re_mode='const', + re_count=1, + re_split=False, color_jitter=0.4, auto_augment=None, + num_aug_splits=0, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -152,6 +141,10 @@ def create_loader( fp16=False, tf_preprocessing=False, ): + re_num_splits = 0 + if re_split: + # apply RE to second half of batch if no aug split otherwise line up with aug split + re_num_splits = num_aug_splits or 2 dataset.transform = create_transform( input_size, is_training=is_training, @@ -163,6 +156,11 @@ def create_loader( std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=num_aug_splits > 0, ) sampler = None @@ -190,11 +188,13 @@ def create_loader( if use_prefetcher: loader = PrefetchLoader( loader, - rand_erase_prob=rand_erase_prob if is_training else 0., - rand_erase_mode=rand_erase_mode, - rand_erase_count=rand_erase_count, mean=mean, std=std, - fp16=fp16) + fp16=fp16, + re_prob=re_prob if is_training else 0., + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) return loader diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 83d51ccb..4678472d 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): return lam*y1 + (1. - lam)*y2 +def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): + lam = 1. + if not disable: + lam = np.random.beta(alpha, alpha) + input = input.mul(lam).add_(1 - lam, input.flip(0)) + target = mixup_target(target, num_classes, lam, smoothing) + return input, target + + class FastCollateMixup: def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 5eed1387..589b2f0b 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -23,13 +23,13 @@ class RandomErasing: This variant of RandomErasing is intended to be applied to either a batch or single image tensor after it has been normalized by dataset mean and std. Args: - probability: The probability that the Random Erasing operation will be performed. - sl: Minimum proportion of erased area against input image. - sh: Maximum proportion of erased area against input image. + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. min_aspect: Minimum aspect ratio of erased area. mode: pixel color mode, one of 'const', 'rand', or 'pixel' 'const' - erase block is constant color of 0 for all channels - 'rand' - erase block is same per-cannel random (normal) color + 'rand' - erase block is same per-channel random (normal) color 'pixel' - erase block is per-pixel random (normal) color max_count: maximum number of erasing blocks per image, area per box is scaled by count. per-image count is randomly chosen between 1 and this value. @@ -37,14 +37,16 @@ class RandomErasing: def __init__( self, - probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - mode='const', max_count=1, device='cuda'): + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): self.probability = probability - self.sl = sl - self.sh = sh - self.min_aspect = min_aspect - self.min_count = 1 - self.max_count = max_count + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -64,9 +66,8 @@ class RandomErasing: random.randint(self.min_count, self.max_count) for _ in range(count): for attempt in range(10): - target_area = random.uniform(self.sl, self.sh) * area / count - log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect)) - aspect_ratio = math.exp(random.uniform(*log_ratio)) + target_area = random.uniform(self.min_area, self.max_area) * area / count + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < img_w and h < img_h: @@ -82,6 +83,8 @@ class RandomErasing: self._erase(input, *input.size(), input.dtype) else: batch_size, chan, img_h, img_w = input.size() - for i in range(batch_size): + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): self._erase(input[i], chan, img_h, img_w, input.dtype) return input diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 41f2a63e..b3b08e30 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -1,5 +1,4 @@ import torch -from torchvision import transforms import torchvision.transforms.functional as F from PIL import Image import warnings @@ -7,10 +6,6 @@ import math import random import numpy as np -from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .random_erasing import RandomErasing -from .auto_augment import auto_augment_transform, rand_augment_transform - class ToNumpy: @@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation: return format_string -def transforms_imagenet_train( - img_size=224, - scale=(0.08, 1.0), - color_jitter=0.4, - auto_augment=None, - interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD -): - tfl = [ - RandomResizedCropAndInterpolation( - img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip() - ] - if auto_augment: - assert isinstance(auto_augment, str) - if isinstance(img_size, tuple): - img_size_min = min(img_size) - else: - img_size_min = img_size - aa_params = dict( - translate_const=int(img_size_min * 0.45), - img_mean=tuple([min(255, round(255 * x)) for x in mean]), - ) - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - tfl += [rand_augment_transform(auto_augment, aa_params)] - else: - tfl += [auto_augment_transform(auto_augment, aa_params)] - else: - # color jitter is enabled when not using AA - if isinstance(color_jitter, (list, tuple)): - # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation - # or 4 if also augmenting hue - assert len(color_jitter) in (3, 4) - else: - # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue - color_jitter = (float(color_jitter),) * 3 - tfl += [transforms.ColorJitter(*color_jitter)] - - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - if random_erasing > 0.: - tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) - return transforms.Compose(tfl) - - -def transforms_imagenet_eval( - img_size=224, - crop_pct=None, - interpolation='bilinear', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD): - crop_pct = crop_pct or DEFAULT_CROP_PCT - - if isinstance(img_size, tuple): - assert len(img_size) == 2 - if img_size[-1] == img_size[-2]: - # fall-back to older behaviour so Resize scales to shortest edge if target is square - scale_size = int(math.floor(img_size[0] / crop_pct)) - else: - scale_size = tuple([int(x / crop_pct) for x in img_size]) - else: - scale_size = int(math.floor(img_size / crop_pct)) - - tfl = [ - transforms.Resize(scale_size, _pil_interp(interpolation)), - transforms.CenterCrop(img_size), - ] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - - return transforms.Compose(tfl) diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py new file mode 100644 index 00000000..767dd157 --- /dev/null +++ b/timm/data/transforms_factory.py @@ -0,0 +1,184 @@ +""" Transforms Factory +Factory methods for building image transforms for use with TIMM (PyTorch Image Models) +""" +import math + +import torch +from torchvision import transforms + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform +from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.random_erasing import RandomErasing + + +def transforms_imagenet_train( + img_size=224, + scale=(0.08, 1.0), + color_jitter=0.4, + auto_augment=None, + interpolation='random', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + separate=False, +): + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ + primary_tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), + transforms.RandomHorizontalFlip() + ] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith('augmix'): + aa_params['translate_pct'] = 0.3 + secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + else: + secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + secondary_tfl += [transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + final_tfl += [ToNumpy()] + else: + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if re_prob > 0.: + final_tfl.append( + RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) + + if separate: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): + crop_pct = crop_pct or DEFAULT_CROP_PCT + + if isinstance(img_size, tuple): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + tfl = [ + transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size), + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + + return transforms.Compose(tfl) + + +def create_transform( + input_size, + is_training=False, + use_prefetcher=False, + color_jitter=0.4, + auto_augment=None, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + crop_pct=None, + tf_preprocessing=False, + separate=False): + + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing and use_prefetcher: + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) + else: + if is_training: + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + crop_pct=crop_pct) + + return transform diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py index f436ccc7..b781472f 100644 --- a/timm/loss/__init__.py +++ b/timm/loss/__init__.py @@ -1 +1,2 @@ from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from .jsd import JsdCrossEntropy \ No newline at end of file diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py new file mode 100644 index 00000000..0f8eb696 --- /dev/null +++ b/timm/loss/jsd.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cross_entropy import LabelSmoothingCrossEntropy + + +class JsdCrossEntropy(nn.Module): + """ Jensen-Shannon Divergence + Cross-Entropy Loss + + Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + + Hacked together by Ross Wightman + """ + def __init__(self, num_splits=3, alpha=12, smoothing=0.1): + super().__init__() + self.num_splits = num_splits + self.alpha = alpha + if smoothing is not None and smoothing > 0: + self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) + else: + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def __call__(self, output, target): + split_size = output.shape[0] // self.num_splits + assert split_size * self.num_splits == output.shape[0] + logits_split = torch.split(output, split_size) + + # Cross-entropy is only computed on clean images + loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) + probs = [F.softmax(logits, dim=1) for logits in logits_split] + + # Clamp mixture distribution to avoid exploding KL divergence + logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() + loss += self.alpha * sum([F.kl_div( + logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) + return loss diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 3d85eb92..0fa4d210 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -21,3 +21,4 @@ from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .split_batchnorm import convert_splitbn_model diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py new file mode 100644 index 00000000..ad01cfeb --- /dev/null +++ b/timm/models/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/train.py b/train.py index 558c29ac..e3eec357 100755 --- a/train.py +++ b/train.py @@ -16,7 +16,6 @@ Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse import time -import logging import yaml from datetime import datetime @@ -29,10 +28,10 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target -from timm.models import create_model, resume_checkpoint +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset +from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * -from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler @@ -84,6 +83,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.)') parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP', help='Drop connect rate (default: 0.)') +parser.add_argument('--jsd', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -119,18 +120,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)'), +parser.add_argument('--aug-splits', type=int, default=0, + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') +parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +parser.add_argument('--train-interpolation', type=str, default='random', + help='Training interpolation (random, bilinear, bicubic default: "random")') # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') @@ -142,6 +149,8 @@ parser.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') parser.add_argument('--dist-bn', type=str, default='', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') +parser.add_argument('--split-bn', action='store_true', + help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') @@ -244,6 +253,15 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + num_aug_splits = 0 + if args.aug_splits > 0: + assert args.aug_splits > 1, 'A split of 1 makes no sense' + num_aug_splits = args.aug_splits + + if args.split_bn: + assert num_aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(num_aug_splits, 2)) + if args.num_gpu > 1: if args.amp: logging.warning( @@ -290,6 +308,7 @@ def main(): if args.distributed: if args.sync_bn: + assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) @@ -328,20 +347,26 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, - rand_erase_prob=args.reprob, - rand_erase_mode=args.remode, - rand_erase_count=args.recount, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, - interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], + num_aug_splits=num_aug_splits, + interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, @@ -373,7 +398,11 @@ def main(): pin_memory=args.pin_mem, ) - if args.mixup > 0.: + if args.jsd: + assert num_aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() @@ -471,11 +500,10 @@ def train_epoch( if not args.prefetcher: input, target = input.cuda(), target.cuda() if args.mixup > 0.: - lam = 1. - if not args.mixup_off_epoch or epoch < args.mixup_off_epoch: - lam = np.random.beta(args.mixup, args.mixup) - input = input.mul(lam).add_(1 - lam, input.flip(0)) - target = mixup_target(target, args.num_classes, lam, args.smoothing) + input, target = mixup_batch( + input, target, + alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) output = model(input)