mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update AugMix, JSD, etc comments and references
This commit is contained in:
parent
833066b540
commit
3eb4a96eda
@ -1,7 +1,19 @@
|
|||||||
""" AutoAugment and RandAugment
|
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||||
Implementation adapted from:
|
|
||||||
|
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||||
|
does not include any of the search code.
|
||||||
|
|
||||||
|
AA and RA Implementation adapted from:
|
||||||
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
||||||
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
|
|
||||||
|
AugMix adapted from:
|
||||||
|
https://github.com/google-research/augmix
|
||||||
|
|
||||||
|
Papers:
|
||||||
|
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||||
|
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||||
|
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||||
|
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
|||||||
|
|
||||||
|
|
||||||
class AugMixAugment:
|
class AugMixAugment:
|
||||||
|
""" AugMix Transform
|
||||||
|
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||||
|
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||||
|
https://arxiv.org/abs/1912.02781
|
||||||
|
"""
|
||||||
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
||||||
self.ops = ops
|
self.ops = ops
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.width = width
|
self.width = width
|
||||||
self.depth = depth
|
self.depth = depth
|
||||||
self.blended = blended
|
self.blended = blended # blended mode is faster but not well tested
|
||||||
|
|
||||||
def _calc_blended_weights(self, ws, m):
|
def _calc_blended_weights(self, ws, m):
|
||||||
ws = ws * m
|
ws = ws * m
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
""" Transforms Factory
|
||||||
|
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
|
||||||
|
"""
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -24,7 +27,13 @@ def transforms_imagenet_train(
|
|||||||
re_num_splits=0,
|
re_num_splits=0,
|
||||||
separate=False,
|
separate=False,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||||
|
for use in a mixing dataset that passes
|
||||||
|
* all data through the first (primary) transform, called the 'clean' data
|
||||||
|
* a portion of the data through the secondary transform
|
||||||
|
* normalizes and converts the branches above with the third, final transform
|
||||||
|
"""
|
||||||
primary_tfl = [
|
primary_tfl = [
|
||||||
RandomResizedCropAndInterpolation(
|
RandomResizedCropAndInterpolation(
|
||||||
img_size, scale=scale, interpolation=interpolation),
|
img_size, scale=scale, interpolation=interpolation),
|
||||||
|
@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy
|
|||||||
class JsdCrossEntropy(nn.Module):
|
class JsdCrossEntropy(nn.Module):
|
||||||
""" Jensen-Shannon Divergence + Cross-Entropy Loss
|
""" Jensen-Shannon Divergence + Cross-Entropy Loss
|
||||||
|
|
||||||
|
Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
||||||
|
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
||||||
|
https://arxiv.org/abs/1912.02781
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
|
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1,6 +1,18 @@
|
|||||||
|
""" Split BatchNorm
|
||||||
|
|
||||||
|
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
|
||||||
|
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
|
||||||
|
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
|
||||||
|
namespace.
|
||||||
|
|
||||||
|
This allows easily removing the auxiliary BN layers after training to efficiently
|
||||||
|
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
|
||||||
|
'Disentangled Learning via An Auxiliary BN'
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
|
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
|
||||||
|
5
train.py
5
train.py
@ -237,8 +237,9 @@ def main():
|
|||||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||||
|
|
||||||
num_aug_splits = 0
|
num_aug_splits = 0
|
||||||
if args.aug_splits:
|
if args.aug_splits > 0:
|
||||||
num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense
|
assert args.aug_splits > 1, 'A split of 1 makes no sense'
|
||||||
|
num_aug_splits = args.aug_splits
|
||||||
|
|
||||||
if args.split_bn:
|
if args.split_bn:
|
||||||
assert num_aug_splits > 1 or args.resplit
|
assert num_aug_splits > 1 or args.resplit
|
||||||
|
Loading…
x
Reference in New Issue
Block a user