Add loss_scaler to checkpoints (#49)

Ensure bit-wise reproducibility when rescheduling jobs
pull/52/head
Francisco Massa 2021-01-15 10:52:41 +01:00 committed by GitHub
parent a8e90967a3
commit adc7f1e849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -357,6 +357,8 @@ def main(args):
args.start_epoch = checkpoint['epoch'] + 1
if args.model_ema:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
if args.eval:
test_stats = evaluate(data_loader_val, model, device)
@ -387,6 +389,7 @@ def main(args):
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)