[Enhance] Ensure metrics is not empty when saving best ckpts (#849)

* [Enhance] Ensure metrics is not empty when saving best ckpts

* fix warn to warning

* delete an unnecessary method
Zaida Zhou 2022-12-28 11:34:08 +08:00 committed by GitHub
parent a9b6753fbe
commit 646927f62f
2 changed files with 74 additions and 88 deletions
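
The core change: the empty-`metrics` guard moves out of the deleted `_get_metric_score` helper and into `after_val_epoch` itself, and it now logs through `runner.logger.warning` instead of `warnings.warn`. Since `_save_best_checkpoint` now reads `metrics[key_indicator]` directly, the guard must run before that call. A minimal runnable sketch of the new control flow (a toy class standing in for `CheckpointHook`; only the guard logic is taken from the diff below):

```python
from unittest.mock import Mock


class GuardSketch:
    """Toy stand-in for CheckpointHook, condensed from the diff below."""

    def after_val_epoch(self, runner, metrics) -> None:
        if len(metrics) == 0:
            # Nothing to rank checkpoints by: warn and skip, instead of
            # failing later on `metrics[key_indicator]`.
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return
        self._save_best_checkpoint(runner, metrics)

    def _save_best_checkpoint(self, runner, metrics):
        # The real hook indexes `metrics[key_indicator]` here, which is
        # why the empty check above must come first.
        print('saving best checkpoint for', metrics)


runner = Mock()  # hypothetical runner whose logger records calls
GuardSketch().after_val_epoch(runner, {})            # warns, saves nothing
GuardSketch().after_val_epoch(runner, {'acc': 0.9})  # proceeds to save
```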

mmengine/hooks/checkpoint_hook.py

@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import warnings
-from collections import OrderedDict
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -294,20 +293,13 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
             metrics (dict): Evaluation results of all metrics
         """
-        self._save_best_checkpoint(runner, metrics)
-
-    def _get_metric_score(self, metrics, key_indicator):
-        eval_res = OrderedDict()
-        if metrics is not None:
-            eval_res.update(metrics)
-
-        if len(eval_res) == 0:
-            warnings.warn(
-                'Since `eval_res` is an empty dict, the behavior to save '
-                'the best checkpoint will be skipped in this evaluation.')
-            return None
-
-        return eval_res[key_indicator]
+        if len(metrics) == 0:
+            runner.logger.warning(
+                'Since `metrics` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return
+
+        self._save_best_checkpoint(runner, metrics)

     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -385,7 +377,7 @@ class CheckpointHook(Hook):
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
-            key_score = self._get_metric_score(metrics, key_indicator)
+            key_score = metrics[key_indicator]

             if len(self.key_indicators) == 1:
                 best_score_key = 'best_score'

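A knock-on effect for the tests that follow: a `warnings.warn` call can be captured with `pytest.warns`, but a `runner.logger.warning` call cannot, so the updated test stubs the logger method with a `Mock` and asserts on the call count instead. A self-contained sketch of that pattern (the `emit` helper is hypothetical, not the mmengine hook):

```python
from unittest.mock import Mock


def emit(logger, metrics):
    # Hypothetical guard: warns via a logger rather than the warnings
    # module, so pytest.warns would see nothing here.
    if len(metrics) == 0:
        logger.warning('metrics is empty, skipping best-checkpoint save')


runner = Mock()
runner.logger.warning = Mock()  # swap in a Mock to record the call
emit(runner.logger, {})
runner.logger.warning.assert_called_once()  # passes: warned exactly once
```
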
tests/test_hooks/test_checkpoint_hook.py

@@ -148,6 +148,7 @@ class TestCheckpointHook:
         runner.work_dir = tmp_path
         runner.epoch = 9
         runner.model = Mock()
+        runner.logger.warning = Mock()
         runner.message_hub = MessageHub.get_instance('test_after_val_epoch')

         with pytest.raises(ValueError):
@ -159,22 +160,11 @@ class TestCheckpointHook:
CheckpointHook( CheckpointHook(
interval=2, by_epoch=True, save_best='auto', rule='unsupport') interval=2, by_epoch=True, save_best='auto', rule='unsupport')
# if eval_res is an empty dict, print a warning information # if metrics is an empty dict, print a warning information
with pytest.warns(UserWarning) as record_warnings: checkpoint_hook = CheckpointHook(
eval_hook = CheckpointHook( interval=2, by_epoch=True, save_best='auto')
interval=2, by_epoch=True, save_best='auto') checkpoint_hook.after_val_epoch(runner, {})
eval_hook._get_metric_score(None, None) runner.logger.warning.assert_called_once()
# Since there will be many warnings thrown, we just need to check
# if the expected exceptions are thrown
expected_message = (
'Since `eval_res` is an empty dict, the behavior to '
'save the best checkpoint will be skipped in this '
'evaluation.')
for warning in record_warnings:
if str(warning.message) == expected_message:
break
else:
assert False
# test error when number of rules and metrics are not same # test error when number of rules and metrics are not same
with pytest.raises(AssertionError) as assert_error: with pytest.raises(AssertionError) as assert_error:
@@ -187,93 +177,97 @@ class TestCheckpointHook:
             '"save_best", but got 3.')
         assert error_message in str(assert_error.value)

-        # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
+        # if save_best is None, no best_ckpt meta should be stored
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best=None)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, {})

         assert 'best_score' not in runner.message_hub.runtime_info
         assert 'best_ckpt' not in runner.message_hub.runtime_info

         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)

         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']

         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path

         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc')
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.8

         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss')
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5

         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.3

         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0

         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']

         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        assert checkpoint_hook.key_indicators == ['FID', 'IS']
+        assert checkpoint_hook.rules == ['less', 'greater']

         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']

         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_mIoU') == 0.6
@@ -293,15 +287,15 @@ class TestCheckpointHook:
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
@@ -309,10 +303,10 @@ class TestCheckpointHook:
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.666
@@ -326,21 +320,21 @@ class TestCheckpointHook:
                 interval=2, save_best='acc', rule=['greater', 'less'])

         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_mIoU') == 0.6
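
To recap what the multi-indicator tests above exercise: passing a list to `save_best` tracks one best score and one best checkpoint per indicator, with per-indicator `best_score_<key>` entries in the message hub. A hedged usage sketch (the `CheckpointHook` arguments are the ones tested above; the import path and surrounding runner wiring are assumptions):

```python
from mmengine.hooks import CheckpointHook  # assumed import path

checkpoint_hook = CheckpointHook(
    interval=2,
    by_epoch=False,             # best ckpts are named best_<key>_iter_<N>.pth
    save_best=['acc', 'mIoU'],  # one tracked best per indicator
)                               # rule omitted: 'greater' is inferred for both

# After a validation epoch with metrics dict(acc=0.5, mIoU=0.6), the
# message hub holds 'best_score_acc' == 0.5 and 'best_score_mIoU' == 0.6,
# and best_acc_iter_9.pth / best_mIoU_iter_9.pth are written to out_dir.
```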