Add SplitBatchNorm. AugMix, Rand/AutoAugment, Split (Aux) BatchNorm, Jensen-Shannon Divergence, RandomErasing all working together
This commit is contained in:
parent 2e955cfd0c
commit 7547119891
@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset, DatasetTar
+from .dataset import Dataset, DatasetTar, AugMixDataset
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform
@@ -323,7 +323,7 @@ class AugmentOp:
         self.magnitude_std = self.hparams.get('magnitude_std', 0)

     def __call__(self, img):
-        if not self.prob >= 1.0 or random.random() > self.prob:
+        if self.prob < 1.0 and random.random() > self.prob:
             return img
         magnitude = self.magnitude
         if self.magnitude_std and self.magnitude_std > 0:
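Note on the hunk above (not part of the diff): the old condition skipped the op whenever prob was below 1.0, regardless of the random draw; the new form applies the op with probability prob and always applies ops configured with prob >= 1.0. A minimal sketch of the corrected gate:

    import random

    def should_apply(prob: float) -> bool:
        # skip only when prob < 1.0 and the draw exceeds prob; prob >= 1.0 always applies
        return not (prob < 1.0 and random.random() > prob)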
@@ -539,7 +539,7 @@ _RAND_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout' # FIXME I implement this as random erasing separately
+    #'Cutout' # NOTE I've implemented this as random erasing separately
 ]


@@ -559,7 +559,7 @@ _RAND_INCREASING_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout' # FIXME I implement this as random erasing separately
+    #'Cutout' # NOTE I've implemented this as random erasing separately
 ]


@@ -627,6 +627,7 @@ def rand_augment_transform(config_str, hparams):
         'n' - integer num layers (number of transform ops selected per image)
         'w' - integer probability weight index (index of a set of weights to influence choice of op)
         'mstd' - float std deviation of magnitude noise applied
+        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
     Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
     'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

@@ -637,6 +638,7 @@ def rand_augment_transform(config_str, hparams):
     magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
     num_layers = 2  # default to 2 ops per image
     weight_idx = None  # default to no probability weights for op choice
+    transforms = _RAND_TRANSFORMS
     config = config_str.split('-')
     assert config[0] == 'rand'
     config = config[1:]
@@ -648,6 +650,9 @@ def rand_augment_transform(config_str, hparams):
         if key == 'mstd':
             # noise param injected via hparams for now
             hparams.setdefault('magnitude_std', float(val))
+        elif key == 'inc':
+            if bool(val):
+                transforms = _RAND_INCREASING_TRANSFORMS
         elif key == 'm':
             magnitude = int(val)
         elif key == 'n':
@@ -656,7 +661,7 @@ def rand_augment_transform(config_str, hparams):
             weight_idx = int(val)
         else:
             assert False, 'Unknown RandAugment config section'
-    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
     choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
     return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
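Note (not part of the diff): with the new 'inc' key and the transforms plumbing above, a config string can now select the increasing-severity op set. A hedged usage sketch; the hparams values are illustrative, not taken from this commit:

    from timm.data.auto_augment import rand_augment_transform

    # 'rand-m9-mstd0.5-inc1' -> magnitude 9, magnitude noise std 0.5, increasing-severity transforms
    tfm = rand_augment_transform('rand-m9-mstd0.5-inc1', hparams=dict(img_mean=(124, 116, 104)))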
@@ -686,12 +691,12 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):


 class AugMixAugment:
-    def __init__(self, ops, alpha=1., width=3, depth=-1):
+    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
         self.ops = ops
         self.alpha = alpha
         self.width = width
         self.depth = depth
-        self.blended = False
+        self.blended = blended

     def _calc_blended_weights(self, ws, m):
         ws = ws * m
@@ -707,7 +712,7 @@ class AugMixAugment:
         # This is my first crack at implementing a slightly faster mixed augmentation. Instead
         # of accumulating the mix for each chain in a Numpy array and then blending with original,
         # it recomputes the blending coefficients and applies one PIL image blend per chain.
-        # TODO I've verified the results are in the right ballpark but they differ by more than rounding.
+        # TODO the results appear in the right ballpark but they differ by more than rounding.
         img_orig = img.copy()
         ws = self._calc_blended_weights(mixing_weights, m)
         for w in ws:
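Note (not part of the diff): the comment above describes the blended mode: rather than accumulating every augmentation chain into a floating-point buffer and blending once at the end, each chain is folded into the running image with a recomputed per-chain weight. A rough sketch of that idea; the helper below is illustrative, not the timm implementation:

    from PIL import Image

    def blend_chains(img_orig, chain_imgs, ws):
        # fold each augmented chain into the running image with its own blend weight
        img = img_orig
        for chain_img, w in zip(chain_imgs, ws):
            img = Image.blend(img, chain_img, w)
        return img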
@@ -755,6 +760,7 @@ def augment_and_mix_transform(config_str, hparams):
         'm' - integer magnitude (severity) of augmentation mix (default: 3)
         'w' - integer width of augmentation chain (default: 3)
         'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
         'mstd' - float std deviation of magnitude noise applied (default: 0)
     Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

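Note (not part of the diff): a hedged usage sketch of the new 'b' key; the empty hparams dict assumes the library falls back to its default hyper-parameters:

    from timm.data.auto_augment import augment_and_mix_transform

    # 'augmix-m5-w4-d2-b1' -> severity 5, chain width 4, chain depth 2, blended mode enabled
    tfm = augment_and_mix_transform('augmix-m5-w4-d2-b1', hparams={})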
@@ -766,6 +772,7 @@ def augment_and_mix_transform(config_str, hparams):
     width = 3
     depth = -1
     alpha = 1.
+    blended = False
     config = config_str.split('-')
     assert config[0] == 'augmix'
     config = config[1:]
@@ -785,7 +792,9 @@ def augment_and_mix_transform(config_str, hparams):
             depth = int(val)
         elif key == 'a':
             alpha = float(val)
+        elif key == 'b':
+            blended = bool(val)
         else:
             assert False, 'Unknown AugMix config section'
     ops = augmix_ops(magnitude=magnitude, hparams=hparams)
-    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)
+    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
@@ -144,13 +144,13 @@ class DatasetTar(data.Dataset):
 class AugMixDataset(torch.utils.data.Dataset):
     """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

-    def __init__(self, dataset, num_aug=2):
+    def __init__(self, dataset, num_splits=2):
         self.augmentation = None
         self.normalize = None
         self.dataset = dataset
         if self.dataset.transform is not None:
             self._set_transforms(self.dataset.transform)
-        self.num_aug = num_aug
+        self.num_splits = num_splits

     def _set_transforms(self, x):
         assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
@@ -170,9 +170,10 @@ class AugMixDataset(torch.utils.data.Dataset):
         return x if self.normalize is None else self.normalize(x)

     def __getitem__(self, i):
-        x, y = self.dataset[i]
-        x_list = [self._normalize(x)]
-        for n in range(self.num_aug):
+        x, y = self.dataset[i]  # all splits share the same dataset base transform
+        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
+        # run the full augmentation on the remaining splits
+        for _ in range(self.num_splits - 1):
             x_list.append(self._normalize(self.augmentation(x)))
         return tuple(x_list), y

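Note (not part of the diff): a hedged usage sketch of the renamed parameter; it assumes the wrapped dataset's transform has already been split into (primary, secondary, final) pieces by create_transform(..., separate=True), as wired up in the loader changes below:

    dataset_train = AugMixDataset(dataset_train, num_splits=3)
    (clean, aug1, aug2), target = dataset_train[0]  # one normalized clean split plus two augmented splits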
@@ -15,7 +15,7 @@ def fast_collate(batch):
     if isinstance(batch[0][0], tuple):
         # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
         # such that all tuples at position n will end up in torch.split(tensor, batch_size) at the nth position
-        inner_tuple_size = len(batch[0][0][0])
+        inner_tuple_size = len(batch[0][0])
         flattened_batch_size = batch_size * inner_tuple_size
         targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
         tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
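Note (not part of the diff): a self-contained check of the layout the comment describes; the tensor here is a stand-in for the collated output, not real data:

    import torch

    batch_size, num_splits = 4, 3
    # collated output shape: (batch_size * num_splits, C, H, W), ordered split-major
    collated = torch.zeros(batch_size * num_splits, 3, 224, 224, dtype=torch.uint8)
    splits = torch.split(collated, batch_size)  # splits[0] = clean split, splits[1:] = augmented splits
    assert len(splits) == num_splits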
@@ -46,13 +46,14 @@ def fast_collate(batch):
 class PrefetchLoader:

     def __init__(self,
-                 loader,
-                 rand_erase_prob=0.,
-                 rand_erase_mode='const',
-                 rand_erase_count=1,
-                 mean=IMAGENET_DEFAULT_MEAN,
-                 std=IMAGENET_DEFAULT_STD,
-                 fp16=False):
+                 loader,
+                 mean=IMAGENET_DEFAULT_MEAN,
+                 std=IMAGENET_DEFAULT_STD,
+                 fp16=False,
+                 re_prob=0.,
+                 re_mode='const',
+                 re_count=1,
+                 re_num_splits=0):
         self.loader = loader
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
@@ -60,9 +61,9 @@ class PrefetchLoader:
         if fp16:
             self.mean = self.mean.half()
             self.std = self.std.half()
-        if rand_erase_prob > 0.:
+        if re_prob > 0.:
             self.random_erasing = RandomErasing(
-                probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
+                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
         else:
             self.random_erasing = None

@@ -122,11 +123,13 @@ def create_loader(
         batch_size,
         is_training=False,
         use_prefetcher=True,
-        rand_erase_prob=0.,
-        rand_erase_mode='const',
-        rand_erase_count=1,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_split=False,
         color_jitter=0.4,
         auto_augment=None,
+        num_aug_splits=0,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
@@ -136,8 +139,11 @@ def create_loader(
         collate_fn=None,
         fp16=False,
         tf_preprocessing=False,
-        separate_transforms=False,
 ):
+    re_num_splits = 0
+    if re_split:
+        # apply RE to second half of batch if no aug split otherwise line up with aug split
+        re_num_splits = num_aug_splits or 2
     dataset.transform = create_transform(
         input_size,
         is_training=is_training,
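Note (not part of the diff): the num_aug_splits or 2 fallback means --resplit still protects a 'clean' portion of the batch even when no aug splits are configured. A tiny sketch of the two cases:

    for num_aug_splits in (0, 3):
        re_num_splits = num_aug_splits or 2  # 0 -> 2 (skip the first half), 3 -> 3 (line up with the clean aug split)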
@@ -149,7 +155,11 @@ def create_loader(
         std=std,
         crop_pct=crop_pct,
         tf_preprocessing=tf_preprocessing,
-        separate=separate_transforms,
+        re_prob=re_prob,
+        re_mode=re_mode,
+        re_count=re_count,
+        re_num_splits=re_num_splits,
+        separate=num_aug_splits > 0,
     )

     sampler = None
@@ -176,11 +186,13 @@ def create_loader(
     if use_prefetcher:
         loader = PrefetchLoader(
             loader,
-            rand_erase_prob=rand_erase_prob if is_training else 0.,
-            rand_erase_mode=rand_erase_mode,
-            rand_erase_count=rand_erase_count,
             mean=mean,
             std=std,
-            fp16=fp16)
+            fp16=fp16,
+            re_prob=re_prob if is_training else 0.,
+            re_mode=re_mode,
+            re_count=re_count,
+            re_num_splits=re_num_splits
+        )

     return loader
@@ -38,7 +38,7 @@ class RandomErasing:
     def __init__(
             self,
             probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
-            mode='const', min_count=1, max_count=None, device='cuda'):
+            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
         self.probability = probability
         self.min_area = min_area
         self.max_area = max_area
@@ -46,6 +46,7 @@ class RandomErasing:
         self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
         self.min_count = min_count
         self.max_count = max_count or min_count
+        self.num_splits = num_splits
         mode = mode.lower()
         self.rand_color = False
         self.per_pixel = False
@@ -82,6 +83,8 @@ class RandomErasing:
             self._erase(input, *input.size(), input.dtype)
         else:
             batch_size, chan, img_h, img_w = input.size()
-            for i in range(batch_size):
+            # skip first slice of batch if num_splits is set (for clean portion of samples)
+            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+            for i in range(batch_start, batch_size):
                 self._erase(input[i], chan, img_h, img_w, input.dtype)
         return input
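Note (not part of the diff): a small numeric sketch of the new skip logic above; with a flattened aug-split batch, the first batch_size // num_splits samples (the clean split) are never erased:

    batch_size, num_splits = 96, 3
    batch_start = batch_size // num_splits if num_splits > 1 else 0
    assert batch_start == 32  # erasing applies to samples 32..95 only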
@@ -15,11 +15,13 @@ def transforms_imagenet_train(
         color_jitter=0.4,
         auto_augment=None,
         interpolation='random',
-        random_erasing=0.4,
-        random_erasing_mode='const',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_num_splits=0,
         separate=False,
 ):

@@ -71,8 +73,9 @@ def transforms_imagenet_train(
             mean=torch.tensor(mean),
             std=torch.tensor(std))
     ]
-    if random_erasing > 0.:
-        final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
+    if re_prob > 0.:
+        final_tfl.append(
+            RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

     if separate:
         return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
@@ -126,6 +129,10 @@ def create_transform(
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
+        re_prob=0.,
+        re_mode='const',
+        re_count=1,
+        re_num_splits=0,
         crop_pct=None,
         tf_preprocessing=False,
         separate=False):
@@ -150,6 +157,10 @@ def create_transform(
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
+            re_prob=re_prob,
+            re_mode=re_mode,
+            re_count=re_count,
+            re_num_splits=re_num_splits,
             separate=separate)
     else:
         assert not separate, "Separate transforms not supported for validation preprocessing"
@@ -6,7 +6,7 @@ from .cross_entropy import LabelSmoothingCrossEntropy


 class JsdCrossEntropy(nn.Module):
-    """ Jenson-Shannon Divergence + Cross-Entropy Loss
+    """ Jensen-Shannon Divergence + Cross-Entropy Loss

     """
     def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
@@ -20,3 +20,4 @@ from .registry import *
 from .factory import create_model
 from .helpers import load_checkpoint, resume_checkpoint
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+from .split_batchnorm import convert_splitbn_model
timm/models/split_batchnorm.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SplitBatchNorm2d(torch.nn.BatchNorm2d):
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 track_running_stats=True, num_splits=1):
+        super().__init__(num_features, eps, momentum, affine, track_running_stats)
+        assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)'
+        self.num_splits = num_splits
+        self.aux_bn = nn.ModuleList([
+            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
+
+    def forward(self, input: torch.Tensor):
+        if self.training:  # aux BN only relevant while training
+            split_size = input.shape[0] // self.num_splits
+            assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
+            split_input = input.split(split_size)
+            x = [super().forward(split_input[0])]
+            for i, a in enumerate(self.aux_bn):
+                x.append(a(split_input[i + 1]))
+            return torch.cat(x, dim=0)
+        else:
+            return super().forward(input)
+
+
+def convert_splitbn_model(module, num_splits=2):
+    """
+    Recursively traverse module and its children to replace all instances of
+    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
+    Args:
+        module (torch.nn.Module): input module
+        num_splits: number of separate batchnorm layers to split input across
+    Example::
+        >>> # model is an instance of torch.nn.Module
+        >>> import apex
+        >>> sync_bn_model = timm.models.convert_splitbn_model(model, num_splits=2)
+    """
+    mod = module
+    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
+        return module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        mod = SplitBatchNorm2d(
+            module.num_features, module.eps, module.momentum, module.affine,
+            module.track_running_stats, num_splits=num_splits)
+        mod.running_mean = module.running_mean
+        mod.running_var = module.running_var
+        mod.num_batches_tracked = module.num_batches_tracked
+        if module.affine:
+            mod.weight.data = module.weight.data.clone().detach()
+            mod.bias.data = module.bias.data.clone().detach()
+        for aux in mod.aux_bn:
+            aux.running_mean = module.running_mean.clone()
+            aux.running_var = module.running_var.clone()
+            aux.num_batches_tracked = module.num_batches_tracked.clone()
+            if module.affine:
+                aux.weight.data = module.weight.data.clone().detach()
+                aux.bias.data = module.bias.data.clone().detach()
+    for name, child in module.named_children():
+        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
+    del module
+    return mod
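Note (not part of the diff): a hedged usage sketch of the new module; the model name and input shape are illustrative, and the batch size must divide evenly by num_splits while training:

    import torch
    import timm
    from timm.models import convert_splitbn_model

    model = timm.create_model('resnet18')               # any model containing BatchNorm2d layers
    model = convert_splitbn_model(model, num_splits=2)
    model.train()
    out = model(torch.randn(8, 3, 224, 224))            # first 4 samples use the primary BN stats, last 4 the aux BN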
train.py (52 lines changed)
@@ -13,16 +13,13 @@ except ImportError:
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False

-from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
-from timm.models import create_model, resume_checkpoint
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
+from timm.models import create_model, resume_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
 from timm.scheduler import create_scheduler

-#FIXME
-from timm.data.dataset import AugMixDataset
-
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -71,6 +68,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.)')
 parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
                     help='Drop connect rate (default: 0.)')
+parser.add_argument('--jsd', action='store_true', default=False,
+                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd")')
@@ -106,18 +105,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
+parser.add_argument('--aug-splits', type=int, default=0,
+                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
 parser.add_argument('--recount', type=int, default=1,
                     help='Random erase count (default: 1)')
+parser.add_argument('--resplit', action='store_true', default=False,
+                    help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+parser.add_argument('--train-interpolation', type=str, default='random',
+                    help='Training interpolation (random, bilinear, bicubic default: "random")')
 # Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
@@ -129,6 +134,8 @@ parser.add_argument('--sync-bn', action='store_true',
                     help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
 parser.add_argument('--dist-bn', type=str, default='',
                     help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
+parser.add_argument('--split-bn', action='store_true',
+                    help='Enable separate BN layers per augmentation split.')
 # Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
@@ -162,10 +169,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)


-parser.add_argument('--jsd', action='store_true', default=False,
-                    help='')
-
-
 def _parse_args():
     # Do we have a config file to parse?
     args_config, remaining = config_parser.parse_known_args()
@@ -233,6 +236,14 @@ def main():

     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

+    num_aug_splits = 0
+    if args.aug_splits:
+        num_aug_splits = max(args.aug_splits, 2)  # split of 1 makes no sense
+
+    if args.split_bn:
+        assert num_aug_splits > 1 or args.resplit
+        model = convert_splitbn_model(model, max(num_aug_splits, 2))
+
     if args.num_gpu > 1:
         if args.amp:
             logging.warning(
@@ -279,6 +290,7 @@ def main():

     if args.distributed:
         if args.sync_bn:
+            assert not args.split_bn
             try:
                 if has_apex:
                     model = convert_syncbn_model(model)
@@ -317,13 +329,11 @@ def main():

     collate_fn = None
     if args.prefetcher and args.mixup > 0:
-        assert not args.jsd
+        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
         collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

-    separate_transforms = False
-    if args.jsd:
-        dataset_train = AugMixDataset(dataset_train)
-        separate_transforms = True
+    if num_aug_splits > 1:
+        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

     loader_train = create_loader(
         dataset_train,
@@ -331,18 +341,19 @@ def main():
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=args.prefetcher,
-        rand_erase_prob=args.reprob,
-        rand_erase_mode=args.remode,
-        rand_erase_count=args.recount,
+        re_prob=args.reprob,
+        re_mode=args.remode,
+        re_count=args.recount,
+        re_split=args.resplit,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
-        interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
+        num_aug_splits=num_aug_splits,
+        interpolation=args.train_interpolation,
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
         collate_fn=collate_fn,
-        separate_transforms=separate_transforms,
     )

     eval_dir = os.path.join(args.data, 'val')
@@ -368,7 +379,8 @@ def main():
     )

     if args.jsd:
-        train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
+        assert num_aug_splits > 1  # JSD only valid with aug splits set
+        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     elif args.mixup > 0.:
         # smoothing is handled with mixup label transform
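Note (not part of the diff): a hedged sketch of how the JSD loss pairs with the aug-split batches; shapes are illustrative, and it assumes the loss splits the flattened logits internally and computes cross-entropy on the clean split:

    import torch
    from timm.loss import JsdCrossEntropy

    loss_fn = JsdCrossEntropy(num_splits=3, smoothing=0.1)
    logits = torch.randn(96, 1000)              # flattened batch: [clean | aug1 | aug2], 32 samples each
    target = torch.randint(0, 1000, (96,))
    loss = loss_fn(logits, target)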