Add support for new AMP checkpointing support w/ amp.state_dict
parent
ba3c97c3ad
commit
3d9c8a6489
|
@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
|
|||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path):
|
||||
optimizer_state = None
|
||||
other_state = {}
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path):
|
|||
new_state_dict[name] = v
|
||||
model.load_state_dict(new_state_dict)
|
||||
if 'optimizer' in checkpoint:
|
||||
optimizer_state = checkpoint['optimizer']
|
||||
other_state['optimizer'] = checkpoint['optimizer']
|
||||
if 'amp' in checkpoint:
|
||||
other_state['amp'] = checkpoint['amp']
|
||||
if 'epoch' in checkpoint:
|
||||
resume_epoch = checkpoint['epoch']
|
||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||
|
@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path):
|
|||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
return optimizer_state, resume_epoch
|
||||
return other_state, resume_epoch
|
||||
else:
|
||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
|
|
@ -11,6 +11,12 @@ import operator
|
|||
import logging
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
try:
|
||||
from apex import amp
|
||||
has_apex = True
|
||||
except ImportError:
|
||||
amp = None
|
||||
has_apex = False
|
||||
|
||||
from torch import distributed as dist
|
||||
|
||||
|
@ -50,7 +56,7 @@ class CheckpointSaver:
|
|||
self.max_history = max_history
|
||||
assert self.max_history >= 1
|
||||
|
||||
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
|
||||
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
||||
assert epoch >= 0
|
||||
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
||||
if (len(self.checkpoint_files) < self.max_history
|
||||
|
@ -59,7 +65,7 @@ class CheckpointSaver:
|
|||
self._cleanup_checkpoints(1)
|
||||
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
|
||||
save_path = os.path.join(self.checkpoint_dir, filename)
|
||||
self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
|
||||
self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
|
||||
self.checkpoint_files.append((save_path, metric))
|
||||
self.checkpoint_files = sorted(
|
||||
self.checkpoint_files, key=lambda x: x[1],
|
||||
|
@ -77,7 +83,7 @@ class CheckpointSaver:
|
|||
|
||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||
|
||||
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
|
||||
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
|
||||
save_state = {
|
||||
'epoch': epoch,
|
||||
'arch': args.model,
|
||||
|
@ -86,6 +92,8 @@ class CheckpointSaver:
|
|||
'args': args,
|
||||
'version': 2, # version < 2 increments epoch before save
|
||||
}
|
||||
if use_amp and 'state_dict' in amp.__dict__:
|
||||
save_state['amp'] = amp.state_dict()
|
||||
if model_ema is not None:
|
||||
save_state['state_dict_ema'] = get_state_dict(model_ema)
|
||||
if metric is not None:
|
||||
|
@ -106,11 +114,11 @@ class CheckpointSaver:
|
|||
logging.error("Exception '{}' while deleting checkpoint".format(e))
|
||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||
|
||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
|
||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
|
||||
assert epoch >= 0
|
||||
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
||||
save_path = os.path.join(self.recovery_dir, filename)
|
||||
self._save(save_path, model, optimizer, args, epoch, model_ema)
|
||||
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
|
||||
if os.path.exists(self.last_recovery_file):
|
||||
try:
|
||||
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||
|
|
30
train.py
30
train.py
|
@ -38,6 +38,8 @@ parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH'
|
|||
help='Initialize model from this checkpoint (default: none)')
|
||||
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||
help='Resume full model and optimizer state from checkpoint (default: none)')
|
||||
parser.add_argument('--no-resume-opt', action='store_true', default=False,
|
||||
help='prevent resume of optimizer state when resuming model')
|
||||
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
||||
help='number of label classes (default: 1000)')
|
||||
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
|
||||
|
@ -189,12 +191,6 @@ def main():
|
|||
|
||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
optimizer_state = None
|
||||
resume_epoch = None
|
||||
if args.resume:
|
||||
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
if args.amp:
|
||||
logging.warning(
|
||||
|
@ -205,8 +201,6 @@ def main():
|
|||
model.cuda()
|
||||
|
||||
optimizer = create_optimizer(args, model)
|
||||
if optimizer_state is not None:
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
use_amp = False
|
||||
if has_apex and args.amp:
|
||||
|
@ -216,6 +210,22 @@ def main():
|
|||
logging.info('NVIDIA APEX {}. AMP {}.'.format(
|
||||
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
resume_state = {}
|
||||
resume_epoch = None
|
||||
if args.resume:
|
||||
resume_state, resume_epoch = resume_checkpoint(model, args.resume)
|
||||
if resume_state and not args.no_resume_opt:
|
||||
if 'optimizer' in resume_state:
|
||||
if args.local_rank == 0:
|
||||
logging.info('Restoring Optimizer state from checkpoint')
|
||||
optimizer.load_state_dict(resume_state['optimizer'])
|
||||
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
||||
if args.local_rank == 0:
|
||||
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
resume_state = None
|
||||
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
|
||||
|
@ -363,7 +373,7 @@ def main():
|
|||
save_metric = eval_metrics[eval_metric]
|
||||
best_metric, best_epoch = saver.save_checkpoint(
|
||||
model, optimizer, args,
|
||||
epoch=epoch, model_ema=model_ema, metric=save_metric)
|
||||
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
@ -456,7 +466,7 @@ def train_epoch(
|
|||
if saver is not None and args.recovery_interval and (
|
||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
||||
saver.save_recovery(
|
||||
model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
|
||||
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
|
||||
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
||||
|
|
Loading…
Reference in New Issue