mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Save checkpoint again to update best_ckpt of ckpt (#1168)
parent 9d9f2b761e
commit 19aa1eb780
@@ -3,6 +3,7 @@ import hashlib
 import logging
 import os.path as osp
 import pickle
+from collections import deque
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -242,6 +243,8 @@ class CheckpointHook(Hook):
                     'Find duplicate elements in "published_keys".')
         self.published_keys = published_keys
 
+        self.last_ckpt = None
+
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
@@ -294,6 +297,25 @@ class CheckpointHook(Hook):
                         key_indicator] = runner.message_hub.get_info(
                             best_ckpt_name)
 
+        if self.max_keep_ckpts > 0:
+            keep_ckpt_ids = []
+            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
+                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')
+
+                while len(keep_ckpt_ids) > self.max_keep_ckpts:
+                    step = keep_ckpt_ids.pop(0)
+                    if is_main_process():
+                        path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(step))
+                        if self.file_backend.isfile(path):
+                            self.file_backend.remove(path)
+                        elif self.file_backend.isdir(path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(path)
+
+            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
+                                              self.max_keep_ckpts)
+
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
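On resume, the hunk above restores `keep_ckpt_ids` from the message hub, trims it down to `max_keep_ckpts` (deleting the oldest checkpoint files), and then wraps it in a bounded deque. A minimal standalone sketch of that bounded bookkeeping, with made-up step ids and the file deletion replaced by a print:

from collections import deque

# Step ids as if they had been restored from a resumed checkpoint's message hub.
restored_keep_ckpt_ids = [7, 8, 9, 10, 11]
max_keep_ckpts = 3

# Trim the oldest entries first, mirroring the pop(0) loop in before_train.
while len(restored_keep_ckpt_ids) > max_keep_ckpts:
    oldest = restored_keep_ckpt_ids.pop(0)
    print(f'would delete the checkpoint saved at step {oldest}')

# A deque with maxlen keeps only the newest ids as training appends new steps.
keep_ckpt_ids = deque(restored_keep_ckpt_ids, maxlen=max_keep_ckpts)
keep_ckpt_ids.append(12)
print(list(keep_ckpt_ids))  # [10, 11, 12]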
@@ -337,14 +359,13 @@ class CheckpointHook(Hook):
         if self.published_keys is None:
             return
 
-        if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info:
-            last_ckpt = runner.message_hub.get_info('last_ckpt')
-            self._publish_model(runner, last_ckpt)
+        if self.save_last and self.last_ckpt is not None:
+            self._publish_model(runner, self.last_ckpt)
 
         if getattr(self, 'best_ckpt_path', None) is not None:
             self._publish_model(runner, str(self.best_ckpt_path))
         if getattr(self, 'best_ckpt_path_dict', None) is not None:
-            for key, best_ckpt in self.best_ckpt_path_dict.items():
+            for best_ckpt in self.best_ckpt_path_dict.values():
                 self._publish_model(runner, best_ckpt)
 
     @master_only
@@ -379,16 +400,35 @@ class CheckpointHook(Hook):
                 f'{final_path}.',
                 logger='current')
 
-    def _save_checkpoint(self, runner) -> None:
-        """Save the current checkpoint and delete outdated checkpoint.
+    def _save_checkpoint_with_step(self, runner, step, meta):
+        # remove other checkpoints before save checkpoint to make the
+        # self.keep_ckpt_ids are saved as expected
+        if self.max_keep_ckpts > 0:
+            # _save_checkpoint and _save_best_checkpoint may call this
+            # _save_checkpoint_with_step in one epoch
+            if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
+                pass
+            else:
+                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
+                    _step = self.keep_ckpt_ids.popleft()
+                    if is_main_process():
+                        ckpt_path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(_step))
 
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        if self.by_epoch:
-            ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
-        else:
-            ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
+                        if self.file_backend.isfile(ckpt_path):
+                            self.file_backend.remove(ckpt_path)
+                        elif self.file_backend.isdir(ckpt_path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(ckpt_path)
+
+                self.keep_ckpt_ids.append(step)
+                runner.message_hub.update_info('keep_ckpt_ids',
+                                               list(self.keep_ckpt_ids))
+
+        ckpt_filename = self.filename_tmpl.format(step)
+        self.last_ckpt = self.file_backend.join_path(self.out_dir,
+                                                     ckpt_filename)
+        runner.message_hub.update_info('last_ckpt', self.last_ckpt)
 
         runner.save_checkpoint(
             self.out_dir,
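`_save_checkpoint` and `_save_best_checkpoint` may both end up calling `_save_checkpoint_with_step` for the same step, so the new method only rotates old checkpoints when `step` is not already the newest entry in `keep_ckpt_ids`. A rough sketch of that guard in isolation (plain Python, hypothetical filenames, deletion replaced by a print):

from collections import deque

max_keep_ckpts = 2
keep_ckpt_ids: deque = deque(maxlen=max_keep_ckpts)

def save_with_step(step: int) -> None:
    # Skip the bookkeeping when the same step is saved twice in one round,
    # e.g. once by the interval save and once by the best-checkpoint re-save.
    if keep_ckpt_ids and keep_ckpt_ids[-1] == step:
        pass
    else:
        if len(keep_ckpt_ids) == max_keep_ckpts:
            outdated = keep_ckpt_ids.popleft()
            print(f'remove epoch_{outdated}.pth')  # stand-in for file removal
        keep_ckpt_ids.append(step)
    print(f'save epoch_{step}.pth, keep_ckpt_ids={list(keep_ckpt_ids)}')

save_with_step(1)
save_with_step(2)
save_with_step(2)   # duplicate call for the same step: no extra rotation
save_with_step(3)   # rotates out step 1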
@@ -396,6 +436,7 @@ class CheckpointHook(Hook):
             self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
+            meta=meta,
             by_epoch=self.by_epoch,
             backend_args=self.backend_args,
             **self.args)
@@ -405,31 +446,24 @@ class CheckpointHook(Hook):
         if not is_main_process():
             return
 
-        runner.message_hub.update_info(
-            'last_ckpt',
-            self.file_backend.join_path(self.out_dir, ckpt_filename))
-
-        # remove other checkpoints
-        if self.max_keep_ckpts > 0:
-            if self.by_epoch:
-                current_ckpt = runner.epoch + 1
-            else:
-                current_ckpt = runner.iter + 1
-            redundant_ckpts = range(
-                current_ckpt - self.max_keep_ckpts * self.interval, 0,
-                -self.interval)
-            for _step in redundant_ckpts:
-                ckpt_path = self.file_backend.join_path(
-                    self.out_dir, self.filename_tmpl.format(_step))
-                if self.file_backend.isfile(ckpt_path):
-                    self.file_backend.remove(ckpt_path)
-                else:
-                    break
-
         save_file = osp.join(runner.work_dir, 'last_checkpoint')
-        filepath = self.file_backend.join_path(self.out_dir, ckpt_filename)
         with open(save_file, 'w') as f:
-            f.write(filepath)
+            f.write(self.last_ckpt)  # type: ignore
+
+    def _save_checkpoint(self, runner) -> None:
+        """Save the current checkpoint and delete outdated checkpoint.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.by_epoch:
+            step = runner.epoch + 1
+            meta = dict(epoch=step, iter=runner.iter)
+        else:
+            step = runner.iter + 1
+            meta = dict(epoch=runner.epoch, iter=step)
+
+        self._save_checkpoint_with_step(runner, step, meta=meta)
 
     def _save_best_checkpoint(self, runner, metrics) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -448,10 +482,13 @@ class CheckpointHook(Hook):
             ckpt_filename = self.filename_tmpl.format(runner.iter)
             cur_type, cur_time = 'iter', runner.iter
 
+        meta = dict(epoch=runner.epoch, iter=runner.iter)
+
         # handle auto in self.key_indicators and self.rules before the loop
         if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])
 
+        best_ckpt_updated = False
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
@@ -475,13 +512,18 @@ class CheckpointHook(Hook):
                                         key_score, best_score):
                 continue
 
+            best_ckpt_updated = True
+
             best_score = key_score
             runner.message_hub.update_info(best_score_key, best_score)
 
-            if best_ckpt_path and \
-                    self.file_backend.isfile(best_ckpt_path) and \
-                    is_main_process():
-                self.file_backend.remove(best_ckpt_path)
+            if best_ckpt_path and is_main_process():
+                if self.file_backend.isfile(best_ckpt_path):
+                    self.file_backend.remove(best_ckpt_path)
+                else:
+                    # checkpoints saved by deepspeed are directories
+                    self.file_backend.rmtree(best_ckpt_path)
+
                 runner.logger.info(
                     f'The previous best checkpoint {best_ckpt_path} '
                     'is removed')
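The stale best checkpoint is now removed with an isfile/else branch because backends such as DeepSpeed store a checkpoint as a directory rather than a single file. A local-filesystem analogue of the same pattern, using os and shutil instead of the hook's file backend:

import os
import shutil

def remove_checkpoint(path: str) -> None:
    # Single-file checkpoints, e.g. regular torch.save output.
    if os.path.isfile(path):
        os.remove(path)
    # Directory checkpoints, e.g. those produced by DeepSpeed.
    elif os.path.isdir(path):
        shutil.rmtree(path)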
@@ -507,12 +549,21 @@ class CheckpointHook(Hook):
                 file_client_args=self.file_client_args,
                 save_optimizer=False,
                 save_param_scheduler=False,
+                meta=meta,
                 by_epoch=False,
                 backend_args=self.backend_args)
             runner.logger.info(
                 f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                 f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
 
+        # Save the checkpoint again to update the best_score and best_ckpt
+        # stored in message_hub, because the checkpoint saved in the
+        # `after_train_epoch` or `after_train_iter` stage only keeps the
+        # previous best checkpoint, not the current one, so the current best
+        # checkpoint could not be removed when resuming training.
+        if best_ckpt_updated and self.last_ckpt is not None:
+            self._save_checkpoint_with_step(runner, cur_time, meta)
+
     def _init_rule(self, rules, key_indicators) -> None:
         """Initialize rule, key_indicator, comparison_func, and best score. If
         key_indicator is a list of string and rule is a string, all metric in
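This is the heart of the fix: the regular checkpoint for the current step is written in `after_train_epoch`/`after_train_iter` before validation updates `best_ckpt` in the message hub, so the snapshot saved in that checkpoint still points at the previous best. Re-saving the same step after a new best refreshes that snapshot. A toy timeline, with the message hub reduced to a plain dict and purely illustrative paths:

# Assumes an epoch-based run with interval=1 and save_best enabled.
message_hub = {}

def save_regular_checkpoint(epoch):
    # runner.save_checkpoint snapshots the message hub into the checkpoint.
    return {'meta': {'epoch': epoch}, 'message_hub': dict(message_hub)}

# 1. after_train_epoch saves epoch_5.pth; at this point the hub still points
#    at the best checkpoint found at epoch 4.
message_hub['best_ckpt'] = 'work_dir/best_acc_epoch_4.pth'
stale_ckpt = save_regular_checkpoint(5)

# 2. Validation for epoch 5 finds a new best and _save_best_checkpoint
#    updates the hub.
message_hub['best_ckpt'] = 'work_dir/best_acc_epoch_5.pth'

# 3. Without the fix, resuming from epoch_5.pth restores the stale hub, so the
#    hook would later try to remove best_acc_epoch_4.pth instead of the real
#    current best. The fix saves epoch_5.pth again with the refreshed hub.
fresh_ckpt = save_regular_checkpoint(5)
print(stale_ckpt['message_hub']['best_ckpt'])  # work_dir/best_acc_epoch_4.pth
print(fresh_ckpt['message_hub']['best_ckpt'])  # work_dir/best_acc_epoch_5.pth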
@@ -2090,7 +2090,7 @@ class Runner:
         file_client_args: Optional[dict] = None,
         save_optimizer: bool = True,
         save_param_scheduler: bool = True,
-        meta: dict = None,
+        meta: Optional[dict] = None,
         by_epoch: bool = True,
         backend_args: Optional[dict] = None,
     ):
@@ -2112,8 +2112,8 @@ class Runner:
                 to the checkpoint. Defaults to True.
             meta (dict, optional): The meta information to be saved in the
                 checkpoint. Defaults to None.
-            by_epoch (bool): Whether the scheduled momentum is updated by
-                epochs. Defaults to True.
+            by_epoch (bool): Decide whether the epoch or iteration number is
+                saved in the checkpoint. Defaults to True.
             backend_args (dict, optional): Arguments to instantiate the
                 prefix of uri corresponding backend. Defaults to None.
                 New in v0.2.0.
@@ -2129,9 +2129,11 @@ class Runner:
             # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
             # called by `after_train_epoch`` method of `CheckpointHook` so
             # `epoch` should be `self.epoch + 1`
-            meta.update(epoch=self.epoch + 1, iter=self.iter)
+            meta.setdefault('epoch', self.epoch + 1)
+            meta.setdefault('iter', self.iter)
         else:
-            meta.update(epoch=self.epoch, iter=self.iter + 1)
+            meta.setdefault('epoch', self.epoch)
+            meta.setdefault('iter', self.iter + 1)
 
         if file_client_args is not None:
             warnings.warn(
@@ -433,12 +433,13 @@ class TestCheckpointHook(RunnerTestCase):
         setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
         checkpoint_cfg = dict(
             type='CheckpointHook',
-            interval=2,
+            interval=1,
             by_epoch=training_type == 'epoch')
         common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
 
         # Test interval in epoch based training
         cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 2
         runner = self.build_runner(cfg)
         runner.train()
 
@@ -502,13 +503,11 @@ class TestCheckpointHook(RunnerTestCase):
 
         self.clear_work_dir()
 
-        # Test max_keep_ckpts
+        # Test max_keep_ckpts=1
         cfg = copy.deepcopy(common_cfg)
-        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.max_keep_ckpts = 1
         runner = self.build_runner(cfg)
         runner.train()
-        print(os.listdir(cfg.work_dir))
         self.assertTrue(
             osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
 
@@ -518,6 +517,45 @@ class TestCheckpointHook(RunnerTestCase):
 
         self.clear_work_dir()
 
+        # Test max_keep_ckpts=3
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 3
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+
+        for i in range(9):
+            self.assertFalse(
+                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
+                         [9, 10, 11])
+
+        # Test max_keep_ckpts when resuming training
+        cfg = copy.deepcopy(common_cfg)
+        setattr(cfg.train_cfg, f'max_{training_type}s', 12)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 2
+        cfg.load_from = osp.join(cfg.work_dir, f'{training_type}_11.pth')
+        cfg.resume = True
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth')))
+
+        self.clear_work_dir()
+
         # Test filename_tmpl
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
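The new assertions expect `keep_ckpt_ids` to equal [9, 10, 11] after 11 steps with `max_keep_ckpts=3`; the same bookkeeping can be reproduced with a plain bounded deque, no runner required:

from collections import deque

keep_ckpt_ids: deque = deque(maxlen=3)   # mirrors max_keep_ckpts=3
for step in range(1, 12):                # 11 epochs/iterations at interval=1
    keep_ckpt_ids.append(step)

assert list(keep_ckpt_ids) == [9, 10, 11]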
@@ -529,18 +567,61 @@ class TestCheckpointHook(RunnerTestCase):
         # Test save_best
         cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.save_best = 'test/acc'
         cfg.val_evaluator = dict(type='TriangleMetric', length=11)
         cfg.train_cfg.val_interval = 1
         runner = self.build_runner(cfg)
         runner.train()
-        best_ckpt = osp.join(cfg.work_dir,
-                             f'best_test_acc_{training_type}_5.pth')
-        self.assertTrue(osp.isfile(best_ckpt))
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
 
         self.clear_work_dir()
 
+        # Test save_best with interval=2
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.save_best = 'test/acc'
+        cfg.default_hooks.checkpoint.interval = 2
+        cfg.val_evaluator = dict(type='TriangleMetric', length=11)
+        cfg.train_cfg.val_interval = 1
+        runner = self.build_runner(cfg)
+        runner.train()
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        # if the current ckpt is the best, the interval will be ignored and
+        # the ckpt will also be saved
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
+
         # test save published keys
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']