mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Support save best based on multi metrics (#349)
* support save best based on multi metrics * add unit test * resolve bugs after rebasing * revise docstring * revise docstring * fix as comment * revise as comment
This commit is contained in:
parent
6ebb7ed481
commit
08602a2385
@ -4,12 +4,13 @@ import warnings
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from math import inf
|
from math import inf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
from mmengine.dist import master_only
|
from mmengine.dist import master_only
|
||||||
from mmengine.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
from mmengine.registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
from mmengine.utils import is_seq_of
|
from mmengine.utils import is_seq_of
|
||||||
|
from mmengine.utils.misc import is_list_of
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
|
||||||
DATA_BATCH = Optional[Sequence[dict]]
|
DATA_BATCH = Optional[Sequence[dict]]
|
||||||
@ -44,20 +45,27 @@ class CheckpointHook(Hook):
|
|||||||
Defaults to -1, which means unlimited.
|
Defaults to -1, which means unlimited.
|
||||||
save_last (bool): Whether to force the last checkpoint to be
|
save_last (bool): Whether to force the last checkpoint to be
|
||||||
saved regardless of interval. Defaults to True.
|
saved regardless of interval. Defaults to True.
|
||||||
save_best (str, optional): If a metric is specified, it would measure
|
save_best (str, List[str], optional): If a metric is specified, it
|
||||||
the best checkpoint during evaluation. The information about best
|
would measure the best checkpoint during evaluation. If a list of
|
||||||
checkpoint would be saved in ``runner.message_hub`` to keep
|
metrics is passed, it would measure a group of best checkpoints
|
||||||
|
corresponding to the passed metrics. The information about best
|
||||||
|
checkpoint(s) would be saved in ``runner.message_hub`` to keep
|
||||||
best score value and best checkpoint path, which will be also
|
best score value and best checkpoint path, which will be also
|
||||||
loaded when resuming checkpoint. Options are the evaluation metrics
|
loaded when resuming checkpoint. Options are the evaluation metrics
|
||||||
on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
|
on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
|
||||||
detection and instance segmentation. ``AR@100`` for proposal
|
detection and instance segmentation. ``AR@100`` for proposal
|
||||||
recall. If ``save_best`` is ``auto``, the first key of the returned
|
recall. If ``save_best`` is ``auto``, the first key of the returned
|
||||||
``OrderedDict`` result will be used. Defaults to None.
|
``OrderedDict`` result will be used. Defaults to None.
|
||||||
rule (str, optional): Comparison rule for best score. If set to
|
rule (str, List[str], optional): Comparison rule for best score. If
|
||||||
None, it will infer a reasonable rule. Keys such as 'acc', 'top'
|
set to None, it will infer a reasonable rule. Keys such as 'acc',
|
||||||
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
|
'top' .etc will be inferred by 'greater' rule. Keys contain 'loss'
|
||||||
be inferred by 'less' rule. Options are 'greater', 'less', None.
|
will be inferred by 'less' rule. If ``save_best`` is a list of
|
||||||
Defaults to None.
|
metrics and ``rule`` is a str, all metrics in ``save_best`` will
|
||||||
|
share the comparison rule. If ``save_best`` and ``rule`` are both
|
||||||
|
lists, their length must be the same, and metrics in ``save_best``
|
||||||
|
will use the corresponding comparison rule in ``rule``. Options
|
||||||
|
are 'greater', 'less', None and list which contains 'greater' and
|
||||||
|
'less'. Defaults to None.
|
||||||
greater_keys (List[str], optional): Metric keys that will be
|
greater_keys (List[str], optional): Metric keys that will be
|
||||||
inferred by 'greater' comparison rule. If ``None``,
|
inferred by 'greater' comparison rule. If ``None``,
|
||||||
_default_greater_keys will be used. Defaults to None.
|
_default_greater_keys will be used. Defaults to None.
|
||||||
@ -67,6 +75,17 @@ class CheckpointHook(Hook):
|
|||||||
file_client_args (dict, optional): Arguments to instantiate a
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Save best based on single metric
|
||||||
|
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
|
||||||
|
>>> rule='less')
|
||||||
|
>>> # Save best based on multi metrics with the same comparison rule
|
||||||
|
>>> CheckpointHook(interval=2, by_epoch=True,
|
||||||
|
>>> save_best=['acc', 'mIoU'], rule='greater')
|
||||||
|
>>> # Save best based on multi metrics with different comparison rule
|
||||||
|
>>> CheckpointHook(interval=2, by_epoch=True,
|
||||||
|
>>> save_best=['FID', 'IS'], rule=['less', 'greater'])
|
||||||
"""
|
"""
|
||||||
out_dir: str
|
out_dir: str
|
||||||
|
|
||||||
@ -93,8 +112,8 @@ class CheckpointHook(Hook):
|
|||||||
out_dir: Optional[Union[str, Path]] = None,
|
out_dir: Optional[Union[str, Path]] = None,
|
||||||
max_keep_ckpts: int = -1,
|
max_keep_ckpts: int = -1,
|
||||||
save_last: bool = True,
|
save_last: bool = True,
|
||||||
save_best: Optional[str] = None,
|
save_best: Union[str, List[str], None] = None,
|
||||||
rule: Optional[str] = None,
|
rule: Union[str, List[str], None] = None,
|
||||||
greater_keys: Optional[Sequence[str]] = None,
|
greater_keys: Optional[Sequence[str]] = None,
|
||||||
less_keys: Optional[Sequence[str]] = None,
|
less_keys: Optional[Sequence[str]] = None,
|
||||||
file_client_args: Optional[dict] = None,
|
file_client_args: Optional[dict] = None,
|
||||||
@ -110,11 +129,39 @@ class CheckpointHook(Hook):
|
|||||||
self.file_client_args = file_client_args
|
self.file_client_args = file_client_args
|
||||||
|
|
||||||
# save best logic
|
# save best logic
|
||||||
assert isinstance(save_best, str) or save_best is None, \
|
assert (isinstance(save_best, str) or is_list_of(save_best, str)
|
||||||
'"save_best" should be a str or None ' \
|
or (save_best is None)), (
|
||||||
f'rather than {type(save_best)}'
|
'"save_best" should be a str or list of str or None, '
|
||||||
|
f'but got {type(save_best)}')
|
||||||
|
|
||||||
|
if isinstance(save_best, list):
|
||||||
|
if 'auto' in save_best:
|
||||||
|
assert len(save_best) == 1, (
|
||||||
|
'Only support one "auto" in "save_best" list.')
|
||||||
|
assert len(save_best) == len(
|
||||||
|
set(save_best)), ('Find duplicate element in "save_best".')
|
||||||
|
else:
|
||||||
|
# convert str to list[str]
|
||||||
|
if save_best is not None:
|
||||||
|
save_best = [save_best] # type: ignore # noqa: F401
|
||||||
self.save_best = save_best
|
self.save_best = save_best
|
||||||
|
|
||||||
|
# rule logic
|
||||||
|
assert (isinstance(rule, str) or is_list_of(rule, str)
|
||||||
|
or (rule is None)), (
|
||||||
|
'"rule" should be a str or list of str or None, '
|
||||||
|
f'but got {type(rule)}')
|
||||||
|
if isinstance(rule, list):
|
||||||
|
# check the length of rule list
|
||||||
|
assert len(rule) in [
|
||||||
|
1,
|
||||||
|
len(self.save_best) # type: ignore
|
||||||
|
], ('Number of "rule" must be 1 or the same as number of '
|
||||||
|
f'"save_best", but got {len(rule)}.')
|
||||||
|
else:
|
||||||
|
# convert str/None to list
|
||||||
|
rule = [rule] # type: ignore # noqa: F401
|
||||||
|
|
||||||
if greater_keys is None:
|
if greater_keys is None:
|
||||||
self.greater_keys = self._default_greater_keys
|
self.greater_keys = self._default_greater_keys
|
||||||
else:
|
else:
|
||||||
@ -132,8 +179,12 @@ class CheckpointHook(Hook):
|
|||||||
self.less_keys = less_keys # type: ignore
|
self.less_keys = less_keys # type: ignore
|
||||||
|
|
||||||
if self.save_best is not None:
|
if self.save_best is not None:
|
||||||
self.best_ckpt_path = None
|
self.is_better_than: Dict[str, Callable] = dict()
|
||||||
self._init_rule(rule, self.save_best)
|
self._init_rule(rule, self.save_best)
|
||||||
|
if len(self.key_indicators) == 1:
|
||||||
|
self.best_ckpt_path: Optional[str] = None
|
||||||
|
else:
|
||||||
|
self.best_ckpt_path_dict: Dict = dict()
|
||||||
|
|
||||||
def before_train(self, runner) -> None:
|
def before_train(self, runner) -> None:
|
||||||
"""Finish all operations, related to checkpoint.
|
"""Finish all operations, related to checkpoint.
|
||||||
@ -162,10 +213,21 @@ class CheckpointHook(Hook):
|
|||||||
f'{self.file_client.name}.')
|
f'{self.file_client.name}.')
|
||||||
|
|
||||||
if self.save_best is not None:
|
if self.save_best is not None:
|
||||||
if 'best_ckpt' not in runner.message_hub.runtime_info:
|
if len(self.key_indicators) == 1:
|
||||||
self.best_ckpt_path = None
|
if 'best_ckpt' not in runner.message_hub.runtime_info:
|
||||||
|
self.best_ckpt_path = None
|
||||||
|
else:
|
||||||
|
self.best_ckpt_path = runner.message_hub.get_info(
|
||||||
|
'best_ckpt')
|
||||||
else:
|
else:
|
||||||
self.best_ckpt_path = runner.message_hub.get_info('best_ckpt')
|
for key_indicator in self.key_indicators:
|
||||||
|
best_ckpt_name = f'best_ckpt_{key_indicator}'
|
||||||
|
if best_ckpt_name not in runner.message_hub.runtime_info:
|
||||||
|
self.best_ckpt_path_dict[key_indicator] = None
|
||||||
|
else:
|
||||||
|
self.best_ckpt_path_dict[
|
||||||
|
key_indicator] = runner.message_hub.get_info(
|
||||||
|
best_ckpt_name)
|
||||||
|
|
||||||
def after_train_epoch(self, runner) -> None:
|
def after_train_epoch(self, runner) -> None:
|
||||||
"""Save the checkpoint and synchronize buffers after each epoch.
|
"""Save the checkpoint and synchronize buffers after each epoch.
|
||||||
@ -195,7 +257,7 @@ class CheckpointHook(Hook):
|
|||||||
"""
|
"""
|
||||||
self._save_best_checkpoint(runner, metrics)
|
self._save_best_checkpoint(runner, metrics)
|
||||||
|
|
||||||
def _get_metric_score(self, metrics):
|
def _get_metric_score(self, metrics, key_indicator):
|
||||||
eval_res = OrderedDict()
|
eval_res = OrderedDict()
|
||||||
if metrics is not None:
|
if metrics is not None:
|
||||||
eval_res.update(metrics)
|
eval_res.update(metrics)
|
||||||
@ -206,10 +268,7 @@ class CheckpointHook(Hook):
|
|||||||
'the best checkpoint will be skipped in this evaluation.')
|
'the best checkpoint will be skipped in this evaluation.')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.key_indicator == 'auto':
|
return eval_res[key_indicator]
|
||||||
self._init_rule(self.rule, list(eval_res.keys())[0])
|
|
||||||
|
|
||||||
return eval_res[self.key_indicator]
|
|
||||||
|
|
||||||
@master_only
|
@master_only
|
||||||
def _save_checkpoint(self, runner) -> None:
|
def _save_checkpoint(self, runner) -> None:
|
||||||
@ -264,6 +323,7 @@ class CheckpointHook(Hook):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
|
metrics (dict): Evaluation results of all metrics.
|
||||||
"""
|
"""
|
||||||
if not self.save_best:
|
if not self.save_best:
|
||||||
return
|
return
|
||||||
@ -277,91 +337,123 @@ class CheckpointHook(Hook):
|
|||||||
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
||||||
cur_type, cur_time = 'iter', runner.iter + 1
|
cur_type, cur_time = 'iter', runner.iter + 1
|
||||||
|
|
||||||
|
# handle auto in self.key_indicators and self.rules before the loop
|
||||||
|
if 'auto' in self.key_indicators:
|
||||||
|
self._init_rule(self.rules, [list(metrics.keys())[0]])
|
||||||
|
|
||||||
# save best logic
|
# save best logic
|
||||||
# get score from messagehub
|
# get score from messagehub
|
||||||
# notice `_get_metirc_score` helps to infer
|
for key_indicator, rule in zip(self.key_indicators, self.rules):
|
||||||
# self.rule when self.save_best is `auto`
|
key_score = self._get_metric_score(metrics, key_indicator)
|
||||||
key_score = self._get_metric_score(metrics)
|
|
||||||
if 'best_score' not in runner.message_hub.runtime_info:
|
|
||||||
best_score = self.init_value_map[self.rule]
|
|
||||||
else:
|
|
||||||
best_score = runner.message_hub.get_info('best_score')
|
|
||||||
|
|
||||||
if not key_score or not self.is_better_than(key_score, best_score):
|
if len(self.key_indicators) == 1:
|
||||||
return
|
best_score_key = 'best_score'
|
||||||
|
runtime_best_ckpt_key = 'best_ckpt'
|
||||||
|
best_ckpt_path = self.best_ckpt_path
|
||||||
|
else:
|
||||||
|
best_score_key = f'best_score_{key_indicator}'
|
||||||
|
runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
|
||||||
|
best_ckpt_path = self.best_ckpt_path_dict[key_indicator]
|
||||||
|
|
||||||
best_score = key_score
|
if best_score_key not in runner.message_hub.runtime_info:
|
||||||
runner.message_hub.update_info('best_score', best_score)
|
best_score = self.init_value_map[rule]
|
||||||
|
else:
|
||||||
|
best_score = runner.message_hub.get_info(best_score_key)
|
||||||
|
|
||||||
if self.best_ckpt_path and self.file_client.isfile(
|
if key_score is None or not self.is_better_than[key_indicator](
|
||||||
self.best_ckpt_path):
|
key_score, best_score):
|
||||||
self.file_client.remove(self.best_ckpt_path)
|
continue
|
||||||
|
|
||||||
|
best_score = key_score
|
||||||
|
runner.message_hub.update_info(best_score_key, best_score)
|
||||||
|
|
||||||
|
if best_ckpt_path and self.file_client.isfile(best_ckpt_path):
|
||||||
|
self.file_client.remove(best_ckpt_path)
|
||||||
|
runner.logger.info(
|
||||||
|
f'The previous best checkpoint {best_ckpt_path} '
|
||||||
|
'is removed')
|
||||||
|
|
||||||
|
best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
|
||||||
|
if len(self.key_indicators) == 1:
|
||||||
|
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
|
||||||
|
self.out_dir, best_ckpt_name)
|
||||||
|
runner.message_hub.update_info(runtime_best_ckpt_key,
|
||||||
|
self.best_ckpt_path)
|
||||||
|
else:
|
||||||
|
self.best_ckpt_path_dict[
|
||||||
|
key_indicator] = self.file_client.join_path( # type: ignore # noqa: E501
|
||||||
|
self.out_dir, best_ckpt_name)
|
||||||
|
runner.message_hub.update_info(
|
||||||
|
runtime_best_ckpt_key,
|
||||||
|
self.best_ckpt_path_dict[key_indicator])
|
||||||
|
runner.save_checkpoint(
|
||||||
|
self.out_dir,
|
||||||
|
filename=best_ckpt_name,
|
||||||
|
file_client_args=self.file_client_args,
|
||||||
|
save_optimizer=False,
|
||||||
|
save_param_scheduler=False,
|
||||||
|
by_epoch=False)
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
f'The previous best checkpoint {self.best_ckpt_path} '
|
f'The best checkpoint with {best_score:0.4f} {key_indicator} '
|
||||||
'is removed')
|
f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
|
||||||
|
|
||||||
best_ckpt_name = f'best_{self.key_indicator}_{ckpt_filename}'
|
def _init_rule(self, rules, key_indicators) -> None:
|
||||||
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
|
"""Initialize rule, key_indicator, comparison_func, and best score. If
|
||||||
self.out_dir, best_ckpt_name)
|
key_indicator is a list of string and rule is a string, all metric in
|
||||||
runner.message_hub.update_info('best_ckpt', self.best_ckpt_path)
|
the key_indicator will share the same rule.
|
||||||
runner.save_checkpoint(
|
|
||||||
self.out_dir,
|
|
||||||
filename=best_ckpt_name,
|
|
||||||
file_client_args=self.file_client_args,
|
|
||||||
save_optimizer=False,
|
|
||||||
save_param_scheduler=False,
|
|
||||||
by_epoch=False)
|
|
||||||
runner.logger.info(
|
|
||||||
f'The best checkpoint with {best_score:0.4f} {self.key_indicator} '
|
|
||||||
f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
|
|
||||||
|
|
||||||
def _init_rule(self, rule, key_indicator) -> None:
|
|
||||||
"""Initialize rule, key_indicator, comparison_func, and best score.
|
|
||||||
Here is the rule to determine which rule is used for key indicator when
|
Here is the rule to determine which rule is used for key indicator when
|
||||||
the rule is not specific (note that the key indicator matching is case-
|
the rule is not specific (note that the key indicator matching is case-
|
||||||
insensitive):
|
insensitive):
|
||||||
|
|
||||||
1. If the key indicator is in ``self.greater_keys``, the rule will be
|
1. If the key indicator is in ``self.greater_keys``, the rule
|
||||||
specified as 'greater'.
|
will be specified as 'greater'.
|
||||||
2. Or if the key indicator is in ``self.less_keys``, the rule will be
|
2. Or if the key indicator is in ``self.less_keys``, the rule
|
||||||
specified as 'less'.
|
will be specified as 'less'.
|
||||||
3. Or if any one item in ``self.greater_keys`` is a substring of
|
3. Or if any one item in ``self.greater_keys`` is a substring of
|
||||||
key_indicator , the rule will be specified as 'greater'.
|
key_indicator, the rule will be specified as 'greater'.
|
||||||
4. Or if any one item in ``self.less_keys`` is a substring of
|
4. Or if any one item in ``self.less_keys`` is a substring of
|
||||||
key_indicator , the rule will be specified as 'less'.
|
key_indicator, the rule will be specified as 'less'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rule (str | None): Comparison rule for best score.
|
rule (List[Optional[str]]): Comparison rule for best score.
|
||||||
key_indicator (str | None): Key indicator to determine the
|
key_indicator (List[str]): Key indicator to determine
|
||||||
comparison rule.
|
the comparison rule.
|
||||||
"""
|
"""
|
||||||
|
if len(rules) == 1:
|
||||||
|
rules = rules * len(key_indicators)
|
||||||
|
|
||||||
if rule not in self.rule_map and rule is not None:
|
self.rules = []
|
||||||
raise KeyError('rule must be greater, less or None, '
|
for rule, key_indicator in zip(rules, key_indicators):
|
||||||
f'but got {rule}.')
|
|
||||||
|
|
||||||
if rule is None and key_indicator != 'auto':
|
if rule not in self.rule_map and rule is not None:
|
||||||
# `_lc` here means we use the lower case of keys for
|
raise KeyError('rule must be greater, less or None, '
|
||||||
# case-insensitive matching
|
f'but got {rule}.')
|
||||||
key_indicator_lc = key_indicator.lower()
|
|
||||||
greater_keys = [key.lower() for key in self.greater_keys]
|
|
||||||
less_keys = [key.lower() for key in self.less_keys]
|
|
||||||
|
|
||||||
if key_indicator_lc in greater_keys:
|
if rule is None and key_indicator != 'auto':
|
||||||
rule = 'greater'
|
# `_lc` here means we use the lower case of keys for
|
||||||
elif key_indicator_lc in less_keys:
|
# case-insensitive matching
|
||||||
rule = 'less'
|
key_indicator_lc = key_indicator.lower()
|
||||||
elif any(key in key_indicator_lc for key in greater_keys):
|
greater_keys = {key.lower() for key in self.greater_keys}
|
||||||
rule = 'greater'
|
less_keys = {key.lower() for key in self.less_keys}
|
||||||
elif any(key in key_indicator_lc for key in less_keys):
|
|
||||||
rule = 'less'
|
if key_indicator_lc in greater_keys:
|
||||||
else:
|
rule = 'greater'
|
||||||
raise ValueError('Cannot infer the rule for key '
|
elif key_indicator_lc in less_keys:
|
||||||
f'{key_indicator}, thus a specific rule '
|
rule = 'less'
|
||||||
'must be specified.')
|
elif any(key in key_indicator_lc for key in greater_keys):
|
||||||
self.rule = rule
|
rule = 'greater'
|
||||||
self.key_indicator = key_indicator
|
elif any(key in key_indicator_lc for key in less_keys):
|
||||||
if self.rule is not None:
|
rule = 'less'
|
||||||
self.is_better_than = self.rule_map[self.rule]
|
else:
|
||||||
|
raise ValueError('Cannot infer the rule for key '
|
||||||
|
f'{key_indicator}, thus a specific rule '
|
||||||
|
'must be specified.')
|
||||||
|
if rule is not None:
|
||||||
|
self.is_better_than[key_indicator] = self.rule_map[rule]
|
||||||
|
self.rules.append(rule)
|
||||||
|
|
||||||
|
self.key_indicators = key_indicators
|
||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
|
@ -49,6 +49,30 @@ class TestCheckpointHook:
|
|||||||
assert checkpoint_hook.out_dir == (
|
assert checkpoint_hook.out_dir == (
|
||||||
f'test_dir/{osp.basename(work_dir)}')
|
f'test_dir/{osp.basename(work_dir)}')
|
||||||
|
|
||||||
|
runner.message_hub = MessageHub.get_instance('test_before_train')
|
||||||
|
# no 'best_ckpt_path' in runtime_info
|
||||||
|
checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
|
||||||
|
checkpoint_hook.before_train(runner)
|
||||||
|
assert checkpoint_hook.best_ckpt_path_dict == dict(acc=None, mIoU=None)
|
||||||
|
assert not hasattr(checkpoint_hook, 'best_ckpt_path')
|
||||||
|
|
||||||
|
# only one 'best_ckpt_path' in runtime_info
|
||||||
|
runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
|
||||||
|
checkpoint_hook.before_train(runner)
|
||||||
|
assert checkpoint_hook.best_ckpt_path_dict == dict(
|
||||||
|
acc='best_acc', mIoU=None)
|
||||||
|
|
||||||
|
# no 'best_ckpt_path' in runtime_info
|
||||||
|
checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
|
||||||
|
checkpoint_hook.before_train(runner)
|
||||||
|
assert checkpoint_hook.best_ckpt_path is None
|
||||||
|
assert not hasattr(checkpoint_hook, 'best_ckpt_path_dict')
|
||||||
|
|
||||||
|
# 'best_ckpt_path' in runtime_info
|
||||||
|
runner.message_hub.update_info('best_ckpt', 'best_ckpt')
|
||||||
|
checkpoint_hook.before_train(runner)
|
||||||
|
assert checkpoint_hook.best_ckpt_path == 'best_ckpt'
|
||||||
|
|
||||||
def test_after_val_epoch(self, tmp_path):
|
def test_after_val_epoch(self, tmp_path):
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = tmp_path
|
runner.work_dir = tmp_path
|
||||||
@ -69,7 +93,7 @@ class TestCheckpointHook:
|
|||||||
with pytest.warns(UserWarning) as record_warnings:
|
with pytest.warns(UserWarning) as record_warnings:
|
||||||
eval_hook = CheckpointHook(
|
eval_hook = CheckpointHook(
|
||||||
interval=2, by_epoch=True, save_best='auto')
|
interval=2, by_epoch=True, save_best='auto')
|
||||||
eval_hook._get_metric_score(None)
|
eval_hook._get_metric_score(None, None)
|
||||||
# Since there will be many warnings thrown, we just need to check
|
# Since there will be many warnings thrown, we just need to check
|
||||||
# if the expected exceptions are thrown
|
# if the expected exceptions are thrown
|
||||||
expected_message = (
|
expected_message = (
|
||||||
@ -82,6 +106,17 @@ class TestCheckpointHook:
|
|||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
# test error when number of rules and metrics are not same
|
||||||
|
with pytest.raises(AssertionError) as assert_error:
|
||||||
|
CheckpointHook(
|
||||||
|
interval=1,
|
||||||
|
save_best=['mIoU', 'acc'],
|
||||||
|
rule=['greater', 'greater', 'less'],
|
||||||
|
by_epoch=True)
|
||||||
|
error_message = ('Number of "rule" must be 1 or the same as number of '
|
||||||
|
'"save_best", but got 3.')
|
||||||
|
assert error_message in str(assert_error.value)
|
||||||
|
|
||||||
# if save_best is None,no best_ckpt meta should be stored
|
# if save_best is None,no best_ckpt meta should be stored
|
||||||
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
|
eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
|
||||||
eval_hook.before_train(runner)
|
eval_hook.before_train(runner)
|
||||||
@ -97,8 +132,8 @@ class TestCheckpointHook:
|
|||||||
best_ckpt_name = 'best_acc_epoch_10.pth'
|
best_ckpt_name = 'best_acc_epoch_10.pth'
|
||||||
best_ckpt_path = eval_hook.file_client.join_path(
|
best_ckpt_path = eval_hook.file_client.join_path(
|
||||||
eval_hook.out_dir, best_ckpt_name)
|
eval_hook.out_dir, best_ckpt_name)
|
||||||
assert eval_hook.key_indicator == 'acc'
|
assert eval_hook.key_indicators == ['acc']
|
||||||
assert eval_hook.rule == 'greater'
|
assert eval_hook.rules == ['greater']
|
||||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||||
runner.message_hub.get_info('best_score') == 0.5
|
runner.message_hub.get_info('best_score') == 0.5
|
||||||
assert 'best_ckpt' in runner.message_hub.runtime_info and \
|
assert 'best_ckpt' in runner.message_hub.runtime_info and \
|
||||||
@ -142,6 +177,42 @@ class TestCheckpointHook:
|
|||||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||||
runner.message_hub.get_info('best_score') == 1.0
|
runner.message_hub.get_info('best_score') == 1.0
|
||||||
|
|
||||||
|
# test multi `save_best` with one rule
|
||||||
|
eval_hook = CheckpointHook(
|
||||||
|
interval=2, save_best=['acc', 'mIoU'], rule='greater')
|
||||||
|
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||||
|
assert eval_hook.rules == ['greater', 'greater']
|
||||||
|
|
||||||
|
# test multi `save_best` with multi rules
|
||||||
|
eval_hook = CheckpointHook(
|
||||||
|
interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
|
||||||
|
assert eval_hook.key_indicators == ['FID', 'IS']
|
||||||
|
assert eval_hook.rules == ['less', 'greater']
|
||||||
|
|
||||||
|
# test multi `save_best` with default rule
|
||||||
|
eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
|
||||||
|
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||||
|
assert eval_hook.rules == ['greater', 'greater']
|
||||||
|
runner.message_hub = MessageHub.get_instance(
|
||||||
|
'test_after_val_epoch_save_multi_best')
|
||||||
|
eval_hook.before_train(runner)
|
||||||
|
metrics = dict(acc=0.5, mIoU=0.6)
|
||||||
|
eval_hook.after_val_epoch(runner, metrics)
|
||||||
|
best_acc_name = 'best_acc_epoch_10.pth'
|
||||||
|
best_acc_path = eval_hook.file_client.join_path(
|
||||||
|
eval_hook.out_dir, best_acc_name)
|
||||||
|
best_mIoU_name = 'best_mIoU_epoch_10.pth'
|
||||||
|
best_mIoU_path = eval_hook.file_client.join_path(
|
||||||
|
eval_hook.out_dir, best_mIoU_name)
|
||||||
|
assert 'best_score_acc' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_score_acc') == 0.5
|
||||||
|
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_score_mIoU') == 0.6
|
||||||
|
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
|
||||||
|
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
|
||||||
|
|
||||||
# test behavior when by_epoch is False
|
# test behavior when by_epoch is False
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = tmp_path
|
runner.work_dir = tmp_path
|
||||||
@ -156,8 +227,8 @@ class TestCheckpointHook:
|
|||||||
interval=2, by_epoch=False, save_best='acc', rule='greater')
|
interval=2, by_epoch=False, save_best='acc', rule='greater')
|
||||||
eval_hook.before_train(runner)
|
eval_hook.before_train(runner)
|
||||||
eval_hook.after_val_epoch(runner, metrics)
|
eval_hook.after_val_epoch(runner, metrics)
|
||||||
assert eval_hook.key_indicator == 'acc'
|
assert eval_hook.key_indicators == ['acc']
|
||||||
assert eval_hook.rule == 'greater'
|
assert eval_hook.rules == ['greater']
|
||||||
best_ckpt_name = 'best_acc_iter_10.pth'
|
best_ckpt_name = 'best_acc_iter_10.pth'
|
||||||
best_ckpt_path = eval_hook.file_client.join_path(
|
best_ckpt_path = eval_hook.file_client.join_path(
|
||||||
eval_hook.out_dir, best_ckpt_name)
|
eval_hook.out_dir, best_ckpt_name)
|
||||||
@ -176,6 +247,38 @@ class TestCheckpointHook:
|
|||||||
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
|
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
|
||||||
assert 'best_score' in runner.message_hub.runtime_info and \
|
assert 'best_score' in runner.message_hub.runtime_info and \
|
||||||
runner.message_hub.get_info('best_score') == 0.666
|
runner.message_hub.get_info('best_score') == 0.666
|
||||||
|
# error when 'auto' in `save_best` list
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CheckpointHook(interval=2, save_best=['auto', 'acc'])
|
||||||
|
# error when one `save_best` with multi `rule`
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CheckpointHook(
|
||||||
|
interval=2, save_best='acc', rule=['greater', 'less'])
|
||||||
|
|
||||||
|
# check best checkpoint name with `by_epoch` is False
|
||||||
|
eval_hook = CheckpointHook(
|
||||||
|
interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
|
||||||
|
assert eval_hook.key_indicators == ['acc', 'mIoU']
|
||||||
|
assert eval_hook.rules == ['greater', 'greater']
|
||||||
|
runner.message_hub = MessageHub.get_instance(
|
||||||
|
'test_after_val_epoch_save_multi_best_by_epoch_is_false')
|
||||||
|
eval_hook.before_train(runner)
|
||||||
|
metrics = dict(acc=0.5, mIoU=0.6)
|
||||||
|
eval_hook.after_val_epoch(runner, metrics)
|
||||||
|
best_acc_name = 'best_acc_iter_10.pth'
|
||||||
|
best_acc_path = eval_hook.file_client.join_path(
|
||||||
|
eval_hook.out_dir, best_acc_name)
|
||||||
|
best_mIoU_name = 'best_mIoU_iter_10.pth'
|
||||||
|
best_mIoU_path = eval_hook.file_client.join_path(
|
||||||
|
eval_hook.out_dir, best_mIoU_name)
|
||||||
|
assert 'best_score_acc' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_score_acc') == 0.5
|
||||||
|
assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_score_mIoU') == 0.6
|
||||||
|
assert 'best_ckpt_acc' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_ckpt_acc') == best_acc_path
|
||||||
|
assert 'best_ckpt_mIoU' in runner.message_hub.runtime_info and \
|
||||||
|
runner.message_hub.get_info('best_ckpt_mIoU') == best_mIoU_path
|
||||||
|
|
||||||
def test_after_train_epoch(self, tmp_path):
|
def test_after_train_epoch(self, tmp_path):
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user