Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Mixup and prefetcher improvements

* Do mixup in a custom collate fn when the prefetcher is enabled, reducing the performance impact
* Move mixup code to its own file
* Add an arg to disable the prefetcher
* Fix missing CUDA transfer when the prefetcher is off
* Random erasing when the prefetcher is off wasn't updated to match the new args; fixed
* Default random erasing to off (prob = 0.) for train

parent 780c0a96a4
commit 4d2056722a
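
Editor's note: condensed from the hunks below, the commit selects between two mixup paths roughly as follows (a paraphrase of the diff, not verbatim code):

# Condensed from the hunks below: how the two mixup paths are chosen.
collate_fn = None
if args.prefetcher and args.mixup > 0:
    # Prefetcher on: images are mixed as uint8 arrays on the CPU, inside the
    # collate fn, so the GPU loop already receives mixed batches.
    collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

# Prefetcher off: per batch inside train_epoch, move to CUDA and mix there.
if not args.prefetcher:
    input, target = input.cuda(), target.cuda()
    if args.mixup > 0.:
        lam = np.random.beta(args.mixup, args.mixup)
        input.mul_(lam).add_(1 - lam, input.flip(0))
        target = mixup_target(target, args.num_classes, lam, args.smoothing)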
data/__init__.py
@@ -3,3 +3,4 @@ from data.config import resolve_data_config
 from data.dataset import Dataset
 from data.transforms import *
 from data.loader import create_loader
+from data.mixup import mixup_target, FastCollateMixup
data/loader.py
@@ -1,6 +1,7 @@
 import torch.utils.data
 from data.transforms import *
 from data.distributed_sampler import OrderedDistributedSampler
+from data.mixup import FastCollateMixup


 def fast_collate(batch):
@@ -60,6 +61,18 @@ class PrefetchLoader:
     def sampler(self):
         return self.loader.sampler

+    @property
+    def mixup_enabled(self):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            return self.loader.collate_fn.mixup_enabled
+        else:
+            return False
+
+    @mixup_enabled.setter
+    def mixup_enabled(self, x):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            self.loader.collate_fn.mixup_enabled = x
+

 def create_loader(
         dataset,
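
Editor's note: the property pair above lets training code flip mixup on or off through the wrapped loader instead of reaching into the collate function. A minimal usage sketch, assuming a dataset_train of samples in the format fast_collate expects, a CUDA device for PrefetchLoader, and an illustrative epoch constant:

from data import create_loader, FastCollateMixup

MIXUP_OFF_EPOCH = 90  # illustrative constant, not from the diff

loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=True,
    use_prefetcher=True,  # create_loader returns a PrefetchLoader in this case
    collate_fn=FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000),
)

for epoch in range(100):
    if epoch >= MIXUP_OFF_EPOCH:
        # Reads and writes forward to FastCollateMixup.mixup_enabled; with any
        # other collate_fn the getter returns False and the setter does nothing.
        loader_train.mixup_enabled = False
    # ... train one epoch over loader_train ...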
@@ -75,6 +88,7 @@ def create_loader(
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        collate_fn=None,
 ):
     if isinstance(input_size, tuple):
         img_size = input_size[-2:]
@@ -108,13 +122,16 @@ def create_loader(
             # of samples per-process, will slightly alter validation results
             sampler = OrderedDistributedSampler(dataset)

+    if collate_fn is None:
+        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
+
     loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
-        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
+        collate_fn=collate_fn,
         drop_last=is_training,
     )
     if use_prefetcher:
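
Editor's note: with collate_fn now a parameter, the effective collate function resolves as summarized below (a restatement of the logic above, not new behavior):

# collate_fn passed by the caller (e.g. FastCollateMixup)  -> used as-is
# collate_fn is None and use_prefetcher is True            -> fast_collate
# collate_fn is None and use_prefetcher is False           -> torch.utils.data.dataloader.default_collate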
data/mixup.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    return lam*y1 + (1. - lam)*y2
+
+
+class FastCollateMixup:
+
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+        self.mixup_alpha = mixup_alpha
+        self.label_smoothing = label_smoothing
+        self.num_classes = num_classes
+        self.mixup_enabled = True
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        lam = 1.
+        if self.mixup_enabled:
+            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            mixed = batch[i][0].astype(np.float32) * lam + \
+                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+            np.round(mixed, out=mixed)
+            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+
+        return tensor, target
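
Editor's note: mixup_target runs fine on the CPU, so its output is easy to inspect; a small check, assuming the data package from this commit is on the import path:

import torch
from data.mixup import mixup_target  # module added above

targets = torch.tensor([0, 2, 1, 3])
# Each row is lam * one_hot(target) + (1 - lam) * one_hot(flipped target),
# with label smoothing folded into both one-hot encodings first.
y = mixup_target(targets, num_classes=4, lam=0.7, smoothing=0.1, device='cpu')
print(y.shape)   # torch.Size([4, 4]); each row sums to 1.0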
data/transforms.py
@@ -159,7 +159,7 @@ def transforms_imagenet_train(
         color_jitter=(0.4, 0.4, 0.4),
         interpolation='random',
         random_erasing=0.4,
-        random_erasing_pp=True,
+        random_erasing_mode='const',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
@@ -183,7 +183,7 @@ def transforms_imagenet_train(
                 std=torch.tensor(std))
         ]
         if random_erasing > 0.:
-            tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
+            tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
     return transforms.Compose(tfl)
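
Editor's note: a sketch of building the non-prefetcher training transform with the renamed argument; the 'pixel' mode name is an assumption about transforms.py, with 'const' being the documented default:

from data.transforms import transforms_imagenet_train

train_tf = transforms_imagenet_train(
    random_erasing=0.4,             # erase probability
    random_erasing_mode='pixel',    # assumed non-default mode; 'const' is the default
    use_prefetcher=False,           # in this path RandomErasing is appended after normalization
)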
train.py (41 lines changed)
@@ -10,7 +10,7 @@ try:
 except ImportError:
     has_apex = False

-from data import Dataset, create_loader, resolve_data_config
+from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
 from models import create_model, resume_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
@@ -66,9 +66,9 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step"')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.1)')
-parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
-                    help='Random erase prob (default: 0.4)')
+                    help='Dropout rate (default: 0.)')
+parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
+                    help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -109,6 +109,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--no-prefetcher', action='store_true', default=False,
+                    help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
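
Editor's note: two illustrative invocations reflecting the new defaults and flag; the dataset path is a placeholder and --mixup is assumed from the args.mixup references elsewhere in train.py:

python train.py /path/to/imagenet --reprob 0.4                  # re-enable random erasing (train default is now 0.)
python train.py /path/to/imagenet --no-prefetcher --mixup 0.2   # run without the fast prefetcher, GPU-side mixup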
@@ -119,6 +121,7 @@ parser.add_argument("--local_rank", default=0, type=int)
 def main():
     args = parser.parse_args()

+    args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
@@ -130,6 +133,7 @@ def main():
     args.world_size = 1
     r = -1
     if args.distributed:
+        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl',
@@ -216,12 +220,16 @@ def main():
         exit(1)
     dataset_train = Dataset(train_dir)

+    collate_fn = None
+    if args.prefetcher and args.mixup > 0:
+        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
+
     loader_train = create_loader(
         dataset_train,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
@@ -229,6 +237,7 @@ def main():
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        collate_fn=collate_fn,
     )

     eval_dir = os.path.join(args.data, 'validation')
@@ -242,7 +251,7 @@ def main():
         input_size=data_config['input_size'],
         batch_size=4 * args.batch_size,
         is_training=False,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         interpolation=data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@@ -309,6 +318,10 @@ def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', use_amp=False):

+    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
+        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
+            loader.mixup_enabled = False
+
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
     losses_m = AverageMeter()
@@ -321,13 +334,15 @@ def train_epoch(
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
-        if args.mixup > 0.:
-            lam = 1.
-            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
-                lam = np.random.beta(args.mixup, args.mixup)
-                input.mul_(lam).add_(1 - lam, input.flip(0))
-            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+        if not args.prefetcher:
+            input = input.cuda()
+            target = target.cuda()
+            if args.mixup > 0.:
+                lam = 1.
+                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                    lam = np.random.beta(args.mixup, args.mixup)
+                    input.mul_(lam).add_(1 - lam, input.flip(0))
+                target = mixup_target(target, args.num_classes, lam, args.smoothing)

         output = model(input)
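
Editor's note: the in-place pair above computes input = lam * input + (1 - lam) * input.flip(0); written out for clarity (illustrative, assuming input is a float CUDA tensor):

# Equivalent out-of-place form of the two in-place calls above:
mixed = lam * input + (1. - lam) * input.flip(0)
# Tensor.add_(1 - lam, input.flip(0)) is the older alpha-first overload; newer
# PyTorch releases write the same update as input.add_(input.flip(0), alpha=1. - lam).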
utils.py (13 lines changed)
@@ -140,19 +140,6 @@ def accuracy(output, target, topk=(1,)):
     return res


-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
-    x = x.long().view(-1, 1)
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
-
-
-def mixup_target(target, num_classes, lam=1., smoothing=0.0):
-    off_value = smoothing / num_classes
-    on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
-    return lam*y1 + (1. - lam)*y2
-
-
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):