mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix save_checkpoint bug with native amp
This commit is contained in:
parent
d98967ed5d
commit
5f563ca4df
5
train.py
5
train.py
@ -544,7 +544,7 @@ def main():
|
|||||||
save_metric = eval_metrics[eval_metric]
|
save_metric = eval_metrics[eval_metric]
|
||||||
best_metric, best_epoch = saver.save_checkpoint(
|
best_metric, best_epoch = saver.save_checkpoint(
|
||||||
model, optimizer, args,
|
model, optimizer, args,
|
||||||
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
|
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
@ -647,8 +647,9 @@ def train_epoch(
|
|||||||
|
|
||||||
if saver is not None and args.recovery_interval and (
|
if saver is not None and args.recovery_interval and (
|
||||||
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
||||||
|
|
||||||
saver.save_recovery(
|
saver.save_recovery(
|
||||||
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
|
model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user