mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Mixup and prefetcher improvements
* Do mixup in custom collate fn if prefetcher enabled, reduces performance impact * Move mixup code to own file * Add arg to disable prefetcher * Fix no cuda transfer when prefetcher off * Random erasing when prefetcher off wasn't changed to match new args, fixed * Default random erasing to off (prob = 0.) for train
This commit is contained in:
parent
780c0a96a4
commit
4d2056722a
@ -3,3 +3,4 @@ from data.config import resolve_data_config
|
||||
from data.dataset import Dataset
|
||||
from data.transforms import *
|
||||
from data.loader import create_loader
|
||||
from data.mixup import mixup_target, FastCollateMixup
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch.utils.data
|
||||
from data.transforms import *
|
||||
from data.distributed_sampler import OrderedDistributedSampler
|
||||
from data.mixup import FastCollateMixup
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
@ -60,6 +61,18 @@ class PrefetchLoader:
|
||||
def sampler(self):
|
||||
return self.loader.sampler
|
||||
|
||||
@property
|
||||
def mixup_enabled(self):
|
||||
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
||||
return self.loader.collate_fn.mixup_enabled
|
||||
else:
|
||||
return False
|
||||
|
||||
@mixup_enabled.setter
|
||||
def mixup_enabled(self, x):
|
||||
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
||||
self.loader.collate_fn.mixup_enabled = x
|
||||
|
||||
|
||||
def create_loader(
|
||||
dataset,
|
||||
@ -75,6 +88,7 @@ def create_loader(
|
||||
num_workers=1,
|
||||
distributed=False,
|
||||
crop_pct=None,
|
||||
collate_fn=None,
|
||||
):
|
||||
if isinstance(input_size, tuple):
|
||||
img_size = input_size[-2:]
|
||||
@ -108,13 +122,16 @@ def create_loader(
|
||||
# of samples per-process, will slightly alter validation results
|
||||
sampler = OrderedDistributedSampler(dataset)
|
||||
|
||||
if collate_fn is None:
|
||||
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=sampler is None and is_training,
|
||||
num_workers=num_workers,
|
||||
sampler=sampler,
|
||||
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=is_training,
|
||||
)
|
||||
if use_prefetcher:
|
||||
|
42
data/mixup.py
Normal file
42
data/mixup.py
Normal file
@ -0,0 +1,42 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
|
||||
x = x.long().view(-1, 1)
|
||||
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
||||
|
||||
|
||||
def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
|
||||
off_value = smoothing / num_classes
|
||||
on_value = 1. - smoothing + off_value
|
||||
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
|
||||
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
|
||||
return lam*y1 + (1. - lam)*y2
|
||||
|
||||
|
||||
class FastCollateMixup:
|
||||
|
||||
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
|
||||
self.mixup_alpha = mixup_alpha
|
||||
self.label_smoothing = label_smoothing
|
||||
self.num_classes = num_classes
|
||||
self.mixup_enabled = True
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size = len(batch)
|
||||
lam = 1.
|
||||
if self.mixup_enabled:
|
||||
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||
|
||||
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
||||
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
mixed = batch[i][0].astype(np.float32) * lam + \
|
||||
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
|
||||
np.round(mixed, out=mixed)
|
||||
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
|
||||
|
||||
return tensor, target
|
@ -159,7 +159,7 @@ def transforms_imagenet_train(
|
||||
color_jitter=(0.4, 0.4, 0.4),
|
||||
interpolation='random',
|
||||
random_erasing=0.4,
|
||||
random_erasing_pp=True,
|
||||
random_erasing_mode='const',
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD
|
||||
@ -183,7 +183,7 @@ def transforms_imagenet_train(
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if random_erasing > 0.:
|
||||
tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
|
||||
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
|
41
train.py
41
train.py
@ -10,7 +10,7 @@ try:
|
||||
except ImportError:
|
||||
has_apex = False
|
||||
|
||||
from data import Dataset, create_loader, resolve_data_config
|
||||
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
||||
from models import create_model, resume_checkpoint
|
||||
from utils import *
|
||||
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
||||
@ -66,9 +66,9 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
|
||||
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
||||
help='LR scheduler (default: "step"')
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.1)')
|
||||
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
|
||||
help='Random erase prob (default: 0.4)')
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
||||
help='Random erase prob (default: 0.)')
|
||||
parser.add_argument('--remode', type=str, default='const',
|
||||
help='Random erase mode (default: "const")')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
@ -109,6 +109,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
|
||||
help='save images of input bathes every log interval for debugging')
|
||||
parser.add_argument('--amp', action='store_true', default=False,
|
||||
help='use NVIDIA amp for mixed precision training')
|
||||
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
||||
help='disable fast prefetcher')
|
||||
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||
help='path to output folder (default: none, current dir)')
|
||||
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
|
||||
@ -119,6 +121,7 @@ parser.add_argument("--local_rank", default=0, type=int)
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||
@ -130,6 +133,7 @@ def main():
|
||||
args.world_size = 1
|
||||
r = -1
|
||||
if args.distributed:
|
||||
args.num_gpu = 1
|
||||
args.device = 'cuda:%d' % args.local_rank
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl',
|
||||
@ -216,12 +220,16 @@ def main():
|
||||
exit(1)
|
||||
dataset_train = Dataset(train_dir)
|
||||
|
||||
collate_fn = None
|
||||
if args.prefetcher and args.mixup > 0:
|
||||
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
|
||||
|
||||
loader_train = create_loader(
|
||||
dataset_train,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
is_training=True,
|
||||
use_prefetcher=True,
|
||||
use_prefetcher=args.prefetcher,
|
||||
rand_erase_prob=args.reprob,
|
||||
rand_erase_mode=args.remode,
|
||||
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
|
||||
@ -229,6 +237,7 @@ def main():
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
distributed=args.distributed,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
eval_dir = os.path.join(args.data, 'validation')
|
||||
@ -242,7 +251,7 @@ def main():
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=4 * args.batch_size,
|
||||
is_training=False,
|
||||
use_prefetcher=True,
|
||||
use_prefetcher=args.prefetcher,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
@ -309,6 +318,10 @@ def train_epoch(
|
||||
epoch, model, loader, optimizer, loss_fn, args,
|
||||
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
|
||||
|
||||
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
|
||||
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
|
||||
loader.mixup_enabled = False
|
||||
|
||||
batch_time_m = AverageMeter()
|
||||
data_time_m = AverageMeter()
|
||||
losses_m = AverageMeter()
|
||||
@ -321,13 +334,15 @@ def train_epoch(
|
||||
for batch_idx, (input, target) in enumerate(loader):
|
||||
last_batch = batch_idx == last_idx
|
||||
data_time_m.update(time.time() - end)
|
||||
|
||||
if args.mixup > 0.:
|
||||
lam = 1.
|
||||
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
|
||||
lam = np.random.beta(args.mixup, args.mixup)
|
||||
input.mul_(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, args.num_classes, lam, args.smoothing)
|
||||
if not args.prefetcher:
|
||||
input = input.cuda()
|
||||
target = target.cuda()
|
||||
if args.mixup > 0.:
|
||||
lam = 1.
|
||||
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
|
||||
lam = np.random.beta(args.mixup, args.mixup)
|
||||
input.mul_(lam).add_(1 - lam, input.flip(0))
|
||||
target = mixup_target(target, args.num_classes, lam, args.smoothing)
|
||||
|
||||
output = model(input)
|
||||
|
||||
|
13
utils.py
13
utils.py
@ -140,19 +140,6 @@ def accuracy(output, target, topk=(1,)):
|
||||
return res
|
||||
|
||||
|
||||
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
|
||||
x = x.long().view(-1, 1)
|
||||
return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
|
||||
|
||||
|
||||
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
|
||||
off_value = smoothing / num_classes
|
||||
on_value = 1. - smoothing + off_value
|
||||
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
|
||||
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
|
||||
return lam*y1 + (1. - lam)*y2
|
||||
|
||||
|
||||
def get_outdir(path, *paths, inc=False):
|
||||
outdir = os.path.join(path, *paths)
|
||||
if not os.path.exists(outdir):
|
||||
|
Loading…
x
Reference in New Issue
Block a user