mirror of https://github.com/open-mmlab/mmcv.git
add max_keep_ckpts parameter (#227)
* add max_keep_ckpts parameters to save memory when saving models * format * format * format * fixed linting error Co-authored-by: z-bingo <z-bingo@outlook.com>pull/235/head
parent
732c379761
commit
8ac858b138
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os
|
||||
|
||||
from ..dist_utils import master_only
|
||||
from .hook import HOOKS, Hook
|
||||
|
||||
|
@ -10,10 +12,12 @@ class CheckpointHook(Hook):
|
|||
interval=-1,
|
||||
save_optimizer=True,
|
||||
out_dir=None,
|
||||
max_keep_ckpts=-1,
|
||||
**kwargs):
|
||||
self.interval = interval
|
||||
self.save_optimizer = save_optimizer
|
||||
self.out_dir = out_dir
|
||||
self.max_keep_ckpts = max_keep_ckpts
|
||||
self.args = kwargs
|
||||
|
||||
@master_only
|
||||
|
@ -25,3 +29,15 @@ class CheckpointHook(Hook):
|
|||
self.out_dir = runner.work_dir
|
||||
runner.save_checkpoint(
|
||||
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
|
||||
|
||||
# remove other checkpoints
|
||||
if self.max_keep_ckpts > 0:
|
||||
filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
|
||||
current_epoch = runner.epoch + 1
|
||||
for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
|
||||
ckpt_path = os.path.join(self.out_dir,
|
||||
filename_tmpl.format(epoch))
|
||||
if os.path.exists(ckpt_path):
|
||||
os.remove(ckpt_path)
|
||||
else:
|
||||
break
|
||||
|
|
Loading…
Reference in New Issue