[Feature]: Add checkpoint hook (#66)
* [Feature]: Add checkpoint hook
* [Fix]: Fix lint
* [Fix]: Delete redundant optional and give an example for out_dir
* [Feature]: Add test for the last_ckpt in UT
* [Fix]: Fix docstring problem
* [Fix]: Add patch to UT
* [Feature]: Add test case for by_epoch
pull/57/head
parent
2448380339
commit
cf239a2b17
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .empty_cache_hook import EmptyCacheHook
|
||||
from .checkpoint_hook import CheckpointHook
|
||||
from .hook import Hook
|
||||
from .iter_timer_hook import IterTimerHook
|
||||
from .optimizer_hook import OptimizerHook
|
||||
|
@ -8,5 +9,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
|
|||
|
||||
# Public API of the hooks package. Only the post-commit entry line is kept:
# the superseded pre-commit line had no trailing comma, so keeping both would
# silently concatenate 'EmptyCacheHook' and 'OptimizerHook' into one bogus
# name and duplicate entries.
__all__ = [
    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
    'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
]
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, Path, optional): The root directory to save
            checkpoints. If not specified, ``runner.work_dir`` will be used
            by default. If specified, the ``out_dir`` will be the
            concatenation of ``out_dir`` and the last level directory of
            ``runner.work_dir``. For example, if the input ``out_dir`` is
            ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``,
            then the ckpt will be saved in ``./tmp/cur_exp``.
            Default: None.
        max_keep_ckpts (int): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be
            saved regardless of interval. Default: True.
        sync_buffer (bool): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to
            ``runner.save_checkpoint``. The hook itself reads
            ``filename_tmpl`` and reads/overrides ``create_symlink``.
    """

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 sync_buffer: bool = False,
                 file_client_args: Optional[dict] = None,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner: object) -> None:
        """Resolve the file client and the checkpoint output directory.

        This function will get the appropriate file client, and the
        directory to save these checkpoints of the model. It also settles
        the ``create_symlink`` option according to what the file client
        allows.

        Args:
            runner (object): The runner of the training process.
        """
        if not self.out_dir:
            self.out_dir = runner.work_dir  # type: ignore

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:  # type: ignore
            basename = osp.basename(
                runner.work_dir.rstrip(  # type: ignore
                    osp.sep))
            self.out_dir = self.file_client.join_path(
                self.out_dir,  # type: ignore
                basename)

        runner.logger.info((  # type: ignore
            f'Checkpoints will be saved to {self.out_dir} by '
            f'{self.file_client.name}.'))

        # disable the create_symlink option because some file backends do not
        # allow to create a symlink
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                # NOTE: each fragment ends with a space so that the implicit
                # string concatenation does not glue words together.
                warnings.warn(
                    ('create_symlink is set as True by the user but is '
                     'changed to be False because creating symbolic link is '
                     f'not allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner: object) -> None:
        """Save the checkpoint and synchronize buffers after each epoch.

        Does nothing when the hook is configured with ``by_epoch=False``.

        Args:
            runner (object): The runner of the training process.
        """
        if not self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` epochs
        # 2. reach the last epoch of training
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            # Single-line f-string: a backslash continuation inside the
            # literal would embed the source indentation in the message.
            runner.logger.info(  # type: ignore
                f'Saving checkpoint at {runner.epoch + 1} '  # type: ignore
                'epochs')
            if self.sync_buffer:
                pass
                # TODO
            self._save_checkpoint(runner)

    # TODO Add master_only decorator
    def _save_checkpoint(self, runner: object) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        After saving, the path of the newest checkpoint is recorded in
        ``runner.meta['hook_msgs']['last_ckpt']`` (when ``runner.meta`` is
        not None), and, if ``max_keep_ckpts > 0``, older checkpoints that
        follow the filename template are removed.

        Args:
            runner (object): The runner of the training process.
        """
        runner.save_checkpoint(  # type: ignore
            self.out_dir,
            save_optimizer=self.save_optimizer,
            **self.args)

        # Resolve the 1-based step and the filename template once; both the
        # bookkeeping and the clean-up below need them.
        if self.by_epoch:
            default_tmpl = 'epoch_{}.pth'
            current_ckpt = runner.epoch + 1  # type: ignore
        else:
            default_tmpl = 'iter_{}.pth'
            current_ckpt = runner.iter + 1  # type: ignore
        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)

        if runner.meta is not None:  # type: ignore
            runner.meta.setdefault('hook_msgs', dict())  # type: ignore
            runner.meta['hook_msgs'][  # type: ignore
                'last_ckpt'] = self.file_client.join_path(
                    self.out_dir,  # type: ignore
                    filename_tmpl.format(current_ckpt))

        # remove other checkpoints
        # ``interval > 0`` guard: a zero step would make ``range`` raise
        # ValueError, and a negative interval yields an empty range anyway.
        if self.max_keep_ckpts > 0 and self.interval > 0:
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir,  # type: ignore
                    filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    # Stop at the first missing file: older ones are assumed
                    # to have been removed by previous calls.
                    break

    def after_train_iter(
            self,
            runner: object,
            data_batch: Optional[Sequence[BaseDataSample]] = None,
            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        """Save the checkpoint and synchronize buffers after each iteration.

        Does nothing when the hook is configured with ``by_epoch=True``.

        Args:
            runner (object): The runner of the training process.
            data_batch (Sequence[BaseDataSample], optional): Data from
                dataloader. Defaults to None.
            outputs (Sequence[BaseDataSample], optional): Outputs from model.
                Defaults to None.
        """
        if self.by_epoch:
            return

        # save checkpoint for following cases:
        # 1. every ``self.interval`` iterations
        # 2. reach the last iteration of training
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            # Single-line f-string: a backslash continuation inside the
            # literal would embed the source indentation in the message.
            runner.logger.info(  # type: ignore
                f'Saving checkpoint at {runner.iter + 1} '  # type: ignore
                'iterations')
            if self.sync_buffer:
                pass
                # TODO
            self._save_checkpoint(runner)
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from mmengine.hooks import CheckpointHook
|
||||
|
||||
# Register the file-client module under the short name 'file_client' so that
# the ``patch('file_client.FileClient._prefix_to_backends', ...)`` target used
# by the tests below can be resolved by name.
sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
|
||||
|
||||
|
||||
class MockPetrel:
    """Fake file backend used by the checkpoint hook tests.

    It exposes only the two read-only attributes the tests rely on:
    ``name`` and ``allow_symlink``.
    """

    # This fake backend always reports that symlinks are not supported.
    _allow_symlink = False

    def __init__(self):
        """Create the mock backend; there is no state to initialise."""

    @property
    def allow_symlink(self):
        """bool: Whether this backend allows creating symlinks."""
        return self._allow_symlink

    @property
    def name(self):
        """str: The class name of this backend."""
        return type(self).__name__
|
||||
|
||||
|
||||
# Replacement prefix table patched into ``FileClient._prefix_to_backends`` by
# the tests: paths starting with the 's3' prefix map to the mock backend.
prefix_to_backends = {'s3': MockPetrel}
|
||||
|
||||
|
||||
class TestCheckpointHook:
    """Unit tests for ``CheckpointHook`` driven by a ``Mock`` runner."""

    @patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
    def test_before_run(self):
        """``before_run`` resolves ``out_dir`` and ``create_symlink``."""
        runner = Mock()
        runner.work_dir = './tmp'

        # the out_dir of the checkpoint hook is None
        checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
        checkpoint_hook.before_run(runner)
        assert checkpoint_hook.out_dir == runner.work_dir

        # the out_dir of the checkpoint hook is not None
        checkpoint_hook = CheckpointHook(
            interval=1, by_epoch=True, out_dir='test_dir')
        checkpoint_hook.before_run(runner)
        assert checkpoint_hook.out_dir == 'test_dir/tmp'

        # create_symlink in args and create_symlink is True
        checkpoint_hook = CheckpointHook(
            interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
        checkpoint_hook.before_run(runner)
        assert checkpoint_hook.args['create_symlink']

        # the patched 's3' backend does not allow symlinks, so the option
        # must be switched off
        runner.work_dir = 's3://path/of/file'
        checkpoint_hook = CheckpointHook(
            interval=1, by_epoch=True, create_symlink=True)
        checkpoint_hook.before_run(runner)
        assert not checkpoint_hook.args['create_symlink']

    def test_after_train_epoch(self):
        """``after_train_epoch`` saves only for epoch-based hooks."""
        runner = Mock()
        runner.work_dir = './tmp'
        runner.epoch = 9
        runner.meta = dict()

        # by epoch is True
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_run(runner)
        checkpoint_hook.after_train_epoch(runner)
        assert (runner.epoch + 1) % 2 == 0
        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'

        # epoch can not be evenly divided by 2
        runner.epoch = 10
        checkpoint_hook.after_train_epoch(runner)
        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'

        # by epoch is False
        runner.epoch = 9
        runner.meta = dict()
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_run(runner)
        checkpoint_hook.after_train_epoch(runner)
        assert runner.meta.get('hook_msgs', None) is None

        # max_keep_ckpts > 0
        with TemporaryDirectory() as tempo_dir:
            runner.work_dir = tempo_dir
            stale_ckpt = os.path.join(tempo_dir, 'epoch_8.pth')
            # Create the stale checkpoint directly instead of shelling out
            # to `touch`, which is not available on every platform.
            with open(stale_ckpt, 'w'):
                pass
            checkpoint_hook = CheckpointHook(
                interval=2, by_epoch=True, max_keep_ckpts=1)
            checkpoint_hook.before_run(runner)
            checkpoint_hook.after_train_epoch(runner)
            assert (runner.epoch + 1) % 2 == 0
            assert not os.path.exists(stale_ckpt)

    def test_after_train_iter(self):
        """``after_train_iter`` saves only for iteration-based hooks."""
        runner = Mock()
        runner.work_dir = './tmp'
        runner.iter = 9
        runner.meta = dict()

        # by epoch is True
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_run(runner)
        checkpoint_hook.after_train_iter(runner)
        assert runner.meta.get('hook_msgs', None) is None

        # by epoch is False
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_run(runner)
        checkpoint_hook.after_train_iter(runner)
        assert (runner.iter + 1) % 2 == 0
        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'

        # iter can not be evenly divided by 2
        # Call the per-iteration callback: `after_train_epoch` returns
        # immediately for a `by_epoch=False` hook and would make the
        # assertion below vacuous.
        runner.iter = 10
        checkpoint_hook.after_train_iter(runner)
        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'

        # max_keep_ckpts > 0
        runner.iter = 9
        with TemporaryDirectory() as tempo_dir:
            runner.work_dir = tempo_dir
            stale_ckpt = os.path.join(tempo_dir, 'iter_8.pth')
            # Create the stale checkpoint directly instead of shelling out
            # to `touch`, which is not available on every platform.
            with open(stale_ckpt, 'w'):
                pass
            checkpoint_hook = CheckpointHook(
                interval=2, by_epoch=False, max_keep_ckpts=1)
            checkpoint_hook.before_run(runner)
            checkpoint_hook.after_train_iter(runner)
            assert not os.path.exists(stale_ckpt)
|
Loading…
Reference in New Issue