[Feature] Add parameter save_begin for CheckpointHook (#1271)

parent 3871881ef6
commit 68360e7ce8
@@ -71,10 +71,11 @@ runner.train()
 - Save the best checkpoints
 - Specify the path to save the checkpoints
 - Make checkpoints for publishing
+- Control the epoch number or iteration number at which checkpoint saving begins
 
 For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).
 
-The four features mentioned above are described below.
+The six features mentioned above are described below.
 
 - Save checkpoints by interval, and support saving them by epoch or iteration
@@ -129,6 +130,14 @@ The four features mentioned above are described below.
 default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
 ```
 
+- Control the epoch number or iteration number at which checkpoint saving begins
+
+  If you want to control when checkpoint saving starts, you can set the `save_begin` parameter, which defaults to 0, meaning checkpoints are saved from the beginning of training. For example, if you train for a total of 10 epochs and `save_begin` is set to 5, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. If `interval=2`, only the checkpoints for epochs 5, 7, and 9 will be saved.
+
+  ```python
+  default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
+  ```
+
 [LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and writes them to the terminal, JSON file, TensorBoard, wandb, etc.
 
 If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
@@ -72,10 +72,11 @@ runner.train()
 - Save the best checkpoints
 - Specify the path to save the checkpoints
 - Make checkpoints for publishing
+- Set the epoch number or iteration number at which checkpoint saving begins
 
 For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).
 
-The four features mentioned above are described below.
+The six features mentioned above are described below.
 
 - Save checkpoints by interval, supporting saving by epoch or iteration
@@ -130,6 +131,14 @@ runner.train()
 default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
 ```
 
+- Set the epoch number or iteration number at which checkpoint saving begins
+
+  If you want to control when checkpoint saving starts, you can set the `save_begin` parameter. It defaults to 0, which means checkpoints are saved from the start of training. For example, if you train for a total of 10 epochs and `save_begin` is set to 5, the checkpoints for epochs 5, 6, 7, 8, 9, and 10 will be saved. If `interval=2`, only the checkpoints for epochs 5, 7, and 9 will be saved.
+
+  ```python
+  default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
+  ```
+
 ### LoggerHook
 
 [LoggerHook](mmengine.hooks.LoggerHook) collects logs and writes them to the terminal, or to files, TensorBoard, and other backends.
|
@ -92,6 +92,11 @@ class CheckpointHook(Hook):
|
||||
publish model with keys in the list after training.
|
||||
Defaults to None.
|
||||
`New in version 0.7.1.`
|
||||
save_begin (int): Control the epoch number or iteration number
|
||||
at which checkpoint saving begins. Defaults to 0, which means
|
||||
saving at the beginning.
|
||||
`New in version 0.8.3.`
|
||||
|
||||
Examples:
|
||||
>>> # Save best based on single metric
|
||||
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
|
||||
@@ -139,6 +144,7 @@ class CheckpointHook(Hook):
                  filename_tmpl: Optional[str] = None,
                  backend_args: Optional[dict] = None,
                  published_keys: Union[str, List[str], None] = None,
+                 save_begin: int = 0,
                  **kwargs) -> None:
         self.interval = interval
         self.by_epoch = by_epoch
||||
@ -244,6 +250,10 @@ class CheckpointHook(Hook):
|
||||
self.published_keys = published_keys
|
||||
|
||||
self.last_ckpt = None
|
||||
if save_begin < 0:
|
||||
raise ValueError(
|
||||
'save_begin should not be less than 0, but got {save_begin}')
|
||||
self.save_begin = save_begin
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""Finish all operations, related to checkpoint.
|
||||
@ -326,9 +336,9 @@ class CheckpointHook(Hook):
|
||||
return
|
||||
|
||||
# save checkpoint for following cases:
|
||||
# 1. every ``self.interval`` epochs
|
||||
# 1. every ``self.interval`` epochs which start at ``self.save_begin``
|
||||
# 2. reach the last epoch of training
|
||||
if self.every_n_epochs(runner, self.interval) or (
|
||||
if self.every_n_epochs(runner, self.interval, self.save_begin) or (
|
||||
self.save_last and self.is_last_train_epoch(runner)):
|
||||
runner.logger.info(
|
||||
f'Saving checkpoint at {runner.epoch + 1} epochs')
|
||||
@@ -644,8 +654,10 @@ class CheckpointHook(Hook):
 
         # save checkpoint for following cases:
         # 1. every ``self.interval`` iterations,
+        #    starting at ``self.save_begin``
         # 2. reach the last iteration of training
-        if self.every_n_train_iters(runner, self.interval) or \
+        if self.every_n_train_iters(runner, self.interval,
+                                    self.save_begin) or \
                 (self.save_last and
                  self.is_last_train_iter(runner)):
             runner.logger.info(
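The iteration-based branch reuses the same dividend logic on `runner.iter`, so the documentation example carries over. As a hedged config sketch (assuming an iteration-based training loop; `by_epoch=False` is an existing `CheckpointHook` option) that saves at iterations 5, 7, 9, and so on:

```python
# Sketch only: with by_epoch=False, the save_begin/interval arithmetic
# runs on the (1-based) iteration count instead of the epoch count.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,  # count in iterations rather than epochs
        interval=2,
        save_begin=5))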
|
@ -335,18 +335,21 @@ class Hook:
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
|
||||
def every_n_epochs(self, runner, n: int) -> bool:
|
||||
def every_n_epochs(self, runner, n: int, start: int = 0) -> bool:
|
||||
"""Test whether current epoch can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
n (int): Whether current epoch can be evenly divided by n.
|
||||
start (int): Starting from `start` to check the logic for
|
||||
every n epochs. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
bool: Whether current epoch can be evenly divided by n.
|
||||
"""
|
||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||
dividend = runner.epoch + 1 - start
|
||||
return dividend % n == 0 if dividend >= 0 and n > 0 else False
|
||||
|
||||
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
|
||||
"""Test whether current inner iteration can be evenly divided by n.
|
||||
@ -363,19 +366,22 @@ class Hook:
|
||||
"""
|
||||
return (batch_idx + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_train_iters(self, runner, n: int) -> bool:
|
||||
def every_n_train_iters(self, runner, n: int, start: int = 0) -> bool:
|
||||
"""Test whether current training iteration can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
n (int): Whether current iteration can be evenly divided by n.
|
||||
start (int): Starting from `start` to check the logic for
|
||||
every n iterations. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
bool: Return True if the current iteration can be evenly divided
|
||||
by n, otherwise False.
|
||||
"""
|
||||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||
dividend = runner.iter + 1 - start
|
||||
return dividend % n == 0 if dividend >= 0 and n > 0 else False
|
||||
|
||||
def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
|
||||
"""Check whether the current iteration reaches the last iteration of
|
||||
|
@ -622,7 +622,7 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
self.assertEqual(best_ckpt['meta']['epoch'], 0)
|
||||
self.assertEqual(best_ckpt['meta']['iter'], 5)
|
||||
|
||||
# test save published keys
|
||||
# Test save published keys
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
|
||||
runner = self.build_runner(cfg)
|
||||
@ -632,3 +632,59 @@ class TestCheckpointHook(RunnerTestCase):
|
||||
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
|
||||
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test save_begin with interval=2, save_begin=5
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.interval = 2
|
||||
cfg.default_hooks.checkpoint.save_begin = 5
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
for i in range(5):
|
||||
self.assertFalse(
|
||||
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
for i in range(5, 11):
|
||||
if (i - 5) % 2 == 1:
|
||||
self.assertFalse(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
else:
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test save_begin with interval=2, save_begin=0
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.interval = 2
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
for i in range(1, 11):
|
||||
if i % 2 == 1:
|
||||
self.assertFalse(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
else:
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
self.clear_work_dir()
|
||||
|
||||
# Test save_begin with interval=2, save_begin=1
|
||||
cfg = copy.deepcopy(common_cfg)
|
||||
cfg.default_hooks.checkpoint.interval = 2
|
||||
cfg.default_hooks.checkpoint.save_begin = 1
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
for i in range(1, 11):
|
||||
if i % 2 == 1:
|
||||
self.assertTrue(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
else:
|
||||
self.assertFalse(
|
||||
osp.isfile(
|
||||
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
|
||||
self.clear_work_dir()
|
||||
|
@ -63,6 +63,9 @@ class TestLoggerHook(RunnerTestCase):
|
||||
|
||||
def test_after_train_iter(self):
|
||||
# Test LoggerHook by iter.
|
||||
# Avoid to compare `Runner.iter` (MagicMock) with other integers.
|
||||
ori_every_n_train_iters = LoggerHook.every_n_train_iters
|
||||
LoggerHook.every_n_train_iters = MagicMock(return_value=True)
|
||||
runner = MagicMock()
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
@ -112,6 +115,7 @@ class TestLoggerHook(RunnerTestCase):
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook.after_train_iter(runner, batch_idx=8)
|
||||
runner.log_processor.get_log_after_iter.assert_called()
|
||||
LoggerHook.every_n_train_iters = ori_every_n_train_iters
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
logger_hook = LoggerHook()
|
||||
|
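Saving and restoring the class attribute by hand works, but the replacement would leak into later tests if an assertion between the two statements failed. An equivalent pattern (a sketch using only `unittest.mock`, not what this commit does) scopes the replacement with a context manager:

```python
from unittest.mock import MagicMock, patch

from mmengine.hooks import LoggerHook

# ``patch.object`` restores the original method automatically when the
# ``with`` block exits, even if an assertion inside it raises.
with patch.object(LoggerHook, 'every_n_train_iters',
                  MagicMock(return_value=True)):
    runner = MagicMock()
    runner.log_processor.get_log_after_iter = MagicMock(
        return_value=(dict(), 'log_str'))
    logger_hook = LoggerHook()
    logger_hook.after_train_iter(runner, batch_idx=8)
    runner.log_processor.get_log_after_iter.assert_called()
```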