mv save_ckpt to torchtools

pull/119/head
KaiyangZhou 2019-02-20 21:51:24 +00:00
parent 26a9400609
commit ae54ca0f6b
1 changed files with 3 additions and 22 deletions

View File

@ -4,10 +4,8 @@ import os
import os.path as osp
import errno
import json
import shutil
from collections import OrderedDict
import torch
import warnings
def mkdir_if_missing(directory):
@ -22,7 +20,7 @@ def mkdir_if_missing(directory):
def check_isfile(path):
isfile = osp.isfile(path)
if not isfile:
print('=> Warning: no file found at "{}" (ignored)'.format(path))
warnings.warn('No file found at "{}"'.format(path))
return isfile
@ -35,21 +33,4 @@ def read_json(fpath):
def write_json(obj, fpath):
mkdir_if_missing(osp.dirname(fpath))
with open(fpath, 'w') as f:
json.dump(obj, f, indent=4, separators=(',', ': '))
def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'):
if len(osp.dirname(fpath)) != 0:
mkdir_if_missing(osp.dirname(fpath))
# remove 'module.' in state_dict's keys if necessary
state_dict = state['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:]
new_state_dict[k] = v
state['state_dict'] = new_state_dict
# save
torch.save(state, fpath)
if is_best:
shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
json.dump(obj, f, indent=4, separators=(',', ': '))