diff --git a/torchreid/utils/torchtools.py b/torchreid/utils/torchtools.py index cc7697b..5fc51f4 100644 --- a/torchreid/utils/torchtools.py +++ b/torchreid/utils/torchtools.py @@ -14,16 +14,17 @@ import torch.nn as nn from .iotools import mkdir_if_missing -def save_checkpoint(state, save_dir, is_best=False): +def save_checkpoint(state, save_dir, is_best=False, remove_module_from_keys=False): mkdir_if_missing(save_dir) - # remove 'module.' in state_dict's keys if necessary - state_dict = state['state_dict'] - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - if k.startswith('module.'): - k = k[7:] - new_state_dict[k] = v - state['state_dict'] = new_state_dict + if remove_module_from_keys: + # remove 'module.' in state_dict's keys + state_dict = state['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('module.'): + k = k[7:] + new_state_dict[k] = v + state['state_dict'] = new_state_dict # save epoch = state['epoch'] fpath = osp.join(save_dir, 'model.pth.tar-' + str(epoch))