From 19aa1eb7800888425ce2a24d5f93b58304821821 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Fri, 2 Jun 2023 14:42:56 +0800 Subject: [PATCH] [Fix] Save checkpoint again to update best_ckpt of ckpt (#1168) --- mmengine/hooks/checkpoint_hook.py | 131 ++++++++++++++++------- mmengine/runner/runner.py | 12 ++- tests/test_hooks/test_checkpoint_hook.py | 97 +++++++++++++++-- 3 files changed, 187 insertions(+), 53 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 23f4a8ab..6b74c52a 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -3,6 +3,7 @@ import hashlib import logging import os.path as osp import pickle +from collections import deque from math import inf from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union @@ -242,6 +243,8 @@ class CheckpointHook(Hook): 'Find duplicate elements in "published_keys".') self.published_keys = published_keys + self.last_ckpt = None + def before_train(self, runner) -> None: """Finish all operations, related to checkpoint. @@ -294,6 +297,25 @@ class CheckpointHook(Hook): key_indicator] = runner.message_hub.get_info( best_ckpt_name) + if self.max_keep_ckpts > 0: + keep_ckpt_ids = [] + if 'keep_ckpt_ids' in runner.message_hub.runtime_info: + keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids') + + while len(keep_ckpt_ids) > self.max_keep_ckpts: + step = keep_ckpt_ids.pop(0) + if is_main_process(): + path = self.file_backend.join_path( + self.out_dir, self.filename_tmpl.format(step)) + if self.file_backend.isfile(path): + self.file_backend.remove(path) + elif self.file_backend.isdir(path): + # checkpoints saved by deepspeed are directories + self.file_backend.rmtree(path) + + self.keep_ckpt_ids: deque = deque(keep_ckpt_ids, + self.max_keep_ckpts) + def after_train_epoch(self, runner) -> None: """Save the checkpoint and synchronize buffers after each epoch. @@ -337,14 +359,13 @@ class CheckpointHook(Hook): if self.published_keys is None: return - if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info: - last_ckpt = runner.message_hub.get_info('last_ckpt') - self._publish_model(runner, last_ckpt) + if self.save_last and self.last_ckpt is not None: + self._publish_model(runner, self.last_ckpt) if getattr(self, 'best_ckpt_path', None) is not None: self._publish_model(runner, str(self.best_ckpt_path)) if getattr(self, 'best_ckpt_path_dict', None) is not None: - for key, best_ckpt in self.best_ckpt_path_dict.items(): + for best_ckpt in self.best_ckpt_path_dict.values(): self._publish_model(runner, best_ckpt) @master_only @@ -379,16 +400,35 @@ class CheckpointHook(Hook): f'{final_path}.', logger='current') - def _save_checkpoint(self, runner) -> None: - """Save the current checkpoint and delete outdated checkpoint. + def _save_checkpoint_with_step(self, runner, step, meta): + # remove other checkpoints before save checkpoint to make the + # self.keep_ckpt_ids are saved as expected + if self.max_keep_ckpts > 0: + # _save_checkpoint and _save_best_checkpoint may call this + # _save_checkpoint_with_step in one epoch + if len(self.keep_ckpt_ids) > 0 and self.keep_ckpt_ids[-1] == step: + pass + else: + if len(self.keep_ckpt_ids) == self.max_keep_ckpts: + _step = self.keep_ckpt_ids.popleft() + if is_main_process(): + ckpt_path = self.file_backend.join_path( + self.out_dir, self.filename_tmpl.format(_step)) - Args: - runner (Runner): The runner of the training process. - """ - if self.by_epoch: - ckpt_filename = self.filename_tmpl.format(runner.epoch + 1) - else: - ckpt_filename = self.filename_tmpl.format(runner.iter + 1) + if self.file_backend.isfile(ckpt_path): + self.file_backend.remove(ckpt_path) + elif self.file_backend.isdir(ckpt_path): + # checkpoints saved by deepspeed are directories + self.file_backend.rmtree(ckpt_path) + + self.keep_ckpt_ids.append(step) + runner.message_hub.update_info('keep_ckpt_ids', + list(self.keep_ckpt_ids)) + + ckpt_filename = self.filename_tmpl.format(step) + self.last_ckpt = self.file_backend.join_path(self.out_dir, + ckpt_filename) + runner.message_hub.update_info('last_ckpt', self.last_ckpt) runner.save_checkpoint( self.out_dir, @@ -396,6 +436,7 @@ class CheckpointHook(Hook): self.file_client_args, save_optimizer=self.save_optimizer, save_param_scheduler=self.save_param_scheduler, + meta=meta, by_epoch=self.by_epoch, backend_args=self.backend_args, **self.args) @@ -405,31 +446,24 @@ class CheckpointHook(Hook): if not is_main_process(): return - runner.message_hub.update_info( - 'last_ckpt', - self.file_backend.join_path(self.out_dir, ckpt_filename)) - - # remove other checkpoints - if self.max_keep_ckpts > 0: - if self.by_epoch: - current_ckpt = runner.epoch + 1 - else: - current_ckpt = runner.iter + 1 - redundant_ckpts = range( - current_ckpt - self.max_keep_ckpts * self.interval, 0, - -self.interval) - for _step in redundant_ckpts: - ckpt_path = self.file_backend.join_path( - self.out_dir, self.filename_tmpl.format(_step)) - if self.file_backend.isfile(ckpt_path): - self.file_backend.remove(ckpt_path) - else: - break - save_file = osp.join(runner.work_dir, 'last_checkpoint') - filepath = self.file_backend.join_path(self.out_dir, ckpt_filename) with open(save_file, 'w') as f: - f.write(filepath) + f.write(self.last_ckpt) # type: ignore + + def _save_checkpoint(self, runner) -> None: + """Save the current checkpoint and delete outdated checkpoint. + + Args: + runner (Runner): The runner of the training process. + """ + if self.by_epoch: + step = runner.epoch + 1 + meta = dict(epoch=step, iter=runner.iter) + else: + step = runner.iter + 1 + meta = dict(epoch=runner.epoch, iter=step) + + self._save_checkpoint_with_step(runner, step, meta=meta) def _save_best_checkpoint(self, runner, metrics) -> None: """Save the current checkpoint and delete outdated checkpoint. @@ -448,10 +482,13 @@ class CheckpointHook(Hook): ckpt_filename = self.filename_tmpl.format(runner.iter) cur_type, cur_time = 'iter', runner.iter + meta = dict(epoch=runner.epoch, iter=runner.iter) + # handle auto in self.key_indicators and self.rules before the loop if 'auto' in self.key_indicators: self._init_rule(self.rules, [list(metrics.keys())[0]]) + best_ckpt_updated = False # save best logic # get score from messagehub for key_indicator, rule in zip(self.key_indicators, self.rules): @@ -475,13 +512,18 @@ class CheckpointHook(Hook): key_score, best_score): continue + best_ckpt_updated = True + best_score = key_score runner.message_hub.update_info(best_score_key, best_score) - if best_ckpt_path and \ - self.file_backend.isfile(best_ckpt_path) and \ - is_main_process(): - self.file_backend.remove(best_ckpt_path) + if best_ckpt_path and is_main_process(): + if self.file_backend.isfile(best_ckpt_path): + self.file_backend.remove(best_ckpt_path) + else: + # checkpoints saved by deepspeed are directories + self.file_backend.rmtree(best_ckpt_path) + runner.logger.info( f'The previous best checkpoint {best_ckpt_path} ' 'is removed') @@ -507,12 +549,21 @@ class CheckpointHook(Hook): file_client_args=self.file_client_args, save_optimizer=False, save_param_scheduler=False, + meta=meta, by_epoch=False, backend_args=self.backend_args) runner.logger.info( f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') + # save checkpoint again to update the best_score and best_ckpt stored + # in message_hub because the checkpoint saved in `after_train_epoch` + # or `after_train_iter` stage only keep the previous best checkpoint + # not the current best checkpoint which causes the current best + # checkpoint can not be removed when resuming training. + if best_ckpt_updated and self.last_ckpt is not None: + self._save_checkpoint_with_step(runner, cur_time, meta) + def _init_rule(self, rules, key_indicators) -> None: """Initialize rule, key_indicator, comparison_func, and best score. If key_indicator is a list of string and rule is a string, all metric in diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 31fac2c0..1e13d345 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -2090,7 +2090,7 @@ class Runner: file_client_args: Optional[dict] = None, save_optimizer: bool = True, save_param_scheduler: bool = True, - meta: dict = None, + meta: Optional[dict] = None, by_epoch: bool = True, backend_args: Optional[dict] = None, ): @@ -2112,8 +2112,8 @@ class Runner: to the checkpoint. Defaults to True. meta (dict, optional): The meta information to be saved in the checkpoint. Defaults to None. - by_epoch (bool): Whether the scheduled momentum is updated by - epochs. Defaults to True. + by_epoch (bool): Decide the number of epoch or iteration saved in + checkpoint. Defaults to True. backend_args (dict, optional): Arguments to instantiate the prefix of uri corresponding backend. Defaults to None. New in v0.2.0. @@ -2129,9 +2129,11 @@ class Runner: # `self.call_hook('after_train_epoch)` but `save_checkpoint` is # called by `after_train_epoch`` method of `CheckpointHook` so # `epoch` should be `self.epoch + 1` - meta.update(epoch=self.epoch + 1, iter=self.iter) + meta.setdefault('epoch', self.epoch + 1) + meta.setdefault('iter', self.iter) else: - meta.update(epoch=self.epoch, iter=self.iter + 1) + meta.setdefault('epoch', self.epoch) + meta.setdefault('iter', self.iter + 1) if file_client_args is not None: warnings.warn( diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index e6469bb3..e4e98272 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -433,12 +433,13 @@ class TestCheckpointHook(RunnerTestCase): setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) checkpoint_cfg = dict( type='CheckpointHook', - interval=2, + interval=1, by_epoch=training_type == 'epoch') common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) # Test interval in epoch based training cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.interval = 2 runner = self.build_runner(cfg) runner.train() @@ -502,13 +503,11 @@ class TestCheckpointHook(RunnerTestCase): self.clear_work_dir() - # Test max_keep_ckpts + # Test max_keep_ckpts=1 cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.interval = 1 cfg.default_hooks.checkpoint.max_keep_ckpts = 1 runner = self.build_runner(cfg) runner.train() - print(os.listdir(cfg.work_dir)) self.assertTrue( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) @@ -518,6 +517,45 @@ class TestCheckpointHook(RunnerTestCase): self.clear_work_dir() + # Test max_keep_ckpts=3 + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.max_keep_ckpts = 3 + runner = self.build_runner(cfg) + runner.train() + self.assertTrue( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth'))) + self.assertTrue( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) + self.assertTrue( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + + for i in range(9): + self.assertFalse( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) + + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], + [9, 10, 11]) + + # Test max_keep_ckpts when resuming traing + cfg = copy.deepcopy(common_cfg) + setattr(cfg.train_cfg, f'max_{training_type}s', 12) + cfg.default_hooks.checkpoint.max_keep_ckpts = 2 + cfg.load_from = osp.join(cfg.work_dir, f'{training_type}_11.pth') + cfg.resume = True + runner = self.build_runner(cfg) + runner.train() + self.assertFalse( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_9.pth'))) + self.assertFalse( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) + self.assertTrue( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) + self.assertTrue( + osp.isfile(osp.join(cfg.work_dir, f'{training_type}_12.pth'))) + + self.clear_work_dir() + # Test filename_tmpl cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth' @@ -529,18 +567,61 @@ class TestCheckpointHook(RunnerTestCase): # Test save_best cfg = copy.deepcopy(common_cfg) - cfg.default_hooks.checkpoint.interval = 1 cfg.default_hooks.checkpoint.save_best = 'test/acc' cfg.val_evaluator = dict(type='TriangleMetric', length=11) cfg.train_cfg.val_interval = 1 runner = self.build_runner(cfg) runner.train() - best_ckpt = osp.join(cfg.work_dir, - f'best_test_acc_{training_type}_5.pth') - self.assertTrue(osp.isfile(best_ckpt)) + best_ckpt_path = osp.join(cfg.work_dir, + f'best_test_acc_{training_type}_5.pth') + best_ckpt = torch.load(best_ckpt_path) + + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + self.assertEqual(best_ckpt_path, + ckpt['message_hub']['runtime_info']['best_ckpt']) + + if training_type == 'epoch': + self.assertEqual(ckpt['meta']['epoch'], 5) + self.assertEqual(ckpt['meta']['iter'], 20) + self.assertEqual(best_ckpt['meta']['epoch'], 5) + self.assertEqual(best_ckpt['meta']['iter'], 20) + else: + self.assertEqual(ckpt['meta']['epoch'], 0) + self.assertEqual(ckpt['meta']['iter'], 5) + self.assertEqual(best_ckpt['meta']['epoch'], 0) + self.assertEqual(best_ckpt['meta']['iter'], 5) self.clear_work_dir() + # Test save_best with interval=2 + cfg = copy.deepcopy(common_cfg) + cfg.default_hooks.checkpoint.save_best = 'test/acc' + cfg.default_hooks.checkpoint.interval = 2 + cfg.val_evaluator = dict(type='TriangleMetric', length=11) + cfg.train_cfg.val_interval = 1 + runner = self.build_runner(cfg) + runner.train() + best_ckpt_path = osp.join(cfg.work_dir, + f'best_test_acc_{training_type}_5.pth') + best_ckpt = torch.load(best_ckpt_path) + + # if the current ckpt is the best, the interval will be ignored the + # the ckpt will also be saved + ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + self.assertEqual(best_ckpt_path, + ckpt['message_hub']['runtime_info']['best_ckpt']) + + if training_type == 'epoch': + self.assertEqual(ckpt['meta']['epoch'], 5) + self.assertEqual(ckpt['meta']['iter'], 20) + self.assertEqual(best_ckpt['meta']['epoch'], 5) + self.assertEqual(best_ckpt['meta']['iter'], 20) + else: + self.assertEqual(ckpt['meta']['epoch'], 0) + self.assertEqual(ckpt['meta']['iter'], 5) + self.assertEqual(best_ckpt['meta']['epoch'], 0) + self.assertEqual(best_ckpt['meta']['iter'], 5) + # test save published keys cfg = copy.deepcopy(common_cfg) cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']