mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor interface of checkpointhook (#127)
* [Refactor] Refactor interface of checkpointhook * fix print format * minor ifx
This commit is contained in:
parent
fff4742e0b
commit
72cf410969
@ -16,7 +16,7 @@ class BaseEvaluator(metaclass=ABCMeta):
|
|||||||
Then it collects all results together from all ranks if distributed
|
Then it collects all results together from all ranks if distributed
|
||||||
training is used. Finally, it computes the metrics of the entire dataset.
|
training is used. Finally, it computes the metrics of the entire dataset.
|
||||||
|
|
||||||
A subclass of class:`BaseEvaluator` should assign a meanful value to the
|
A subclass of class:`BaseEvaluator` should assign a meaningful value to the
|
||||||
class attribute `default_prefix`. See the argument `prefix` for details.
|
class attribute `default_prefix`. See the argument `prefix` for details.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -25,6 +25,9 @@ class CheckpointHook(Hook):
|
|||||||
save_optimizer (bool): Whether to save optimizer state_dict in the
|
save_optimizer (bool): Whether to save optimizer state_dict in the
|
||||||
checkpoint. It is usually used for resuming experiments.
|
checkpoint. It is usually used for resuming experiments.
|
||||||
Default: True.
|
Default: True.
|
||||||
|
save_param_scheduler (bool): Whether to save param_scheduler state_dict
|
||||||
|
in the checkpoint. It is usually used for resuming experiments.
|
||||||
|
Default: True.
|
||||||
out_dir (str, optional | Path): The root directory to save checkpoints.
|
out_dir (str, optional | Path): The root directory to save checkpoints.
|
||||||
If not specified, ``runner.work_dir`` will be used by default. If
|
If not specified, ``runner.work_dir`` will be used by default. If
|
||||||
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
||||||
@ -44,6 +47,7 @@ class CheckpointHook(Hook):
|
|||||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
|
out_dir: str
|
||||||
|
|
||||||
priority = 'VERY_LOW'
|
priority = 'VERY_LOW'
|
||||||
|
|
||||||
@ -51,7 +55,8 @@ class CheckpointHook(Hook):
|
|||||||
interval: int = -1,
|
interval: int = -1,
|
||||||
by_epoch: bool = True,
|
by_epoch: bool = True,
|
||||||
save_optimizer: bool = True,
|
save_optimizer: bool = True,
|
||||||
out_dir: Union[str, Path] = None,
|
save_param_scheduler: bool = True,
|
||||||
|
out_dir: Optional[Union[str, Path]] = None,
|
||||||
max_keep_ckpts: int = -1,
|
max_keep_ckpts: int = -1,
|
||||||
save_last: bool = True,
|
save_last: bool = True,
|
||||||
sync_buffer: bool = False,
|
sync_buffer: bool = False,
|
||||||
@ -60,7 +65,8 @@ class CheckpointHook(Hook):
|
|||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
self.save_optimizer = save_optimizer
|
self.save_optimizer = save_optimizer
|
||||||
self.out_dir = out_dir
|
self.save_param_scheduler = save_param_scheduler
|
||||||
|
self.out_dir = out_dir # type: ignore
|
||||||
self.max_keep_ckpts = max_keep_ckpts
|
self.max_keep_ckpts = max_keep_ckpts
|
||||||
self.save_last = save_last
|
self.save_last = save_last
|
||||||
self.args = kwargs
|
self.args = kwargs
|
||||||
@ -121,8 +127,8 @@ class CheckpointHook(Hook):
|
|||||||
# 2. reach the last epoch of training
|
# 2. reach the last epoch of training
|
||||||
if self.every_n_epochs(runner, self.interval) or (
|
if self.every_n_epochs(runner, self.interval) or (
|
||||||
self.save_last and self.is_last_train_epoch(runner)):
|
self.save_last and self.is_last_train_epoch(runner)):
|
||||||
runner.logger.info(f'Saving checkpoint at \
|
runner.logger.info(
|
||||||
{runner.epoch + 1} epochs')
|
f'Saving checkpoint at {runner.epoch + 1} epochs')
|
||||||
if self.sync_buffer:
|
if self.sync_buffer:
|
||||||
pass
|
pass
|
||||||
# TODO
|
# TODO
|
||||||
@ -135,18 +141,26 @@ class CheckpointHook(Hook):
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
runner.save_checkpoint(
|
|
||||||
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
|
|
||||||
if runner.meta is not None:
|
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
cur_ckpt_filename = self.args.get(
|
ckpt_filename = self.args.get(
|
||||||
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
|
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
|
||||||
else:
|
else:
|
||||||
cur_ckpt_filename = self.args.get(
|
ckpt_filename = self.args.get(
|
||||||
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
||||||
|
|
||||||
|
runner.save_checkpoint(
|
||||||
|
self.out_dir,
|
||||||
|
filename=ckpt_filename,
|
||||||
|
save_optimizer=self.save_optimizer,
|
||||||
|
save_param_scheduler=self.save_param_scheduler,
|
||||||
|
by_epoch=self.by_epoch,
|
||||||
|
**self.args)
|
||||||
|
|
||||||
|
if runner.meta is not None:
|
||||||
runner.meta.setdefault('hook_msgs', dict())
|
runner.meta.setdefault('hook_msgs', dict())
|
||||||
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
|
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
|
||||||
self.out_dir, cur_ckpt_filename) # type: ignore
|
self.out_dir, ckpt_filename)
|
||||||
|
|
||||||
# remove other checkpoints
|
# remove other checkpoints
|
||||||
if self.max_keep_ckpts > 0:
|
if self.max_keep_ckpts > 0:
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
@ -161,7 +175,7 @@ class CheckpointHook(Hook):
|
|||||||
filename_tmpl = self.args.get('filename_tmpl', name)
|
filename_tmpl = self.args.get('filename_tmpl', name)
|
||||||
for _step in redundant_ckpts:
|
for _step in redundant_ckpts:
|
||||||
ckpt_path = self.file_client.join_path(
|
ckpt_path = self.file_client.join_path(
|
||||||
self.out_dir, filename_tmpl.format(_step)) # type: ignore
|
self.out_dir, filename_tmpl.format(_step))
|
||||||
if self.file_client.isfile(ckpt_path):
|
if self.file_client.isfile(ckpt_path):
|
||||||
self.file_client.remove(ckpt_path)
|
self.file_client.remove(ckpt_path)
|
||||||
else:
|
else:
|
||||||
@ -188,8 +202,8 @@ class CheckpointHook(Hook):
|
|||||||
# 2. reach the last iteration of training
|
# 2. reach the last iteration of training
|
||||||
if self.every_n_iters(runner, self.interval) or \
|
if self.every_n_iters(runner, self.interval) or \
|
||||||
(self.save_last and self.is_last_iter(runner, mode='train')):
|
(self.save_last and self.is_last_iter(runner, mode='train')):
|
||||||
runner.logger.info(f'Saving checkpoint at \
|
runner.logger.info(
|
||||||
{runner.iter + 1} iterations')
|
f'Saving checkpoint at {runner.iter + 1} iterations')
|
||||||
if self.sync_buffer:
|
if self.sync_buffer:
|
||||||
pass
|
pass
|
||||||
# TODO
|
# TODO
|
||||||
|
Loading…
x
Reference in New Issue
Block a user