mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update AugMix, JSD, etc comments and references
This commit is contained in:
parent
833066b540
commit
3eb4a96eda
@ -1,7 +1,19 @@
|
||||
""" AutoAugment and RandAugment
|
||||
Implementation adapted from:
|
||||
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||
|
||||
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||
does not include any of the search code.
|
||||
|
||||
AA and RA Implementation adapted from:
|
||||
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
|
||||
|
||||
AugMix adapted from:
|
||||
https://github.com/google-research/augmix
|
||||
|
||||
Papers:
|
||||
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
""" AugMix Transform
|
||||
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
"""
|
||||
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
||||
self.ops = ops
|
||||
self.alpha = alpha
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.blended = blended
|
||||
self.blended = blended # blended mode is faster but not well tested
|
||||
|
||||
def _calc_blended_weights(self, ws, m):
|
||||
ws = ws * m
|
||||
|
@ -1,3 +1,6 @@
|
||||
""" Transforms Factory
|
||||
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
@ -24,7 +27,13 @@ def transforms_imagenet_train(
|
||||
re_num_splits=0,
|
||||
separate=False,
|
||||
):
|
||||
|
||||
"""
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
for use in a mixing dataset that passes
|
||||
* all data through the first (primary) transform, called the 'clean' data
|
||||
* a portion of the data through the secondary transform
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
"""
|
||||
primary_tfl = [
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
|
@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy
|
||||
class JsdCrossEntropy(nn.Module):
|
||||
""" Jensen-Shannon Divergence + Cross-Entropy Loss
|
||||
|
||||
Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||
https://arxiv.org/abs/1912.02781
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
|
||||
super().__init__()
|
||||
|
@ -1,6 +1,18 @@
|
||||
""" Split BatchNorm
|
||||
|
||||
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
|
||||
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
|
||||
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
|
||||
namespace.
|
||||
|
||||
This allows easily removing the auxiliary BN layers after training to efficiently
|
||||
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
|
||||
'Disentangled Learning via An Auxiliary BN'
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
|
||||
|
5
train.py
5
train.py
@ -237,8 +237,9 @@ def main():
|
||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||
|
||||
num_aug_splits = 0
|
||||
if args.aug_splits:
|
||||
num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense
|
||||
if args.aug_splits > 0:
|
||||
assert args.aug_splits > 1, 'A split of 1 makes no sense'
|
||||
num_aug_splits = args.aug_splits
|
||||
|
||||
if args.split_bn:
|
||||
assert num_aug_splits > 1 or args.resplit
|
||||
|
Loading…
x
Reference in New Issue
Block a user