Merge pull request #82 from xjinchuan/master

A bug about saving the model
revert-82-master
Hao Luo 2019-09-11 12:11:13 +08:00 committed by GitHub
commit 1ad455fb32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 7 deletions

View File

@ -156,8 +156,8 @@ def do_train(
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
timer = Timer(average=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict()})
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
'optimizer': optimizer})
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
@ -235,10 +235,10 @@ def do_train_with_center(
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
timer = Timer(average=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'center_param': center_criterion.state_dict(),
'optimizer_center': optimizer_center.state_dict()})
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
'optimizer': optimizer,
'center_param': center_criterion,
'optimizer_center': optimizer_center})
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
@ -287,4 +287,4 @@ def do_train_with_center(
for r in [1, 5, 10]:
logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
trainer.run(train_loader, max_epochs=epochs)
trainer.run(train_loader, max_epochs=epochs)