Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
add support for native torch AMP in torch 1.6
This commit is contained in:
parent 470220b1f4
commit d98967ed5d

train.py (39 changed lines)
@@ -25,8 +25,11 @@ try:
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.cuda import amp
     from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
 
+
+
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, convert_splitbn_model
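The added import makes `amp` refer to `torch.cuda.amp` when Apex is not installed. A minimal availability-check sketch in the same spirit (the separate `has_native_amp` flag and the try/except around the native import are illustrative assumptions, not part of this commit):

# Sketch only: probe for both AMP backends; native AMP needs torch >= 1.6.
try:
    from apex import amp as apex_amp
    has_apex = True
except ImportError:
    has_apex = False

try:
    from torch.cuda.amp import autocast, GradScaler
    has_native_amp = True
except ImportError:
    has_native_amp = False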
@@ -327,6 +330,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
+    elif args.amp:
+        _logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
+        scaler = torch.cuda.amp.GradScaler()
+        use_amp = True
     if args.local_rank == 0:
         _logger.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
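The `GradScaler` created in the new `elif` branch drives the whole native-AMP update that the later train_epoch hunks wire in. A self-contained toy example of its scale → step → update lifecycle (illustrative model, optimizer, and data; assumes a CUDA device; not code from train.py):

# Toy demonstration of the torch.cuda.amp.GradScaler lifecycle used by this commit.
import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
target = torch.randint(0, 4, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():           # forward pass in mixed precision
    loss = torch.nn.functional.cross_entropy(model(x), target)
scaler.scale(loss).backward()             # backward on the scaled loss
scaler.step(optimizer)                    # unscales grads; skips the step on inf/NaN
scaler.update()                           # adjusts the loss scale for the next step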
@@ -506,7 +513,8 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
+                use_amp=use_amp, has_apex=has_apex, scaler = scaler,
+                model_ema=model_ema, mixup_fn=mixup_fn)
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -546,7 +554,8 @@ def main():
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
+        has_apex=False, scaler = None, model_ema=None, mixup_fn=None):
 
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -570,20 +579,32 @@ def train_epoch(
             input, target = input.cuda(), target.cuda()
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
-        output = model(input)
-        loss = loss_fn(output, target)
+        if not has_apex and use_amp:
+            with torch.cuda.amp.autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+        else:
+            output = model(input)
+            loss = loss_fn(output, target)
 
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))
 
         optimizer.zero_grad()
         if use_amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+            if has_apex:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                scaler.scale(loss).backward()
 
         else:
             loss.backward()
-        optimizer.step()
+        if not has_apex and use_amp:
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()
 
         torch.cuda.synchronize()
         if model_ema is not None: