fix checkpoint reloading bug
parent f9e80cba49
commit da3f2fea1a
@@ -237,6 +237,7 @@ def do_train_with_center(
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict(),
                                                                     'center_param': center_criterion.state_dict(),
                                                                     'optimizer_center': optimizer_center.state_dict()})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
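Note: a minimal sketch of what the handler above persists each epoch, assuming
ignite's ModelCheckpoint-style naming of '<prefix>_<key>_<epoch>.pth'. The fix is
that the center-loss parameters and their dedicated optimizer are now saved
alongside the model and optimizer, so a run can later be resumed in full. The
function name save_checkpoints, output_dir and prefix are hypothetical,
introduced only for illustration.

    # Hedged sketch, not the repository's code: one file per state dict,
    # keyed so that resume code can locate each companion file by name.
    import torch

    def save_checkpoints(output_dir, prefix, epoch, model, optimizer,
                         center_criterion, optimizer_center):
        to_save = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'center_param': center_criterion.state_dict(),      # center-loss centers
            'optimizer_center': optimizer_center.state_dict(),  # center-loss optimizer
        }
        for key, state in to_save.items():
            torch.save(state, f'{output_dir}/{prefix}_{key}_{epoch}.pth')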
@@ -82,10 +82,13 @@ def train(cfg):
        print('Start epoch:', start_epoch)
        path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer')
        print('Path to the checkpoint of optimizer:', path_to_optimizer)
        path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param')
        print('Path to the checkpoint of center_param:', path_to_center_param)
        path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center')
        print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
        model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH))
        optimizer.load_state_dict(torch.load(path_to_optimizer))
        center_criterion.load_state_dict(torch.load(path_to_center_param))
        optimizer_center.load_state_dict(torch.load(path_to_optimizer_center))
        scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                      cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch)
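Note: the resume logic above relies on a filename convention rather than a
single bundled checkpoint: cfg.MODEL.PRETRAIN_PATH must point at the *model*
file, and the three companion files are found by swapping the key segment of
the name. Passing start_epoch into WarmupMultiStepLR then restarts the LR
schedule at the right step. A hypothetical illustration (the path below is
made up):

    # Hedged sketch: 'logs/market1501/resnet50_model_80.pth' is a hypothetical
    # checkpoint path; only the '<key>' segment differs between the four files.
    pretrain_path = 'logs/market1501/resnet50_model_80.pth'

    path_to_optimizer = pretrain_path.replace('model', 'optimizer')
    path_to_center_param = pretrain_path.replace('model', 'center_param')
    path_to_optimizer_center = pretrain_path.replace('model', 'optimizer_center')

    print(path_to_optimizer)         # logs/market1501/resnet50_optimizer_80.pth
    print(path_to_center_param)      # logs/market1501/resnet50_center_param_80.pth
    print(path_to_optimizer_center)  # logs/market1501/resnet50_optimizer_center_80.pth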