add remove_module_from_keys option to save_ckpt

pull/119/head
KaiyangZhou 2019-02-20 22:13:26 +00:00
parent 2fb120ecbc
commit 19e03d4097
1 changed files with 10 additions and 9 deletions

View File

@ -14,16 +14,17 @@ import torch.nn as nn
from .iotools import mkdir_if_missing
def save_checkpoint(state, save_dir, is_best=False):
def save_checkpoint(state, save_dir, is_best=False, remove_module_from_keys=False):
mkdir_if_missing(save_dir)
# remove 'module.' in state_dict's keys if necessary
state_dict = state['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
new_state_dict[k] = v
state['state_dict'] = new_state_dict
if remove_module_from_keys:
# remove 'module.' in state_dict's keys
state_dict = state['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
new_state_dict[k] = v
state['state_dict'] = new_state_dict
# save
epoch = state['epoch']
fpath = osp.join(save_dir, 'model.pth.tar-' + str(epoch))