commit
1ad455fb32
|
@ -156,8 +156,8 @@ def do_train(
|
|||
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
|
||||
timer = Timer(average=True)
|
||||
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict()})
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
|
||||
'optimizer': optimizer})
|
||||
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
|
||||
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
|
||||
|
||||
|
@ -235,10 +235,10 @@ def do_train_with_center(
|
|||
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
|
||||
timer = Timer(average=True)
|
||||
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'center_param': center_criterion.state_dict(),
|
||||
'optimizer_center': optimizer_center.state_dict()})
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
|
||||
'optimizer': optimizer,
|
||||
'center_param': center_criterion,
|
||||
'optimizer_center': optimizer_center})
|
||||
|
||||
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
|
||||
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
|
||||
|
@ -287,4 +287,4 @@ def do_train_with_center(
|
|||
for r in [1, 5, 10]:
|
||||
logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
|
||||
|
||||
trainer.run(train_loader, max_epochs=epochs)
|
||||
trainer.run(train_loader, max_epochs=epochs)
|
||||
|
|
Loading…
Reference in New Issue