[Fix] Save checkpoint again to update best_ckpt of ckpt (#1168)

Zaida Zhou 2023-06-02 14:42:56 +08:00 committed by GitHub
parent 9d9f2b761e
commit 19aa1eb780
3 changed files with 187 additions and 53 deletions
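In short: with `save_best` enabled, the best checkpoint is written after the regular checkpoint of the same step, so the regular checkpoint's message_hub snapshot still references the previous best; on resume, that stale `best_ckpt` entry prevents the outdated best checkpoint from being removed. This commit re-saves the regular checkpoint once a new best is found, and persists `keep_ckpt_ids` in message_hub so `max_keep_ckpts` cleanup also survives resumes. A minimal sketch of the stale-snapshot problem, using a plain dict as a hypothetical stand-in for MessageHub:

# hypothetical stand-in for runner.message_hub.runtime_info
runtime_info = {'best_ckpt': 'best_acc_epoch_4.pth'}

# 1. The regular checkpoint is saved first and snapshots runtime_info.
epoch_5_ckpt = {'message_hub': dict(runtime_info)}

# 2. The best checkpoint for the same step is saved afterwards and only
#    then updates runtime_info.
runtime_info['best_ckpt'] = 'best_acc_epoch_5.pth'

# The on-disk epoch_5 checkpoint still points at the old best, so a resumed
# run cannot clean the new one up; hence the fix re-saves epoch_5.
assert epoch_5_ckpt['message_hub']['best_ckpt'] == 'best_acc_epoch_4.pth'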

mmengine/hooks/checkpoint_hook.py

@@ -3,6 +3,7 @@ import hashlib
 import logging
 import os.path as osp
 import pickle
+from collections import deque
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -242,6 +243,8 @@ class CheckpointHook(Hook):
                         'Find duplicate elements in "published_keys".')
             self.published_keys = published_keys
 
+        self.last_ckpt = None
+
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
@@ -294,6 +297,25 @@ class CheckpointHook(Hook):
                     key_indicator] = runner.message_hub.get_info(
                         best_ckpt_name)
 
+        if self.max_keep_ckpts > 0:
+            keep_ckpt_ids = []
+            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
+                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')
+
+                while len(keep_ckpt_ids) > self.max_keep_ckpts:
+                    step = keep_ckpt_ids.pop(0)
+                    if is_main_process():
+                        path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(step))
+                        if self.file_backend.isfile(path):
+                            self.file_backend.remove(path)
+                        elif self.file_backend.isdir(path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(path)
+
+            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
+                                              self.max_keep_ckpts)
+
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
@@ -337,14 +359,13 @@ class CheckpointHook(Hook):
         if self.published_keys is None:
             return
 
-        if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info:
-            last_ckpt = runner.message_hub.get_info('last_ckpt')
-            self._publish_model(runner, last_ckpt)
+        if self.save_last and self.last_ckpt is not None:
+            self._publish_model(runner, self.last_ckpt)
 
         if getattr(self, 'best_ckpt_path', None) is not None:
             self._publish_model(runner, str(self.best_ckpt_path))
         if getattr(self, 'best_ckpt_path_dict', None) is not None:
-            for key, best_ckpt in self.best_ckpt_path_dict.items():
+            for best_ckpt in self.best_ckpt_path_dict.values():
                 self._publish_model(runner, best_ckpt)
 
     @master_only
@@ -379,16 +400,35 @@ class CheckpointHook(Hook):
             f'{final_path}.',
             logger='current')
 
-    def _save_checkpoint(self, runner) -> None:
-        """Save the current checkpoint and delete outdated checkpoint.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        if self.by_epoch:
-            ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
-        else:
-            ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
+    def _save_checkpoint_with_step(self, runner, step, meta):
+        # remove other checkpoints before save checkpoint to make the
+        # self.keep_ckpt_ids are saved as expected
+        if self.max_keep_ckpts > 0:
+            # _save_checkpoint and _save_best_checkpoint may call this
+            # _save_checkpoint_with_step in one epoch
+            if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
+                pass
+            else:
+                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
+                    _step = self.keep_ckpt_ids.popleft()
+                    if is_main_process():
+                        ckpt_path = self.file_backend.join_path(
+                            self.out_dir, self.filename_tmpl.format(_step))
+
+                        if self.file_backend.isfile(ckpt_path):
+                            self.file_backend.remove(ckpt_path)
+                        elif self.file_backend.isdir(ckpt_path):
+                            # checkpoints saved by deepspeed are directories
+                            self.file_backend.rmtree(ckpt_path)
+
+                self.keep_ckpt_ids.append(step)
+                runner.message_hub.update_info('keep_ckpt_ids',
+                                               list(self.keep_ckpt_ids))
+
+        ckpt_filename = self.filename_tmpl.format(step)
+        self.last_ckpt = self.file_backend.join_path(self.out_dir,
+                                                     ckpt_filename)
+        runner.message_hub.update_info('last_ckpt', self.last_ckpt)
 
         runner.save_checkpoint(
             self.out_dir,
@@ -396,6 +436,7 @@ class CheckpointHook(Hook):
             self.file_client_args,
             save_optimizer=self.save_optimizer,
             save_param_scheduler=self.save_param_scheduler,
+            meta=meta,
             by_epoch=self.by_epoch,
             backend_args=self.backend_args,
             **self.args)
@@ -405,31 +446,24 @@ class CheckpointHook(Hook):
         if not is_main_process():
             return
 
-        runner.message_hub.update_info(
-            'last_ckpt',
-            self.file_backend.join_path(self.out_dir, ckpt_filename))
-
-        # remove other checkpoints
-        if self.max_keep_ckpts > 0:
-            if self.by_epoch:
-                current_ckpt = runner.epoch + 1
-            else:
-                current_ckpt = runner.iter + 1
-            redundant_ckpts = range(
-                current_ckpt - self.max_keep_ckpts * self.interval, 0,
-                -self.interval)
-            for _step in redundant_ckpts:
-                ckpt_path = self.file_backend.join_path(
-                    self.out_dir, self.filename_tmpl.format(_step))
-                if self.file_backend.isfile(ckpt_path):
-                    self.file_backend.remove(ckpt_path)
-                else:
-                    break
-
         save_file = osp.join(runner.work_dir, 'last_checkpoint')
-        filepath = self.file_backend.join_path(self.out_dir, ckpt_filename)
         with open(save_file, 'w') as f:
-            f.write(filepath)
+            f.write(self.last_ckpt)  # type: ignore
+
+    def _save_checkpoint(self, runner) -> None:
+        """Save the current checkpoint and delete outdated checkpoint.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.by_epoch:
+            step = runner.epoch + 1
+            meta = dict(epoch=step, iter=runner.iter)
+        else:
+            step = runner.iter + 1
+            meta = dict(epoch=runner.epoch, iter=step)
+
+        self._save_checkpoint_with_step(runner, step, meta=meta)
 
     def _save_best_checkpoint(self, runner, metrics) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -448,10 +482,13 @@ class CheckpointHook(Hook):
             ckpt_filename = self.filename_tmpl.format(runner.iter)
             cur_type, cur_time = 'iter', runner.iter
 
+        meta = dict(epoch=runner.epoch, iter=runner.iter)
+
         # handle auto in self.key_indicators and self.rules before the loop
         if 'auto' in self.key_indicators:
             self._init_rule(self.rules, [list(metrics.keys())[0]])
 
+        best_ckpt_updated = False
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
@@ -475,13 +512,18 @@ class CheckpointHook(Hook):
                     key_score, best_score):
                 continue
 
+            best_ckpt_updated = True
+
             best_score = key_score
             runner.message_hub.update_info(best_score_key, best_score)
 
-            if best_ckpt_path and \
-                    self.file_backend.isfile(best_ckpt_path) and \
-                    is_main_process():
-                self.file_backend.remove(best_ckpt_path)
+            if best_ckpt_path and is_main_process():
+                if self.file_backend.isfile(best_ckpt_path):
+                    self.file_backend.remove(best_ckpt_path)
+                else:
+                    # checkpoints saved by deepspeed are directories
+                    self.file_backend.rmtree(best_ckpt_path)
                 runner.logger.info(
                     f'The previous best checkpoint {best_ckpt_path} '
                     'is removed')
@@ -507,12 +549,21 @@ class CheckpointHook(Hook):
                     file_client_args=self.file_client_args,
                     save_optimizer=False,
                     save_param_scheduler=False,
+                    meta=meta,
                     by_epoch=False,
                     backend_args=self.backend_args)
                 runner.logger.info(
                     f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                     f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
 
+        # Save the checkpoint again to update the best_score and best_ckpt
+        # stored in message_hub: the checkpoint saved in the
+        # `after_train_epoch` or `after_train_iter` stage still records the
+        # previous best checkpoint rather than the current one, so without
+        # this re-save the current best cannot be removed when resuming.
+        if best_ckpt_updated and self.last_ckpt is not None:
+            self._save_checkpoint_with_step(runner, cur_time, meta)
+
     def _init_rule(self, rules, key_indicators) -> None:
         """Initialize rule, key_indicator, comparison_func, and best score. If
         key_indicator is a list of string and rule is a string, all metric in
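Note the interplay with `_save_checkpoint_with_step`: when the re-save triggered by a new best lands on a step that was already saved, the `keep_ckpt_ids[-1] == step` guard keeps that id from being recorded twice. A small standalone sketch of that guard:

from collections import deque

keep_ckpt_ids = deque([4, 5], maxlen=3)

def record(step):
    # mirrors the dedup branch in _save_checkpoint_with_step
    if len(keep_ckpt_ids) > 0 and keep_ckpt_ids[-1] == step:
        return  # same step: the file is rewritten, the id is already tracked
    keep_ckpt_ids.append(step)

record(5)  # best-triggered re-save of the current step: no duplicate id
record(6)  # next regular save
assert list(keep_ckpt_ids) == [4, 5, 6]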

mmengine/runner/runner.py

@@ -2090,7 +2090,7 @@ class Runner:
         file_client_args: Optional[dict] = None,
         save_optimizer: bool = True,
         save_param_scheduler: bool = True,
-        meta: dict = None,
+        meta: Optional[dict] = None,
         by_epoch: bool = True,
         backend_args: Optional[dict] = None,
     ):
@@ -2112,8 +2112,8 @@ class Runner:
                 to the checkpoint. Defaults to True.
             meta (dict, optional): The meta information to be saved in the
                 checkpoint. Defaults to None.
-            by_epoch (bool): Whether the scheduled momentum is updated by
-                epochs. Defaults to True.
+            by_epoch (bool): Decide whether the epoch number or the iteration
+                number is advanced in the saved meta. Defaults to True.
             backend_args (dict, optional): Arguments to instantiate the
                 prefix of uri corresponding backend. Defaults to None.
                 New in v0.2.0.
@@ -2129,9 +2129,11 @@ class Runner:
             # `self.call_hook('after_train_epoch')` but `save_checkpoint` is
             # called by the `after_train_epoch` method of `CheckpointHook`,
             # so `epoch` should be `self.epoch + 1`
-            meta.update(epoch=self.epoch + 1, iter=self.iter)
+            meta.setdefault('epoch', self.epoch + 1)
+            meta.setdefault('iter', self.iter)
         else:
-            meta.update(epoch=self.epoch, iter=self.iter + 1)
+            meta.setdefault('epoch', self.epoch)
+            meta.setdefault('iter', self.iter + 1)
 
         if file_client_args is not None:
             warnings.warn(
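The switch from `update` to `setdefault` is what lets `CheckpointHook` pass its own `meta`: with `update`, the runner's own counters would clobber the hook's values. A quick illustration of the difference:

# meta as passed by _save_best_checkpoint (epoch-based training)
meta = dict(epoch=5, iter=20)
meta.setdefault('epoch', 99)  # hook's value wins: stays 5
meta.setdefault('iter', 99)   # stays 20
assert meta == {'epoch': 5, 'iter': 20}

meta = dict(epoch=5, iter=20)
meta.update(epoch=99, iter=99)  # old behaviour: hook's meta is overwritten
assert meta == {'epoch': 99, 'iter': 99}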

tests/test_hooks/test_checkpoint_hook.py

@@ -433,12 +433,13 @@ class TestCheckpointHook(RunnerTestCase):
         setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
         checkpoint_cfg = dict(
             type='CheckpointHook',
-            interval=2,
+            interval=1,
             by_epoch=training_type == 'epoch')
         common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
 
         # Test interval in epoch based training
         cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.interval = 2
         runner = self.build_runner(cfg)
         runner.train()
@@ -502,13 +503,11 @@ class TestCheckpointHook(RunnerTestCase):
         self.clear_work_dir()
 
-        # Test max_keep_ckpts
+        # Test max_keep_ckpts=1
         cfg = copy.deepcopy(common_cfg)
-        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.max_keep_ckpts = 1
         runner = self.build_runner(cfg)
         runner.train()
-        print(os.listdir(cfg.work_dir))
         self.assertTrue(
             osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
@@ -518,6 +517,45 @@ class TestCheckpointHook(RunnerTestCase):
         self.clear_work_dir()
 
+        # Test max_keep_ckpts=3
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 3
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+        for i in range(9):
+            self.assertFalse(
+                osp.isfile(osp.join(cfg.work_dir,
+                                    f'{training_type}_{i}.pth')))
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
+                         [9, 10, 11])
+
+        # Test max_keep_ckpts when resuming training
+        cfg = copy.deepcopy(common_cfg)
+        setattr(cfg.train_cfg, f'max_{training_type}s', 12)
+        cfg.default_hooks.checkpoint.max_keep_ckpts = 2
+        cfg.load_from = osp.join(cfg.work_dir, f'{training_type}_11.pth')
+        cfg.resume = True
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
+        self.assertFalse(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
+        self.assertTrue(
+            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth')))
+        self.clear_work_dir()
+
         # Test filename_tmpl
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
@@ -529,18 +567,61 @@ class TestCheckpointHook(RunnerTestCase):
         # Test save_best
         cfg = copy.deepcopy(common_cfg)
-        cfg.default_hooks.checkpoint.interval = 1
         cfg.default_hooks.checkpoint.save_best = 'test/acc'
         cfg.val_evaluator = dict(type='TriangleMetric', length=11)
         cfg.train_cfg.val_interval = 1
         runner = self.build_runner(cfg)
         runner.train()
-        best_ckpt = osp.join(cfg.work_dir,
-                             f'best_test_acc_{training_type}_5.pth')
-        self.assertTrue(osp.isfile(best_ckpt))
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
+
+        self.clear_work_dir()
+
+        # Test save_best with interval=2
+        cfg = copy.deepcopy(common_cfg)
+        cfg.default_hooks.checkpoint.save_best = 'test/acc'
+        cfg.default_hooks.checkpoint.interval = 2
+        cfg.val_evaluator = dict(type='TriangleMetric', length=11)
+        cfg.train_cfg.val_interval = 1
+        runner = self.build_runner(cfg)
+        runner.train()
+        best_ckpt_path = osp.join(cfg.work_dir,
+                                  f'best_test_acc_{training_type}_5.pth')
+        best_ckpt = torch.load(best_ckpt_path)
+
+        # If the current checkpoint is the best one, the interval is ignored
+        # and the checkpoint is saved as well.
+        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        self.assertEqual(best_ckpt_path,
+                         ckpt['message_hub']['runtime_info']['best_ckpt'])
+
+        if training_type == 'epoch':
+            self.assertEqual(ckpt['meta']['epoch'], 5)
+            self.assertEqual(ckpt['meta']['iter'], 20)
+            self.assertEqual(best_ckpt['meta']['epoch'], 5)
+            self.assertEqual(best_ckpt['meta']['iter'], 20)
+        else:
+            self.assertEqual(ckpt['meta']['epoch'], 0)
+            self.assertEqual(ckpt['meta']['iter'], 5)
+            self.assertEqual(best_ckpt['meta']['epoch'], 0)
+            self.assertEqual(best_ckpt['meta']['iter'], 5)
+
         # test save published keys
         cfg = copy.deepcopy(common_cfg)
         cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']