fast-reid/utils/serialization.py
2018-10-18 19:04:28 +08:00

36 lines
780 B
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import errno
import os
import shutil
import sys
import os.path as osp
import torch
def mkdir_if_missing(dir_path):
try:
os.makedirs(dir_path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
fpath = '_'.join((str(state['epoch']), filename))
fpath = osp.join(save_dir, fpath)
mkdir_if_missing(save_dir)
torch.save(state, fpath)
if is_best:
shutil.copy(fpath, osp.join(save_dir, 'model_best.pth.tar'))