Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl
parent ff8688ca3d
commit 232ab7fb12
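In rough terms, the pieces in this diff combine as follows during training when the new --jsd flag is set: AugMixDataset returns a clean view plus num_aug augmented views per sample, fast_collate stacks them along the batch dimension, and JsdCrossEntropy computes cross-entropy on the clean split plus a weighted Jensen-Shannon consistency term across all splits. A hedged sketch of one training step, reusing the names introduced further down (not runnable as-is):

# Sketch of one training step with --jsd (hypothetical shapes; mirrors train.py below).
input, target = next(iter(loader_train))  # input: (3 * B, C, H, W) = clean + 2 AugMix views
output = model(input)                     # (3 * B, num_classes)
loss = train_loss_fn(output, target)      # JsdCrossEntropy: CE on clean third + alpha * JSD across thirds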
timm/data/__init__.py
@@ -2,7 +2,8 @@ from .constants import *
 from .config import resolve_data_config
 from .dataset import Dataset, DatasetTar
 from .transforms import *
-from .loader import create_loader, create_transform
-from .mixup import mixup_target, FastCollateMixup
+from .loader import create_loader
+from .transforms_factory import create_transform
+from .mixup import mixup_batch, FastCollateMixup
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
timm/data/auto_augment.py
@@ -8,12 +8,11 @@ Hacked together by Ross Wightman
 import random
 import math
 import re
-from PIL import Image, ImageOps, ImageEnhance
+from PIL import Image, ImageOps, ImageEnhance, ImageChops
 import PIL
 import numpy as np

-

 _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

 _FILL = (128, 128, 128)
@@ -192,36 +191,47 @@ def _translate_abs_level_to_arg(level, hparams):
     return level,


-def _translate_rel_level_to_arg(level, _hparams):
-    # range [-0.45, 0.45]
-    level = (level / _MAX_LEVEL) * 0.45
+def _translate_rel_level_to_arg(level, hparams):
+    # default range [-0.45, 0.45]
+    translate_pct = hparams.get('translate_pct', 0.45)
+    level = (level / _MAX_LEVEL) * translate_pct
     level = _randomly_negate(level)
     return level,


+def _posterize_level_to_arg(level, _hparams):
+    # As per Tensorflow TPU EfficientNet impl
+    # range [0, 4], 'keep 0 up to 4 MSB of original image'
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+    # As per Tensorflow models research and UDA impl
+    # range [4, 0], 'keep 4 down to 0 MSB of original image',
+    # intensity/severity of augmentation increases with level
+    return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
 def _posterize_original_level_to_arg(level, _hparams):
     # As per original AutoAugment paper description
     # range [4, 8], 'keep 4 up to 8 MSB of image'
+    # intensity/severity of augmentation decreases with level
     return int((level / _MAX_LEVEL) * 4) + 4,


-def _posterize_research_level_to_arg(level, _hparams):
-    # As per Tensorflow models research and UDA impl
-    # range [4, 0], 'keep 4 down to 0 MSB of original image'
-    return 4 - int((level / _MAX_LEVEL) * 4),
-
-
-def _posterize_tpu_level_to_arg(level, _hparams):
-    # As per Tensorflow TPU EfficientNet impl
-    # range [0, 4], 'keep 0 up to 4 MSB of original image'
-    return int((level / _MAX_LEVEL) * 4),
-
-
 def _solarize_level_to_arg(level, _hparams):
     # range [0, 256]
+    # intensity/severity of augmentation decreases with level
     return int((level / _MAX_LEVEL) * 256),


+def _solarize_increasing_level_to_arg(level, _hparams):
+    # range [0, 256]
+    # intensity/severity of augmentation increases with level
+    return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
 def _solarize_add_level_to_arg(level, _hparams):
     # range [0, 110]
     return int((level / _MAX_LEVEL) * 110),
@@ -233,10 +243,11 @@ LEVEL_TO_ARG = {
     'Invert': None,
     'Rotate': _rotate_level_to_arg,
     # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+    'Posterize': _posterize_level_to_arg,
+    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
     'PosterizeOriginal': _posterize_original_level_to_arg,
-    'PosterizeResearch': _posterize_research_level_to_arg,
-    'PosterizeTpu': _posterize_tpu_level_to_arg,
     'Solarize': _solarize_level_to_arg,
+    'SolarizeIncreasing': _solarize_level_to_arg,
     'SolarizeAdd': _solarize_add_level_to_arg,
     'Color': _enhance_level_to_arg,
     'Contrast': _enhance_level_to_arg,
@@ -256,10 +267,11 @@ NAME_TO_OP = {
     'Equalize': equalize,
     'Invert': invert,
     'Rotate': rotate,
+    'Posterize': posterize,
+    'PosterizeIncreasing': posterize,
     'PosterizeOriginal': posterize,
-    'PosterizeResearch': posterize,
-    'PosterizeTpu': posterize,
     'Solarize': solarize,
+    'SolarizeIncreasing': solarize,
     'SolarizeAdd': solarize_add,
     'Color': color,
     'Contrast': contrast,
@@ -274,7 +286,7 @@ NAME_TO_OP = {
 }


-class AutoAugmentOp:
+class AugmentOp:

     def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
         hparams = hparams or _HPARAMS_DEFAULT
@@ -295,12 +307,12 @@ class AutoAugmentOp:
         self.magnitude_std = self.hparams.get('magnitude_std', 0)

     def __call__(self, img):
-        if random.random() > self.prob:
+        if not self.prob >= 1.0 or random.random() > self.prob:
             return img
         magnitude = self.magnitude
         if self.magnitude_std and self.magnitude_std > 0:
             magnitude = random.gauss(magnitude, self.magnitude_std)
         magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
         level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
         return self.aug_fn(img, *level_args, **self.kwargs)
@@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams):
         [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
         [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
         [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
-        [('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
         [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
         [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
         [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams):
         [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
         [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
         [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
-        [('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
         [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
         [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
     ]
-    pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
     return pc


 def auto_augment_policy_v0r(hparams):
-    # ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
+    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+    # in Google research implementation (number of bits discarded increases with magnitude)
     policy = [
         [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
         [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams):
         [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
         [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
         [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
-        [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
         [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
         [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
         [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@@ -363,11 +376,11 @@ def auto_augment_policy_v0r(hparams):
         [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
         [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
         [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
-        [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
         [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
         [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
     ]
-    pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
     return pc

@@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams):
         [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
         [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
     ]
-    pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
     return pc


 def auto_augment_policy_originalr(hparams):
     # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
     policy = [
-        [('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
         [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
         [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
-        [('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
+        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
         [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
         [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
         [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
-        [('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
         [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
-        [('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
+        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
         [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
         [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
         [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
@@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams):
         [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
         [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
     ]
-    pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
     return pc

@@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [
     'Equalize',
     'Invert',
     'Rotate',
-    'PosterizeTpu',
+    'Posterize',
     'Solarize',
     'SolarizeAdd',
     'Color',
@@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = {
     'Contrast': .005,
     'Brightness': .005,
     'Equalize': .005,
-    'PosterizeTpu': 0,
+    'Posterize': 0,
     'Invert': 0,
 }

@@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None):
 def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
     hparams = hparams or _HPARAMS_DEFAULT
     transforms = transforms or _RAND_TRANSFORMS
-    return [AutoAugmentOp(
+    return [AugmentOp(
         name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]

@@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams):
     ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
     choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
     return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
+
+_AUGMIX_TRANSFORMS = [
+    'AutoContrast',
+    'Contrast',  # not in paper
+    'Brightness',  # not in paper
+    'Sharpness',  # not in paper
+    'Equalize',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _AUGMIX_TRANSFORMS
+    return [AugmentOp(
+        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+    def __init__(self, ops, alpha=1., width=3, depth=-1):
+        self.ops = ops
+        self.alpha = alpha
+        self.width = width
+        self.depth = depth
+        self.recursive = True
+
+    def _apply_recursive(self, img, ws, prod=1.):
+        alpha = ws[-1] / prod
+        if len(ws) > 1:
+            img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha))
+
+        depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+        ops = np.random.choice(self.ops, depth, replace=True)
+        img_aug = img  # no ops are in-place, deep copy not necessary
+        for op in ops:
+            img_aug = op(img_aug)
+        return Image.blend(img, img_aug, alpha)
+
+    def _apply_basic(self, img, ws, m):
+        w, h = img.size
+        c = len(img.getbands())
+        mixed = np.zeros((w, h, c), dtype=np.float32)
+        for w in ws:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            img_aug = np.asarray(img_aug, dtype=np.float32)
+            mixed += w * img_aug
+        np.clip(mixed, 0, 255., out=mixed)
+        mixed = Image.fromarray(mixed.astype(np.uint8))
+        return Image.blend(img, mixed, m)
+
+    def __call__(self, img):
+        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+        m = np.float32(np.random.beta(self.alpha, self.alpha))
+        if self.recursive:
+            mixing_weights *= m
+            mixed = self._apply_recursive(img, mixing_weights)
+        else:
+            mixed = self._apply_basic(img, mixing_weights, m)
+        return mixed


+def augment_and_mix_transform(config_str, hparams):
+    """Perform AugMix augmentations and compute mixture.
+    Args:
+        image: Raw input image as float32 np.ndarray of shape (h, w, c)
+        severity: Severity of underlying augmentation operators (between 1 to 10).
+        width: Width of augmentation chain
+        depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
+            from [1, 3]
+        alpha: Probability coefficient for Beta and Dirichlet distributions.
+    Returns:
+        mixed: Augmented and mixed image.
+    """
+    # FIXME parse args from config str
+    severity = 3
+    width = 3
+    depth = -1
+    alpha = 1.
+    ops = augmix_ops(magnitude=severity, hparams=hparams)
+    return AugMixAugment(ops, alpha, width, depth)
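A rough standalone usage sketch of the AugMix pieces added above (the image path is a placeholder; in the repo they are wired up through create_transform/AugMixDataset rather than called directly). Each call builds `width` chains of 1-3 ops and blends them into the input: `_apply_basic` follows the paper's Dirichlet/Beta mixing directly, while `_apply_recursive` folds the weights in through repeated pairwise `Image.blend` calls, applying each chain to the running blend rather than to the clean image.

from PIL import Image
from timm.data.auto_augment import augmix_ops, AugMixAugment

ops = augmix_ops(magnitude=3)                    # severity 3, prob=1.0 per op
augmix = AugMixAugment(ops, alpha=1., width=3, depth=-1)

img = Image.open('example.jpg').convert('RGB')   # placeholder input image
mixed = augmix(img)                              # PIL image, same size as the input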
timm/data/dataset.py
@@ -140,3 +140,41 @@ class DatasetTar(data.Dataset):
     def __len__(self):
         return len(self.imgs)

+
+class AugMixDataset(torch.utils.data.Dataset):
+    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
+
+    def __init__(self, dataset, num_aug=2):
+        self.augmentation = None
+        self.normalize = None
+        self.dataset = dataset
+        if self.dataset.transform is not None:
+            self._set_transforms(self.dataset.transform)
+        self.num_aug = num_aug
+
+    def _set_transforms(self, x):
+        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
+        self.dataset.transform = x[0]
+        self.augmentation = x[1]
+        self.normalize = x[2]
+
+    @property
+    def transform(self):
+        return self.dataset.transform
+
+    @transform.setter
+    def transform(self, x):
+        self._set_transforms(x)
+
+    def _normalize(self, x):
+        return x if self.normalize is None else self.normalize(x)
+
+    def __getitem__(self, i):
+        x, y = self.dataset[i]
+        x_list = [self._normalize(x)]
+        for n in range(self.num_aug):
+            x_list.append(self._normalize(self.augmentation(x)))
+        return tuple(x_list), y
+
+    def __len__(self):
+        return len(self.dataset)
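The wrapper expects `dataset.transform` to be set to a 3-tuple of transforms — the (primary, augmentation, normalize) stages returned by `transforms_imagenet_train(..., separate=True)` further down — and yields `num_aug + 1` views per sample: the first clean, the rest augmented, all normalized. A hedged sketch of the wiring train.py performs when --jsd is set (variable names are hypothetical; `create_loader(..., separate_transforms=True)` does this internally):

dataset_train = AugMixDataset(dataset_train, num_aug=2)
dataset_train.transform = (primary_tfl, augmentation_tfl, normalize_tfl)  # 3-stage tuple

(clean, aug1, aug2), target = dataset_train[0]  # num_aug + 1 views of the same sample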
timm/data/loader.py
@@ -1,17 +1,46 @@
 import torch.utils.data
-from .transforms import *
+import numpy as np

+from .transforms_factory import create_transform
+from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .distributed_sampler import OrderedDistributedSampler
+from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup


 def fast_collate(batch):
-    targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-    batch_size = len(targets)
-    tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-    for i in range(batch_size):
-        tensor[i] += torch.from_numpy(batch[i][0])
-    return tensor, targets
+    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
+    assert isinstance(batch[0], tuple)
+    batch_size = len(batch)
+    if isinstance(batch[0][0], tuple):
+        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
+        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
+        inner_tuple_size = len(batch[0][0])
+        flattened_batch_size = batch_size * inner_tuple_size
+        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
+        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
+            for j in range(inner_tuple_size):
+                targets[i + j * batch_size] = batch[i][1]
+                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+        return tensor, targets
+    elif isinstance(batch[0][0], np.ndarray):
+        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        assert len(targets) == batch_size
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            tensor[i] += torch.from_numpy(batch[i][0])
+        return tensor, targets
+    elif isinstance(batch[0][0], torch.Tensor):
+        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        assert len(targets) == batch_size
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            tensor[i].copy_(batch[i][0])
+        return tensor, targets
+    else:
+        assert False


 class PrefetchLoader:
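For the tuple branch of fast_collate above, the flattened ordering guarantees that the n-th chunk of `torch.split(tensor, batch_size)` contains the n-th tuple element of every sample, so the clean and augmented views stay grouped after collation. A small illustration (shapes are hypothetical):

import torch

batch_size, num_splits = 4, 3  # e.g. clean + 2 AugMix views per sample
# fast_collate emits a tensor of shape (batch_size * num_splits, C, H, W)
tensor = torch.zeros(batch_size * num_splits, 3, 224, 224, dtype=torch.uint8)
clean, aug1, aug2 = torch.split(tensor, batch_size)  # each (batch_size, 3, 224, 224)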
@@ -87,49 +116,6 @@ class PrefetchLoader:
         self.loader.collate_fn.mixup_enabled = x


-def create_transform(
-        input_size,
-        is_training=False,
-        use_prefetcher=False,
-        color_jitter=0.4,
-        auto_augment=None,
-        interpolation='bilinear',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        crop_pct=None,
-        tf_preprocessing=False):
-
-    if isinstance(input_size, tuple):
-        img_size = input_size[-2:]
-    else:
-        img_size = input_size
-
-    if tf_preprocessing and use_prefetcher:
-        from timm.data.tf_preprocessing import TfPreprocessTransform
-        transform = TfPreprocessTransform(
-            is_training=is_training, size=img_size, interpolation=interpolation)
-    else:
-        if is_training:
-            transform = transforms_imagenet_train(
-                img_size,
-                color_jitter=color_jitter,
-                auto_augment=auto_augment,
-                interpolation=interpolation,
-                use_prefetcher=use_prefetcher,
-                mean=mean,
-                std=std)
-        else:
-            transform = transforms_imagenet_eval(
-                img_size,
-                interpolation=interpolation,
-                use_prefetcher=use_prefetcher,
-                mean=mean,
-                std=std,
-                crop_pct=crop_pct)
-
-    return transform
-
-
 def create_loader(
         dataset,
         input_size,
@@ -150,6 +136,7 @@ def create_loader(
         collate_fn=None,
         fp16=False,
         tf_preprocessing=False,
+        separate_transforms=False,
 ):
     dataset.transform = create_transform(
         input_size,
@@ -162,6 +149,7 @@ def create_loader(
         std=std,
         crop_pct=crop_pct,
         tf_preprocessing=tf_preprocessing,
+        separate=separate_transforms,
     )

     sampler = None
timm/data/mixup.py
@@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
     return lam*y1 + (1. - lam)*y2


+def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
+    lam = 1.
+    if not disable:
+        lam = np.random.beta(alpha, alpha)
+    input = input.mul(lam).add_(1 - lam, input.flip(0))
+    target = mixup_target(target, num_classes, lam, smoothing)
+    return input, target
+
+
 class FastCollateMixup:

     def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
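mixup_batch simply factors the inline mixup logic previously in train.py into one helper; a rough usage sketch (shapes and device are assumptions):

import torch
from timm.data.mixup import mixup_batch

input = torch.randn(8, 3, 224, 224).cuda()
target = torch.randint(0, 1000, (8,)).cuda()
# mixes each sample with the reversed batch and returns smoothed, lam-weighted soft targets
input, target = mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1)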
timm/data/transforms.py
@@ -1,5 +1,4 @@
 import torch
-from torchvision import transforms
 import torchvision.transforms.functional as F
 from PIL import Image
 import warnings
@@ -7,10 +6,6 @@ import math
 import random
 import numpy as np

-from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .random_erasing import RandomErasing
-from .auto_augment import auto_augment_transform, rand_augment_transform


 class ToNumpy:
@@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation:
         return format_string


-def transforms_imagenet_train(
-        img_size=224,
-        scale=(0.08, 1.0),
-        color_jitter=0.4,
-        auto_augment=None,
-        interpolation='random',
-        random_erasing=0.4,
-        random_erasing_mode='const',
-        use_prefetcher=False,
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD
-):
-    tfl = [
-        RandomResizedCropAndInterpolation(
-            img_size, scale=scale, interpolation=interpolation),
-        transforms.RandomHorizontalFlip()
-    ]
-    if auto_augment:
-        assert isinstance(auto_augment, str)
-        if isinstance(img_size, tuple):
-            img_size_min = min(img_size)
-        else:
-            img_size_min = img_size
-        aa_params = dict(
-            translate_const=int(img_size_min * 0.45),
-            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
-        )
-        if interpolation and interpolation != 'random':
-            aa_params['interpolation'] = _pil_interp(interpolation)
-        if auto_augment.startswith('rand'):
-            tfl += [rand_augment_transform(auto_augment, aa_params)]
-        else:
-            tfl += [auto_augment_transform(auto_augment, aa_params)]
-    else:
-        # color jitter is enabled when not using AA
-        if isinstance(color_jitter, (list, tuple)):
-            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
-            # or 4 if also augmenting hue
-            assert len(color_jitter) in (3, 4)
-        else:
-            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
-            color_jitter = (float(color_jitter),) * 3
-        tfl += [transforms.ColorJitter(*color_jitter)]
-
-    if use_prefetcher:
-        # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
-    else:
-        tfl += [
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=torch.tensor(mean),
-                std=torch.tensor(std))
-        ]
-        if random_erasing > 0.:
-            tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
-    return transforms.Compose(tfl)
-
-
-def transforms_imagenet_eval(
-        img_size=224,
-        crop_pct=None,
-        interpolation='bilinear',
-        use_prefetcher=False,
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD):
-    crop_pct = crop_pct or DEFAULT_CROP_PCT
-
-    if isinstance(img_size, tuple):
-        assert len(img_size) == 2
-        if img_size[-1] == img_size[-2]:
-            # fall-back to older behaviour so Resize scales to shortest edge if target is square
-            scale_size = int(math.floor(img_size[0] / crop_pct))
-        else:
-            scale_size = tuple([int(x / crop_pct) for x in img_size])
-    else:
-        scale_size = int(math.floor(img_size / crop_pct))
-
-    tfl = [
-        transforms.Resize(scale_size, _pil_interp(interpolation)),
-        transforms.CenterCrop(img_size),
-    ]
-    if use_prefetcher:
-        # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
-    else:
-        tfl += [
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=torch.tensor(mean),
-                std=torch.tensor(std))
-        ]
-
-    return transforms.Compose(tfl)
timm/data/transforms_factory.py
@@ -0,0 +1,164 @@
+import math
+
+import torch
+from torchvision import transforms
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
+from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
+from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
+from timm.data.random_erasing import RandomErasing
+
+
+def transforms_imagenet_train(
+        img_size=224,
+        scale=(0.08, 1.0),
+        color_jitter=0.4,
+        auto_augment=None,
+        interpolation='random',
+        random_erasing=0.4,
+        random_erasing_mode='const',
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        separate=False,
+):
+
+    primary_tfl = [
+        RandomResizedCropAndInterpolation(
+            img_size, scale=scale, interpolation=interpolation),
+        transforms.RandomHorizontalFlip()
+    ]
+
+    secondary_tfl = []
+    if auto_augment:
+        assert isinstance(auto_augment, str)
+        if isinstance(img_size, tuple):
+            img_size_min = min(img_size)
+        else:
+            img_size_min = img_size
+        aa_params = dict(
+            translate_const=int(img_size_min * 0.45),
+            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+        )
+        if interpolation and interpolation != 'random':
+            aa_params['interpolation'] = _pil_interp(interpolation)
+        if auto_augment.startswith('rand'):
+            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
+        elif auto_augment.startswith('augmix'):
+            aa_params['translate_pct'] = 0.3
+            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
+        else:
+            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
+    elif color_jitter is not None:
+        # color jitter is enabled when not using AA
+        if isinstance(color_jitter, (list, tuple)):
+            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
+            # or 4 if also augmenting hue
+            assert len(color_jitter) in (3, 4)
+        else:
+            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+            color_jitter = (float(color_jitter),) * 3
+        secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+
+    final_tfl = []
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        final_tfl += [ToNumpy()]
+    else:
+        final_tfl += [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean),
+                std=torch.tensor(std))
+        ]
+        if random_erasing > 0.:
+            final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
+
+    if separate:
+        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
+    else:
+        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
+
+
+def transforms_imagenet_eval(
+        img_size=224,
+        crop_pct=None,
+        interpolation='bilinear',
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD):
+    crop_pct = crop_pct or DEFAULT_CROP_PCT
+
+    if isinstance(img_size, tuple):
+        assert len(img_size) == 2
+        if img_size[-1] == img_size[-2]:
+            # fall-back to older behaviour so Resize scales to shortest edge if target is square
+            scale_size = int(math.floor(img_size[0] / crop_pct))
+        else:
+            scale_size = tuple([int(x / crop_pct) for x in img_size])
+    else:
+        scale_size = int(math.floor(img_size / crop_pct))
+
+    tfl = [
+        transforms.Resize(scale_size, _pil_interp(interpolation)),
+        transforms.CenterCrop(img_size),
+    ]
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean),
+                std=torch.tensor(std))
+        ]
+
+    return transforms.Compose(tfl)
+
+
+def create_transform(
+        input_size,
+        is_training=False,
+        use_prefetcher=False,
+        color_jitter=0.4,
+        auto_augment=None,
+        interpolation='bilinear',
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        crop_pct=None,
+        tf_preprocessing=False,
+        separate=False):
+
+    if isinstance(input_size, tuple):
+        img_size = input_size[-2:]
+    else:
+        img_size = input_size
+
+    if tf_preprocessing and use_prefetcher:
+        assert not separate, "Separate transforms not supported for TF preprocessing"
+        from timm.data.tf_preprocessing import TfPreprocessTransform
+        transform = TfPreprocessTransform(
+            is_training=is_training, size=img_size, interpolation=interpolation)
+    else:
+        if is_training:
+            transform = transforms_imagenet_train(
+                img_size,
+                color_jitter=color_jitter,
+                auto_augment=auto_augment,
+                interpolation=interpolation,
+                use_prefetcher=use_prefetcher,
+                mean=mean,
+                std=std,
+                separate=separate)
+        else:
+            assert not separate, "Separate transforms not supported for validation preprocessing"
+            transform = transforms_imagenet_eval(
+                img_size,
+                interpolation=interpolation,
+                use_prefetcher=use_prefetcher,
+                mean=mean,
+                std=std,
+                crop_pct=crop_pct)
+
+    return transform
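The separate=True path is what AugMixDataset consumes: rather than one composed pipeline, transforms_imagenet_train returns three transforms.Compose stages. A small sketch (argument values are assumptions):

from timm.data.transforms_factory import transforms_imagenet_train

# (primary, augmentation, normalize): crop/flip, then AA/RandAugment/AugMix, then tensor+normalize
primary, augmentation, normalize = transforms_imagenet_train(
    img_size=224, auto_augment='augmix', use_prefetcher=False, separate=True)
# AugMixDataset keeps `primary` on the wrapped dataset, applies `augmentation` only to the
# extra views, and runs `normalize` on every view (ToNumpy takes its place when prefetching).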
timm/loss/__init__.py
@@ -1 +1,2 @@
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from .jsd import JsdCrossEntropy
timm/loss/jsd.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cross_entropy import LabelSmoothingCrossEntropy
+
+
+class JsdCrossEntropy(nn.Module):
+    """ Jensen-Shannon Divergence + Cross-Entropy Loss
+
+    """
+    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
+        super().__init__()
+        self.num_splits = num_splits
+        self.alpha = alpha
+        if smoothing is not None and smoothing > 0:
+            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
+        else:
+            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+
+    def __call__(self, output, target):
+        split_size = output.shape[0] // self.num_splits
+        assert split_size * self.num_splits == output.shape[0]
+        logits_split = torch.split(output, split_size)
+
+        # Cross-entropy is only computed on clean images
+        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
+        probs = [F.softmax(logits, dim=1) for logits in logits_split]
+
+        # Clamp mixture distribution to avoid exploding KL divergence
+        logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
+        loss += self.alpha * sum([F.kl_div(
+            logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
+        return loss
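The loss assumes the (clean, aug1, aug2) ordering produced by fast_collate above, i.e. a logits tensor of length num_splits * batch_size with the clean split first; only the clean split's targets are actually used. A rough usage sketch (shapes hypothetical):

import torch
from timm.loss import JsdCrossEntropy

batch_size, num_classes = 8, 1000
loss_fn = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)

# logits for clean + 2 augmented views, stacked in split order as fast_collate emits them
output = torch.randn(3 * batch_size, num_classes)
target = torch.randint(0, num_classes, (3 * batch_size,))  # labels repeat across the splits
loss = loss_fn(output, target)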
train.py
@@ -1,7 +1,6 @@

 import argparse
 import time
-import logging
 import yaml
 from datetime import datetime

@@ -14,13 +13,16 @@ except ImportError:
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False

-from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
 from timm.models import create_model, resume_checkpoint
 from timm.utils import *
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
 from timm.scheduler import create_scheduler

+#FIXME
+from timm.data.dataset import AugMixDataset
+
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)


+parser.add_argument('--jsd', action='store_true', default=False,
+                    help='')
+
+
 def _parse_args():
     # Do we have a config file to parse?
     args_config, remaining = config_parser.parse_known_args()
@@ -311,8 +317,14 @@ def main():

     collate_fn = None
     if args.prefetcher and args.mixup > 0:
+        assert not args.jsd
         collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

+    separate_transforms = False
+    if args.jsd:
+        dataset_train = AugMixDataset(dataset_train)
+        separate_transforms = True
+
     loader_train = create_loader(
         dataset_train,
         input_size=data_config['input_size'],
@@ -330,6 +342,7 @@ def main():
         num_workers=args.workers,
         distributed=args.distributed,
         collate_fn=collate_fn,
+        separate_transforms=separate_transforms,
     )

     eval_dir = os.path.join(args.data, 'val')
@@ -354,7 +367,10 @@ def main():
         crop_pct=data_config['crop_pct'],
     )

-    if args.mixup > 0.:
+    if args.jsd:
+        train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    elif args.mixup > 0.:
         # smoothing is handled with mixup label transform
         train_loss_fn = SoftTargetCrossEntropy().cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
@@ -452,11 +468,10 @@ def train_epoch(
         if not args.prefetcher:
             input, target = input.cuda(), target.cuda()
             if args.mixup > 0.:
-                lam = 1.
-                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
-                    lam = np.random.beta(args.mixup, args.mixup)
-                input = input.mul(lam).add_(1 - lam, input.flip(0))
-                target = mixup_target(target, args.num_classes, lam, args.smoothing)
+                input, target = mixup_batch(
+                    input, target,
+                    alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
+                    disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)

         output = model(input)
