diff --git a/engine/trainer.py b/engine/trainer.py index abc9e90..2ccc6bc 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -156,8 +156,8 @@ def do_train( checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) timer = Timer(average=True) - trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(), - 'optimizer': optimizer.state_dict()}) + trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model(), + 'optimizer': optimizer()}) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @@ -235,10 +235,10 @@ def do_train_with_center( checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) timer = Timer(average=True) - trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'center_param': center_criterion.state_dict(), - 'optimizer_center': optimizer_center.state_dict()}) + trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model, + 'optimizer': optimizer, + 'center_param': center_criterion, + 'optimizer_center': optimizer_center}) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @@ -287,4 +287,4 @@ def do_train_with_center( for r in [1, 5, 10]: logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) - trainer.run(train_loader, max_epochs=epochs) \ No newline at end of file + trainer.run(train_loader, max_epochs=epochs)