From da3f2fea1a183d125f3009f9269759a8bd474830 Mon Sep 17 00:00:00 2001 From: shaoniangu Date: Fri, 5 Jul 2019 01:40:45 +0800 Subject: [PATCH] fix breakpoint reloading bug --- engine/trainer.py | 1 + tools/train.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/engine/trainer.py b/engine/trainer.py index c0b506c..abc9e90 100644 --- a/engine/trainer.py +++ b/engine/trainer.py @@ -237,6 +237,7 @@ def do_train_with_center( trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), + 'center_param': center_criterion.state_dict(), 'optimizer_center': optimizer_center.state_dict()}) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, diff --git a/tools/train.py b/tools/train.py index 53d4c12..3bb7835 100644 --- a/tools/train.py +++ b/tools/train.py @@ -82,10 +82,13 @@ def train(cfg): print('Start epoch:', start_epoch) path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') print('Path to the checkpoint of optimizer:', path_to_optimizer) + path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') + print('Path to the checkpoint of center_param:', path_to_center_param) path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) optimizer.load_state_dict(torch.load(path_to_optimizer)) + center_criterion.load_state_dict(torch.load(path_to_center_param)) optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)