Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl
parent
ff8688ca3d
commit
232ab7fb12
|
@ -2,7 +2,8 @@ from .constants import *
|
|||
from .config import resolve_data_config
|
||||
from .dataset import Dataset, DatasetTar
|
||||
from .transforms import *
|
||||
from .loader import create_loader, create_transform
|
||||
from .mixup import mixup_target, FastCollateMixup
|
||||
from .loader import create_loader
|
||||
from .transforms_factory import create_transform
|
||||
from .mixup import mixup_batch, FastCollateMixup
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
|
|
|
@ -8,12 +8,11 @@ Hacked together by Ross Wightman
|
|||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
|
||||
|
||||
_FILL = (128, 128, 128)
|
||||
|
@ -192,36 +191,47 @@ def _translate_abs_level_to_arg(level, hparams):
|
|||
return level,
|
||||
|
||||
|
||||
def _translate_rel_level_to_arg(level, _hparams):
|
||||
# range [-0.45, 0.45]
|
||||
level = (level / _MAX_LEVEL) * 0.45
|
||||
def _translate_rel_level_to_arg(level, hparams):
|
||||
# default range [-0.45, 0.45]
|
||||
translate_pct = hparams.get('translate_pct', 0.45)
|
||||
level = (level / _MAX_LEVEL) * translate_pct
|
||||
level = _randomly_negate(level)
|
||||
return level,
|
||||
|
||||
|
||||
def _posterize_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_increasing_level_to_arg(level, hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image',
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 4 - _posterize_level_to_arg(level, hparams)[0],
|
||||
|
||||
|
||||
def _posterize_original_level_to_arg(level, _hparams):
|
||||
# As per original AutoAugment paper description
|
||||
# range [4, 8], 'keep 4 up to 8 MSB of image'
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 4) + 4,
|
||||
|
||||
|
||||
def _posterize_research_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow models research and UDA impl
|
||||
# range [4, 0], 'keep 4 down to 0 MSB of original image'
|
||||
return 4 - int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _posterize_tpu_level_to_arg(level, _hparams):
|
||||
# As per Tensorflow TPU EfficientNet impl
|
||||
# range [0, 4], 'keep 0 up to 4 MSB of original image'
|
||||
return int((level / _MAX_LEVEL) * 4),
|
||||
|
||||
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _MAX_LEVEL) * 256),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation increases with level
|
||||
return 256 - _solarize_level_to_arg(level, _hparams)[0],
|
||||
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _MAX_LEVEL) * 110),
|
||||
|
@ -233,10 +243,11 @@ LEVEL_TO_ARG = {
|
|||
'Invert': None,
|
||||
'Rotate': _rotate_level_to_arg,
|
||||
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
||||
'Posterize': _posterize_level_to_arg,
|
||||
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
|
||||
'PosterizeOriginal': _posterize_original_level_to_arg,
|
||||
'PosterizeResearch': _posterize_research_level_to_arg,
|
||||
'PosterizeTpu': _posterize_tpu_level_to_arg,
|
||||
'Solarize': _solarize_level_to_arg,
|
||||
'SolarizeIncreasing': _solarize_level_to_arg,
|
||||
'SolarizeAdd': _solarize_add_level_to_arg,
|
||||
'Color': _enhance_level_to_arg,
|
||||
'Contrast': _enhance_level_to_arg,
|
||||
|
@ -256,10 +267,11 @@ NAME_TO_OP = {
|
|||
'Equalize': equalize,
|
||||
'Invert': invert,
|
||||
'Rotate': rotate,
|
||||
'Posterize': posterize,
|
||||
'PosterizeIncreasing': posterize,
|
||||
'PosterizeOriginal': posterize,
|
||||
'PosterizeResearch': posterize,
|
||||
'PosterizeTpu': posterize,
|
||||
'Solarize': solarize,
|
||||
'SolarizeIncreasing': solarize,
|
||||
'SolarizeAdd': solarize_add,
|
||||
'Color': color,
|
||||
'Contrast': contrast,
|
||||
|
@ -274,7 +286,7 @@ NAME_TO_OP = {
|
|||
}
|
||||
|
||||
|
||||
class AutoAugmentOp:
|
||||
class AugmentOp:
|
||||
|
||||
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
|
@ -295,12 +307,12 @@ class AutoAugmentOp:
|
|||
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() > self.prob:
|
||||
if not self.prob >= 1.0 or random.random() > self.prob:
|
||||
return img
|
||||
magnitude = self.magnitude
|
||||
if self.magnitude_std and self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
||||
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
||||
return self.aug_fn(img, *level_args, **self.kwargs)
|
||||
|
||||
|
@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams):
|
|||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
|
@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams):
|
|||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
||||
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_v0r(hparams):
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
|
||||
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
|
||||
# in Google research implementation (number of bits discarded increases with magnitude)
|
||||
policy = [
|
||||
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
||||
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
||||
|
@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams):
|
|||
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
||||
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
||||
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
||||
[('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
||||
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
||||
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
||||
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
||||
|
@ -363,11 +376,11 @@ def auto_augment_policy_v0r(hparams):
|
|||
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
||||
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
||||
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
|
||||
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
||||
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
|
@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams):
|
|||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_originalr(hparams):
|
||||
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
||||
policy = [
|
||||
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
|
||||
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
|
||||
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
|
||||
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
||||
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
||||
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
||||
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
|
||||
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
|
||||
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
|
||||
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
||||
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
||||
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
||||
|
@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams):
|
|||
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
||||
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
||||
]
|
||||
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
|
@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [
|
|||
'Equalize',
|
||||
'Invert',
|
||||
'Rotate',
|
||||
'PosterizeTpu',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
'SolarizeAdd',
|
||||
'Color',
|
||||
|
@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = {
|
|||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'PosterizeTpu': 0,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
}
|
||||
|
||||
|
@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None):
|
|||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AutoAugmentOp(
|
||||
return [AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
|
@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams):
|
|||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
_AUGMIX_TRANSFORMS = [
|
||||
'AutoContrast',
|
||||
'Contrast', # not in paper
|
||||
'Brightness', # not in paper
|
||||
'Sharpness', # not in paper
|
||||
'Equalize',
|
||||
'Rotate',
|
||||
'PosterizeIncreasing',
|
||||
'SolarizeIncreasing',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
def __init__(self, ops, alpha=1., width=3, depth=-1):
|
||||
self.ops = ops
|
||||
self.alpha = alpha
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.recursive = True
|
||||
|
||||
def _apply_recursive(self, img, ws, prod=1.):
|
||||
alpha = ws[-1] / prod
|
||||
if len(ws) > 1:
|
||||
img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha))
|
||||
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
return Image.blend(img, img_aug, alpha)
|
||||
|
||||
def _apply_basic(self, img, ws, m):
|
||||
w, h = img.size
|
||||
c = len(img.getbands())
|
||||
mixed = np.zeros((w, h, c), dtype=np.float32)
|
||||
for w in ws:
|
||||
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
||||
ops = np.random.choice(self.ops, depth, replace=True)
|
||||
img_aug = img # no ops are in-place, deep copy not necessary
|
||||
for op in ops:
|
||||
img_aug = op(img_aug)
|
||||
img_aug = np.asarray(img_aug, dtype=np.float32)
|
||||
mixed += w * img_aug
|
||||
np.clip(mixed, 0, 255., out=mixed)
|
||||
mixed = Image.fromarray(mixed.astype(np.uint8))
|
||||
return Image.blend(img, mixed, m)
|
||||
|
||||
def __call__(self, img):
|
||||
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
|
||||
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
||||
if self.recursive:
|
||||
mixing_weights *= m
|
||||
mixed = self._apply_recursive(img, mixing_weights)
|
||||
else:
|
||||
mixed = self._apply_basic(img, mixing_weights, m)
|
||||
return mixed
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
"""Perform AugMix augmentations and compute mixture.
|
||||
Args:
|
||||
image: Raw input image as float32 np.ndarray of shape (h, w, c)
|
||||
severity: Severity of underlying augmentation operators (between 1 to 10).
|
||||
width: Width of augmentation chain
|
||||
depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
|
||||
from [1, 3]
|
||||
alpha: Probability coefficient for Beta and Dirichlet distributions.
|
||||
Returns:
|
||||
mixed: Augmented and mixed image.
|
||||
"""
|
||||
# FIXME parse args from config str
|
||||
severity = 3
|
||||
width = 3
|
||||
depth = -1
|
||||
alpha = 1.
|
||||
ops = augmix_ops(magnitude=severity, hparams=hparams)
|
||||
return AugMixAugment(ops, alpha, width, depth)
|
||||
|
|
|
@ -140,3 +140,41 @@ class DatasetTar(data.Dataset):
|
|||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
|
||||
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
|
||||
|
||||
def __init__(self, dataset, num_aug=2):
|
||||
self.augmentation = None
|
||||
self.normalize = None
|
||||
self.dataset = dataset
|
||||
if self.dataset.transform is not None:
|
||||
self._set_transforms(self.dataset.transform)
|
||||
self.num_aug = num_aug
|
||||
|
||||
def _set_transforms(self, x):
|
||||
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
|
||||
self.dataset.transform = x[0]
|
||||
self.augmentation = x[1]
|
||||
self.normalize = x[2]
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self.dataset.transform
|
||||
|
||||
@transform.setter
|
||||
def transform(self, x):
|
||||
self._set_transforms(x)
|
||||
|
||||
def _normalize(self, x):
|
||||
return x if self.normalize is None else self.normalize(x)
|
||||
|
||||
def __getitem__(self, i):
|
||||
x, y = self.dataset[i]
|
||||
x_list = [self._normalize(x)]
|
||||
for n in range(self.num_aug):
|
||||
x_list.append(self._normalize(self.augmentation(x)))
|
||||
return tuple(x_list), y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
|
|
@ -1,17 +1,46 @@
|
|||
import torch.utils.data
|
||||
from .transforms import *
|
||||
import numpy as np
|
||||
|
||||
from .transforms_factory import create_transform
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .distributed_sampler import OrderedDistributedSampler
|
||||
from .random_erasing import RandomErasing
|
||||
from .mixup import FastCollateMixup
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
batch_size = len(targets)
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
|
||||
return tensor, targets
|
||||
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
|
||||
assert isinstance(batch[0], tuple)
|
||||
batch_size = len(batch)
|
||||
if isinstance(batch[0][0], tuple):
|
||||
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
|
||||
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
|
||||
inner_tuple_size = len(batch[0][0][0])
|
||||
flattened_batch_size = batch_size * inner_tuple_size
|
||||
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
|
||||
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
|
||||
for j in range(inner_tuple_size):
|
||||
targets[i + j * batch_size] = batch[i][1]
|
||||
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], np.ndarray):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], torch.Tensor):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i].copy_(batch[i][0])
|
||||
return tensor, targets
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
class PrefetchLoader:
|
||||
|
@ -87,49 +116,6 @@ class PrefetchLoader:
|
|||
self.loader.collate_fn.mixup_enabled = x
|
||||
|
||||
|
||||
def create_transform(
|
||||
input_size,
|
||||
is_training=False,
|
||||
use_prefetcher=False,
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='bilinear',
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
crop_pct=None,
|
||||
tf_preprocessing=False):
|
||||
|
||||
if isinstance(input_size, tuple):
|
||||
img_size = input_size[-2:]
|
||||
else:
|
||||
img_size = input_size
|
||||
|
||||
if tf_preprocessing and use_prefetcher:
|
||||
from timm.data.tf_preprocessing import TfPreprocessTransform
|
||||
transform = TfPreprocessTransform(
|
||||
is_training=is_training, size=img_size, interpolation=interpolation)
|
||||
else:
|
||||
if is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
img_size,
|
||||
color_jitter=color_jitter,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std)
|
||||
else:
|
||||
transform = transforms_imagenet_eval(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct)
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def create_loader(
|
||||
dataset,
|
||||
input_size,
|
||||
|
@ -150,6 +136,7 @@ def create_loader(
|
|||
collate_fn=None,
|
||||
fp16=False,
|
||||
tf_preprocessing=False,
|
||||
separate_transforms=False,
|
||||
):
|
||||
dataset.transform = create_transform(
|
||||
input_size,
|
||||
|
@ -162,6 +149,7 @@ def create_loader(
|
|||
std=std,
|
||||
crop_pct=crop_pct,
|
||||
tf_preprocessing=tf_preprocessing,
|
||||
separate=separate_transforms,
|
||||
)
|
||||
|
||||
sampler = None
|
||||
|
|
|
@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
|
|||
return lam*y1 + (1. - lam)*y2
|
||||
|
||||
|
||||
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
|
||||
lam = 1.
|
||||
if not disable:
|
||||
lam = np.random.beta(alpha, alpha)
|
||||
input = input.mul(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, num_classes, lam, smoothing)
|
||||
return input, target
|
||||
|
||||
|
||||
class FastCollateMixup:
|
||||
|
||||
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import torch
|
||||
from torchvision import transforms
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
import warnings
|
||||
|
@ -7,10 +6,6 @@ import math
|
|||
import random
|
||||
import numpy as np
|
||||
|
||||
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .random_erasing import RandomErasing
|
||||
from .auto_augment import auto_augment_transform, rand_augment_transform
|
||||
|
||||
|
||||
class ToNumpy:
|
||||
|
||||
|
@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation:
|
|||
return format_string
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.08, 1.0),
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='random',
|
||||
random_erasing=0.4,
|
||||
random_erasing_mode='const',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD
|
||||
):
|
||||
tfl = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
transforms.RandomHorizontalFlip()
|
||||
]
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
if isinstance(img_size, tuple):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
if auto_augment.startswith('rand'):
|
||||
tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
# color jitter is enabled when not using AA
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
assert len(color_jitter) in (3, 4)
|
||||
else:
|
||||
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
||||
color_jitter = (float(color_jitter),) * 3
|
||||
tfl += [transforms.ColorJitter(*color_jitter)]
|
||||
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if random_erasing > 0.:
|
||||
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
def transforms_imagenet_eval(
|
||||
img_size=224,
|
||||
crop_pct=None,
|
||||
interpolation='bilinear',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
|
||||
if isinstance(img_size, tuple):
|
||||
assert len(img_size) == 2
|
||||
if img_size[-1] == img_size[-2]:
|
||||
# fall-back to older behaviour so Resize scales to shortest edge if target is square
|
||||
scale_size = int(math.floor(img_size[0] / crop_pct))
|
||||
else:
|
||||
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
|
||||
tfl = [
|
||||
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
||||
transforms.CenterCrop(img_size),
|
||||
]
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
|
||||
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
|
||||
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
|
||||
from timm.data.random_erasing import RandomErasing
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.08, 1.0),
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='random',
|
||||
random_erasing=0.4,
|
||||
random_erasing_mode='const',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
separate=False,
|
||||
):
|
||||
|
||||
primary_tfl = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
transforms.RandomHorizontalFlip()
|
||||
]
|
||||
|
||||
secondary_tfl = []
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
if isinstance(img_size, tuple):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
img_size_min = img_size
|
||||
aa_params = dict(
|
||||
translate_const=int(img_size_min * 0.45),
|
||||
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
||||
)
|
||||
if interpolation and interpolation != 'random':
|
||||
aa_params['interpolation'] = _pil_interp(interpolation)
|
||||
if auto_augment.startswith('rand'):
|
||||
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
||||
elif auto_augment.startswith('augmix'):
|
||||
aa_params['translate_pct'] = 0.3
|
||||
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
elif color_jitter is not None:
|
||||
# color jitter is enabled when not using AA
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
assert len(color_jitter) in (3, 4)
|
||||
else:
|
||||
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
|
||||
color_jitter = (float(color_jitter),) * 3
|
||||
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
|
||||
|
||||
final_tfl = []
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
final_tfl += [ToNumpy()]
|
||||
else:
|
||||
final_tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if random_erasing > 0.:
|
||||
final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
|
||||
|
||||
if separate:
|
||||
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
|
||||
else:
|
||||
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
||||
|
||||
|
||||
def transforms_imagenet_eval(
|
||||
img_size=224,
|
||||
crop_pct=None,
|
||||
interpolation='bilinear',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
|
||||
if isinstance(img_size, tuple):
|
||||
assert len(img_size) == 2
|
||||
if img_size[-1] == img_size[-2]:
|
||||
# fall-back to older behaviour so Resize scales to shortest edge if target is square
|
||||
scale_size = int(math.floor(img_size[0] / crop_pct))
|
||||
else:
|
||||
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
|
||||
tfl = [
|
||||
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
||||
transforms.CenterCrop(img_size),
|
||||
]
|
||||
if use_prefetcher:
|
||||
# prefetcher and collate will handle tensor conversion and norm
|
||||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
def create_transform(
|
||||
input_size,
|
||||
is_training=False,
|
||||
use_prefetcher=False,
|
||||
color_jitter=0.4,
|
||||
auto_augment=None,
|
||||
interpolation='bilinear',
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
crop_pct=None,
|
||||
tf_preprocessing=False,
|
||||
separate=False):
|
||||
|
||||
if isinstance(input_size, tuple):
|
||||
img_size = input_size[-2:]
|
||||
else:
|
||||
img_size = input_size
|
||||
|
||||
if tf_preprocessing and use_prefetcher:
|
||||
assert not separate, "Separate transforms not supported for TF preprocessing"
|
||||
from timm.data.tf_preprocessing import TfPreprocessTransform
|
||||
transform = TfPreprocessTransform(
|
||||
is_training=is_training, size=img_size, interpolation=interpolation)
|
||||
else:
|
||||
if is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
img_size,
|
||||
color_jitter=color_jitter,
|
||||
auto_augment=auto_augment,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
separate=separate)
|
||||
else:
|
||||
assert not separate, "Separate transforms not supported for validation preprocessing"
|
||||
transform = transforms_imagenet_eval(
|
||||
img_size,
|
||||
interpolation=interpolation,
|
||||
use_prefetcher=use_prefetcher,
|
||||
mean=mean,
|
||||
std=std,
|
||||
crop_pct=crop_pct)
|
||||
|
||||
return transform
|
|
@ -1 +1,2 @@
|
|||
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from .jsd import JsdCrossEntropy
|
|
@ -0,0 +1,34 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .cross_entropy import LabelSmoothingCrossEntropy
|
||||
|
||||
|
||||
class JsdCrossEntropy(nn.Module):
|
||||
""" Jenson-Shannon Divergence + Cross-Entropy Loss
|
||||
|
||||
"""
|
||||
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
|
||||
super().__init__()
|
||||
self.num_splits = num_splits
|
||||
self.alpha = alpha
|
||||
if smoothing is not None and smoothing > 0:
|
||||
self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
|
||||
else:
|
||||
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def __call__(self, output, target):
|
||||
split_size = output.shape[0] // self.num_splits
|
||||
assert split_size * self.num_splits == output.shape[0]
|
||||
logits_split = torch.split(output, split_size)
|
||||
|
||||
# Cross-entropy is only computed on clean images
|
||||
loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
|
||||
probs = [F.softmax(logits, dim=1) for logits in logits_split]
|
||||
|
||||
# Clamp mixture distribution to avoid exploding KL divergence
|
||||
logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
|
||||
loss += self.alpha * sum([F.kl_div(
|
||||
logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
|
||||
return loss
|
33
train.py
33
train.py
|
@ -1,7 +1,6 @@
|
|||
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
|
@ -14,13 +13,16 @@ except ImportError:
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
has_apex = False
|
||||
|
||||
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
||||
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
|
||||
from timm.models import create_model, resume_checkpoint
|
||||
from timm.utils import *
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
||||
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
||||
from timm.optim import create_optimizer
|
||||
from timm.scheduler import create_scheduler
|
||||
|
||||
#FIXME
|
||||
from timm.data.dataset import AugMixDataset
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.utils
|
||||
|
@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|||
parser.add_argument("--local_rank", default=0, type=int)
|
||||
|
||||
|
||||
parser.add_argument('--jsd', action='store_true', default=False,
|
||||
help='')
|
||||
|
||||
|
||||
def _parse_args():
|
||||
# Do we have a config file to parse?
|
||||
args_config, remaining = config_parser.parse_known_args()
|
||||
|
@ -311,8 +317,14 @@ def main():
|
|||
|
||||
collate_fn = None
|
||||
if args.prefetcher and args.mixup > 0:
|
||||
assert not args.jsd
|
||||
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
|
||||
|
||||
separate_transforms = False
|
||||
if args.jsd:
|
||||
dataset_train = AugMixDataset(dataset_train)
|
||||
separate_transforms = True
|
||||
|
||||
loader_train = create_loader(
|
||||
dataset_train,
|
||||
input_size=data_config['input_size'],
|
||||
|
@ -330,6 +342,7 @@ def main():
|
|||
num_workers=args.workers,
|
||||
distributed=args.distributed,
|
||||
collate_fn=collate_fn,
|
||||
separate_transforms=separate_transforms,
|
||||
)
|
||||
|
||||
eval_dir = os.path.join(args.data, 'val')
|
||||
|
@ -354,7 +367,10 @@ def main():
|
|||
crop_pct=data_config['crop_pct'],
|
||||
)
|
||||
|
||||
if args.mixup > 0.:
|
||||
if args.jsd:
|
||||
train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
elif args.mixup > 0.:
|
||||
# smoothing is handled with mixup label transform
|
||||
train_loss_fn = SoftTargetCrossEntropy().cuda()
|
||||
validate_loss_fn = nn.CrossEntropyLoss().cuda()
|
||||
|
@ -452,11 +468,10 @@ def train_epoch(
|
|||
if not args.prefetcher:
|
||||
input, target = input.cuda(), target.cuda()
|
||||
if args.mixup > 0.:
|
||||
lam = 1.
|
||||
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
|
||||
lam = np.random.beta(args.mixup, args.mixup)
|
||||
input = input.mul(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, args.num_classes, lam, args.smoothing)
|
||||
input, target = mixup_batch(
|
||||
input, target,
|
||||
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
|
||||
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
|
||||
|
||||
output = model(input)
|
||||
|
||||
|
|
Loading…
Reference in New Issue