[Feature] Add parameter save_begin for CheckpointHook (#1271)

KerwinKai 2023-07-25 19:21:21 +08:00 committed by GitHub
parent 3871881ef6
commit 68360e7ce8
6 changed files with 106 additions and 10 deletions

View File

@@ -71,10 +71,11 @@ runner.train()
- Save the best checkpoints
- Specify the path to save the checkpoints
- Make checkpoints for publish
- Control the epoch number or iteration number at which checkpoint saving begins
For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).
The four features mentioned above are described below.
The six features mentioned above are described below.
- Save checkpoints by interval, and support saving them by epoch or iteration
@@ -129,6 +130,14 @@ The four features mentioned above are described below.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```
- Control the epoch number or iteration number at which checkpoint saving begins
To control when checkpoint saving begins, set the `save_begin` parameter to an epoch number or iteration number. It defaults to 0, which means checkpoints are saved from the beginning of training. For example, if you train for 10 epochs in total and set `save_begin=5`, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. With `interval=2`, only the checkpoints for epochs 5, 7, and 9 are saved, as the sketch after the config below illustrates.
```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
```
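For illustration, the interval condition the hook evaluates can be sketched as a standalone function. This is a minimal sketch, not MMEngine code; `should_save` and its arguments are hypothetical names, and the real hook may additionally save the final checkpoint when `save_last` is enabled (its default).
```python
# Hypothetical sketch of the interval/save_begin condition described above.
def should_save(epoch: int, interval: int = 2, save_begin: int = 5) -> bool:
    # epoch is the 1-based epoch count
    dividend = epoch - save_begin
    return dividend >= 0 and interval > 0 and dividend % interval == 0

print([epoch for epoch in range(1, 11) if should_save(epoch)])  # [5, 7, 9]
```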
[LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and writes them to the terminal, JSON files, TensorBoard, wandb, etc.
If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
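A minimal config sketch for this, assuming the same `default_hooks` pattern used for `CheckpointHook` above:
```python
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
```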

View File

@@ -72,10 +72,11 @@ runner.train()
- Save the best checkpoints
- Specify the path to save the checkpoints
- Make checkpoints for publish
- Control the epoch number or iteration number at which checkpoint saving begins
For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).
The four features mentioned above are described below.
The six features mentioned above are described below.
- Save checkpoints by interval, and support saving them by epoch or iteration
@@ -130,6 +131,14 @@ runner.train()
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```
- Control the epoch number or iteration number at which checkpoint saving begins
To control when checkpoint saving begins, set the `save_begin` parameter to an epoch number or iteration number. It defaults to 0, which means checkpoints are saved from the beginning of training. For example, if you train for 10 epochs in total and set `save_begin=5`, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. With `interval=2`, only the checkpoints for epochs 5, 7, and 9 are saved.
```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
```
### LoggerHook
[LoggerHook](mmengine.hooks.LoggerHook) collects logs and outputs them to the terminal, or to files, TensorBoard, and other backends.

View File

@@ -92,6 +92,11 @@ class CheckpointHook(Hook):
publish model with keys in the list after training.
Defaults to None.
`New in version 0.7.1.`
save_begin (int): Control the epoch number or iteration number
at which checkpoint saving begins. Defaults to 0, which means
saving from the beginning of training.
`New in version 0.8.3.`
Examples:
>>> # Save best based on single metric
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
@@ -139,6 +144,7 @@ class CheckpointHook(Hook):
filename_tmpl: Optional[str] = None,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
save_begin: int = 0,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
@@ -244,6 +250,10 @@ class CheckpointHook(Hook):
self.published_keys = published_keys
self.last_ckpt = None
if save_begin < 0:
raise ValueError(
f'save_begin should not be less than 0, but got {save_begin}')
self.save_begin = save_begin
def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint.
@@ -326,9 +336,9 @@ class CheckpointHook(Hook):
return
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 1. every ``self.interval`` epochs, starting from ``self.save_begin``
# 2. reach the last epoch of training
if self.every_n_epochs(runner, self.interval) or (
if self.every_n_epochs(runner, self.interval, self.save_begin) or (
self.save_last and self.is_last_train_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
@@ -644,8 +654,10 @@ class CheckpointHook(Hook):
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# starting from ``self.save_begin``
# 2. reach the last iteration of training
if self.every_n_train_iters(runner, self.interval) or \
if self.every_n_train_iters(runner, self.interval,
self.save_begin) or \
(self.save_last and
self.is_last_train_iter(runner)):
runner.logger.info(

View File

@@ -335,18 +335,21 @@ class Hook:
mode (str): Current mode of runner. Defaults to 'train'.
"""
def every_n_epochs(self, runner, n: int) -> bool:
def every_n_epochs(self, runner, n: int, start: int = 0) -> bool:
"""Test whether current epoch can be evenly divided by n.
Args:
runner (Runner): The runner of the training, validation or testing
process.
n (int): The interval by which the epoch count must be divisible.
start (int): The epoch number from which checking begins.
Defaults to 0.
Returns:
bool: Whether current epoch can be evenly divided by n.
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
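# ``runner.epoch`` is 0-based; a negative dividend means the ``start`` epoch has not been reached yet.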
dividend = runner.epoch + 1 - start
return dividend % n == 0 if dividend >= 0 and n > 0 else False
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
@@ -363,19 +366,22 @@ class Hook:
"""
return (batch_idx + 1) % n == 0 if n > 0 else False
def every_n_train_iters(self, runner, n: int) -> bool:
def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool:
"""Test whether current training iteration can be evenly divided by n.
Args:
runner (Runner): The runner of the training, validation or testing
process.
n (int): The interval by which the iteration count must be divisible.
start (int): The iteration number from which checking begins.
Defaults to 0.
Returns:
bool: Return True if the current iteration can be evenly divided
by n, otherwise False.
"""
return (runner.iter + 1) % n == 0 if n > 0 else False
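# ``runner.iter`` is 0-based; a negative dividend means the ``start`` iteration has not been reached yet.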
dividend = runner.iter + 1 - start
return dividend % n == 0 if dividend >= 0 and n > 0 else False
def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
"""Check whether the current iteration reaches the last iteration of

View File

@@ -622,7 +622,7 @@ class TestCheckpointHook(RunnerTestCase):
self.assertEqual(best_ckpt['meta']['epoch'], 0)
self.assertEqual(best_ckpt['meta']['iter'], 5)
# test save published keys
# Test save published keys
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
runner = self.build_runner(cfg)
@@ -632,3 +632,59 @@ class TestCheckpointHook(RunnerTestCase):
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
self.clear_work_dir()
# Test save_begin with interval=2, save_begin=5
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
cfg.default_hooks.checkpoint.save_begin = 5
runner = self.build_runner(cfg)
runner.train()
for i in range(5):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
for i in range(5, 11):
if (i - 5) % 2 == 1:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()
# Test save_begin with interval=2, save_begin=0
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
runner = self.build_runner(cfg)
runner.train()
for i in range(1, 11):
if i % 2 == 1:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()
# Test save_begin with interval=2, save_begin=1
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
cfg.default_hooks.checkpoint.save_begin = 1
runner = self.build_runner(cfg)
runner.train()
for i in range(1, 11):
if i % 2 == 1:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
else:
self.assertFalse(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()

View File

@@ -63,6 +63,9 @@ class TestLoggerHook(RunnerTestCase):
def test_after_train_iter(self):
# Test LoggerHook by iter.
# Avoid comparing `Runner.iter` (a MagicMock) with integers.
ori_every_n_train_iters = LoggerHook.every_n_train_iters
LoggerHook.every_n_train_iters = MagicMock(return_value=True)
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
@@ -112,6 +115,7 @@ class TestLoggerHook(RunnerTestCase):
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=8)
runner.log_processor.get_log_after_iter.assert_called()
LoggerHook.every_n_train_iters = ori_every_n_train_iters
def test_after_val_epoch(self):
logger_hook = LoggerHook()