mv save_ckpt to torchtools
parent
26a9400609
commit
ae54ca0f6b
|
@ -4,10 +4,8 @@ import os
|
|||
import os.path as osp
|
||||
import errno
|
||||
import json
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
|
||||
def mkdir_if_missing(directory):
|
||||
|
@ -22,7 +20,7 @@ def mkdir_if_missing(directory):
|
|||
def check_isfile(path):
|
||||
isfile = osp.isfile(path)
|
||||
if not isfile:
|
||||
print('=> Warning: no file found at "{}" (ignored)'.format(path))
|
||||
warnings.warn('No file found at "{}"'.format(path))
|
||||
return isfile
|
||||
|
||||
|
||||
|
@ -35,21 +33,4 @@ def read_json(fpath):
|
|||
def write_json(obj, fpath):
|
||||
mkdir_if_missing(osp.dirname(fpath))
|
||||
with open(fpath, 'w') as f:
|
||||
json.dump(obj, f, indent=4, separators=(',', ': '))
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'):
|
||||
if len(osp.dirname(fpath)) != 0:
|
||||
mkdir_if_missing(osp.dirname(fpath))
|
||||
# remove 'module.' in state_dict's keys if necessary
|
||||
state_dict = state['state_dict']
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
k = k[7:]
|
||||
new_state_dict[k] = v
|
||||
state['state_dict'] = new_state_dict
|
||||
# save
|
||||
torch.save(state, fpath)
|
||||
if is_best:
|
||||
shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
|
||||
json.dump(obj, f, indent=4, separators=(',', ': '))
|
Loading…
Reference in New Issue