add max_keep_ckpts parameter (#227)

* add max_keep_ckpts parameters to save memory when saving models

* format

* format

* format

* fixed linting error

Co-authored-by: z-bingo <z-bingo@outlook.com>
pull/235/head
Bin Zhang 2020-04-17 16:01:30 +08:00 committed by GitHub
parent 732c379761
commit 8ac858b138
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 16 additions and 0 deletions

View File

@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
from ..dist_utils import master_only
from .hook import HOOKS, Hook
@ -10,10 +12,12 @@ class CheckpointHook(Hook):
interval=-1,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
self.interval = interval
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.args = kwargs
@master_only
@ -25,3 +29,15 @@ class CheckpointHook(Hook):
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
# remove other checkpoints
if self.max_keep_ckpts > 0:
filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
current_epoch = runner.epoch + 1
for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(epoch))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
else:
break