mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improve torch amp support and add channels_last support for train/validate scripts
This commit is contained in:
parent
1d34a0a851
commit
c2cd1a332e
@ -49,7 +49,8 @@ class CheckpointSaver:
|
|||||||
checkpoint_dir='',
|
checkpoint_dir='',
|
||||||
recovery_dir='',
|
recovery_dir='',
|
||||||
decreasing=False,
|
decreasing=False,
|
||||||
max_history=10):
|
max_history=10,
|
||||||
|
save_amp=False):
|
||||||
|
|
||||||
# state
|
# state
|
||||||
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
|
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
|
||||||
@ -67,13 +68,14 @@ class CheckpointSaver:
|
|||||||
self.decreasing = decreasing # a lower metric is better if True
|
self.decreasing = decreasing # a lower metric is better if True
|
||||||
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
|
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
|
||||||
self.max_history = max_history
|
self.max_history = max_history
|
||||||
|
self.save_apex_amp = save_amp # save APEX amp state
|
||||||
assert self.max_history >= 1
|
assert self.max_history >= 1
|
||||||
|
|
||||||
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
||||||
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
||||||
self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
|
self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
|
||||||
if os.path.exists(last_save_path):
|
if os.path.exists(last_save_path):
|
||||||
os.unlink(last_save_path) # required for Windows support.
|
os.unlink(last_save_path) # required for Windows support.
|
||||||
os.rename(tmp_save_path, last_save_path)
|
os.rename(tmp_save_path, last_save_path)
|
||||||
@ -105,7 +107,7 @@ class CheckpointSaver:
|
|||||||
|
|
||||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||||
|
|
||||||
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
|
||||||
save_state = {
|
save_state = {
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'arch': args.model,
|
'arch': args.model,
|
||||||
@ -114,7 +116,7 @@ class CheckpointSaver:
|
|||||||
'args': args,
|
'args': args,
|
||||||
'version': 2, # version < 2 increments epoch before save
|
'version': 2, # version < 2 increments epoch before save
|
||||||
}
|
}
|
||||||
if use_amp and 'state_dict' in amp.__dict__:
|
if self.save_apex_amp and 'state_dict' in amp.__dict__:
|
||||||
save_state['amp'] = amp.state_dict()
|
save_state['amp'] = amp.state_dict()
|
||||||
if model_ema is not None:
|
if model_ema is not None:
|
||||||
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
||||||
@ -136,11 +138,11 @@ class CheckpointSaver:
|
|||||||
_logger.error("Exception '{}' while deleting checkpoint".format(e))
|
_logger.error("Exception '{}' while deleting checkpoint".format(e))
|
||||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||||
|
|
||||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
|
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
|
||||||
assert epoch >= 0
|
assert epoch >= 0
|
||||||
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
||||||
save_path = os.path.join(self.recovery_dir, filename)
|
save_path = os.path.join(self.recovery_dir, filename)
|
||||||
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
|
self._save(save_path, model, optimizer, args, epoch, model_ema)
|
||||||
if os.path.exists(self.last_recovery_file):
|
if os.path.exists(self.last_recovery_file):
|
||||||
try:
|
try:
|
||||||
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||||
|
171
train.py
171
train.py
@ -18,18 +18,12 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
import yaml
|
import yaml
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
try:
|
import torch
|
||||||
from apex import amp
|
import torch.nn as nn
|
||||||
from apex.parallel import DistributedDataParallel as DDP
|
import torchvision.utils
|
||||||
from apex.parallel import convert_syncbn_model
|
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
||||||
has_apex = True
|
|
||||||
except ImportError:
|
|
||||||
from torch.cuda import amp
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
has_apex = False
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
||||||
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
|
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
|
||||||
@ -38,14 +32,24 @@ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCro
|
|||||||
from timm.optim import create_optimizer
|
from timm.optim import create_optimizer
|
||||||
from timm.scheduler import create_scheduler
|
from timm.scheduler import create_scheduler
|
||||||
|
|
||||||
import torch
|
try:
|
||||||
import torch.nn as nn
|
from apex import amp
|
||||||
import torchvision.utils
|
from apex.parallel import DistributedDataParallel as ApexDDP
|
||||||
|
from apex.parallel import convert_syncbn_model
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
has_apex = False
|
||||||
|
|
||||||
|
has_native_amp = False
|
||||||
|
try:
|
||||||
|
if getattr(torch.cuda.amp, 'autocast') is not None:
|
||||||
|
has_native_amp = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
_logger = logging.getLogger('train')
|
_logger = logging.getLogger('train')
|
||||||
|
|
||||||
|
|
||||||
# The first arg parser parses out only the --config argument, this argument is used to
|
# The first arg parser parses out only the --config argument, this argument is used to
|
||||||
# load a yaml file containing key-values that override the defaults for the main parser below
|
# load a yaml file containing key-values that override the defaults for the main parser below
|
||||||
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
|
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
|
||||||
@ -221,7 +225,13 @@ parser.add_argument('--num-gpu', type=int, default=1,
|
|||||||
parser.add_argument('--save-images', action='store_true', default=False,
|
parser.add_argument('--save-images', action='store_true', default=False,
|
||||||
help='save images of input bathes every log interval for debugging')
|
help='save images of input bathes every log interval for debugging')
|
||||||
parser.add_argument('--amp', action='store_true', default=False,
|
parser.add_argument('--amp', action='store_true', default=False,
|
||||||
help='use NVIDIA amp for mixed precision training')
|
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
|
||||||
|
parser.add_argument('--apex-amp', action='store_true', default=False,
|
||||||
|
help='Use NVIDIA Apex AMP mixed precision')
|
||||||
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
||||||
|
help='Use Native Torch AMP mixed precision')
|
||||||
|
parser.add_argument('--channels-last', action='store_true', default=False,
|
||||||
|
help='Use channels_last memory layout')
|
||||||
parser.add_argument('--pin-mem', action='store_true', default=False,
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||||
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
||||||
@ -254,6 +264,23 @@ def _parse_args():
|
|||||||
return args, args_text
|
return args, args_text
|
||||||
|
|
||||||
|
|
||||||
|
class ApexScaler:
|
||||||
|
def __call__(self, loss, optimizer):
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
class NativeScaler:
|
||||||
|
def __init__(self):
|
||||||
|
self._scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
def __call__(self, loss, optimizer):
|
||||||
|
self._scaler.scale(loss).backward()
|
||||||
|
self._scaler.step(optimizer)
|
||||||
|
self._scaler.update()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
setup_default_logging()
|
setup_default_logging()
|
||||||
args, args_text = _parse_args()
|
args, args_text = _parse_args()
|
||||||
@ -263,7 +290,8 @@ def main():
|
|||||||
if 'WORLD_SIZE' in os.environ:
|
if 'WORLD_SIZE' in os.environ:
|
||||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||||
if args.distributed and args.num_gpu > 1:
|
if args.distributed and args.num_gpu > 1:
|
||||||
_logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
_logger.warning(
|
||||||
|
'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
|
||||||
args.num_gpu = 1
|
args.num_gpu = 1
|
||||||
|
|
||||||
args.device = 'cuda:0'
|
args.device = 'cuda:0'
|
||||||
@ -315,28 +343,50 @@ def main():
|
|||||||
assert num_aug_splits > 1 or args.resplit
|
assert num_aug_splits > 1 or args.resplit
|
||||||
model = convert_splitbn_model(model, max(num_aug_splits, 2))
|
model = convert_splitbn_model(model, max(num_aug_splits, 2))
|
||||||
|
|
||||||
if args.num_gpu > 1:
|
use_amp = None
|
||||||
if args.amp:
|
if args.amp:
|
||||||
|
# for backwards compat, `--amp` arg tries apex before native amp
|
||||||
|
if has_apex:
|
||||||
|
args.apex_amp = True
|
||||||
|
elif has_native_amp:
|
||||||
|
args.native_amp = True
|
||||||
|
if args.apex_amp and has_apex:
|
||||||
|
use_amp = 'apex'
|
||||||
|
elif args.native_amp and has_native_amp:
|
||||||
|
use_amp = 'native'
|
||||||
|
elif args.apex_amp or args.native_amp:
|
||||||
|
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
||||||
|
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
||||||
|
|
||||||
|
if args.num_gpu > 1:
|
||||||
|
if use_amp == 'apex':
|
||||||
_logger.warning(
|
_logger.warning(
|
||||||
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
|
||||||
args.amp = False
|
use_amp = None
|
||||||
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||||
|
assert not args.channels_last, "Channels last not supported with DP, use DDP."
|
||||||
else:
|
else:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
if args.channels_last:
|
||||||
|
model = model.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
optimizer = create_optimizer(args, model)
|
optimizer = create_optimizer(args, model)
|
||||||
|
|
||||||
use_amp = False
|
amp_autocast = suppress # do nothing
|
||||||
if has_apex and args.amp:
|
loss_scaler = None
|
||||||
|
if use_amp == 'apex':
|
||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
use_amp = True
|
loss_scaler = ApexScaler()
|
||||||
elif args.amp:
|
|
||||||
_logger.info('Using torch AMP. Install NVIDIA Apex for Apex AMP.')
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
|
||||||
use_amp = True
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
_logger.info('NVIDIA APEX {}. AMP {}.'.format(
|
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
|
||||||
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
elif use_amp == 'native':
|
||||||
|
amp_autocast = torch.cuda.amp.autocast
|
||||||
|
loss_scaler = NativeScaler()
|
||||||
|
if args.local_rank == 0:
|
||||||
|
_logger.info('Using native Torch AMP. Training in mixed precision.')
|
||||||
|
else:
|
||||||
|
if args.local_rank == 0:
|
||||||
|
_logger.info('AMP not enabled. Training in float32.')
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
# optionally resume from a checkpoint
|
||||||
resume_state = {}
|
resume_state = {}
|
||||||
@ -346,7 +396,7 @@ def main():
|
|||||||
if resume_state and not args.no_resume_opt:
|
if resume_state and not args.no_resume_opt:
|
||||||
if 'optimizer' in resume_state:
|
if 'optimizer' in resume_state:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
_logger.info('Restoring Optimizer state from checkpoint')
|
_logger.info('Restoring optimizer state from checkpoint')
|
||||||
optimizer.load_state_dict(resume_state['optimizer'])
|
optimizer.load_state_dict(resume_state['optimizer'])
|
||||||
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
@ -367,7 +417,8 @@ def main():
|
|||||||
if args.sync_bn:
|
if args.sync_bn:
|
||||||
assert not args.split_bn
|
assert not args.split_bn
|
||||||
try:
|
try:
|
||||||
if has_apex:
|
if has_apex and use_amp != 'native':
|
||||||
|
# Apex SyncBN preferred unless native amp is activated
|
||||||
model = convert_syncbn_model(model)
|
model = convert_syncbn_model(model)
|
||||||
else:
|
else:
|
||||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||||
@ -377,12 +428,15 @@ def main():
|
|||||||
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
_logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
||||||
if has_apex:
|
if has_apex and use_amp != 'native':
|
||||||
model = DDP(model, delay_allreduce=True)
|
# Apex DDP preferred unless native amp is activated
|
||||||
|
if args.local_rank == 0:
|
||||||
|
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
|
||||||
|
model = ApexDDP(model, delay_allreduce=True)
|
||||||
else:
|
else:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
_logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
|
_logger.info("Using native Torch DistributedDataParallel.")
|
||||||
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
||||||
# NOTE: EMA model does not need to be wrapped by DDP
|
# NOTE: EMA model does not need to be wrapped by DDP
|
||||||
|
|
||||||
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
||||||
@ -501,7 +555,7 @@ def main():
|
|||||||
])
|
])
|
||||||
output_dir = get_outdir(output_base, 'train', exp_name)
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
||||||
decreasing = True if eval_metric == 'loss' else False
|
decreasing = True if eval_metric == 'loss' else False
|
||||||
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex')
|
||||||
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
||||||
f.write(args_text)
|
f.write(args_text)
|
||||||
|
|
||||||
@ -513,22 +567,20 @@ def main():
|
|||||||
train_metrics = train_epoch(
|
train_metrics = train_epoch(
|
||||||
epoch, model, loader_train, optimizer, train_loss_fn, args,
|
epoch, model, loader_train, optimizer, train_loss_fn, args,
|
||||||
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
||||||
use_amp=use_amp, has_apex=has_apex, scaler = scaler,
|
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
|
||||||
model_ema=model_ema, mixup_fn=mixup_fn)
|
|
||||||
|
|
||||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
_logger.info("Distributing BatchNorm running means and vars")
|
_logger.info("Distributing BatchNorm running means and vars")
|
||||||
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
||||||
|
|
||||||
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
|
||||||
|
|
||||||
if model_ema is not None and not args.model_ema_force_cpu:
|
if model_ema is not None and not args.model_ema_force_cpu:
|
||||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||||
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
||||||
|
|
||||||
ema_eval_metrics = validate(
|
ema_eval_metrics = validate(
|
||||||
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
|
||||||
eval_metrics = ema_eval_metrics
|
eval_metrics = ema_eval_metrics
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
@ -543,8 +595,7 @@ def main():
|
|||||||
# save proper checkpoint with eval metric
|
# save proper checkpoint with eval metric
|
||||||
save_metric = eval_metrics[eval_metric]
|
save_metric = eval_metrics[eval_metric]
|
||||||
best_metric, best_epoch = saver.save_checkpoint(
|
best_metric, best_epoch = saver.save_checkpoint(
|
||||||
model, optimizer, args,
|
model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
|
||||||
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
@ -554,8 +605,8 @@ def main():
|
|||||||
|
|
||||||
def train_epoch(
|
def train_epoch(
|
||||||
epoch, model, loader, optimizer, loss_fn, args,
|
epoch, model, loader, optimizer, loss_fn, args,
|
||||||
lr_scheduler=None, saver=None, output_dir='', use_amp=False,
|
lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
|
||||||
has_apex=False, scaler = None, model_ema=None, mixup_fn=None):
|
loss_scaler=None, model_ema=None, mixup_fn=None):
|
||||||
|
|
||||||
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
|
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
|
||||||
if args.prefetcher and loader.mixup_enabled:
|
if args.prefetcher and loader.mixup_enabled:
|
||||||
@ -579,11 +630,10 @@ def train_epoch(
|
|||||||
input, target = input.cuda(), target.cuda()
|
input, target = input.cuda(), target.cuda()
|
||||||
if mixup_fn is not None:
|
if mixup_fn is not None:
|
||||||
input, target = mixup_fn(input, target)
|
input, target = mixup_fn(input, target)
|
||||||
if not has_apex and use_amp:
|
if args.channels_last:
|
||||||
with torch.cuda.amp.autocast():
|
input = input.contiguous(memory_format=torch.channels_last)
|
||||||
output = model(input)
|
|
||||||
loss = loss_fn(output, target)
|
with amp_autocast():
|
||||||
else:
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target)
|
||||||
|
|
||||||
@ -591,19 +641,10 @@ def train_epoch(
|
|||||||
losses_m.update(loss.item(), input.size(0))
|
losses_m.update(loss.item(), input.size(0))
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
if use_amp:
|
if loss_scaler is not None:
|
||||||
if has_apex:
|
loss_scaler(loss, optimizer)
|
||||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
||||||
scaled_loss.backward()
|
|
||||||
else:
|
|
||||||
scaler.scale(loss).backward()
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if not has_apex and use_amp:
|
|
||||||
scaler.step(optimizer)
|
|
||||||
scaler.update()
|
|
||||||
else:
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -648,8 +689,7 @@ def train_epoch(
|
|||||||
if saver is not None and args.recovery_interval and (
|
if saver is not None and args.recovery_interval and (
|
||||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
||||||
|
|
||||||
saver.save_recovery(
|
saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
|
||||||
model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
|
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
||||||
@ -663,7 +703,7 @@ def train_epoch(
|
|||||||
return OrderedDict([('loss', losses_m.avg)])
|
return OrderedDict([('loss', losses_m.avg)])
|
||||||
|
|
||||||
|
|
||||||
def validate(model, loader, loss_fn, args, log_suffix=''):
|
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
|
||||||
batch_time_m = AverageMeter()
|
batch_time_m = AverageMeter()
|
||||||
losses_m = AverageMeter()
|
losses_m = AverageMeter()
|
||||||
top1_m = AverageMeter()
|
top1_m = AverageMeter()
|
||||||
@ -679,7 +719,10 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
|||||||
if not args.prefetcher:
|
if not args.prefetcher:
|
||||||
input = input.cuda()
|
input = input.cuda()
|
||||||
target = target.cuda()
|
target = target.cuda()
|
||||||
|
if args.channels_last:
|
||||||
|
input = input.contiguous(memory_format=torch.channels_last)
|
||||||
|
|
||||||
|
with amp_autocast():
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if isinstance(output, (tuple, list)):
|
if isinstance(output, (tuple, list)):
|
||||||
output = output[0]
|
output = output[0]
|
||||||
|
59
validate.py
59
validate.py
@ -17,17 +17,26 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from contextlib import suppress
|
||||||
try:
|
|
||||||
from apex import amp
|
|
||||||
has_apex = True
|
|
||||||
except ImportError:
|
|
||||||
has_apex = False
|
|
||||||
|
|
||||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||||
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
|
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
|
||||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||||
|
|
||||||
|
has_apex = False
|
||||||
|
try:
|
||||||
|
from apex import amp
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
has_native_amp = False
|
||||||
|
try:
|
||||||
|
if getattr(torch.cuda.amp, 'autocast') is not None:
|
||||||
|
has_native_amp = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
_logger = logging.getLogger('validate')
|
_logger = logging.getLogger('validate')
|
||||||
|
|
||||||
@ -69,8 +78,14 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|||||||
help='disable fast prefetcher')
|
help='disable fast prefetcher')
|
||||||
parser.add_argument('--pin-mem', action='store_true', default=False,
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||||
|
parser.add_argument('--channels-last', action='store_true', default=False,
|
||||||
|
help='Use channels_last memory layout')
|
||||||
parser.add_argument('--amp', action='store_true', default=False,
|
parser.add_argument('--amp', action='store_true', default=False,
|
||||||
help='Use AMP mixed precision')
|
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
|
||||||
|
parser.add_argument('--apex-amp', action='store_true', default=False,
|
||||||
|
help='Use NVIDIA Apex AMP mixed precision')
|
||||||
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
||||||
|
help='Use Native Torch AMP mixed precision')
|
||||||
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
||||||
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
||||||
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
||||||
@ -104,6 +119,18 @@ def validate(args):
|
|||||||
# might as well try to validate something
|
# might as well try to validate something
|
||||||
args.pretrained = args.pretrained or not args.checkpoint
|
args.pretrained = args.pretrained or not args.checkpoint
|
||||||
args.prefetcher = not args.no_prefetcher
|
args.prefetcher = not args.no_prefetcher
|
||||||
|
amp_autocast = suppress # do nothing
|
||||||
|
if args.amp:
|
||||||
|
if has_apex:
|
||||||
|
args.apex_amp = True
|
||||||
|
elif has_native_amp:
|
||||||
|
args.native_amp = True
|
||||||
|
else:
|
||||||
|
_logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
|
||||||
|
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
|
||||||
|
if args.native_amp:
|
||||||
|
amp_autocast = torch.cuda.amp.autocast
|
||||||
|
|
||||||
if args.legacy_jit:
|
if args.legacy_jit:
|
||||||
set_jit_legacy()
|
set_jit_legacy()
|
||||||
|
|
||||||
@ -128,10 +155,12 @@ def validate(args):
|
|||||||
torch.jit.optimized_execution(True)
|
torch.jit.optimized_execution(True)
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
if args.amp:
|
|
||||||
model = amp.initialize(model.cuda(), opt_level='O1')
|
|
||||||
else:
|
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
if args.apex_amp:
|
||||||
|
model = amp.initialize(model, opt_level='O1')
|
||||||
|
|
||||||
|
if args.channels_last:
|
||||||
|
model = model.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
if args.num_gpu > 1:
|
if args.num_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
||||||
@ -178,17 +207,21 @@ def validate(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
|
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
|
||||||
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
|
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
|
||||||
|
if args.channels_last:
|
||||||
|
input = input.contiguous(memory_format=torch.channels_last)
|
||||||
model(input)
|
model(input)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
for batch_idx, (input, target) in enumerate(loader):
|
for batch_idx, (input, target) in enumerate(loader):
|
||||||
if args.no_prefetcher:
|
if args.no_prefetcher:
|
||||||
target = target.cuda()
|
target = target.cuda()
|
||||||
input = input.cuda()
|
input = input.cuda()
|
||||||
if args.fp16:
|
if args.channels_last:
|
||||||
input = input.half()
|
input = input.contiguous(memory_format=torch.channels_last)
|
||||||
|
|
||||||
# compute output
|
# compute output
|
||||||
|
with amp_autocast():
|
||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
if valid_labels is not None:
|
if valid_labels is not None:
|
||||||
output = output[:, valid_labels]
|
output = output[:, valid_labels]
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
@ -197,7 +230,7 @@ def validate(args):
|
|||||||
real_labels.add_result(output)
|
real_labels.add_result(output)
|
||||||
|
|
||||||
# measure accuracy and record loss
|
# measure accuracy and record loss
|
||||||
acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
|
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
|
||||||
losses.update(loss.item(), input.size(0))
|
losses.update(loss.item(), input.size(0))
|
||||||
top1.update(acc1.item(), input.size(0))
|
top1.update(acc1.item(), input.size(0))
|
||||||
top5.update(acc5.item(), input.size(0))
|
top5.update(acc5.item(), input.size(0))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user