Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py.
parent
99b82ae5ab
commit
4203efa36d
|
@ -66,7 +66,7 @@ class CheckpointSaver:
|
|||
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
|
||||
self._save(tmp_save_path, epoch, metric)
|
||||
if os.path.exists(last_save_path):
|
||||
os.unlink(last_save_path) # required for Windows support.
|
||||
os.unlink(last_save_path) # required for Windows support.
|
||||
os.rename(tmp_save_path, last_save_path)
|
||||
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
|
||||
if (len(self.checkpoint_files) < self.max_history
|
||||
|
@ -118,7 +118,7 @@ class CheckpointSaver:
|
|||
def _cleanup_checkpoints(self, trim=0):
|
||||
trim = min(len(self.checkpoint_files), trim)
|
||||
delete_index = self.max_history - trim
|
||||
if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
|
||||
if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
|
||||
return
|
||||
to_delete = self.checkpoint_files[delete_index:]
|
||||
for d in to_delete:
|
||||
|
@ -147,7 +147,4 @@ class CheckpointSaver:
|
|||
recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
|
||||
files = glob.glob(recovery_path + '*' + self.extension)
|
||||
files = sorted(files)
|
||||
if len(files):
|
||||
return files[0]
|
||||
else:
|
||||
return ''
|
||||
return files[0] if len(files) else ''
|
||||
|
|
4
train.py
4
train.py
|
@ -236,6 +236,8 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
|
|||
help='how many batches to wait before logging training status')
|
||||
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
|
||||
help='how many batches to wait before writing recovery checkpoint')
|
||||
parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
|
||||
help='number of checkpoints to keep (default: 10)')
|
||||
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
||||
help='how many training processes to use (default: 1)')
|
||||
parser.add_argument('--save-images', action='store_true', default=False,
|
||||
|
@ -547,7 +549,7 @@ def main():
|
|||
decreasing = True if eval_metric == 'loss' else False
|
||||
saver = CheckpointSaver(
|
||||
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
|
||||
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
|
||||
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
|
||||
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
||||
f.write(args_text)
|
||||
|
||||
|
|
Loading…
Reference in New Issue