Add more augmentation arguments, including a no_aug disable flag. Fix #209
parent e3f58fc90c
commit fa28067704
timm/data/loader.py

@@ -131,10 +131,15 @@ def create_loader(
         batch_size,
         is_training=False,
         use_prefetcher=True,
+        no_aug=False,
         re_prob=0.,
         re_mode='const',
         re_count=1,
         re_split=False,
+        scale=None,
+        ratio=None,
+        hflip=0.5,
+        vflip=0.,
         color_jitter=0.4,
         auto_augment=None,
         num_aug_splits=0,
@@ -158,6 +163,11 @@ def create_loader(
         input_size,
         is_training=is_training,
         use_prefetcher=use_prefetcher,
+        no_aug=no_aug,
+        scale=scale,
+        ratio=ratio,
+        hflip=hflip,
+        vflip=vflip,
         color_jitter=color_jitter,
         auto_augment=auto_augment,
         interpolation=interpolation,
@@ -200,12 +210,13 @@ def create_loader(
         drop_last=is_training,
     )
     if use_prefetcher:
+        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
         loader = PrefetchLoader(
             loader,
             mean=mean,
             std=std,
             fp16=fp16,
-            re_prob=re_prob if is_training else 0.,
+            re_prob=prefetch_re_prob,
             re_mode=re_mode,
             re_count=re_count,
             re_num_splits=re_num_splits
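For context, a hedged usage sketch of the new `no_aug` flag flowing through `create_loader` (not part of the commit; the dataset path and sizes are placeholder assumptions):

```python
from timm.data import Dataset, create_loader

# Placeholder path; any ImageFolder-style layout works.
dataset_train = Dataset('/data/imagenet/train')

# With no_aug=True the train pipeline degrades to resize + center-crop +
# normalize: random crop/flip/color-jitter/auto-augment are skipped, and the
# random erasing normally applied inside PrefetchLoader is disabled too.
loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=32,
    is_training=True,
    use_prefetcher=True,
    no_aug=True,
    re_prob=0.5,  # ignored: prefetch_re_prob evaluates to 0. when no_aug is set
)
```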
timm/data/transforms_factory.py

@@ -14,9 +14,39 @@ from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation,
 from timm.data.random_erasing import RandomErasing
 
 
+def transforms_noaug_train(
+        img_size=224,
+        interpolation='bilinear',
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+):
+    if interpolation == 'random':
+        # random interpolation not supported with no-aug
+        interpolation = 'bilinear'
+    tfl = [
+        transforms.Resize(img_size, _pil_interp(interpolation)),
+        transforms.CenterCrop(img_size)
+    ]
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean),
+                std=torch.tensor(std))
+        ]
+    return transforms.Compose(tfl)
+
+
 def transforms_imagenet_train(
         img_size=224,
-        scale=(0.08, 1.0),
+        scale=None,
+        ratio=None,
+        hflip=0.5,
+        vflip=0.,
         color_jitter=0.4,
         auto_augment=None,
         interpolation='random',
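The new helper can also be used on its own; a minimal sketch, assuming a local image file (placeholder path):

```python
from PIL import Image
from timm.data.transforms_factory import transforms_noaug_train

# Deterministic eval-style pipeline: Resize -> CenterCrop -> ToTensor -> Normalize.
tf = transforms_noaug_train(img_size=224, interpolation='bilinear')
img = Image.open('example.jpg').convert('RGB')  # placeholder image
x = tf(img)  # float tensor of shape [3, 224, 224]
```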
@@ -36,11 +66,14 @@ def transforms_imagenet_train(
      * a portion of the data through the secondary transform
      * normalizes and converts the branches above with the third, final transform
     """
+    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
     primary_tfl = [
-        RandomResizedCropAndInterpolation(
-            img_size, scale=scale, interpolation=interpolation),
-        transforms.RandomHorizontalFlip()
-    ]
+        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
+    if hflip > 0.:
+        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
+    if vflip > 0.:
+        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
 
     secondary_tfl = []
     if auto_augment:
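A quick illustration of the `or`-defaulting introduced above (illustrative only, not part of the commit):

```python
# Passing scale=None / ratio=None restores the standard ImageNet ranges:
scale = tuple(None or (0.08, 1.0))     # -> (0.08, 1.0)
ratio = tuple(None or (3./4., 4./3.))  # -> (0.75, 1.3333...)
```

Note the flips are now gated rather than unconditional: `hflip=0.` means no `RandomHorizontalFlip` is appended at all, instead of the previous always-on flip with its default p=0.5.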
@@ -135,6 +168,11 @@ def create_transform(
         input_size,
         is_training=False,
         use_prefetcher=False,
+        no_aug=False,
+        scale=None,
+        ratio=None,
+        hflip=0.5,
+        vflip=0.,
         color_jitter=0.4,
         auto_augment=None,
         interpolation='bilinear',

@@ -159,9 +197,21 @@ def create_transform(
         transform = TfPreprocessTransform(
             is_training=is_training, size=img_size, interpolation=interpolation)
     else:
-        if is_training:
+        if is_training and no_aug:
+            assert not separate, "Cannot perform split augmentation with no_aug"
+            transform = transforms_noaug_train(
+                img_size,
+                interpolation=interpolation,
+                use_prefetcher=use_prefetcher,
+                mean=mean,
+                std=std)
+        elif is_training:
             transform = transforms_imagenet_train(
                 img_size,
+                scale=scale,
+                ratio=ratio,
+                hflip=hflip,
+                vflip=vflip,
                 color_jitter=color_jitter,
                 auto_augment=auto_augment,
                 interpolation=interpolation,
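Putting the factory together, a hedged sketch of the new branch in `create_transform` (values illustrative):

```python
from timm.data import create_transform

# is_training=True with no_aug=True now routes to transforms_noaug_train()
# instead of transforms_imagenet_train().
train_tf = create_transform(
    input_size=224,
    is_training=True,
    no_aug=True,
    interpolation='random',  # coerced to 'bilinear' under no-aug
)
print(train_tf)  # Compose: Resize -> CenterCrop -> ToTensor -> Normalize
```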
train.py (55 changed lines)
@@ -51,6 +51,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
+# Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
@@ -82,16 +83,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
 parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                     help='ratio of validation batch size to training batch size (default: 1)')
-parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
-                    help='Dropout rate (default: 0.)')
-parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
-                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
-parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
-                    help='Drop path rate (default: None)')
-parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
-                    help='Drop block rate (default: None)')
-parser.add_argument('--jsd', action='store_true', default=False,
-                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 
+# Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd")')
@@ -101,6 +93,7 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=0.0001,
                     help='weight decay (default: 0.0001)')
 
+# Learning rate schedule parameters
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step")')
@@ -134,13 +127,26 @@ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                     help='patience epochs for Plateau LR scheduler (default: 10)')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
-# Augmentation parameters
+
+# Augmentation & regularization parameters
+parser.add_argument('--no-aug', action='store_true', default=False,
+                    help='Disable all training augmentation, override other train aug args')
+parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
+                    help='Random resize scale (default: 0.08 1.0)')
+parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
+                    help='Random resize aspect ratio (default: 0.75 1.33)')
+parser.add_argument('--hflip', type=float, default=0.5,
+                    help='Horizontal flip training aug probability')
+parser.add_argument('--vflip', type=float, default=0.,
+                    help='Vertical flip training aug probability')
 parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
 parser.add_argument('--aug-splits', type=int, default=0,
                     help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
+parser.add_argument('--jsd', action='store_true', default=False,
+                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
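Hedged command-line sketches using the new and reorganized flags (dataset path and model are placeholders):

```bash
# Train with all augmentation disabled:
python train.py /data/imagenet --model resnet50 -b 128 --no-aug

# Or tune the newly exposed augs individually:
python train.py /data/imagenet --model resnet50 -b 128 \
    --scale 0.5 1.0 --ratio 0.75 1.33 --hflip 0.5 --vflip 0.2 --color-jitter 0.3
```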
@@ -150,13 +156,22 @@ parser.add_argument('--recount', type=int, default=1,
 parser.add_argument('--resplit', action='store_true', default=False,
                     help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
-                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+                    help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
-                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
+                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
-                    help='label smoothing (default: 0.1)')
+                    help='Label smoothing (default: 0.1)')
 parser.add_argument('--train-interpolation', type=str, default='random',
                     help='Training interpolation (random, bilinear, bicubic default: "random")')
+parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+                    help='Dropout rate (default: 0.)')
+parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
+                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
+parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
+                    help='Drop path rate (default: None)')
+parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
+                    help='Drop block rate (default: None)')
+
 # Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
@@ -170,6 +185,7 @@ parser.add_argument('--dist-bn', type=str, default='',
                     help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 parser.add_argument('--split-bn', action='store_true',
                     help='Enable separate BN layers per augmentation split.')
 
+# Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
@@ -177,6 +193,7 @@ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                     help='decay factor for model weights moving average (default: 0.9998)')
 
+# Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
@@ -378,20 +395,28 @@ def main():
     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
 
+    train_interpolation = args.train_interpolation
+    if args.no_aug or not train_interpolation:
+        train_interpolation = data_config['interpolation']
     loader_train = create_loader(
         dataset_train,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=args.prefetcher,
+        no_aug=args.no_aug,
         re_prob=args.reprob,
         re_mode=args.remode,
         re_count=args.recount,
         re_split=args.resplit,
+        scale=args.scale,
+        ratio=args.ratio,
+        hflip=args.hflip,
+        vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
         num_aug_splits=num_aug_splits,
-        interpolation=args.train_interpolation,
+        interpolation=train_interpolation,
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
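For reference, a hedged sketch of where `data_config['interpolation']` originates (the model name is a placeholder):

```python
from timm.models import create_model
from timm.data import resolve_data_config

# Each timm model's default_cfg carries its native input config; with no
# explicit overrides, resolve_data_config falls back to those defaults.
model = create_model('resnet50')
data_config = resolve_data_config({}, model=model)
print(data_config['interpolation'])  # e.g. 'bilinear'
```

This is why `--no-aug` swaps a `--train-interpolation` of `'random'` for the model's own interpolation: with augmentation disabled there is a single deterministic resize, so a concrete mode is required.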