mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature]: Add checkpoint hook (#66)
* [Feature]: Add checkpoint hook * [Fix]: Fix lint * [Fix]: Delete redundant optional and give an example to our_dir * [Feature]: Add test the last_ckpt in UT * [Fix]: Fix docstring problem * [Fix]: Add patch to UT * [Feature]: Add Test case for by epoch
This commit is contained in:
parent
2448380339
commit
cf239a2b17
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .empty_cache_hook import EmptyCacheHook
|
from .empty_cache_hook import EmptyCacheHook
|
||||||
|
from .checkpoint_hook import CheckpointHook
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
from .iter_timer_hook import IterTimerHook
|
from .iter_timer_hook import IterTimerHook
|
||||||
from .optimizer_hook import OptimizerHook
|
from .optimizer_hook import OptimizerHook
|
||||||
@ -8,5 +9,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||||
'OptimizerHook', 'EmptyCacheHook'
|
'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
|
||||||
]
|
]
|
||||||
|
206
mmengine/hooks/checkpoint_hook.py
Normal file
206
mmengine/hooks/checkpoint_hook.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
|
from mmengine.fileio import FileClient
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class CheckpointHook(Hook):
|
||||||
|
"""Save checkpoints periodically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interval (int): The saving period. If ``by_epoch=True``, interval
|
||||||
|
indicates epochs, otherwise it indicates iterations.
|
||||||
|
Default: -1, which means "never".
|
||||||
|
by_epoch (bool): Saving checkpoints by epoch or by iteration.
|
||||||
|
Default: True.
|
||||||
|
save_optimizer (bool): Whether to save optimizer state_dict in the
|
||||||
|
checkpoint. It is usually used for resuming experiments.
|
||||||
|
Default: True.
|
||||||
|
out_dir (str, optional | Path): The root directory to save checkpoints.
|
||||||
|
If not specified, ``runner.work_dir`` will be used by default. If
|
||||||
|
specified, the ``out_dir`` will be the concatenation of ``out_dir``
|
||||||
|
and the last level directory of ``runner.work_dir``. For example,
|
||||||
|
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
|
||||||
|
``./work_dir/cur_exp``, then the ckpt will be saved in
|
||||||
|
``./tmp/cur_exp``. Deafule to None.
|
||||||
|
max_keep_ckpts (int): The maximum checkpoints to keep.
|
||||||
|
In some cases we want only the latest few checkpoints and would
|
||||||
|
like to delete old ones to save the disk space.
|
||||||
|
Default: -1, which means unlimited.
|
||||||
|
save_last (bool): Whether to force the last checkpoint to be
|
||||||
|
saved regardless of interval. Default: True.
|
||||||
|
sync_buffer (bool): Whether to synchronize buffers in
|
||||||
|
different gpus. Default: False.
|
||||||
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
|
Default: None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
interval: int = -1,
|
||||||
|
by_epoch: bool = True,
|
||||||
|
save_optimizer: bool = True,
|
||||||
|
out_dir: Union[str, Path] = None,
|
||||||
|
max_keep_ckpts: int = -1,
|
||||||
|
save_last: bool = True,
|
||||||
|
sync_buffer: bool = False,
|
||||||
|
file_client_args: Optional[dict] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
self.interval = interval
|
||||||
|
self.by_epoch = by_epoch
|
||||||
|
self.save_optimizer = save_optimizer
|
||||||
|
self.out_dir = out_dir
|
||||||
|
self.max_keep_ckpts = max_keep_ckpts
|
||||||
|
self.save_last = save_last
|
||||||
|
self.args = kwargs
|
||||||
|
self.sync_buffer = sync_buffer
|
||||||
|
self.file_client_args = file_client_args
|
||||||
|
|
||||||
|
def before_run(self, runner: object) -> None:
|
||||||
|
"""Finish all operations, related to checkpoint.
|
||||||
|
|
||||||
|
This function will get the appropriate file client, and the directory
|
||||||
|
to save these checkpoints of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
if not self.out_dir:
|
||||||
|
self.out_dir = runner.work_dir # type: ignore
|
||||||
|
|
||||||
|
self.file_client = FileClient.infer_client(self.file_client_args,
|
||||||
|
self.out_dir)
|
||||||
|
|
||||||
|
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
|
||||||
|
# `self.out_dir` is set so the final `self.out_dir` is the
|
||||||
|
# concatenation of `self.out_dir` and the last level directory of
|
||||||
|
# `runner.work_dir`
|
||||||
|
if self.out_dir != runner.work_dir: # type: ignore
|
||||||
|
basename = osp.basename(
|
||||||
|
runner.work_dir.rstrip( # type: ignore
|
||||||
|
osp.sep))
|
||||||
|
self.out_dir = self.file_client.join_path(
|
||||||
|
self.out_dir, # type: ignore
|
||||||
|
basename)
|
||||||
|
|
||||||
|
runner.logger.info(( # type: ignore
|
||||||
|
f'Checkpoints will be saved to {self.out_dir} by '
|
||||||
|
f'{self.file_client.name}.'))
|
||||||
|
|
||||||
|
# disable the create_symlink option because some file backends do not
|
||||||
|
# allow to create a symlink
|
||||||
|
if 'create_symlink' in self.args:
|
||||||
|
if self.args[
|
||||||
|
'create_symlink'] and not self.file_client.allow_symlink:
|
||||||
|
self.args['create_symlink'] = False
|
||||||
|
warnings.warn(
|
||||||
|
('create_symlink is set as True by the user but is changed'
|
||||||
|
'to be False because creating symbolic link is not '
|
||||||
|
f'allowed in {self.file_client.name}'))
|
||||||
|
else:
|
||||||
|
self.args['create_symlink'] = self.file_client.allow_symlink
|
||||||
|
|
||||||
|
def after_train_epoch(self, runner: object) -> None:
|
||||||
|
"""Save the checkpoint and synchronize buffers after each epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
if not self.by_epoch:
|
||||||
|
return
|
||||||
|
|
||||||
|
# save checkpoint for following cases:
|
||||||
|
# 1. every ``self.interval`` epochs
|
||||||
|
# 2. reach the last epoch of training
|
||||||
|
if self.every_n_epochs(
|
||||||
|
runner, self.interval) or (self.save_last
|
||||||
|
and self.is_last_epoch(runner)):
|
||||||
|
runner.logger.info( # type: ignore
|
||||||
|
f'Saving checkpoint at \
|
||||||
|
{runner.epoch + 1} epochs') # type: ignore
|
||||||
|
if self.sync_buffer:
|
||||||
|
pass
|
||||||
|
# TODO
|
||||||
|
self._save_checkpoint(runner)
|
||||||
|
|
||||||
|
# TODO Add master_only decorator
|
||||||
|
def _save_checkpoint(self, runner: object) -> None:
|
||||||
|
"""Save the current checkpoint and delete outdated checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
runner.save_checkpoint( # type: ignore
|
||||||
|
self.out_dir,
|
||||||
|
save_optimizer=self.save_optimizer,
|
||||||
|
**self.args)
|
||||||
|
if runner.meta is not None: # type: ignore
|
||||||
|
if self.by_epoch:
|
||||||
|
cur_ckpt_filename = self.args.get(
|
||||||
|
'filename_tmpl',
|
||||||
|
'epoch_{}.pth').format(runner.epoch + 1) # type: ignore
|
||||||
|
else:
|
||||||
|
cur_ckpt_filename = self.args.get(
|
||||||
|
'filename_tmpl',
|
||||||
|
'iter_{}.pth').format(runner.iter + 1) # type: ignore
|
||||||
|
runner.meta.setdefault('hook_msgs', dict()) # type: ignore
|
||||||
|
runner.meta['hook_msgs'][ # type: ignore
|
||||||
|
'last_ckpt'] = self.file_client.join_path(
|
||||||
|
self.out_dir, cur_ckpt_filename) # type: ignore
|
||||||
|
# remove other checkpoints
|
||||||
|
if self.max_keep_ckpts > 0:
|
||||||
|
if self.by_epoch:
|
||||||
|
name = 'epoch_{}.pth'
|
||||||
|
current_ckpt = runner.epoch + 1 # type: ignore
|
||||||
|
else:
|
||||||
|
name = 'iter_{}.pth'
|
||||||
|
current_ckpt = runner.iter + 1 # type: ignore
|
||||||
|
redundant_ckpts = range(
|
||||||
|
current_ckpt - self.max_keep_ckpts * self.interval, 0,
|
||||||
|
-self.interval)
|
||||||
|
filename_tmpl = self.args.get('filename_tmpl', name)
|
||||||
|
for _step in redundant_ckpts:
|
||||||
|
ckpt_path = self.file_client.join_path(
|
||||||
|
self.out_dir, filename_tmpl.format(_step)) # type: ignore
|
||||||
|
if self.file_client.isfile(ckpt_path):
|
||||||
|
self.file_client.remove(ckpt_path)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def after_train_iter(
|
||||||
|
self,
|
||||||
|
runner: object,
|
||||||
|
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||||
|
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||||
|
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||||
|
Defaults to None.
|
||||||
|
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
if self.by_epoch:
|
||||||
|
return
|
||||||
|
|
||||||
|
# save checkpoint for following cases:
|
||||||
|
# 1. every ``self.interval`` iterations
|
||||||
|
# 2. reach the last iteration of training
|
||||||
|
if self.every_n_iters(
|
||||||
|
runner, self.interval) or (self.save_last
|
||||||
|
and self.is_last_iter(runner)):
|
||||||
|
runner.logger.info( # type: ignore
|
||||||
|
f'Saving checkpoint at \
|
||||||
|
{runner.iter + 1} iterations') # type: ignore
|
||||||
|
if self.sync_buffer:
|
||||||
|
pass
|
||||||
|
# TODO
|
||||||
|
self._save_checkpoint(runner)
|
131
tests/test_hook/test_checkpoint_hook.py
Normal file
131
tests/test_hook/test_checkpoint_hook.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from mmengine.hooks import CheckpointHook
|
||||||
|
|
||||||
|
sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
|
||||||
|
|
||||||
|
|
||||||
|
class MockPetrel:
|
||||||
|
|
||||||
|
_allow_symlink = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_symlink(self):
|
||||||
|
return self._allow_symlink
|
||||||
|
|
||||||
|
|
||||||
|
prefix_to_backends = {'s3': MockPetrel}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointHook:
|
||||||
|
|
||||||
|
@patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
|
||||||
|
def test_before_run(self):
|
||||||
|
runner = Mock()
|
||||||
|
runner.work_dir = './tmp'
|
||||||
|
|
||||||
|
# the out_dir of the checkpoint hook is None
|
||||||
|
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
assert checkpoint_hook.out_dir == runner.work_dir
|
||||||
|
|
||||||
|
# the out_dir of the checkpoint hook is not None
|
||||||
|
checkpoint_hook = CheckpointHook(
|
||||||
|
interval=1, by_epoch=True, out_dir='test_dir')
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
assert checkpoint_hook.out_dir == 'test_dir/tmp'
|
||||||
|
|
||||||
|
# create_symlink in args and create_symlink is True
|
||||||
|
checkpoint_hook = CheckpointHook(
|
||||||
|
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
assert checkpoint_hook.args['create_symlink']
|
||||||
|
|
||||||
|
runner.work_dir = 's3://path/of/file'
|
||||||
|
checkpoint_hook = CheckpointHook(
|
||||||
|
interval=1, by_epoch=True, create_symlink=True)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
assert not checkpoint_hook.args['create_symlink']
|
||||||
|
|
||||||
|
def test_after_train_epoch(self):
|
||||||
|
runner = Mock()
|
||||||
|
runner.work_dir = './tmp'
|
||||||
|
runner.epoch = 9
|
||||||
|
runner.meta = dict()
|
||||||
|
|
||||||
|
# by epoch is True
|
||||||
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
|
assert (runner.epoch + 1) % 2 == 0
|
||||||
|
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
|
||||||
|
|
||||||
|
# epoch can not be evenly divided by 2
|
||||||
|
runner.epoch = 10
|
||||||
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
|
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
|
||||||
|
|
||||||
|
# by epoch is False
|
||||||
|
runner.epoch = 9
|
||||||
|
runner.meta = dict()
|
||||||
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
|
assert runner.meta.get('hook_msgs', None) is None
|
||||||
|
|
||||||
|
# max_keep_ckpts > 0
|
||||||
|
with TemporaryDirectory() as tempo_dir:
|
||||||
|
runner.work_dir = tempo_dir
|
||||||
|
os.system(f'touch {tempo_dir}/epoch_8.pth')
|
||||||
|
checkpoint_hook = CheckpointHook(
|
||||||
|
interval=2, by_epoch=True, max_keep_ckpts=1)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
|
assert (runner.epoch + 1) % 2 == 0
|
||||||
|
assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
|
||||||
|
|
||||||
|
def test_after_train_iter(self):
|
||||||
|
runner = Mock()
|
||||||
|
runner.work_dir = './tmp'
|
||||||
|
runner.iter = 9
|
||||||
|
runner.meta = dict()
|
||||||
|
|
||||||
|
# by epoch is True
|
||||||
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_iter(runner)
|
||||||
|
assert runner.meta.get('hook_msgs', None) is None
|
||||||
|
|
||||||
|
# by epoch is False
|
||||||
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_iter(runner)
|
||||||
|
assert (runner.iter + 1) % 2 == 0
|
||||||
|
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
||||||
|
|
||||||
|
# epoch can not be evenly divided by 2
|
||||||
|
runner.iter = 10
|
||||||
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
|
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
||||||
|
|
||||||
|
# max_keep_ckpts > 0
|
||||||
|
runner.iter = 9
|
||||||
|
with TemporaryDirectory() as tempo_dir:
|
||||||
|
runner.work_dir = tempo_dir
|
||||||
|
os.system(f'touch {tempo_dir}/iter_8.pth')
|
||||||
|
checkpoint_hook = CheckpointHook(
|
||||||
|
interval=2, by_epoch=False, max_keep_ckpts=1)
|
||||||
|
checkpoint_hook.before_run(runner)
|
||||||
|
checkpoint_hook.after_train_iter(runner)
|
||||||
|
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
|
Loading…
x
Reference in New Issue
Block a user