[Enhancement] Support save best based on multi metrics (#349)
* support save best based on multi metrics
* add unit test
* resolve bugs after rebasing
* revise docstring
* revise docstring
* fix as comment
* revise as comment
parent 6ebb7ed481
commit 08602a2385
@@ -4,12 +4,13 @@ import warnings
 from collections import OrderedDict
 from math import inf
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Sequence, Union

 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
 from mmengine.utils import is_seq_of
+from mmengine.utils.misc import is_list_of
 from .hook import Hook

 DATA_BATCH = Optional[Sequence[dict]]
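The new `is_list_of` import is what lets the constructor accept either a single metric name or a list of names. A quick sketch of its assumed behaviour (it mirrors `is_seq_of` restricted to lists):

    from mmengine.utils.misc import is_list_of

    assert is_list_of(['acc', 'mIoU'], str)  # every element is a str
    assert not is_list_of('acc', str)        # a bare str is not a list
    assert not is_list_of(['acc', 1], str)   # mixed element types fail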
@@ -44,20 +45,27 @@ class CheckpointHook(Hook):
             Defaults to -1, which means unlimited.
         save_last (bool): Whether to force the last checkpoint to be
             saved regardless of interval. Defaults to True.
-        save_best (str, optional): If a metric is specified, it would measure
-            the best checkpoint during evaluation. The information about best
-            checkpoint would be saved in ``runner.message_hub`` to keep
+        save_best (str, List[str], optional): If a metric is specified, it
+            would measure the best checkpoint during evaluation. If a list of
+            metrics is passed, it would measure a group of best checkpoints
+            corresponding to the passed metrics. The information about best
+            checkpoint(s) would be saved in ``runner.message_hub`` to keep
             best score value and best checkpoint path, which will be also
             loaded when resuming checkpoint. Options are the evaluation metrics
             on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
             detection and instance segmentation. ``AR@100`` for proposal
             recall. If ``save_best`` is ``auto``, the first key of the returned
             ``OrderedDict`` result will be used. Defaults to None.
-        rule (str, optional): Comparison rule for best score. If set to
-            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
-            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
-            be inferred by 'less' rule. Options are 'greater', 'less', None.
-            Defaults to None.
+        rule (str, List[str], optional): Comparison rule for best score. If
+            set to None, it will infer a reasonable rule. Keys such as 'acc',
+            'top' .etc will be inferred by 'greater' rule. Keys contain 'loss'
+            will be inferred by 'less' rule. If ``save_best`` is a list of
+            metrics and ``rule`` is a str, all metrics in ``save_best`` will
+            share the comparison rule. If ``save_best`` and ``rule`` are both
+            lists, their length must be the same, and metrics in ``save_best``
+            will use the corresponding comparison rule in ``rule``. Options
+            are 'greater', 'less', None and list which contains 'greater' and
+            'less'. Defaults to None.
         greater_keys (List[str], optional): Metric keys that will be
             inferred by 'greater' comparison rule. If ``None``,
             _default_greater_keys will be used. Defaults to None.
@@ -67,6 +75,17 @@ class CheckpointHook(Hook):
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
+
+    Examples:
+        >>> # Save best based on single metric
+        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
+        >>>                rule='less')
+        >>> # Save best based on multi metrics with the same comparison rule
+        >>> CheckpointHook(interval=2, by_epoch=True,
+        >>>                save_best=['acc', 'mIoU'], rule='greater')
+        >>> # Save best based on multi metrics with different comparison rule
+        >>> CheckpointHook(interval=2, by_epoch=True,
+        >>>                save_best=['FID', 'IS'], rule=['less', 'greater'])
     """
     out_dir: str

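For readers who drive the hook from config files rather than Python, a sketch of how the new option would look in a standard MMEngine-style `default_hooks` config (the surrounding keys are illustrative, not part of this diff):

    default_hooks = dict(
        checkpoint=dict(
            type='CheckpointHook',
            interval=2,
            save_best=['acc', 'mIoU'],  # one best checkpoint per metric
            rule='greater'))            # single rule shared by all metrics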
@@ -93,8 +112,8 @@ class CheckpointHook(Hook):
                  out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
-                 save_best: Optional[str] = None,
-                 rule: Optional[str] = None,
+                 save_best: Union[str, List[str], None] = None,
+                 rule: Union[str, List[str], None] = None,
                  greater_keys: Optional[Sequence[str]] = None,
                  less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
@@ -110,11 +129,39 @@ class CheckpointHook(Hook):
         self.file_client_args = file_client_args

         # save best logic
-        assert isinstance(save_best, str) or save_best is None, \
-            '"save_best" should be a str or None ' \
-            f'rather than {type(save_best)}'
+        assert (isinstance(save_best, str) or is_list_of(save_best, str)
+                or (save_best is None)), (
+                    '"save_best" should be a str or list of str or None, '
+                    f'but got {type(save_best)}')
+
+        if isinstance(save_best, list):
+            if 'auto' in save_best:
+                assert len(save_best) == 1, (
+                    'Only support one "auto" in "save_best" list.')
+            assert len(save_best) == len(
+                set(save_best)), ('Find duplicate element in "save_best".')
+        else:
+            # convert str to list[str]
+            if save_best is not None:
+                save_best = [save_best]  # type: ignore # noqa: F401
         self.save_best = save_best

+        # rule logic
+        assert (isinstance(rule, str) or is_list_of(rule, str)
+                or (rule is None)), (
+                    '"rule" should be a str or list of str or None, '
+                    f'but got {type(rule)}')
+        if isinstance(rule, list):
+            # check the length of rule list
+            assert len(rule) in [
+                1,
+                len(self.save_best)  # type: ignore
+            ], ('Number of "rule" must be 1 or the same as number of '
+                f'"save_best", but got {len(rule)}.')
+        else:
+            # convert str/None to list
+            rule = [rule]  # type: ignore # noqa: F401
+
         if greater_keys is None:
             self.greater_keys = self._default_greater_keys
         else:
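The net effect of the two assert blocks above is a normalization step: `save_best` and `rule` always leave `__init__` as lists. A standalone sketch of the same rules, handy for reasoning about edge cases (the function name and signature are illustrative, not part of the hook):

    from typing import List, Optional, Tuple, Union

    def normalize(save_best: Union[str, List[str], None],
                  rule: Union[str, List[str], None]
                  ) -> Tuple[Optional[List[str]], List[Optional[str]]]:
        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, 'only a bare "auto" is allowed'
            assert len(save_best) == len(set(save_best)), 'duplicate metric'
        elif save_best is not None:
            save_best = [save_best]  # str -> [str]
        if isinstance(rule, list):
            assert len(rule) in (1, len(save_best or []))
        else:
            rule = [rule]            # str/None -> [str/None]
        return save_best, rule

    # normalize('acc', None)           -> (['acc'], [None])
    # normalize(['FID', 'IS'], 'less') -> (['FID', 'IS'], ['less']),
    # with the single rule broadcast to both metrics later in _init_rule.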
@@ -132,8 +179,12 @@ class CheckpointHook(Hook):
             self.less_keys = less_keys  # type: ignore

         if self.save_best is not None:
-            self.best_ckpt_path = None
+            self.is_better_than: Dict[str, Callable] = dict()
             self._init_rule(rule, self.save_best)
+            if len(self.key_indicators) == 1:
+                self.best_ckpt_path: Optional[str] = None
+            else:
+                self.best_ckpt_path_dict: Dict = dict()

     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
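So after construction the hook carries exactly one of two bookkeeping attributes; a sketch of the resulting layout:

    # save_best='acc'           -> self.best_ckpt_path: Optional[str]   (legacy shape)
    # save_best=['acc', 'mIoU'] -> self.best_ckpt_path_dict: Dict[str, Optional[str]]
    # In both cases self.is_better_than maps each metric name to its comparator.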
@@ -162,10 +213,21 @@ class CheckpointHook(Hook):
                     f'{self.file_client.name}.')

         if self.save_best is not None:
-            if 'best_ckpt' not in runner.message_hub.runtime_info:
-                self.best_ckpt_path = None
+            if len(self.key_indicators) == 1:
+                if 'best_ckpt' not in runner.message_hub.runtime_info:
+                    self.best_ckpt_path = None
+                else:
+                    self.best_ckpt_path = runner.message_hub.get_info(
+                        'best_ckpt')
             else:
-                self.best_ckpt_path = runner.message_hub.get_info('best_ckpt')
+                for key_indicator in self.key_indicators:
+                    best_ckpt_name = f'best_ckpt_{key_indicator}'
+                    if best_ckpt_name not in runner.message_hub.runtime_info:
+                        self.best_ckpt_path_dict[key_indicator] = None
+                    else:
+                        self.best_ckpt_path_dict[
+                            key_indicator] = runner.message_hub.get_info(
+                                best_ckpt_name)

     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
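The resume logic above follows a simple key convention in `runner.message_hub`: the single-metric case keeps the legacy keys, while the multi-metric case suffixes each key with the metric name. A sketch with placeholder values:

    # single metric (save_best='acc'):
    #   'best_score' -> 0.9
    #   'best_ckpt'  -> '<out_dir>/best_acc_epoch_10.pth'
    # multiple metrics (save_best=['acc', 'mIoU']):
    #   'best_score_acc'  -> 0.9,  'best_ckpt_acc'  -> '<out_dir>/best_acc_epoch_10.pth'
    #   'best_score_mIoU' -> 0.7,  'best_ckpt_mIoU' -> '<out_dir>/best_mIoU_epoch_10.pth'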
@@ -195,7 +257,7 @@ class CheckpointHook(Hook):
         """
         self._save_best_checkpoint(runner, metrics)

-    def _get_metric_score(self, metrics):
+    def _get_metric_score(self, metrics, key_indicator):
         eval_res = OrderedDict()
         if metrics is not None:
             eval_res.update(metrics)
@@ -206,10 +268,7 @@ class CheckpointHook(Hook):
                 'the best checkpoint will be skipped in this evaluation.')
             return None

-        if self.key_indicator == 'auto':
-            self._init_rule(self.rule, list(eval_res.keys())[0])
-
-        return eval_res[self.key_indicator]
+        return eval_res[key_indicator]

     @master_only
     def _save_checkpoint(self, runner) -> None:
@@ -264,6 +323,7 @@ class CheckpointHook(Hook):

         Args:
             runner (Runner): The runner of the training process.
+            metrics (dict): Evaluation results of all metrics.
         """
         if not self.save_best:
             return
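The loop in the next hunk leans on the hook's two lookup tables, `rule_map` and `init_value_map`. They are not shown in this diff; a sketch of their assumed semantics:

    from math import inf

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}

    # The first real score always beats the initial value, so the first
    # evaluation of a run (or a resume with no stored score) saves a
    # "best" checkpoint unconditionally.
    assert rule_map['greater'](0.1, init_value_map['greater'])
    assert rule_map['less'](0.1, init_value_map['less'])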
@@ -277,91 +337,123 @@ class CheckpointHook(Hook):
                 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
             cur_type, cur_time = 'iter', runner.iter + 1

+        # handle auto in self.key_indicators and self.rules before the loop
+        if 'auto' in self.key_indicators:
+            self._init_rule(self.rules, [list(metrics.keys())[0]])
+
         # save best logic
         # get score from messagehub
-        # notice `_get_metirc_score` helps to infer
-        # self.rule when self.save_best is `auto`
-        key_score = self._get_metric_score(metrics)
-        if 'best_score' not in runner.message_hub.runtime_info:
-            best_score = self.init_value_map[self.rule]
-        else:
-            best_score = runner.message_hub.get_info('best_score')
+        for key_indicator, rule in zip(self.key_indicators, self.rules):
+            key_score = self._get_metric_score(metrics, key_indicator)

-        if not key_score or not self.is_better_than(key_score, best_score):
-            return
+            if len(self.key_indicators) == 1:
+                best_score_key = 'best_score'
+                runtime_best_ckpt_key = 'best_ckpt'
+                best_ckpt_path = self.best_ckpt_path
+            else:
+                best_score_key = f'best_score_{key_indicator}'
+                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
+                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

-        best_score = key_score
-        runner.message_hub.update_info('best_score', best_score)
+            if best_score_key not in runner.message_hub.runtime_info:
+                best_score = self.init_value_map[rule]
+            else:
+                best_score = runner.message_hub.get_info(best_score_key)

-        if self.best_ckpt_path and self.file_client.isfile(
-                self.best_ckpt_path):
-            self.file_client.remove(self.best_ckpt_path)
-            runner.logger.info(
-                f'The previous best checkpoint {self.best_ckpt_path} '
-                'is removed')
+            if key_score is None or not self.is_better_than[key_indicator](
+                    key_score, best_score):
+                continue

-        best_ckpt_name = f'best_{self.key_indicator}_{ckpt_filename}'
-        self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
-            self.out_dir, best_ckpt_name)
-        runner.message_hub.update_info('best_ckpt', self.best_ckpt_path)
-        runner.save_checkpoint(
-            self.out_dir,
-            filename=best_ckpt_name,
-            file_client_args=self.file_client_args,
-            save_optimizer=False,
-            save_param_scheduler=False,
-            by_epoch=False)
-        runner.logger.info(
-            f'The best checkpoint with {best_score:0.4f} {self.key_indicator} '
-            f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
+            best_score = key_score
+            runner.message_hub.update_info(best_score_key, best_score)
+
+            if best_ckpt_path and self.file_client.isfile(best_ckpt_path):
+                self.file_client.remove(best_ckpt_path)
+                runner.logger.info(
+                    f'The previous best checkpoint {best_ckpt_path} '
+                    'is removed')
+
+            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
+            if len(self.key_indicators) == 1:
+                self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
+                    self.out_dir, best_ckpt_name)
+                runner.message_hub.update_info(runtime_best_ckpt_key,
+                                               self.best_ckpt_path)
+            else:
+                self.best_ckpt_path_dict[
+                    key_indicator] = self.file_client.join_path(  # type: ignore # noqa: E501
+                        self.out_dir, best_ckpt_name)
+                runner.message_hub.update_info(
+                    runtime_best_ckpt_key,
+                    self.best_ckpt_path_dict[key_indicator])
+            runner.save_checkpoint(
+                self.out_dir,
+                filename=best_ckpt_name,
+                file_client_args=self.file_client_args,
+                save_optimizer=False,
+                save_param_scheduler=False,
+                by_epoch=False)
+            runner.logger.info(
+                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
+                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

-    def _init_rule(self, rule, key_indicator) -> None:
-        """Initialize rule, key_indicator, comparison_func, and best score.
+    def _init_rule(self, rules, key_indicators) -> None:
+        """Initialize rule, key_indicator, comparison_func, and best score. If
+        key_indicator is a list of string and rule is a string, all metric in
+        the key_indicator will share the same rule.
+
         Here is the rule to determine which rule is used for key indicator when
         the rule is not specific (note that the key indicator matching is case-
         insensitive):

-        1. If the key indicator is in ``self.greater_keys``, the rule will be
-           specified as 'greater'.
-        2. Or if the key indicator is in ``self.less_keys``, the rule will be
-           specified as 'less'.
+        1. If the key indicator is in ``self.greater_keys``, the rule
+           will be specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule
+           will be specified as 'less'.
         3. Or if any one item in ``self.greater_keys`` is a substring of
-           key_indicator , the rule will be specified as 'greater'.
+           key_indicator, the rule will be specified as 'greater'.
         4. Or if any one item in ``self.less_keys`` is a substring of
-           key_indicator , the rule will be specified as 'less'.
+           key_indicator, the rule will be specified as 'less'.

         Args:
-            rule (str | None): Comparison rule for best score.
-            key_indicator (str | None): Key indicator to determine the
-                comparison rule.
+            rule (List[Optional[str]]): Comparison rule for best score.
+            key_indicator (List[str]): Key indicator to determine
+                the comparison rule.
         """
+        if len(rules) == 1:
+            rules = rules * len(key_indicators)

-        if rule not in self.rule_map and rule is not None:
-            raise KeyError('rule must be greater, less or None, '
-                           f'but got {rule}.')
+        self.rules = []
+        for rule, key_indicator in zip(rules, key_indicators):
+            if rule not in self.rule_map and rule is not None:
+                raise KeyError('rule must be greater, less or None, '
+                               f'but got {rule}.')

-        if rule is None and key_indicator != 'auto':
-            # `_lc` here means we use the lower case of keys for
-            # case-insensitive matching
-            key_indicator_lc = key_indicator.lower()
-            greater_keys = [key.lower() for key in self.greater_keys]
-            less_keys = [key.lower() for key in self.less_keys]
+            if rule is None and key_indicator != 'auto':
+                # `_lc` here means we use the lower case of keys for
+                # case-insensitive matching
+                key_indicator_lc = key_indicator.lower()
+                greater_keys = {key.lower() for key in self.greater_keys}
+                less_keys = {key.lower() for key in self.less_keys}

-            if key_indicator_lc in greater_keys:
-                rule = 'greater'
-            elif key_indicator_lc in less_keys:
-                rule = 'less'
-            elif any(key in key_indicator_lc for key in greater_keys):
-                rule = 'greater'
-            elif any(key in key_indicator_lc for key in less_keys):
-                rule = 'less'
-            else:
-                raise ValueError('Cannot infer the rule for key '
-                                 f'{key_indicator}, thus a specific rule '
-                                 'must be specified.')
-        self.rule = rule
-        self.key_indicator = key_indicator
-        if self.rule is not None:
-            self.is_better_than = self.rule_map[self.rule]
+                if key_indicator_lc in greater_keys:
+                    rule = 'greater'
+                elif key_indicator_lc in less_keys:
+                    rule = 'less'
+                elif any(key in key_indicator_lc for key in greater_keys):
+                    rule = 'greater'
+                elif any(key in key_indicator_lc for key in less_keys):
+                    rule = 'less'
+                else:
+                    raise ValueError('Cannot infer the rule for key '
+                                     f'{key_indicator}, thus a specific rule '
+                                     'must be specified.')
+            if rule is not None:
+                self.is_better_than[key_indicator] = self.rule_map[rule]
+            self.rules.append(rule)
+
+        self.key_indicators = key_indicators

     def after_train_iter(self,
                          runner,
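Putting `_init_rule` together: after construction with multiple metrics, `key_indicators`, `rules` and the per-metric comparator dict line up index by index. A sketch of the resulting state (assuming the usual `mmengine.hooks` import path):

    from mmengine.hooks import CheckpointHook

    hook = CheckpointHook(interval=2, save_best=['acc', 'FID'],
                          rule=['greater', 'less'])
    assert hook.key_indicators == ['acc', 'FID']
    assert hook.rules == ['greater', 'less']
    assert set(hook.is_better_than) == {'acc', 'FID'}
    assert hook.is_better_than['acc'](0.9, 0.5)    # greater: 0.9 beats 0.5
    assert hook.is_better_than['FID'](12.0, 30.0)  # less: 12.0 beats 30.0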
@@ -49,6 +49,30 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')

+        runner.message_hub = MessageHub.get_instance('test_before_train')
+        # no 'best_ckpt_path' in runtime_info
+        checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None)
+        assert not hasattr(checkpoint_hook, 'best_ckpt_path')
+
+        # only one 'best_ckpt_path' in runtime_info
+        runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path_dict == dict(
+            acc='best_acc', mIoU=None)
+
+        # no 'best_ckpt_path' in runtime_info
+        checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path is None
+        assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict')
+
+        # 'best_ckpt_path' in runtime_info
+        runner.message_hub.update_info('best_ckpt', 'best_ckpt')
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.best_ckpt_path == 'best_ckpt'
+
     def test_after_val_epoch(self, tmp_path):
         runner = Mock()
         runner.work_dir = tmp_path
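The test grabs a named `MessageHub` instance so the runtime info it plants cannot leak into other tests. A minimal sketch of that pattern (the instance name is arbitrary, and the `mmengine.logging` import path is assumed since the test's import block is not part of this diff):

    from mmengine.logging import MessageHub

    hub = MessageHub.get_instance('isolated_for_this_test')
    hub.update_info('best_ckpt_acc', 'best_acc')
    assert hub.get_info('best_ckpt_acc') == 'best_acc'
    assert 'best_ckpt_mIoU' not in hub.runtime_info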
@@ -69,7 +93,7 @@ class TestCheckpointHook:
         with pytest.warns(UserWarning) as record_warnings:
             eval_hook = CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None)
+            eval_hook._get_metric_score(None, None)
         # Since there will be many warnings thrown, we just need to check
         # if the expected exceptions are thrown
         expected_message = (
@@ -82,6 +106,17 @@ class TestCheckpointHook:
         else:
             assert False

+        # test error when number of rules and metrics are not same
+        with pytest.raises(AssertionError) as assert_error:
+            CheckpointHook(
+                interval=1,
+                save_best=['mIoU', 'acc'],
+                rule=['greater', 'greater', 'less'],
+                by_epoch=True)
+        error_message = ('Number of "rule" must be 1 or the same as number of '
+                         '"save_best", but got 3.')
+        assert error_message in str(assert_error.value)
+
         # if save_best is None,no best_ckpt meta should be stored
         eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
         eval_hook.before_train(runner)
@@ -97,8 +132,8 @@ class TestCheckpointHook:
         best_ckpt_name = 'best_acc_epoch_10.pth'
         best_ckpt_path = eval_hook.file_client.join_path(
             eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicator == 'acc'
-        assert eval_hook.rule == 'greater'
+        assert eval_hook.key_indicators == ['acc']
+        assert eval_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
@@ -142,6 +177,42 @@ class TestCheckpointHook:
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0

+        # test multi `save_best` with one rule
+        eval_hook = CheckpointHook(
+            interval=2, save_best=['acc', 'mIoU'], rule='greater')
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+
+        # test multi `save_best` with multi rules
+        eval_hook = CheckpointHook(
+            interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
+        assert eval_hook.key_indicators == ['FID', 'IS']
+        assert eval_hook.rules == ['less', 'greater']
+
+        # test multi `save_best` with default rule
+        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+        runner.message_hub = MessageHub.get_instance(
+            'test_after_val_epoch_save_multi_best')
+        eval_hook.before_train(runner)
+        metrics = dict(acc=0.5, mIoU=0.6)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_acc_name = 'best_acc_epoch_10.pth'
+        best_acc_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_acc_name)
+        best_mIoU_name = 'best_mIoU_epoch_10.pth'
+        best_mIoU_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_mIoU_name)
+        assert 'best_score_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_acc') == 0.5
+        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_mIoU') == 0.6
+        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
+        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
+
         # test behavior when by_epoch is False
         runner = Mock()
         runner.work_dir = tmp_path
@@ -156,8 +227,8 @@ class TestCheckpointHook:
             interval=2, by_epoch=False, save_best='acc', rule='greater')
         eval_hook.before_train(runner)
         eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicator == 'acc'
-        assert eval_hook.rule == 'greater'
+        assert eval_hook.key_indicators == ['acc']
+        assert eval_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_10.pth'
         best_ckpt_path = eval_hook.file_client.join_path(
             eval_hook.out_dir, best_ckpt_name)
@@ -176,6 +247,38 @@ class TestCheckpointHook:
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.666
+        # error when 'auto' in `save_best` list
+        with pytest.raises(AssertionError):
+            CheckpointHook(interval=2, save_best=['auto', 'acc'])
+        # error when one `save_best` with multi `rule`
+        with pytest.raises(AssertionError):
+            CheckpointHook(
+                interval=2, save_best='acc', rule=['greater', 'less'])
+
+        # check best checkpoint name with `by_epoch` is False
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
+        assert eval_hook.key_indicators == ['acc', 'mIoU']
+        assert eval_hook.rules == ['greater', 'greater']
+        runner.message_hub = MessageHub.get_instance(
+            'test_after_val_epoch_save_multi_best_by_epoch_is_false')
+        eval_hook.before_train(runner)
+        metrics = dict(acc=0.5, mIoU=0.6)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_acc_name = 'best_acc_iter_10.pth'
+        best_acc_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_acc_name)
+        best_mIoU_name = 'best_mIoU_iter_10.pth'
+        best_mIoU_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_mIoU_name)
+        assert 'best_score_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_acc') == 0.5
+        assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score_mIoU') == 0.6
+        assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
+        assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path

     def test_after_train_epoch(self, tmp_path):
         runner = Mock()