From adc7f1e849f61df54b1a30a3a9755664f2030e05 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 15 Jan 2021 10:52:41 +0100 Subject: [PATCH] Add loss_scaler to checkpoints (#49) Ensure bit-wise reproducibility when rescheduling jobs --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index 2c2cc56..2d93567 100644 --- a/main.py +++ b/main.py @@ -357,6 +357,8 @@ def main(args): args.start_epoch = checkpoint['epoch'] + 1 if args.model_ema: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) if args.eval: test_stats = evaluate(data_loader_val, model, device) @@ -387,6 +389,7 @@ def main(args): 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema), + 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path)