fix save_checkpoint bug with native amp
parent d98967ed5d
commit 5f563ca4df

train.py
@@ -544,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex & use_amp)
 
     except KeyboardInterrupt:
         pass
@@ -647,8 +647,9 @@ def train_epoch(
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex & use_amp, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
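The change gates the saver's use_amp flag on has_apex: that flag (as the fix implies) tells the saver to serialize apex's global amp.state_dict() into the checkpoint, and under native torch.cuda.amp apex is never initialized, so passing the raw flag through would ask the saver for apex state that does not exist. A minimal sketch of the gated save path follows; save_checkpoint here is a hypothetical stand-in, not timm's actual CheckpointSaver:

import torch

try:
    from apex import amp  # apex AMP keeps loss-scale state in a global object
    has_apex = True
except ImportError:
    amp = None
    has_apex = False


def save_checkpoint(path, model, optimizer, epoch, use_amp=False):
    # Hypothetical stand-in for the saver: only touch apex's amp.state_dict()
    # when apex is actually driving mixed precision.
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if use_amp and has_apex:
        state['amp'] = amp.state_dict()
    torch.save(state, path)

Callers then apply the same gating the diff shows, e.g. save_checkpoint(path, model, optimizer, epoch, use_amp=has_apex & use_amp), so the apex branch is unreachable when native AMP (or no AMP) is in use.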
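On the native-AMP side the commit title refers to, loss-scaling state lives in a per-instance torch.cuda.amp.GradScaler rather than a global amp object, so it is checkpointed and restored separately. A sketch under the assumption that the training loop owns a scaler (hypothetical helpers, not part of this commit):

import torch
from torch.cuda.amp import GradScaler


def save_native_amp_checkpoint(path, model, optimizer, scaler, epoch):
    # Native AMP: the GradScaler carries the loss-scale state, so it is
    # saved per instance instead of via apex's global amp.state_dict().
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
    }, path)


def load_native_amp_checkpoint(path, model, optimizer, scaler):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scaler.load_state_dict(ckpt['scaler'])
    return ckpt['epoch']


# usage sketch:
#   scaler = GradScaler()
#   ... scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update() ...
#   save_native_amp_checkpoint('recovery.pth', model, optimizer, scaler, epoch)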