diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 53f982c1..ea762361 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -4,12 +4,13 @@ import warnings
 from collections import OrderedDict
 from math import inf
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Sequence, Union
 
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
 from mmengine.utils import is_seq_of
+from mmengine.utils.misc import is_list_of
 from .hook import Hook
 
 DATA_BATCH = Optional[Sequence[dict]]
@@ -44,20 +45,27 @@ class CheckpointHook(Hook):
             Defaults to -1, which means unlimited.
         save_last (bool): Whether to force the last checkpoint to be saved
             regardless of interval. Defaults to True.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be saved in ``runner.message_hub`` to keep
+        save_best (str, List[str], optional): If a metric is specified, it
+            would measure the best checkpoint during evaluation. If a list of
+            metrics is passed, it would measure a group of best checkpoints
+            corresponding to the passed metrics. The information about best
+            checkpoint(s) would be saved in ``runner.message_hub`` to keep
             best score value and best checkpoint path, which will be also
             loaded when resuming checkpoint. Options are the evaluation
             metrics on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for
             bbox detection and instance segmentation. ``AR@100`` for proposal
             recall. If ``save_best`` is ``auto``, the first key of the
             returned ``OrderedDict`` result will be used. Defaults to None.
-        rule (str, optional): Comparison rule for best score. If set to
-            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
-            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
-            be inferred by 'less' rule. Options are 'greater', 'less', None.
-            Defaults to None.
+        rule (str, List[str], optional): Comparison rule for best score. If
+            set to None, it will infer a reasonable rule. Keys such as 'acc'
+            and 'top' will be inferred by the 'greater' rule; keys containing
+            'loss' will be inferred by the 'less' rule. If ``save_best`` is a
+            list of metrics and ``rule`` is a str, all metrics in
+            ``save_best`` will share the comparison rule. If ``save_best``
+            and ``rule`` are both lists, their lengths must be the same, and
+            each metric in ``save_best`` will use the corresponding
+            comparison rule in ``rule``. Options are 'greater', 'less', None
+            and a list containing 'greater' and 'less'. Defaults to None.
         greater_keys (List[str], optional): Metric keys that will be
             inferred by 'greater' comparison rule. If ``None``,
             _default_greater_keys will be used. Defaults to None.
@@ -67,6 +75,17 @@ class CheckpointHook(Hook):
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
+
+    Examples:
+        >>> # Save best based on a single metric
+        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
+        >>>     rule='less')
+        >>> # Save best based on multiple metrics with the same comparison
+        >>> # rule
+        >>> CheckpointHook(interval=2, by_epoch=True,
+        >>>     save_best=['acc', 'mIoU'], rule='greater')
+        >>> # Save best based on multiple metrics with different comparison
+        >>> # rules
+        >>> CheckpointHook(interval=2, by_epoch=True,
+        >>>     save_best=['FID', 'IS'], rule=['less', 'greater'])
     """
 
     out_dir: str
@@ -93,8 +112,8 @@ class CheckpointHook(Hook):
                  out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
-                 save_best: Optional[str] = None,
-                 rule: Optional[str] = None,
+                 save_best: Union[str, List[str], None] = None,
+                 rule: Union[str, List[str], None] = None,
                  greater_keys: Optional[Sequence[str]] = None,
                  less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
@@ -110,11 +129,39 @@ class CheckpointHook(Hook):
         self.file_client_args = file_client_args
 
         # save best logic
-        assert isinstance(save_best, str) or save_best is None, \
-            '"save_best" should be a str or None ' \
-            f'rather than {type(save_best)}'
+        assert (isinstance(save_best, str) or is_list_of(save_best, str)
+                or (save_best is None)), (
+                    '"save_best" should be a str or list of str or None, '
+                    f'but got {type(save_best)}')
+
+        if isinstance(save_best, list):
+            if 'auto' in save_best:
+                assert len(save_best) == 1, (
+                    'Only support one "auto" in "save_best" list.')
+            assert len(save_best) == len(
+                set(save_best)), ('Find duplicate element in "save_best".')
+        else:
+            # convert str to list[str]
+            if save_best is not None:
+                save_best = [save_best]  # type: ignore # noqa: F401
         self.save_best = save_best
 
+        # rule logic
+        assert (isinstance(rule, str) or is_list_of(rule, str)
+                or (rule is None)), (
+                    '"rule" should be a str or list of str or None, '
+                    f'but got {type(rule)}')
+        if isinstance(rule, list):
+            # check the length of rule list
+            assert len(rule) in [
+                1,
+                len(self.save_best)  # type: ignore
+            ], ('Number of "rule" must be 1 or the same as number of '
+                f'"save_best", but got {len(rule)}.')
+        else:
+            # convert str/None to list
+            rule = [rule]  # type: ignore # noqa: F401
+
         if greater_keys is None:
             self.greater_keys = self._default_greater_keys
         else:
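Note: the constructor above normalizes `save_best` and `rule` into parallel list form before `_init_rule` runs. For reviewers who want to poke at that argument handling in isolation, here is a minimal runnable sketch; the helper name `normalize_save_best_and_rule` is hypothetical (the hook performs these steps inline in `__init__`):

```python
from typing import List, Optional, Tuple, Union


def normalize_save_best_and_rule(
    save_best: Union[str, List[str], None],
    rule: Union[str, List[str], None],
) -> Tuple[Optional[List[str]], List[Optional[str]]]:
    """Mirror the constructor's normalization of `save_best` and `rule`."""
    if isinstance(save_best, list):
        if 'auto' in save_best:
            # 'auto' cannot be combined with other metrics
            assert len(save_best) == 1, (
                'Only support one "auto" in "save_best" list.')
        assert len(save_best) == len(set(save_best)), (
            'Find duplicate element in "save_best".')
    elif save_best is not None:
        save_best = [save_best]  # str -> [str]
    if isinstance(rule, list):
        assert save_best is not None and len(rule) in (1, len(save_best)), (
            'Number of "rule" must be 1 or the same as number of '
            f'"save_best", but got {len(rule)}.')
    else:
        rule = [rule]  # str/None -> [str/None]
    return save_best, rule


print(normalize_save_best_and_rule(['acc', 'mIoU'], 'greater'))
# (['acc', 'mIoU'], ['greater'])
```

A single-element `rule` list is broadcast to every metric later, inside `_init_rule`.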
@@ -132,8 +179,12 @@ class CheckpointHook(Hook):
             self.less_keys = less_keys  # type: ignore
 
         if self.save_best is not None:
-            self.best_ckpt_path = None
+            self.is_better_than: Dict[str, Callable] = dict()
             self._init_rule(rule, self.save_best)
+            if len(self.key_indicators) == 1:
+                self.best_ckpt_path: Optional[str] = None
+            else:
+                self.best_ckpt_path_dict: Dict = dict()
 
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
@@ -162,10 +213,21 @@ class CheckpointHook(Hook):
                 f'{self.file_client.name}.')
 
         if self.save_best is not None:
-            if 'best_ckpt' not in runner.message_hub.runtime_info:
-                self.best_ckpt_path = None
+            if len(self.key_indicators) == 1:
+                if 'best_ckpt' not in runner.message_hub.runtime_info:
+                    self.best_ckpt_path = None
+                else:
+                    self.best_ckpt_path = runner.message_hub.get_info(
+                        'best_ckpt')
             else:
-                self.best_ckpt_path = runner.message_hub.get_info('best_ckpt')
+                for key_indicator in self.key_indicators:
+                    best_ckpt_name = f'best_ckpt_{key_indicator}'
+                    if best_ckpt_name not in runner.message_hub.runtime_info:
+                        self.best_ckpt_path_dict[key_indicator] = None
+                    else:
+                        self.best_ckpt_path_dict[
+                            key_indicator] = runner.message_hub.get_info(
+                                best_ckpt_name)
 
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
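Note: the resume logic above relies on a message-hub naming convention: with a single metric the legacy keys `best_score`/`best_ckpt` are kept for backward compatibility, while multiple metrics get per-metric suffixed keys. A small sketch of the convention (the function is illustrative, not part of the hook):

```python
def runtime_info_keys(key_indicators):
    """Return (best_score, best_ckpt) message-hub keys per metric."""
    if len(key_indicators) == 1:
        # legacy single-metric layout, unchanged for backward compatibility
        return {key_indicators[0]: ('best_score', 'best_ckpt')}
    return {
        k: (f'best_score_{k}', f'best_ckpt_{k}')
        for k in key_indicators
    }


print(runtime_info_keys(['acc']))
# {'acc': ('best_score', 'best_ckpt')}
print(runtime_info_keys(['acc', 'mIoU']))
# {'acc': ('best_score_acc', 'best_ckpt_acc'),
#  'mIoU': ('best_score_mIoU', 'best_ckpt_mIoU')}
```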
@@ -195,7 +257,7 @@ class CheckpointHook(Hook):
         """
         self._save_best_checkpoint(runner, metrics)
 
-    def _get_metric_score(self, metrics):
+    def _get_metric_score(self, metrics, key_indicator):
         eval_res = OrderedDict()
         if metrics is not None:
             eval_res.update(metrics)
@@ -206,10 +268,7 @@ class CheckpointHook(Hook):
                 'the best checkpoint will be skipped in this evaluation.')
             return None
 
-        if self.key_indicator == 'auto':
-            self._init_rule(self.rule, list(eval_res.keys())[0])
-
-        return eval_res[self.key_indicator]
+        return eval_res[key_indicator]
 
     @master_only
     def _save_checkpoint(self, runner) -> None:
@@ -264,6 +323,7 @@ class CheckpointHook(Hook):
 
         Args:
             runner (Runner): The runner of the training process.
+            metrics (dict): Evaluation results of all metrics.
         """
         if not self.save_best:
             return
@@ -277,91 +337,123 @@ class CheckpointHook(Hook):
                 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
             cur_type, cur_time = 'iter', runner.iter + 1
 
+        # handle 'auto' in self.key_indicators and self.rules before the loop
+        if 'auto' in self.key_indicators:
+            self._init_rule(self.rules, [list(metrics.keys())[0]])
+
         # save best logic
         # get score from messagehub
-        # notice `_get_metirc_score` helps to infer
-        # self.rule when self.save_best is `auto`
-        key_score = self._get_metric_score(metrics)
-        if 'best_score' not in runner.message_hub.runtime_info:
-            best_score = self.init_value_map[self.rule]
-        else:
-            best_score = runner.message_hub.get_info('best_score')
+        for key_indicator, rule in zip(self.key_indicators, self.rules):
+            key_score = self._get_metric_score(metrics, key_indicator)
 
-        if not key_score or not self.is_better_than(key_score, best_score):
-            return
+            if len(self.key_indicators) == 1:
+                best_score_key = 'best_score'
+                runtime_best_ckpt_key = 'best_ckpt'
+                best_ckpt_path = self.best_ckpt_path
+            else:
+                best_score_key = f'best_score_{key_indicator}'
+                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
+                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]
 
-        best_score = key_score
-        runner.message_hub.update_info('best_score', best_score)
+            if best_score_key not in runner.message_hub.runtime_info:
+                best_score = self.init_value_map[rule]
+            else:
+                best_score = runner.message_hub.get_info(best_score_key)
 
-        if self.best_ckpt_path and self.file_client.isfile(
-                self.best_ckpt_path):
-            self.file_client.remove(self.best_ckpt_path)
+            if key_score is None or not self.is_better_than[key_indicator](
+                    key_score, best_score):
+                continue
+
+            best_score = key_score
+            runner.message_hub.update_info(best_score_key, best_score)
+
+            if best_ckpt_path and self.file_client.isfile(best_ckpt_path):
+                self.file_client.remove(best_ckpt_path)
+                runner.logger.info(
+                    f'The previous best checkpoint {best_ckpt_path} '
+                    'is removed')
+
+            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
+            if len(self.key_indicators) == 1:
+                self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
+                    self.out_dir, best_ckpt_name)
+                runner.message_hub.update_info(runtime_best_ckpt_key,
+                                               self.best_ckpt_path)
+            else:
+                self.best_ckpt_path_dict[
+                    key_indicator] = self.file_client.join_path(  # type: ignore # noqa: E501
+                        self.out_dir, best_ckpt_name)
+                runner.message_hub.update_info(
+                    runtime_best_ckpt_key,
+                    self.best_ckpt_path_dict[key_indicator])
+            runner.save_checkpoint(
+                self.out_dir,
+                filename=best_ckpt_name,
+                file_client_args=self.file_client_args,
+                save_optimizer=False,
+                save_param_scheduler=False,
+                by_epoch=False)
             runner.logger.info(
-                f'The previous best checkpoint {self.best_ckpt_path} '
-                'is removed')
+                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
+                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
 
-        best_ckpt_name = f'best_{self.key_indicator}_{ckpt_filename}'
-        self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
-            self.out_dir, best_ckpt_name)
-        runner.message_hub.update_info('best_ckpt', self.best_ckpt_path)
-        runner.save_checkpoint(
-            self.out_dir,
-            filename=best_ckpt_name,
-            file_client_args=self.file_client_args,
-            save_optimizer=False,
-            save_param_scheduler=False,
-            by_epoch=False)
-        runner.logger.info(
-            f'The best checkpoint with {best_score:0.4f} {self.key_indicator} '
-            f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
 
+    def _init_rule(self, rules, key_indicators) -> None:
+        """Initialize rules, key_indicators, comparison functions, and best
+        scores. If ``key_indicators`` is a list of strings and ``rules``
+        contains a single rule, all metrics in ``key_indicators`` will share
+        the same rule.
 
-    def _init_rule(self, rule, key_indicator) -> None:
-        """Initialize rule, key_indicator, comparison_func, and best score.
         Here is the rule to determine which rule is used for key indicator
         when the rule is not specific (note that the key indicator matching
         is case-insensitive):
 
-        1. If the key indicator is in ``self.greater_keys``, the rule will be
-           specified as 'greater'.
-        2. Or if the key indicator is in ``self.less_keys``, the rule will be
-           specified as 'less'.
+        1. If the key indicator is in ``self.greater_keys``, the rule
+           will be specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule
+           will be specified as 'less'.
         3. Or if any one item in ``self.greater_keys`` is a substring of
-           key_indicator , the rule will be specified as 'greater'.
+           key_indicator, the rule will be specified as 'greater'.
         4. Or if any one item in ``self.less_keys`` is a substring of
-           key_indicator , the rule will be specified as 'less'.
+           key_indicator, the rule will be specified as 'less'.
+
         Args:
-            rule (str | None): Comparison rule for best score.
-            key_indicator (str | None): Key indicator to determine the
-                comparison rule.
+            rules (List[Optional[str]]): Comparison rules for best scores.
+            key_indicators (List[str]): Key indicators to determine
+                the comparison rules.
""" + if len(rules) == 1: + rules = rules * len(key_indicators) - if rule not in self.rule_map and rule is not None: - raise KeyError('rule must be greater, less or None, ' - f'but got {rule}.') + self.rules = [] + for rule, key_indicator in zip(rules, key_indicators): - if rule is None and key_indicator != 'auto': - # `_lc` here means we use the lower case of keys for - # case-insensitive matching - key_indicator_lc = key_indicator.lower() - greater_keys = [key.lower() for key in self.greater_keys] - less_keys = [key.lower() for key in self.less_keys] + if rule not in self.rule_map and rule is not None: + raise KeyError('rule must be greater, less or None, ' + f'but got {rule}.') - if key_indicator_lc in greater_keys: - rule = 'greater' - elif key_indicator_lc in less_keys: - rule = 'less' - elif any(key in key_indicator_lc for key in greater_keys): - rule = 'greater' - elif any(key in key_indicator_lc for key in less_keys): - rule = 'less' - else: - raise ValueError('Cannot infer the rule for key ' - f'{key_indicator}, thus a specific rule ' - 'must be specified.') - self.rule = rule - self.key_indicator = key_indicator - if self.rule is not None: - self.is_better_than = self.rule_map[self.rule] + if rule is None and key_indicator != 'auto': + # `_lc` here means we use the lower case of keys for + # case-insensitive matching + key_indicator_lc = key_indicator.lower() + greater_keys = {key.lower() for key in self.greater_keys} + less_keys = {key.lower() for key in self.less_keys} + + if key_indicator_lc in greater_keys: + rule = 'greater' + elif key_indicator_lc in less_keys: + rule = 'less' + elif any(key in key_indicator_lc for key in greater_keys): + rule = 'greater' + elif any(key in key_indicator_lc for key in less_keys): + rule = 'less' + else: + raise ValueError('Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + 'must be specified.') + if rule is not None: + self.is_better_than[key_indicator] = self.rule_map[rule] + self.rules.append(rule) + + self.key_indicators = key_indicators def after_train_iter(self, runner, diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py index 56b6dec5..ee2a0dff 100644 --- a/tests/test_hook/test_checkpoint_hook.py +++ b/tests/test_hook/test_checkpoint_hook.py @@ -49,6 +49,30 @@ class TestCheckpointHook: assert checkpoint_hook.out_dir == ( f'test_dir/{osp.basename(work_dir)}') + runner.message_hub = MessageHub.get_instance('test_before_train') + # no 'best_ckpt_path' in runtime_info + checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU']) + checkpoint_hook.before_train(runner) + assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None) + assert not hasattr(checkpoint_hook, 'best_ckpt_path') + + # only one 'best_ckpt_path' in runtime_info + runner.message_hub.update_info('best_ckpt_acc', 'best_acc') + checkpoint_hook.before_train(runner) + assert checkpoint_hook.best_ckpt_path_dict == dict( + acc='best_acc', mIoU=None) + + # no 'best_ckpt_path' in runtime_info + checkpoint_hook = CheckpointHook(interval=1, save_best='acc') + checkpoint_hook.before_train(runner) + assert checkpoint_hook.best_ckpt_path is None + assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict') + + # 'best_ckpt_path' in runtime_info + runner.message_hub.update_info('best_ckpt', 'best_ckpt') + checkpoint_hook.before_train(runner) + assert checkpoint_hook.best_ckpt_path == 'best_ckpt' + def test_after_val_epoch(self, tmp_path): runner = Mock() runner.work_dir = tmp_path @@ -69,7 +93,7 
diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py
index 56b6dec5..ee2a0dff 100644
--- a/tests/test_hook/test_checkpoint_hook.py
+++ b/tests/test_hook/test_checkpoint_hook.py
@@ -49,6 +49,30 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')
 
+        runner.message_hub = MessageHub.get_instance('test_before_train')
+        # no 'best_ckpt_path' in runtime_info
+        checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None)
+        assert not hasattr(checkpoint_hook, 'best_ckpt_path')
+
+        # only one 'best_ckpt_path' in runtime_info
+        runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path_dict == dict(
+            acc='best_acc', mIoU=None)
+
+        # no 'best_ckpt_path' in runtime_info
+        checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path is None
+        assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict')
+
+        # 'best_ckpt_path' in runtime_info
+        runner.message_hub.update_info('best_ckpt', 'best_ckpt')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path == 'best_ckpt'
+
     def test_after_val_epoch(self, tmp_path):
         runner = Mock()
         runner.work_dir = tmp_path
@@ -69,7 +93,7 @@ class TestCheckpointHook:
         with pytest.warns(UserWarning) as record_warnings:
             eval_hook = CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None)
+            eval_hook._get_metric_score(None, None)
         # Since there will be many warnings thrown, we just need to check
         # if the expected exceptions are thrown
         expected_message = (
@@ -82,6 +106,17 @@ class TestCheckpointHook:
         else:
             assert False
 
+        # test error when the numbers of rules and metrics do not match
+        with pytest.raises(AssertionError) as assert_error:
+            CheckpointHook(
+                interval=1,
+                save_best=['mIoU', 'acc'],
+                rule=['greater', 'greater', 'less'],
+                by_epoch=True)
+        error_message = ('Number of "rule" must be 1 or the same as number of '
+                         '"save_best", but got 3.')
+        assert error_message in str(assert_error.value)
+
         # if save_best is None,no best_ckpt meta should be stored
         eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
         eval_hook.before_train(runner)
@@ -97,8 +132,8 @@ class TestCheckpointHook:
         best_ckpt_name = 'best_acc_epoch_10.pth'
         best_ckpt_path = eval_hook.file_client.join_path(
             eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicator == 'acc'
-        assert eval_hook.rule == 'greater'
+        assert eval_hook.key_indicators == ['acc']
+        assert eval_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
@@ -142,6 +177,42 @@ class TestCheckpointHook:
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0
 
+        # test multiple `save_best` metrics with one rule
+        eval_hook = CheckpointHook(
+            interval=2, save_best=['acc', 'mIoU'], rule='greater')
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+
+        # test multiple `save_best` metrics with multiple rules
+        eval_hook = CheckpointHook(
+            interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
+        assert eval_hook.key_indicators == ['FID', 'IS']
+        assert eval_hook.rules == ['less', 'greater']
+
+        # test multiple `save_best` metrics with default rules
+        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+        runner.message_hub = MessageHub.get_instance(
+            'test_after_val_epoch_save_multi_best')
+        eval_hook.before_train(runner)
+        metrics = dict(acc=0.5, mIoU=0.6)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_acc_name = 'best_acc_epoch_10.pth'
+        best_acc_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_acc_name)
+        best_mIoU_name = 'best_mIoU_epoch_10.pth'
+        best_mIoU_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_mIoU_name)
+        assert 'best_score_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_acc') == 0.5
+        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_mIoU') == 0.6
+        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
+        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
+
         # test behavior when by_epoch is False
         runner = Mock()
         runner.work_dir = tmp_path
@@ -156,8 +227,8 @@ class TestCheckpointHook:
             interval=2, by_epoch=False, save_best='acc', rule='greater')
         eval_hook.before_train(runner)
         eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicator == 'acc'
-        assert eval_hook.rule == 'greater'
+        assert eval_hook.key_indicators == ['acc']
+        assert eval_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_10.pth'
         best_ckpt_path = eval_hook.file_client.join_path(
             eval_hook.out_dir, best_ckpt_name)
@@ -176,6 +247,38 @@ class TestCheckpointHook:
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.666
+
+        # error when 'auto' appears in a `save_best` list
+        with pytest.raises(AssertionError):
+            CheckpointHook(interval=2, save_best=['auto', 'acc'])
+        # error when a single `save_best` gets multiple rules
+        with pytest.raises(AssertionError):
+            CheckpointHook(
+                interval=2, save_best='acc', rule=['greater', 'less'])
+
+        # check best checkpoint names when `by_epoch` is False
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+        runner.message_hub = MessageHub.get_instance(
+            'test_after_val_epoch_save_multi_best_by_epoch_is_false')
+        eval_hook.before_train(runner)
+        metrics = dict(acc=0.5, mIoU=0.6)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_acc_name = 'best_acc_iter_10.pth'
+        best_acc_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_acc_name)
+        best_mIoU_name = 'best_mIoU_iter_10.pth'
+        best_mIoU_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_mIoU_name)
+        assert 'best_score_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_acc') == 0.5
+        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_mIoU') == 0.6
+        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
+        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
 
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
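Note: one behavior change worth calling out for reviewers is that `'auto'` is now resolved once in `_save_best_checkpoint`, before the save loop, rather than inside `_get_metric_score`. A sketch of the expansion with an illustrative `metrics` dict:

```python
metrics = {'acc': 0.5, 'mIoU': 0.6}  # illustrative evaluation results

key_indicators = ['auto']
if 'auto' in key_indicators:
    # adopt the first metric key returned by the evaluator, as
    # _save_best_checkpoint does via _init_rule(self.rules, [...])
    key_indicators = [list(metrics.keys())[0]]

print(key_indicators)  # ['acc']
```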