fix breakpoint reloading bug

pull/51/head
shaoniangu 2019-07-05 01:40:45 +08:00
parent f9e80cba49
commit da3f2fea1a
2 changed files with 4 additions and 0 deletions

View File

@ -237,6 +237,7 @@ def do_train_with_center(
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'center_param': center_criterion.state_dict(),
'optimizer_center': optimizer_center.state_dict()})
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,

View File

@ -82,10 +82,13 @@ def train(cfg):
print('Start epoch:', start_epoch)
path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
print('Path to the checkpoint of optimizer:', path_to_optimizer)
path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
print('Path to the checkpoint of center_param:', path_to_center_param)
path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
optimizer.load_state_dict(torch.load(path_to_optimizer))
center_criterion.load_state_dict(torch.load(path_to_center_param))
optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)