mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

commit 7547119891 (parent 2e955cfd0c)
Add SplitBatchNorm. AugMix, Rand/AutoAugment, Split (Aux) BatchNorm, Jensen-Shannon Divergence, RandomErasing all working together
@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset, DatasetTar
+from .dataset import Dataset, DatasetTar, AugMixDataset
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform
@@ -323,7 +323,7 @@ class AugmentOp:
         self.magnitude_std = self.hparams.get('magnitude_std', 0)

     def __call__(self, img):
-        if not self.prob >= 1.0 or random.random() > self.prob:
+        if self.prob < 1.0 and random.random() > self.prob:
             return img
         magnitude = self.magnitude
         if self.magnitude_std and self.magnitude_std > 0:
@@ -539,7 +539,7 @@ _RAND_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout'  # FIXME I implement this as random erasing separately
+    #'Cutout'  # NOTE I've implemented this as random erasing separately
 ]

@@ -559,7 +559,7 @@ _RAND_INCREASING_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout'  # FIXME I implement this as random erasing separately
+    #'Cutout'  # NOTE I've implemented this as random erasing separately
 ]

@@ -627,6 +627,7 @@ def rand_augment_transform(config_str, hparams):
     'n' - integer num layers (number of transform ops selected per image)
     'w' - integer probability weight index (index of a set of weights to influence choice of op)
     'mstd' - float std deviation of magnitude noise applied
+    'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
     Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
     'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

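
Usage sketch (not part of the diff): how a config string like the ones documented above becomes a transform. The hparams keys shown (translate_const, img_mean) are assumptions about the usual fill/translate defaults, not values taken from this commit.

from timm.data.auto_augment import rand_augment_transform

# 'rand-m9-n3-mstd0.5-inc1': magnitude 9, 3 ops per image, magnitude noise std 0.5,
# using the new 'increasing severity' transform set enabled via the 'inc' flag
tfm = rand_augment_transform(
    'rand-m9-n3-mstd0.5-inc1',
    hparams=dict(translate_const=117, img_mean=(128, 128, 128)),  # assumed hparams
)
aug_img = tfm(pil_img)  # pil_img: any PIL.Image to augment
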
@@ -637,6 +638,7 @@ def rand_augment_transform(config_str, hparams):
     magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
     num_layers = 2  # default to 2 ops per image
     weight_idx = None  # default to no probability weights for op choice
+    transforms = _RAND_TRANSFORMS
     config = config_str.split('-')
     assert config[0] == 'rand'
     config = config[1:]
@@ -648,6 +650,9 @@ def rand_augment_transform(config_str, hparams):
         if key == 'mstd':
             # noise param injected via hparams for now
             hparams.setdefault('magnitude_std', float(val))
+        elif key == 'inc':
+            if bool(val):
+                transforms = _RAND_INCREASING_TRANSFORMS
         elif key == 'm':
             magnitude = int(val)
         elif key == 'n':
@@ -656,7 +661,7 @@ def rand_augment_transform(config_str, hparams):
             weight_idx = int(val)
         else:
             assert False, 'Unknown RandAugment config section'
-    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
     choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
     return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)

@@ -686,12 +691,12 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):


 class AugMixAugment:
-    def __init__(self, ops, alpha=1., width=3, depth=-1):
+    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
         self.ops = ops
         self.alpha = alpha
         self.width = width
         self.depth = depth
-        self.blended = False
+        self.blended = blended

     def _calc_blended_weights(self, ws, m):
         ws = ws * m
@@ -707,7 +712,7 @@ class AugMixAugment:
         # This is my first crack at implementing a slightly faster mixed augmentation. Instead
         # of accumulating the mix for each chain in a Numpy array and then blending with original,
         # it recomputes the blending coefficients and applies one PIL image blend per chain.
-        # TODO I've verified the results are in the right ballpark but they differ by more than rounding.
+        # TODO the results appear in the right ballpark but they differ by more than rounding.
         img_orig = img.copy()
         ws = self._calc_blended_weights(mixing_weights, m)
         for w in ws:
@@ -755,6 +760,7 @@ def augment_and_mix_transform(config_str, hparams):
     'm' - integer magnitude (severity) of augmentation mix (default: 3)
     'w' - integer width of augmentation chain (default: 3)
     'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+    'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
     'mstd' - float std deviation of magnitude noise applied (default: 0)
     Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

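
Usage sketch (not part of the diff): the AugMix config string parsed below, with the new 'b' (blended) flag. The img_mean hparam is an assumed fill value, not taken from this commit.

from timm.data.auto_augment import augment_and_mix_transform

# 'augmix-m5-w4-d2-b1': severity 5, 4 chains, depth 2, per-chain blended mode
amix = augment_and_mix_transform(
    'augmix-m5-w4-d2-b1',
    hparams=dict(img_mean=(128, 128, 128)),  # assumed fill value
)
mixed_img = amix(pil_img)  # pil_img: any PIL.Image
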
@@ -766,6 +772,7 @@ def augment_and_mix_transform(config_str, hparams):
     width = 3
     depth = -1
     alpha = 1.
+    blended = False
     config = config_str.split('-')
     assert config[0] == 'augmix'
     config = config[1:]
@@ -785,7 +792,9 @@ def augment_and_mix_transform(config_str, hparams):
             depth = int(val)
         elif key == 'a':
             alpha = float(val)
+        elif key == 'b':
+            blended = bool(val)
         else:
             assert False, 'Unknown AugMix config section'
     ops = augmix_ops(magnitude=magnitude, hparams=hparams)
-    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)
+    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
@@ -144,13 +144,13 @@ class DatasetTar(data.Dataset):
 class AugMixDataset(torch.utils.data.Dataset):
     """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

-    def __init__(self, dataset, num_aug=2):
+    def __init__(self, dataset, num_splits=2):
         self.augmentation = None
         self.normalize = None
         self.dataset = dataset
         if self.dataset.transform is not None:
             self._set_transforms(self.dataset.transform)
-        self.num_aug = num_aug
+        self.num_splits = num_splits

     def _set_transforms(self, x):
         assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
@@ -170,9 +170,10 @@ class AugMixDataset(torch.utils.data.Dataset):
         return x if self.normalize is None else self.normalize(x)

     def __getitem__(self, i):
-        x, y = self.dataset[i]
-        x_list = [self._normalize(x)]
-        for n in range(self.num_aug):
+        x, y = self.dataset[i]  # all splits share the same dataset base transform
+        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
+        # run the full augmentation on the remaining splits
+        for _ in range(self.num_splits - 1):
             x_list.append(self._normalize(self.augmentation(x)))
         return tuple(x_list), y

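
Usage sketch (not part of the diff): wrapping a training dataset so each item becomes a (clean, augmented, ...) tuple; the path is a placeholder, and the split transforms are installed by create_loader further below.

from timm.data import Dataset, AugMixDataset

dataset_train = Dataset('/path/to/train')                   # placeholder path
dataset_train = AugMixDataset(dataset_train, num_splits=2)
# once create_loader(..., num_aug_splits=2) sets the separate transforms,
# dataset_train[i] returns ((clean_tensor, augmented_tensor), label)
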
@@ -15,7 +15,7 @@ def fast_collate(batch):
     if isinstance(batch[0][0], tuple):
         # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
         # such that all tuple entries at position n end up in the nth chunk of torch.split(tensor, batch_size)
-        inner_tuple_size = len(batch[0][0][0])
+        inner_tuple_size = len(batch[0][0])
         flattened_batch_size = batch_size * inner_tuple_size
         targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
         tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
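
Sketch of the ordering contract described in the comment above (not part of the diff): after fast_collate flattens the tuples, the nth split is simply the nth chunk of torch.split; this is how the JSD loss and SplitBatchNorm consume the batch. Shapes here are illustrative.

import torch

batch_size, num_splits = 4, 2
flat = torch.arange(batch_size * num_splits)       # stands in for the collated (8, C, H, W) tensor
clean, augmented = torch.split(flat, batch_size)   # chunk n holds every sample's nth tuple entry
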
@@ -47,12 +47,13 @@ class PrefetchLoader:

     def __init__(self,
                  loader,
-                 rand_erase_prob=0.,
-                 rand_erase_mode='const',
-                 rand_erase_count=1,
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD,
-                 fp16=False):
+                 fp16=False,
+                 re_prob=0.,
+                 re_mode='const',
+                 re_count=1,
+                 re_num_splits=0):
         self.loader = loader
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
@@ -60,9 +61,9 @@ class PrefetchLoader:
         if fp16:
             self.mean = self.mean.half()
             self.std = self.std.half()
-        if rand_erase_prob > 0.:
+        if re_prob > 0.:
             self.random_erasing = RandomErasing(
-                probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
+                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
         else:
             self.random_erasing = None

@@ -122,11 +123,13 @@ def create_loader(
         batch_size,
         is_training=False,
         use_prefetcher=True,
-        rand_erase_prob=0.,
-        rand_erase_mode='const',
-        rand_erase_count=1,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_split=False,
         color_jitter=0.4,
         auto_augment=None,
+        num_aug_splits=0,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
@@ -136,8 +139,11 @@ def create_loader(
         collate_fn=None,
         fp16=False,
         tf_preprocessing=False,
-        separate_transforms=False,
 ):
+    re_num_splits = 0
+    if re_split:
+        # apply RE to second half of batch if no aug split otherwise line up with aug split
+        re_num_splits = num_aug_splits or 2
     dataset.transform = create_transform(
         input_size,
         is_training=is_training,
@@ -149,7 +155,11 @@ def create_loader(
         std=std,
         crop_pct=crop_pct,
         tf_preprocessing=tf_preprocessing,
-        separate=separate_transforms,
+        re_prob=re_prob,
+        re_mode=re_mode,
+        re_count=re_count,
+        re_num_splits=re_num_splits,
+        separate=num_aug_splits > 0,
     )

     sampler = None
@@ -176,11 +186,13 @@ def create_loader(
     if use_prefetcher:
         loader = PrefetchLoader(
             loader,
-            rand_erase_prob=rand_erase_prob if is_training else 0.,
-            rand_erase_mode=rand_erase_mode,
-            rand_erase_count=rand_erase_count,
             mean=mean,
             std=std,
-            fp16=fp16)
+            fp16=fp16,
+            re_prob=re_prob if is_training else 0.,
+            re_mode=re_mode,
+            re_count=re_count,
+            re_num_splits=re_num_splits
+        )

     return loader
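
Usage sketch (not part of the diff): calling the updated create_loader with the renamed random-erasing arguments and the new split options; sizes and values are placeholders.

loader_train = create_loader(
    dataset_train,                     # e.g. the AugMixDataset wrapper from above
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    use_prefetcher=True,
    re_prob=0.5, re_mode='pixel', re_count=1,
    re_split=True,                     # erase only the non-clean portion of each batch
    num_aug_splits=2,                  # separate clean vs. augmented transforms
    auto_augment='rand-m9-mstd0.5',
    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
)
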
@@ -38,7 +38,7 @@ class RandomErasing:
     def __init__(
             self,
             probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
-            mode='const', min_count=1, max_count=None, device='cuda'):
+            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
         self.probability = probability
         self.min_area = min_area
         self.max_area = max_area
@@ -46,6 +46,7 @@ class RandomErasing:
         self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
         self.min_count = min_count
         self.max_count = max_count or min_count
+        self.num_splits = num_splits
         mode = mode.lower()
         self.rand_color = False
         self.per_pixel = False
@@ -82,6 +83,8 @@ class RandomErasing:
             self._erase(input, *input.size(), input.dtype)
         else:
             batch_size, chan, img_h, img_w = input.size()
-            for i in range(batch_size):
+            # skip first slice of batch if num_splits is set (for clean portion of samples)
+            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+            for i in range(batch_start, batch_size):
                 self._erase(input[i], chan, img_h, img_w, input.dtype)
         return input

@@ -15,11 +15,13 @@ def transforms_imagenet_train(
         color_jitter=0.4,
         auto_augment=None,
         interpolation='random',
-        random_erasing=0.4,
-        random_erasing_mode='const',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_num_splits=0,
         separate=False,
 ):

@@ -71,8 +73,9 @@ def transforms_imagenet_train(
             mean=torch.tensor(mean),
             std=torch.tensor(std))
     ]
-    if random_erasing > 0.:
-        final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
+    if re_prob > 0.:
+        final_tfl.append(
+            RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

     if separate:
         return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
@@ -126,6 +129,10 @@ def create_transform(
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_num_splits=0,
         crop_pct=None,
         tf_preprocessing=False,
         separate=False):
@@ -150,6 +157,10 @@ def create_transform(
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
+            re_prob=re_prob,
+            re_mode=re_mode,
+            re_count=re_count,
+            re_num_splits=re_num_splits,
             separate=separate)
     else:
         assert not separate, "Separate transforms not supported for validation preprocessing"
@@ -6,7 +6,7 @@ from .cross_entropy import LabelSmoothingCrossEntropy


 class JsdCrossEntropy(nn.Module):
-    """ Jenson-Shannon Divergence + Cross-Entropy Loss
+    """ Jensen-Shannon Divergence + Cross-Entropy Loss

     """
     def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
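
Usage sketch (not part of the diff): the loss expects the flattened, split-ordered batch produced by fast_collate (clean rows first); this mirrors how train.py wires it up further below.

from timm.loss import JsdCrossEntropy

num_aug_splits = 2
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=0.1).cuda()

# output: model logits for the full flattened batch (clean split first, then each aug split)
# target: labels for the same flattened batch
loss = train_loss_fn(output, target)
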
@@ -20,3 +20,4 @@ from .registry import *
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+from .split_batchnorm import convert_splitbn_model
timm/models/split_batchnorm.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SplitBatchNorm2d(torch.nn.BatchNorm2d):
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True, num_splits=1):
+        super().__init__(num_features, eps, momentum, affine, track_running_stats)
+        assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)'
+        self.num_splits = num_splits
+        self.aux_bn = nn.ModuleList([
+            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
+
+    def forward(self, input: torch.Tensor):
+        if self.training:  # aux BN only relevant while training
+            split_size = input.shape[0] // self.num_splits
+            assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
+            split_input = input.split(split_size)
+            x = [super().forward(split_input[0])]
+            for i, a in enumerate(self.aux_bn):
+                x.append(a(split_input[i + 1]))
+            return torch.cat(x, dim=0)
+        else:
+            return super().forward(input)
+
+
+def convert_splitbn_model(module, num_splits=2):
+    """
+    Recursively traverse module and its children to replace all instances of
+    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.
+    Args:
+        module (torch.nn.Module): input module
+        num_splits: number of separate batchnorm layers to split input across
+    Example::
+        >>> # model is an instance of torch.nn.Module
+        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
+    """
+    mod = module
+    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
+        return module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        mod = SplitBatchNorm2d(
+            module.num_features, module.eps, module.momentum, module.affine,
+            module.track_running_stats, num_splits=num_splits)
+        mod.running_mean = module.running_mean
+        mod.running_var = module.running_var
+        mod.num_batches_tracked = module.num_batches_tracked
+        if module.affine:
+            mod.weight.data = module.weight.data.clone().detach()
+            mod.bias.data = module.bias.data.clone().detach()
+        for aux in mod.aux_bn:
+            aux.running_mean = module.running_mean.clone()
+            aux.running_var = module.running_var.clone()
+            aux.num_batches_tracked = module.num_batches_tracked.clone()
+            if module.affine:
+                aux.weight.data = module.weight.data.clone().detach()
+                aux.bias.data = module.bias.data.clone().detach()
+    for name, child in module.named_children():
+        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
+    del module
+    return mod
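
Usage sketch (not part of the diff): converting a model and feeding it a split-ordered batch; the model name is a placeholder. The batch must be divisible by num_splits, matching the assert in forward above.

import torch
from timm.models import create_model, convert_splitbn_model

model = create_model('resnet50', pretrained=False)   # placeholder model
model = convert_splitbn_model(model, num_splits=2)   # primary BN + 1 aux BN per layer

model.train()
x = torch.randn(8, 3, 224, 224)   # 8 is divisible by num_splits=2
out = model(x)                    # rows 0-3 use the primary BN stats, rows 4-7 the aux BN
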
train.py (52 lines changed)
@@ -13,16 +13,13 @@ except ImportError:
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False

-from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
-from timm.models import create_model, resume_checkpoint
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
+from timm.models import create_model, resume_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
 from timm.scheduler import create_scheduler

-#FIXME
-from timm.data.dataset import AugMixDataset
-
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -71,6 +68,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.)')
 parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
                     help='Drop connect rate (default: 0.)')
+parser.add_argument('--jsd', action='store_true', default=False,
+                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -106,18 +105,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
+parser.add_argument('--aug-splits', type=int, default=0,
+                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
 parser.add_argument('--recount', type=int, default=1,
                     help='Random erase count (default: 1)')
+parser.add_argument('--resplit', action='store_true', default=False,
+                    help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+parser.add_argument('--train-interpolation', type=str, default='random',
+                    help='Training interpolation (random, bilinear, bicubic default: "random")')
 # Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
|
|||||||
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
|
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
|
||||||
parser.add_argument('--dist-bn', type=str, default='',
|
parser.add_argument('--dist-bn', type=str, default='',
|
||||||
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
||||||
|
parser.add_argument('--split-bn', action='store_true',
|
||||||
|
help='Enable separate BN layers per augmentation split.')
|
||||||
# Model Exponential Moving Average
|
# Model Exponential Moving Average
|
||||||
parser.add_argument('--model-ema', action='store_true', default=False,
|
parser.add_argument('--model-ema', action='store_true', default=False,
|
||||||
help='Enable tracking moving average of model weights')
|
help='Enable tracking moving average of model weights')
|
||||||
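
Command-line sketch (not part of the diff) combining the new flags added above; the dataset path, model, and hyper-parameter values are placeholders.

python train.py /path/to/imagenet --model resnet50 \
    --aug-splits 2 --jsd --split-bn \
    --aa rand-m9-mstd0.5-inc1 \
    --reprob 0.5 --remode pixel --resplit
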
@@ -162,10 +169,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)


-parser.add_argument('--jsd', action='store_true', default=False,
-                    help='')
-
-
 def _parse_args():
     # Do we have a config file to parse?
     args_config, remaining = config_parser.parse_known_args()
@@ -233,6 +236,14 @@ def main():

     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

+    num_aug_splits = 0
+    if args.aug_splits:
+        num_aug_splits = max(args.aug_splits, 2)  # split of 1 makes no sense
+
+    if args.split_bn:
+        assert num_aug_splits > 1 or args.resplit
+        model = convert_splitbn_model(model, max(num_aug_splits, 2))
+
     if args.num_gpu > 1:
         if args.amp:
             logging.warning(
@@ -279,6 +290,7 @@ def main():

     if args.distributed:
         if args.sync_bn:
+            assert not args.split_bn
             try:
                 if has_apex:
                     model = convert_syncbn_model(model)
@@ -317,13 +329,11 @@ def main():

     collate_fn = None
     if args.prefetcher and args.mixup > 0:
-        assert not args.jsd
+        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
         collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

-    separate_transforms = False
-    if args.jsd:
-        dataset_train = AugMixDataset(dataset_train)
-        separate_transforms = True
+    if num_aug_splits > 1:
+        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

     loader_train = create_loader(
         dataset_train,
@@ -331,18 +341,19 @@ def main():
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=args.prefetcher,
-        rand_erase_prob=args.reprob,
-        rand_erase_mode=args.remode,
-        rand_erase_count=args.recount,
+        re_prob=args.reprob,
+        re_mode=args.remode,
+        re_count=args.recount,
+        re_split=args.resplit,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
-        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
+        num_aug_splits=num_aug_splits,
+        interpolation=args.train_interpolation,
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
         collate_fn=collate_fn,
-        separate_transforms=separate_transforms,
     )

     eval_dir = os.path.join(args.data, 'val')
@@ -368,7 +379,8 @@ def main():
     )

     if args.jsd:
-        train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
+        assert num_aug_splits > 1  # JSD only valid with aug splits set
+        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     elif args.mixup > 0.:
         # smoothing is handled with mixup label transform