[Fix] Save checkpoint again to update best_ckpt of ckpt (#1168)

This commit is contained in:
Zaida Zhou 2023-06-02 14:42:56 +08:00 committed by GitHub
parent 9d9f2b761e
commit 19aa1eb780
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 187 additions and 53 deletions

View File

@ -3,6 +3,7 @@ import hashlib
import logging import logging
import os.path as osp import os.path as osp
import pickle import pickle
from collections import deque
from math import inf from math import inf
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union from typing import Callable, Dict, List, Optional, Sequence, Union
@ -242,6 +243,8 @@ class CheckpointHook(Hook):
'Find duplicate elements in "published_keys".') 'Find duplicate elements in "published_keys".')
self.published_keys = published_keys self.published_keys = published_keys
self.last_ckpt = None
def before_train(self, runner) -> None: def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint. """Finish all operations, related to checkpoint.
@ -294,6 +297,25 @@ class CheckpointHook(Hook):
key_indicator] = runner.message_hub.get_info( key_indicator] = runner.message_hub.get_info(
best_ckpt_name) best_ckpt_name)
if self.max_keep_ckpts > 0:
keep_ckpt_ids = []
if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')
while len(keep_ckpt_ids) > self.max_keep_ckpts:
step = keep_ckpt_ids.pop(0)
if is_main_process():
path = self.file_backend.join_path(
self.out_dir, self.filename_tmpl.format(step))
if self.file_backend.isfile(path):
self.file_backend.remove(path)
elif self.file_backend.isdir(path):
# checkpoints saved by deepspeed are directories
self.file_backend.rmtree(path)
self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
self.max_keep_ckpts)
def after_train_epoch(self, runner) -> None: def after_train_epoch(self, runner) -> None:
"""Save the checkpoint and synchronize buffers after each epoch. """Save the checkpoint and synchronize buffers after each epoch.
@ -337,14 +359,13 @@ class CheckpointHook(Hook):
if self.published_keys is None: if self.published_keys is None:
return return
if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info: if self.save_last and self.last_ckpt is not None:
last_ckpt = runner.message_hub.get_info('last_ckpt') self._publish_model(runner, self.last_ckpt)
self._publish_model(runner, last_ckpt)
if getattr(self, 'best_ckpt_path', None) is not None: if getattr(self, 'best_ckpt_path', None) is not None:
self._publish_model(runner, str(self.best_ckpt_path)) self._publish_model(runner, str(self.best_ckpt_path))
if getattr(self, 'best_ckpt_path_dict', None) is not None: if getattr(self, 'best_ckpt_path_dict', None) is not None:
for key, best_ckpt in self.best_ckpt_path_dict.items(): for best_ckpt in self.best_ckpt_path_dict.values():
self._publish_model(runner, best_ckpt) self._publish_model(runner, best_ckpt)
@master_only @master_only
@ -379,16 +400,35 @@ class CheckpointHook(Hook):
f'{final_path}.', f'{final_path}.',
logger='current') logger='current')
def _save_checkpoint(self, runner) -> None: def _save_checkpoint_with_step(self, runner, step, meta):
"""Save the current checkpoint and delete outdated checkpoint. # remove other checkpoints before save checkpoint to make the
# self.keep_ckpt_ids are saved as expected
if self.max_keep_ckpts > 0:
# _save_checkpoint and _save_best_checkpoint may call this
# _save_checkpoint_with_step in one epoch
if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step:
pass
else:
if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
_step = self.keep_ckpt_ids.popleft()
if is_main_process():
ckpt_path = self.file_backend.join_path(
self.out_dir, self.filename_tmpl.format(_step))
Args: if self.file_backend.isfile(ckpt_path):
runner (Runner): The runner of the training process. self.file_backend.remove(ckpt_path)
""" elif self.file_backend.isdir(ckpt_path):
if self.by_epoch: # checkpoints saved by deepspeed are directories
ckpt_filename = self.filename_tmpl.format(runner.epoch + 1) self.file_backend.rmtree(ckpt_path)
else:
ckpt_filename = self.filename_tmpl.format(runner.iter + 1) self.keep_ckpt_ids.append(step)
runner.message_hub.update_info('keep_ckpt_ids',
list(self.keep_ckpt_ids))
ckpt_filename = self.filename_tmpl.format(step)
self.last_ckpt = self.file_backend.join_path(self.out_dir,
ckpt_filename)
runner.message_hub.update_info('last_ckpt', self.last_ckpt)
runner.save_checkpoint( runner.save_checkpoint(
self.out_dir, self.out_dir,
@ -396,6 +436,7 @@ class CheckpointHook(Hook):
self.file_client_args, self.file_client_args,
save_optimizer=self.save_optimizer, save_optimizer=self.save_optimizer,
save_param_scheduler=self.save_param_scheduler, save_param_scheduler=self.save_param_scheduler,
meta=meta,
by_epoch=self.by_epoch, by_epoch=self.by_epoch,
backend_args=self.backend_args, backend_args=self.backend_args,
**self.args) **self.args)
@ -405,31 +446,24 @@ class CheckpointHook(Hook):
if not is_main_process(): if not is_main_process():
return return
runner.message_hub.update_info(
'last_ckpt',
self.file_backend.join_path(self.out_dir, ckpt_filename))
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
current_ckpt = runner.epoch + 1
else:
current_ckpt = runner.iter + 1
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
for _step in redundant_ckpts:
ckpt_path = self.file_backend.join_path(
self.out_dir, self.filename_tmpl.format(_step))
if self.file_backend.isfile(ckpt_path):
self.file_backend.remove(ckpt_path)
else:
break
save_file = osp.join(runner.work_dir, 'last_checkpoint') save_file = osp.join(runner.work_dir, 'last_checkpoint')
filepath = self.file_backend.join_path(self.out_dir, ckpt_filename)
with open(save_file, 'w') as f: with open(save_file, 'w') as f:
f.write(filepath) f.write(self.last_ckpt) # type: ignore
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
Args:
runner (Runner): The runner of the training process.
"""
if self.by_epoch:
step = runner.epoch + 1
meta = dict(epoch=step, iter=runner.iter)
else:
step = runner.iter + 1
meta = dict(epoch=runner.epoch, iter=step)
self._save_checkpoint_with_step(runner, step, meta=meta)
def _save_best_checkpoint(self, runner, metrics) -> None: def _save_best_checkpoint(self, runner, metrics) -> None:
"""Save the current checkpoint and delete outdated checkpoint. """Save the current checkpoint and delete outdated checkpoint.
@ -448,10 +482,13 @@ class CheckpointHook(Hook):
ckpt_filename = self.filename_tmpl.format(runner.iter) ckpt_filename = self.filename_tmpl.format(runner.iter)
cur_type, cur_time = 'iter', runner.iter cur_type, cur_time = 'iter', runner.iter
meta = dict(epoch=runner.epoch, iter=runner.iter)
# handle auto in self.key_indicators and self.rules before the loop # handle auto in self.key_indicators and self.rules before the loop
if 'auto' in self.key_indicators: if 'auto' in self.key_indicators:
self._init_rule(self.rules, [list(metrics.keys())[0]]) self._init_rule(self.rules, [list(metrics.keys())[0]])
best_ckpt_updated = False
# save best logic # save best logic
# get score from messagehub # get score from messagehub
for key_indicator, rule in zip(self.key_indicators, self.rules): for key_indicator, rule in zip(self.key_indicators, self.rules):
@ -475,13 +512,18 @@ class CheckpointHook(Hook):
key_score, best_score): key_score, best_score):
continue continue
best_ckpt_updated = True
best_score = key_score best_score = key_score
runner.message_hub.update_info(best_score_key, best_score) runner.message_hub.update_info(best_score_key, best_score)
if best_ckpt_path and \ if best_ckpt_path and is_main_process():
self.file_backend.isfile(best_ckpt_path) and \ if self.file_backend.isfile(best_ckpt_path):
is_main_process(): self.file_backend.remove(best_ckpt_path)
self.file_backend.remove(best_ckpt_path) else:
# checkpoints saved by deepspeed are directories
self.file_backend.rmtree(best_ckpt_path)
runner.logger.info( runner.logger.info(
f'The previous best checkpoint {best_ckpt_path} ' f'The previous best checkpoint {best_ckpt_path} '
'is removed') 'is removed')
@ -507,12 +549,21 @@ class CheckpointHook(Hook):
file_client_args=self.file_client_args, file_client_args=self.file_client_args,
save_optimizer=False, save_optimizer=False,
save_param_scheduler=False, save_param_scheduler=False,
meta=meta,
by_epoch=False, by_epoch=False,
backend_args=self.backend_args) backend_args=self.backend_args)
runner.logger.info( runner.logger.info(
f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'The best checkpoint with {best_score:0.4f} {key_indicator} '
f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
# save checkpoint again to update the best_score and best_ckpt stored
# in message_hub because the checkpoint saved in `after_train_epoch`
# or `after_train_iter` stage only keep the previous best checkpoint
# not the current best checkpoint which causes the current best
# checkpoint can not be removed when resuming training.
if best_ckpt_updated and self.last_ckpt is not None:
self._save_checkpoint_with_step(runner, cur_time, meta)
def _init_rule(self, rules, key_indicators) -> None: def _init_rule(self, rules, key_indicators) -> None:
"""Initialize rule, key_indicator, comparison_func, and best score. If """Initialize rule, key_indicator, comparison_func, and best score. If
key_indicator is a list of string and rule is a string, all metric in key_indicator is a list of string and rule is a string, all metric in

View File

@ -2090,7 +2090,7 @@ class Runner:
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
save_optimizer: bool = True, save_optimizer: bool = True,
save_param_scheduler: bool = True, save_param_scheduler: bool = True,
meta: dict = None, meta: Optional[dict] = None,
by_epoch: bool = True, by_epoch: bool = True,
backend_args: Optional[dict] = None, backend_args: Optional[dict] = None,
): ):
@ -2112,8 +2112,8 @@ class Runner:
to the checkpoint. Defaults to True. to the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None. checkpoint. Defaults to None.
by_epoch (bool): Whether the scheduled momentum is updated by by_epoch (bool): Decide the number of epoch or iteration saved in
epochs. Defaults to True. checkpoint. Defaults to True.
backend_args (dict, optional): Arguments to instantiate the backend_args (dict, optional): Arguments to instantiate the
prefix of uri corresponding backend. Defaults to None. prefix of uri corresponding backend. Defaults to None.
New in v0.2.0. New in v0.2.0.
@ -2129,9 +2129,11 @@ class Runner:
# `self.call_hook('after_train_epoch)` but `save_checkpoint` is # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
# called by `after_train_epoch`` method of `CheckpointHook` so # called by `after_train_epoch`` method of `CheckpointHook` so
# `epoch` should be `self.epoch + 1` # `epoch` should be `self.epoch + 1`
meta.update(epoch=self.epoch + 1, iter=self.iter) meta.setdefault('epoch', self.epoch + 1)
meta.setdefault('iter', self.iter)
else: else:
meta.update(epoch=self.epoch, iter=self.iter + 1) meta.setdefault('epoch', self.epoch)
meta.setdefault('iter', self.iter + 1)
if file_client_args is not None: if file_client_args is not None:
warnings.warn( warnings.warn(

View File

@ -433,12 +433,13 @@ class TestCheckpointHook(RunnerTestCase):
setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
checkpoint_cfg = dict( checkpoint_cfg = dict(
type='CheckpointHook', type='CheckpointHook',
interval=2, interval=1,
by_epoch=training_type == 'epoch') by_epoch=training_type == 'epoch')
common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
# Test interval in epoch based training # Test interval in epoch based training
cfg = copy.deepcopy(common_cfg) cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 2
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
@ -502,13 +503,11 @@ class TestCheckpointHook(RunnerTestCase):
self.clear_work_dir() self.clear_work_dir()
# Test max_keep_ckpts # Test max_keep_ckpts=1
cfg = copy.deepcopy(common_cfg) cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 1
cfg.default_hooks.checkpoint.max_keep_ckpts = 1 cfg.default_hooks.checkpoint.max_keep_ckpts = 1
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
print(os.listdir(cfg.work_dir))
self.assertTrue( self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
@ -518,6 +517,45 @@ class TestCheckpointHook(RunnerTestCase):
self.clear_work_dir() self.clear_work_dir()
# Test max_keep_ckpts=3
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.max_keep_ckpts = 3
runner = self.build_runner(cfg)
runner.train()
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
for i in range(9):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
[9, 10, 11])
# Test max_keep_ckpts when resuming traing
cfg = copy.deepcopy(common_cfg)
setattr(cfg.train_cfg, f'max_{training_type}s', 12)
cfg.default_hooks.checkpoint.max_keep_ckpts = 2
cfg.load_from = osp.join(cfg.work_dir, f'{training_type}_11.pth')
cfg.resume = True
runner = self.build_runner(cfg)
runner.train()
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth')))
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth')))
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth')))
self.clear_work_dir()
# Test filename_tmpl # Test filename_tmpl
cfg = copy.deepcopy(common_cfg) cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth' cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
@ -529,18 +567,61 @@ class TestCheckpointHook(RunnerTestCase):
# Test save_best # Test save_best
cfg = copy.deepcopy(common_cfg) cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 1
cfg.default_hooks.checkpoint.save_best = 'test/acc' cfg.default_hooks.checkpoint.save_best = 'test/acc'
cfg.val_evaluator = dict(type='TriangleMetric', length=11) cfg.val_evaluator = dict(type='TriangleMetric', length=11)
cfg.train_cfg.val_interval = 1 cfg.train_cfg.val_interval = 1
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
best_ckpt = osp.join(cfg.work_dir, best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth') f'best_test_acc_{training_type}_5.pth')
self.assertTrue(osp.isfile(best_ckpt)) best_ckpt = torch.load(best_ckpt_path)
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])
if training_type == 'epoch':
self.assertEqual(ckpt['meta']['epoch'], 5)
self.assertEqual(ckpt['meta']['iter'], 20)
self.assertEqual(best_ckpt['meta']['epoch'], 5)
self.assertEqual(best_ckpt['meta']['iter'], 20)
else:
self.assertEqual(ckpt['meta']['epoch'], 0)
self.assertEqual(ckpt['meta']['iter'], 5)
self.assertEqual(best_ckpt['meta']['epoch'], 0)
self.assertEqual(best_ckpt['meta']['iter'], 5)
self.clear_work_dir() self.clear_work_dir()
# Test save_best with interval=2
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.save_best = 'test/acc'
cfg.default_hooks.checkpoint.interval = 2
cfg.val_evaluator = dict(type='TriangleMetric', length=11)
cfg.train_cfg.val_interval = 1
runner = self.build_runner(cfg)
runner.train()
best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth')
best_ckpt = torch.load(best_ckpt_path)
# if the current ckpt is the best, the interval will be ignored the
# the ckpt will also be saved
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])
if training_type == 'epoch':
self.assertEqual(ckpt['meta']['epoch'], 5)
self.assertEqual(ckpt['meta']['iter'], 20)
self.assertEqual(best_ckpt['meta']['epoch'], 5)
self.assertEqual(best_ckpt['meta']['iter'], 20)
else:
self.assertEqual(ckpt['meta']['epoch'], 0)
self.assertEqual(ckpt['meta']['iter'], 5)
self.assertEqual(best_ckpt['meta']['epoch'], 0)
self.assertEqual(best_ckpt['meta']['iter'], 5)
# test save published keys # test save published keys
cfg = copy.deepcopy(common_cfg) cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict'] cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']