diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 66c16257..ee2240b4 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .constants import * from .config import resolve_data_config -from .dataset import Dataset, DatasetTar +from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index ec2602b3..8d7b36f9 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -323,7 +323,7 @@ class AugmentOp: self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): - if not self.prob >= 1.0 or random.random() > self.prob: + if self.prob < 1.0 and random.random() > self.prob: return img magnitude = self.magnitude if self.magnitude_std and self.magnitude_std > 0: @@ -539,7 +539,7 @@ _RAND_TRANSFORMS = [ 'ShearY', 'TranslateXRel', 'TranslateYRel', - #'Cutout' # FIXME I implement this as random erasing separately + #'Cutout' # NOTE I've implemented this as random erasing separately ] @@ -559,7 +559,7 @@ _RAND_INCREASING_TRANSFORMS = [ 'ShearY', 'TranslateXRel', 'TranslateYRel', - #'Cutout' # FIXME I implement this as random erasing separately + #'Cutout' # NOTE I've implemented this as random erasing separately ] @@ -627,6 +627,7 @@ def rand_augment_transform(config_str, hparams): 'n' - integer num layers (number of transform ops selected per image) 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 @@ -637,6 +638,7 @@ def rand_augment_transform(config_str, hparams): magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) num_layers = 2 # default to 2 ops per image weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS config = config_str.split('-') assert config[0] == 'rand' config = config[1:] @@ -648,6 +650,9 @@ def rand_augment_transform(config_str, hparams): if key == 'mstd': # noise param injected via hparams for now hparams.setdefault('magnitude_std', float(val)) + elif key == 'inc': + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS elif key == 'm': magnitude = int(val) elif key == 'n': @@ -656,7 +661,7 @@ def rand_augment_transform(config_str, hparams): weight_idx = int(val) else: assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) @@ -686,12 +691,12 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None): class AugMixAugment: - def __init__(self, ops, alpha=1., width=3, depth=-1): + def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): self.ops = ops self.alpha = alpha self.width = width self.depth = depth - self.blended = False + self.blended = blended def _calc_blended_weights(self, ws, m): ws = ws * m @@ -707,7 +712,7 @@ class AugMixAugment: # This is my first
crack and implementing a slightly faster mixed augmentation. Instead # of accumulating the mix for each chain in a Numpy array and then blending with original, # it recomputes the blending coefficients and applies one PIL image blend per chain. - # TODO I've verified the results are in the right ballpark but they differ by more than rounding. + # TODO the results appear in the right ballpark but they differ by more than rounding. img_orig = img.copy() ws = self._calc_blended_weights(mixing_weights, m) for w in ws: @@ -755,6 +760,7 @@ def augment_and_mix_transform(config_str, hparams): 'm' - integer magnitude (severity) of augmentation mix (default: 3) 'w' - integer width of augmentation chain (default: 3) 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) 'mstd' - float std deviation of magnitude noise applied (default: 0) Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 @@ -766,6 +772,7 @@ def augment_and_mix_transform(config_str, hparams): width = 3 depth = -1 alpha = 1. + blended = False config = config_str.split('-') assert config[0] == 'augmix' config = config[1:] @@ -785,7 +792,9 @@ def augment_and_mix_transform(config_str, hparams): depth = int(val) elif key == 'a': alpha = float(val) + elif key == 'b': + blended = bool(val) else: assert False, 'Unknown AugMix config section' ops = augmix_ops(magnitude=magnitude, hparams=hparams) - return AugMixAugment(ops, alpha=alpha, width=width, depth=depth) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 3220883a..fc252d9e 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -144,13 +144,13 @@ class DatasetTar(data.Dataset): class AugMixDataset(torch.utils.data.Dataset): """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" - def __init__(self, dataset, num_aug=2): + def __init__(self, dataset, num_splits=2): self.augmentation = None self.normalize = None self.dataset = dataset if self.dataset.transform is not None: self._set_transforms(self.dataset.transform) - self.num_aug = num_aug + self.num_splits = num_splits def _set_transforms(self, x): assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' @@ -170,9 +170,10 @@ class AugMixDataset(torch.utils.data.Dataset): return x if self.normalize is None else self.normalize(x) def __getitem__(self, i): - x, y = self.dataset[i] - x_list = [self._normalize(x)] - for n in range(self.num_aug): + x, y = self.dataset[i] # all splits share the same dataset base transform + x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) + # run the full augmentation on the remaining splits + for _ in range(self.num_splits - 1): x_list.append(self._normalize(self.augmentation(x))) return tuple(x_list), y diff --git a/timm/data/loader.py b/timm/data/loader.py index 06d431ee..e2ec8797 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -15,7 +15,7 @@ def fast_collate(batch): if isinstance(batch[0][0], tuple): # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position - inner_tuple_size = len(batch[0][0][0]) + inner_tuple_size = len(batch[0][0]) flattened_batch_size = batch_size * inner_tuple_size 
targets = torch.zeros(flattened_batch_size, dtype=torch.int64) tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) @@ -46,13 +46,14 @@ def fast_collate(batch): class PrefetchLoader: def __init__(self, - loader, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False): + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): self.loader = loader self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) @@ -60,9 +61,9 @@ class PrefetchLoader: if fp16: self.mean = self.mean.half() self.std = self.std.half() - if rand_erase_prob > 0.: + if re_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) else: self.random_erasing = None @@ -122,11 +123,13 @@ def create_loader( batch_size, is_training=False, use_prefetcher=True, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, + re_prob=0., + re_mode='const', + re_count=1, + re_split=False, color_jitter=0.4, auto_augment=None, + num_aug_splits=0, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -136,8 +139,11 @@ def create_loader( collate_fn=None, fp16=False, tf_preprocessing=False, - separate_transforms=False, ): + re_num_splits = 0 + if re_split: + # apply RE to second half of batch if no aug split otherwise line up with aug split + re_num_splits = num_aug_splits or 2 dataset.transform = create_transform( input_size, is_training=is_training, @@ -149,7 +155,11 @@ def create_loader( std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, - separate=separate_transforms, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=num_aug_splits > 0, ) sampler = None @@ -176,11 +186,13 @@ def create_loader( if use_prefetcher: loader = PrefetchLoader( loader, - rand_erase_prob=rand_erase_prob if is_training else 0., - rand_erase_mode=rand_erase_mode, - rand_erase_count=rand_erase_count, mean=mean, std=std, - fp16=fp16) + fp16=fp16, + re_prob=re_prob if is_training else 0., + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) return loader diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 2b6b61a5..589b2f0b 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -38,7 +38,7 @@ class RandomErasing: def __init__( self, probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, - mode='const', min_count=1, max_count=None, device='cuda'): + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): self.probability = probability self.min_area = min_area self.max_area = max_area @@ -46,6 +46,7 @@ class RandomErasing: self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) self.min_count = min_count self.max_count = max_count or min_count + self.num_splits = num_splits mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -82,6 +83,8 @@ class RandomErasing: self._erase(input, *input.size(), input.dtype) else: batch_size, chan, img_h, img_w = input.size() - for i in range(batch_size): + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = 
batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): self._erase(input[i], chan, img_h, img_w, input.dtype) return input diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index b70ae76f..faf55b70 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -15,11 +15,13 @@ def transforms_imagenet_train( color_jitter=0.4, auto_augment=None, interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, separate=False, ): @@ -71,8 +73,9 @@ def transforms_imagenet_train( mean=torch.tensor(mean), std=torch.tensor(std)) ] - if random_erasing > 0.: - final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + if re_prob > 0.: + final_tfl.append( + RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) if separate: return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) @@ -126,6 +129,10 @@ def create_transform( interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, crop_pct=None, tf_preprocessing=False, separate=False): @@ -150,6 +157,10 @@ def create_transform( use_prefetcher=use_prefetcher, mean=mean, std=std, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, separate=separate) else: assert not separate, "Separate transforms not supported for validation preprocessing" diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py index 0f99c699..ad6ca1e5 100644 --- a/timm/loss/jsd.py +++ b/timm/loss/jsd.py @@ -6,7 +6,7 @@ from .cross_entropy import LabelSmoothingCrossEntropy class JsdCrossEntropy(nn.Module): - """ Jenson-Shannon Divergence + Cross-Entropy Loss + """ Jensen-Shannon Divergence + Cross-Entropy Loss """ def __init__(self, num_splits=3, alpha=12, smoothing=0.1): diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 7119c4f5..a0be7bd0 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -20,3 +20,4 @@ from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .split_batchnorm import convert_splitbn_model diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py new file mode 100644 index 00000000..0ed30d77 --- /dev/null +++ b/timm/models/split_batchnorm.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=1): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = 
input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> import timm + >>> split_bn_model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/train.py b/train.py index 9910a059..bb6db08d 100644 --- a/train.py +++ b/train.py @@ -13,16 +13,13 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch -from timm.models import create_model, resume_checkpoint +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset +from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler -#FIXME -from timm.data.dataset import AugMixDataset - import torch import torch.nn as nn import torchvision.utils @@ -71,6 +68,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.)') parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP', help='Drop connect rate (default: 0.)') +parser.add_argument('--jsd', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -106,18 +105,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". 
(default: None)'), +parser.add_argument('--aug-splits', type=int, default=0, + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') +parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase the first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +parser.add_argument('--train-interpolation', type=str, default='random', + help='Training interpolation (random, bilinear, bicubic, default: "random")') # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') @@ -129,6 +134,8 @@ parser.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') parser.add_argument('--dist-bn', type=str, default='', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') +parser.add_argument('--split-bn', action='store_true', + help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') @@ -162,10 +169,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument('--jsd', action='store_true', default=False, - help='') - - def _parse_args(): # Do we have a config file to parse? 
args_config, remaining = config_parser.parse_known_args() @@ -233,6 +236,14 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + num_aug_splits = 0 + if args.aug_splits: + num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense + + if args.split_bn: + assert num_aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(num_aug_splits, 2)) + if args.num_gpu > 1: if args.amp: logging.warning( @@ -279,6 +290,7 @@ def main(): if args.distributed: if args.sync_bn: + assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) @@ -317,13 +329,11 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: - assert not args.jsd + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) - separate_transforms = False - if args.jsd: - dataset_train = AugMixDataset(dataset_train) - separate_transforms = True + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) loader_train = create_loader( dataset_train, @@ -331,18 +341,19 @@ def main(): batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, - rand_erase_prob=args.reprob, - rand_erase_mode=args.remode, - rand_erase_count=args.recount, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, - interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], + num_aug_splits=num_aug_splits, + interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, - separate_transforms=separate_transforms, ) eval_dir = os.path.join(args.data, 'val') @@ -368,7 +379,8 @@ def main(): ) if args.jsd: - train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda() + assert num_aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform
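The pieces in this diff are meant to compose: `AugMixDataset` returns `num_splits` versions of each sample (one clean, the rest augmented), `fast_collate` deinterleaves the tuples so each split occupies a contiguous chunk of the batch, `SplitBatchNorm2d` keeps separate BN statistics per chunk, `RandomErasing(num_splits=...)` skips the clean chunk, and `JsdCrossEntropy` consumes the per-split logits. Below is a minimal sketch of that wiring, assuming a CUDA machine; the dataset path, model name, and hyper-parameter values are placeholders chosen for illustration, not values taken from this diff.

```python
from timm.data import Dataset, AugMixDataset, create_loader
from timm.models import create_model, convert_splitbn_model
from timm.loss import JsdCrossEntropy

num_aug_splits = 3  # one clean split + two augmented splits per sample

# Model: swap every BatchNorm for SplitBatchNorm2d so each split keeps its own BN stats.
model = create_model('resnet50')  # model name is a placeholder
model = convert_splitbn_model(model, num_splits=num_aug_splits).cuda()

# Data: wrap the train dataset so __getitem__ returns a tuple of num_aug_splits tensors,
# then let create_loader install the separate (primary, secondary, final) transforms.
dataset_train = Dataset('/path/to/imagenet/train')  # placeholder path
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',  # new 'inc' key selects _RAND_INCREASING_TRANSFORMS
    num_aug_splits=num_aug_splits,        # enables separate transforms for the wrapper dataset
    re_prob=0.5, re_mode='pixel', re_count=1,
    re_split=True,                        # random erasing skips the clean split
)

# Loss: cross-entropy on the clean split plus a JSD consistency term across all splits.
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=0.1).cuda()
```

The equivalent `train.py` invocation with the flags added in this diff would look something like `./train.py /path/to/imagenet --model resnet50 --aug-splits 3 --jsd --split-bn --aa rand-m9-mstd0.5-inc1 --reprob 0.5 --remode pixel --resplit`, again with the path and model choice as placeholders.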