add remove_module_from_keys option to save_ckpt
parent
2fb120ecbc
commit
19e03d4097
|
@ -14,16 +14,17 @@ import torch.nn as nn
|
|||
from .iotools import mkdir_if_missing
|
||||
|
||||
|
||||
def save_checkpoint(state, save_dir, is_best=False):
|
||||
def save_checkpoint(state, save_dir, is_best=False, remove_module_from_keys=False):
|
||||
mkdir_if_missing(save_dir)
|
||||
# remove 'module.' in state_dict's keys if necessary
|
||||
state_dict = state['state_dict']
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
k = k[7:]
|
||||
new_state_dict[k] = v
|
||||
state['state_dict'] = new_state_dict
|
||||
if remove_module_from_keys:
|
||||
# remove 'module.' in state_dict's keys
|
||||
state_dict = state['state_dict']
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
k = k[7:]
|
||||
new_state_dict[k] = v
|
||||
state['state_dict'] = new_state_dict
|
||||
# save
|
||||
epoch = state['epoch']
|
||||
fpath = osp.join(save_dir, 'model.pth.tar-' + str(epoch))
|
||||
|
|
Loading…
Reference in New Issue