[Refactor] Refactor interface of CheckpointHook (#127)

* [Refactor] Refactor interface of CheckpointHook

* fix print format

* minor fix
Zaida Zhou 2022-03-13 23:39:28 +08:00 committed by GitHub
parent fff4742e0b
commit 72cf410969
2 changed files with 30 additions and 16 deletions

@@ -16,7 +16,7 @@ class BaseEvaluator(metaclass=ABCMeta):
     Then it collects all results together from all ranks if distributed
     training is used. Finally, it computes the metrics of the entire dataset.
 
-    A subclass of class:`BaseEvaluator` should assign a meanful value to the
+    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
     class attribute `default_prefix`. See the argument `prefix` for details.
 
     Args:
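
The sentence fixed above tells implementers to assign a meaningful `default_prefix` in their subclass. A minimal sketch of what that looks like; the subclass name, the prefix value, and the import path are assumptions, not part of this commit:

    # Hypothetical subclass, for illustration only.
    from mmengine.evaluator import BaseEvaluator

    class AccuracyEvaluator(BaseEvaluator):
        # A meaningful default; used to namespace reported metrics when
        # the caller does not pass an explicit `prefix` argument.
        default_prefix = 'accuracy'
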

@@ -25,6 +25,9 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
             Default: True.
+        save_param_scheduler (bool): Whether to save param_scheduler state_dict
+            in the checkpoint. It is usually used for resuming experiments.
+            Default: True.
         out_dir (str, optional | Path): The root directory to save checkpoints.
             If not specified, ``runner.work_dir`` will be used by default. If
             specified, the ``out_dir`` will be the concatenation of ``out_dir``
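
Together with the constructor change below, the new flag is switched on like any other CheckpointHook option. A minimal sketch in the usual MMEngine config-dict style; the `type` registry key and surrounding wiring are assumptions, while the option names come from the docstring above:

    # Hypothetical config snippet, for illustration only.
    checkpoint_hook = dict(
        type='CheckpointHook',
        interval=1,                 # save every epoch
        save_optimizer=True,        # keep optimizer state for resuming
        save_param_scheduler=True)  # new: keep param scheduler state too
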
@@ -44,6 +47,7 @@ class CheckpointHook(Hook):
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Default: None.
     """
+    out_dir: str
 
     priority = 'VERY_LOW'
@@ -51,7 +55,8 @@ class CheckpointHook(Hook):
                  interval: int = -1,
                  by_epoch: bool = True,
                  save_optimizer: bool = True,
-                 out_dir: Union[str, Path] = None,
+                 save_param_scheduler: bool = True,
+                 out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
                  sync_buffer: bool = False,
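
Equivalently, constructing the hook directly with the new signature; the values are illustrative and the import path is an assumption:

    from mmengine.hooks import CheckpointHook  # assumed import path

    hook = CheckpointHook(
        interval=1,
        by_epoch=True,
        save_optimizer=True,
        save_param_scheduler=True,  # the newly added switch
        out_dir=None,               # now Optional: falls back to runner.work_dir
        max_keep_ckpts=3,
        save_last=True)
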
@@ -60,7 +65,8 @@ class CheckpointHook(Hook):
         self.interval = interval
         self.by_epoch = by_epoch
         self.save_optimizer = save_optimizer
-        self.out_dir = out_dir
+        self.save_param_scheduler = save_param_scheduler
+        self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
         self.args = kwargs
@@ -121,8 +127,8 @@ class CheckpointHook(Hook):
         # 2. reach the last epoch of training
         if self.every_n_epochs(runner, self.interval) or (
                 self.save_last and self.is_last_train_epoch(runner)):
-            runner.logger.info(f'Saving checkpoint at \
-                {runner.epoch + 1} epochs')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
@@ -135,18 +141,26 @@ class CheckpointHook(Hook):
 
         Args:
             runner (Runner): The runner of the training process.
         """
+        if self.by_epoch:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+        else:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+
         runner.save_checkpoint(
-            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+            self.out_dir,
+            filename=ckpt_filename,
+            save_optimizer=self.save_optimizer,
+            save_param_scheduler=self.save_param_scheduler,
+            by_epoch=self.by_epoch,
+            **self.args)
         if runner.meta is not None:
-            if self.by_epoch:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
-            else:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
             runner.meta.setdefault('hook_msgs', dict())
             runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
-                self.out_dir, cur_ckpt_filename)  # type: ignore
+                self.out_dir, ckpt_filename)
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
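
The `ckpt_filename` computation added above is plain `str.format` on a template looked up in `self.args` (the stored `**kwargs`). A worked example with made-up values:

    # Made-up values mirroring the lookup above.
    args = {}   # no custom 'filename_tmpl' was passed to the hook
    epoch = 11  # pretend runner.epoch == 11 (zero-based)
    filename = args.get('filename_tmpl', 'epoch_{}.pth').format(epoch + 1)
    assert filename == 'epoch_12.pth'

    # With a custom template, e.g. CheckpointHook(filename_tmpl='ckpt_{:03d}.pth'):
    args = {'filename_tmpl': 'ckpt_{:03d}.pth'}
    assert args.get('filename_tmpl', 'epoch_{}.pth').format(epoch + 1) == 'ckpt_012.pth'
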
@@ -161,7 +175,7 @@ class CheckpointHook(Hook):
             filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
                 ckpt_path = self.file_client.join_path(
-                    self.out_dir, filename_tmpl.format(_step))  # type: ignore
+                    self.out_dir, filename_tmpl.format(_step))
                 if self.file_client.isfile(ckpt_path):
                     self.file_client.remove(ckpt_path)
                 else:
@@ -188,8 +202,8 @@ class CheckpointHook(Hook):
         # 2. reach the last iteration of training
         if self.every_n_iters(runner, self.interval) or \
                 (self.save_last and self.is_last_iter(runner, mode='train')):
-            runner.logger.info(f'Saving checkpoint at \
-                {runner.iter + 1} iterations')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
             if self.sync_buffer:
                 pass
                 # TODO