mirror of https://github.com/facebookresearch/deit
Add loss_scaler to checkpoints (#49)
Ensure bit-wise reproducibility when rescheduling jobspull/52/head
parent
a8e90967a3
commit
adc7f1e849
3
main.py
3
main.py
|
@ -357,6 +357,8 @@ def main(args):
|
|||
args.start_epoch = checkpoint['epoch'] + 1
|
||||
if args.model_ema:
|
||||
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
|
||||
if 'scaler' in checkpoint:
|
||||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||
|
||||
if args.eval:
|
||||
test_stats = evaluate(data_loader_val, model, device)
|
||||
|
@ -387,6 +389,7 @@ def main(args):
|
|||
'lr_scheduler': lr_scheduler.state_dict(),
|
||||
'epoch': epoch,
|
||||
'model_ema': get_state_dict(model_ema),
|
||||
'scaler': loss_scaler.state_dict(),
|
||||
'args': args,
|
||||
}, checkpoint_path)
|
||||
|
||||
|
|
Loading…
Reference in New Issue