Update checkpoint save to fix old hard-link + fuse issue I ran into again... fix #340
parent
c4fb98f399
commit
deb9895600
|
@ -6,9 +6,10 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
"""
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -32,7 +33,8 @@ class CheckpointSaver:
|
|||
recovery_dir='',
|
||||
decreasing=False,
|
||||
max_history=10,
|
||||
unwrap_fn=unwrap_model):
|
||||
unwrap_fn=unwrap_model
|
||||
):
|
||||
|
||||
# objects to save state_dicts of
|
||||
self.model = model
|
||||
|
@ -46,7 +48,8 @@ class CheckpointSaver:
|
|||
self.best_epoch = None
|
||||
self.best_metric = None
|
||||
self.curr_recovery_file = ''
|
||||
self.last_recovery_file = ''
|
||||
self.prev_recovery_file = ''
|
||||
self.can_hardlink = True
|
||||
|
||||
# config
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
|
@ -60,41 +63,26 @@ class CheckpointSaver:
|
|||
self.unwrap_fn = unwrap_fn
|
||||
assert self.max_history >= 1
|
||||
|
||||
def save_checkpoint(self, epoch, metric=None):
|
||||
assert epoch >= 0
|
||||
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
||||
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
||||
self._save(tmp_save_path, epoch, metric)
|
||||
if os.path.exists(last_save_path):
|
||||
os.unlink(last_save_path) # required for Windows support.
|
||||
os.rename(tmp_save_path, last_save_path)
|
||||
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
||||
if (len(self.checkpoint_files) < self.max_history
|
||||
or metric is None or self.cmp(metric, worst_file[1])):
|
||||
if len(self.checkpoint_files) >= self.max_history:
|
||||
self._cleanup_checkpoints(1)
|
||||
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
|
||||
save_path = os.path.join(self.checkpoint_dir, filename)
|
||||
os.link(last_save_path, save_path)
|
||||
self.checkpoint_files.append((save_path, metric))
|
||||
self.checkpoint_files = sorted(
|
||||
self.checkpoint_files, key=lambda x: x[1],
|
||||
reverse=not self.decreasing) # sort in descending order if a lower metric is not better
|
||||
def _replace(self, src, dst):
|
||||
if self.can_hardlink:
|
||||
try:
|
||||
if os.path.exists(dst):
|
||||
os.unlink(dst) # required for Windows support.
|
||||
except (OSError, NotImplementedError) as e:
|
||||
self.can_hardlink = False
|
||||
os.replace(src, dst)
|
||||
|
||||
checkpoints_str = "Current checkpoints:\n"
|
||||
for c in self.checkpoint_files:
|
||||
checkpoints_str += ' {}\n'.format(c)
|
||||
_logger.info(checkpoints_str)
|
||||
|
||||
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
|
||||
self.best_epoch = epoch
|
||||
self.best_metric = metric
|
||||
best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
|
||||
if os.path.exists(best_save_path):
|
||||
os.unlink(best_save_path)
|
||||
os.link(last_save_path, best_save_path)
|
||||
|
||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||
def _duplicate(self, src, dst):
|
||||
if self.can_hardlink:
|
||||
try:
|
||||
if os.path.exists(dst):
|
||||
# for Windows
|
||||
os.unlink(dst)
|
||||
os.link(src, dst)
|
||||
return
|
||||
except (OSError, NotImplementedError) as e:
|
||||
self.can_hardlink = False
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
def _save(self, save_path, epoch, metric=None):
|
||||
save_state = {
|
||||
|
@ -129,18 +117,61 @@ class CheckpointSaver:
|
|||
_logger.error("Exception '{}' while deleting checkpoint".format(e))
|
||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||
|
||||
def save_checkpoint(self, epoch, metric=None):
|
||||
assert epoch >= 0
|
||||
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
|
||||
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
||||
self._save(tmp_save_path, epoch, metric)
|
||||
self._replace(tmp_save_path, last_save_path)
|
||||
|
||||
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
||||
if (
|
||||
len(self.checkpoint_files) < self.max_history
|
||||
or metric is None
|
||||
or self.cmp(metric, worst_file[1])
|
||||
):
|
||||
if len(self.checkpoint_files) >= self.max_history:
|
||||
self._cleanup_checkpoints(1)
|
||||
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
|
||||
save_path = os.path.join(self.checkpoint_dir, filename)
|
||||
self._duplicate(last_save_path, save_path)
|
||||
|
||||
self.checkpoint_files.append((save_path, metric))
|
||||
self.checkpoint_files = sorted(
|
||||
self.checkpoint_files,
|
||||
key=lambda x: x[1],
|
||||
reverse=not self.decreasing # sort in descending order if a lower metric is not better
|
||||
)
|
||||
|
||||
checkpoints_str = "Current checkpoints:\n"
|
||||
for c in self.checkpoint_files:
|
||||
checkpoints_str += ' {}\n'.format(c)
|
||||
_logger.info(checkpoints_str)
|
||||
|
||||
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
|
||||
self.best_epoch = epoch
|
||||
self.best_metric = metric
|
||||
best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
|
||||
self._duplicate(last_save_path, best_save_path)
|
||||
|
||||
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
|
||||
|
||||
def save_recovery(self, epoch, batch_idx=0):
|
||||
assert epoch >= 0
|
||||
tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
|
||||
self._save(tmp_save_path, epoch)
|
||||
|
||||
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
|
||||
save_path = os.path.join(self.recovery_dir, filename)
|
||||
self._save(save_path, epoch)
|
||||
if os.path.exists(self.last_recovery_file):
|
||||
self._replace(tmp_save_path, save_path)
|
||||
|
||||
if os.path.exists(self.prev_recovery_file):
|
||||
try:
|
||||
_logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||
os.remove(self.last_recovery_file)
|
||||
_logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
|
||||
os.remove(self.prev_recovery_file)
|
||||
except Exception as e:
|
||||
_logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
|
||||
self.last_recovery_file = self.curr_recovery_file
|
||||
_logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
|
||||
self.prev_recovery_file = self.curr_recovery_file
|
||||
self.curr_recovery_file = save_path
|
||||
|
||||
def find_recovery(self):
|
||||
|
|
Loading…
Reference in New Issue