[Refactor] Refactor interface of checkpointhook (#127)
* [Refactor] Refactor interface of checkpointhook * fix print format * minor fix
parent fff4742e0b
commit 72cf410969
@@ -16,7 +16,7 @@ class BaseEvaluator(metaclass=ABCMeta):
     Then it collects all results together from all ranks if distributed
     training is used. Finally, it computes the metrics of the entire dataset.
 
-    A subclass of class:`BaseEvaluator` should assign a meanful value to the
+    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
     class attribute `default_prefix`. See the argument `prefix` for details.
 
     Args:
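Note: the docstring above describes the ``default_prefix`` contract for evaluator subclasses. A toy, self-contained sketch of that contract (not the real mmengine API; the class and metric names are invented):

# Toy sketch of the `default_prefix` contract; not the real BaseEvaluator API.
class ToyAccuracyEvaluator:
    # A meaningful default used to namespace metric names, e.g. 'toy/accuracy'.
    default_prefix = 'toy'

    def __init__(self, prefix=None):
        # The `prefix` argument, when given, overrides the class-level default.
        self.prefix = prefix if prefix is not None else self.default_prefix

    def compute_metrics(self, num_correct, num_total):
        return {f'{self.prefix}/accuracy': num_correct / num_total}

print(ToyAccuracyEvaluator().compute_metrics(90, 100))  # {'toy/accuracy': 0.9}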
@@ -25,6 +25,9 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
             Default: True.
+        save_param_scheduler (bool): Whether to save param_scheduler state_dict
+            in the checkpoint. It is usually used for resuming experiments.
+            Default: True.
         out_dir (str, optional | Path): The root directory to save checkpoints.
             If not specified, ``runner.work_dir`` will be used by default. If
             specified, the ``out_dir`` will be the concatenation of ``out_dir``
@@ -44,6 +47,7 @@ class CheckpointHook(Hook):
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Default: None.
     """
+    out_dir: str
 
     priority = 'VERY_LOW'
 
@@ -51,7 +55,8 @@ class CheckpointHook(Hook):
                  interval: int = -1,
                  by_epoch: bool = True,
                  save_optimizer: bool = True,
-                 out_dir: Union[str, Path] = None,
+                 save_param_scheduler: bool = True,
+                 out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
                  sync_buffer: bool = False,
@@ -60,7 +65,8 @@ class CheckpointHook(Hook):
         self.interval = interval
         self.by_epoch = by_epoch
         self.save_optimizer = save_optimizer
-        self.out_dir = out_dir
+        self.save_param_scheduler = save_param_scheduler
+        self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
         self.args = kwargs
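For reference, an illustrative construction with the refactored signature; the argument values below are made up, and the import path is assumed:

# Illustrative only: constructing the hook with the refactored signature.
from mmengine.hooks import CheckpointHook  # import path assumed

hook = CheckpointHook(
    interval=1,                  # save every epoch (since by_epoch=True)
    by_epoch=True,
    save_optimizer=True,
    save_param_scheduler=True,   # new argument introduced by this refactor
    out_dir=None,                # falls back to runner.work_dir
    max_keep_ckpts=3,            # keep only the 3 most recent checkpoints
    save_last=True)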
@@ -121,8 +127,8 @@ class CheckpointHook(Hook):
         # 2. reach the last epoch of training
         if self.every_n_epochs(runner, self.interval) or (
                 self.save_last and self.is_last_train_epoch(runner)):
-            runner.logger.info(f'Saving checkpoint at \
-                {runner.epoch + 1} epochs')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
@@ -135,18 +141,26 @@ class CheckpointHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        runner.save_checkpoint(
-            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
-        if runner.meta is not None:
-            if self.by_epoch:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
-            else:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+        if self.by_epoch:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+        else:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+
+        runner.save_checkpoint(
+            self.out_dir,
+            filename=ckpt_filename,
+            save_optimizer=self.save_optimizer,
+            save_param_scheduler=self.save_param_scheduler,
+            by_epoch=self.by_epoch,
+            **self.args)
+
+        if runner.meta is not None:
             runner.meta.setdefault('hook_msgs', dict())
             runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
-                self.out_dir, cur_ckpt_filename)  # type: ignore
+                self.out_dir, ckpt_filename)
+
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
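The refactored ``_save_checkpoint`` now derives the checkpoint filename before calling ``runner.save_checkpoint``. A standalone sketch of that template logic (plain Python, no mmengine imports; the epoch/iteration values are made up):

# Standalone sketch of the filename-template logic shown above.
args = {}                     # mirrors self.args; no custom 'filename_tmpl' given
by_epoch = True
epoch, iteration = 4, 9999    # stand-ins for runner.epoch / runner.iter (0-based)

if by_epoch:
    ckpt_filename = args.get('filename_tmpl', 'epoch_{}.pth').format(epoch + 1)
else:
    ckpt_filename = args.get('filename_tmpl', 'iter_{}.pth').format(iteration + 1)

print(ckpt_filename)          # epoch_5.pth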
@@ -161,7 +175,7 @@ class CheckpointHook(Hook):
             filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
                 ckpt_path = self.file_client.join_path(
-                    self.out_dir, filename_tmpl.format(_step))  # type: ignore
+                    self.out_dir, filename_tmpl.format(_step))
                 if self.file_client.isfile(ckpt_path):
                     self.file_client.remove(ckpt_path)
                 else:
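The ``redundant_ckpts`` iterated here is computed just above this hunk and is unchanged by the commit. A standalone sketch of that retention window, assuming the mmcv-style ``range`` computation; the numbers are made up:

# Sketch of the max_keep_ckpts retention window (assumed mmcv-style computation).
interval = 1          # checkpoint saved every epoch
max_keep_ckpts = 3    # keep only the 3 newest checkpoints
current_ckpt = 10     # runner.epoch + 1 at save time

redundant_ckpts = range(
    current_ckpt - max_keep_ckpts * interval, 0, -interval)
print(list(redundant_ckpts))  # [7, 6, 5, 4, 3, 2, 1] -> epoch_{n}.pth removed for each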
@@ -188,8 +202,8 @@ class CheckpointHook(Hook):
         # 2. reach the last iteration of training
         if self.every_n_iters(runner, self.interval) or \
                 (self.save_last and self.is_last_iter(runner, mode='train')):
-            runner.logger.info(f'Saving checkpoint at \
-                {runner.iter + 1} iterations')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
             if self.sync_buffer:
                 pass
                 # TODO