[Enhance] Ensure metrics is not empty when saving best ckpts (#849)

* [Enhance] Ensure metrics is not empty when saving best ckpts
* fix warn to warning
* delete an unnecessary method
This commit is contained in:
parent a9b6753fbe
commit 646927f62f
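In effect, `CheckpointHook.after_val_epoch` now bails out early when evaluation produced no metrics, instead of routing every score lookup through the removed `_get_metric_score` helper. A minimal standalone sketch of the new control flow; `FakeRunner` and the print stub are illustration-only stand-ins, not mmengine API:

    import logging

    logging.basicConfig(level=logging.WARNING)

    class FakeRunner:
        """Stand-in for mmengine's Runner; only a logger is needed here."""
        logger = logging.getLogger('runner')

    def after_val_epoch(runner, metrics):
        # New guard: an empty metrics dict means there is nothing to rank,
        # so warn once and skip the best-checkpoint logic entirely.
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return
        print(f'would save best checkpoint based on {metrics}')

    after_val_epoch(FakeRunner(), {})            # warns and returns
    after_val_epoch(FakeRunner(), {'acc': 0.5})  # proceeds to save logic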
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import warnings
-from collections import OrderedDict
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -294,20 +293,13 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
             metrics (dict): Evaluation results of all metrics
         """
-        self._save_best_checkpoint(runner, metrics)
-
-    def _get_metric_score(self, metrics, key_indicator):
-        eval_res = OrderedDict()
-        if metrics is not None:
-            eval_res.update(metrics)
-
-        if len(eval_res) == 0:
-            warnings.warn(
-                'Since `eval_res` is an empty dict, the behavior to save '
-                'the best checkpoint will be skipped in this evaluation.')
-            return None
-
-        return eval_res[key_indicator]
+        if len(metrics) == 0:
+            runner.logger.warning(
+                'Since `metrics` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return
+
+        self._save_best_checkpoint(runner, metrics)
 
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -385,7 +377,7 @@ class CheckpointHook(Hook):
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
-            key_score = self._get_metric_score(metrics, key_indicator)
+            key_score = metrics[key_indicator]
 
             if len(self.key_indicators) == 1:
                 best_score_key = 'best_score'
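Because the empty-dict case now returns before the save-best loop runs, the per-indicator score can be read with a plain dict lookup; the `_get_metric_score` indirection (and its `OrderedDict` copy) is no longer needed. A rough sketch of the simplified loop, with the hook's configured state and rule comparison reduced to illustrative stubs:

    from math import inf

    # Illustrative stand-ins for the hook's configured indicators and rules.
    key_indicators = ['acc', 'loss']
    rules = ['greater', 'less']
    best_scores = {'acc': -inf, 'loss': inf}

    metrics = {'acc': 0.8, 'loss': 0.4}  # guaranteed non-empty at this point

    for key_indicator, rule in zip(key_indicators, rules):
        key_score = metrics[key_indicator]  # direct lookup replaces the helper
        improved = (key_score > best_scores[key_indicator] if rule == 'greater'
                    else key_score < best_scores[key_indicator])
        if improved:
            best_scores[key_indicator] = key_score
            print(f'new best {key_indicator}: {key_score}')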
@@ -148,6 +148,7 @@ class TestCheckpointHook:
         runner.work_dir = tmp_path
         runner.epoch = 9
         runner.model = Mock()
+        runner.logger.warning = Mock()
         runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
 
         with pytest.raises(ValueError):
@@ -159,22 +160,11 @@ class TestCheckpointHook:
             CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto', rule='unsupport')
 
-        # if eval_res is an empty dict, print a warning information
-        with pytest.warns(UserWarning) as record_warnings:
-            eval_hook = CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None, None)
-        # Since there will be many warnings thrown, we just need to check
-        # if the expected exceptions are thrown
-        expected_message = (
-            'Since `eval_res` is an empty dict, the behavior to '
-            'save the best checkpoint will be skipped in this '
-            'evaluation.')
-        for warning in record_warnings:
-            if str(warning.message) == expected_message:
-                break
-        else:
-            assert False
+        # if metrics is an empty dict, print a warning information
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.after_val_epoch(runner, {})
+        runner.logger.warning.assert_called_once()
 
         # test error when number of rules and metrics are not same
         with pytest.raises(AssertionError) as assert_error:
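The rewritten test no longer fishes the message out of `pytest.warns`; it mocks `runner.logger.warning` and asserts a single call. A self-contained analogue of that pattern, where the guard function is a hypothetical stand-in rather than the hook itself:

    from unittest.mock import Mock

    def after_val_epoch(runner, metrics):
        # Hypothetical stand-in for the guard added to CheckpointHook.
        if len(metrics) == 0:
            runner.logger.warning('empty metrics; skip saving best checkpoint')
            return

    runner = Mock()
    runner.logger.warning = Mock()

    after_val_epoch(runner, {})
    runner.logger.warning.assert_called_once()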
@@ -187,93 +177,97 @@ class TestCheckpointHook:
             '"save_best", but got 3.')
         assert error_message in str(assert_error.value)
 
-        # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
+        # if save_best is None, no best_ckpt meta should be stored
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best=None)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, {})
         assert 'best_score' not in runner.message_hub.runtime_info
         assert 'best_ckpt' not in runner.message_hub.runtime_info
 
         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
 
         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc')
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.8
 
         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss')
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.3
 
         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0
 
         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
 
         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        assert checkpoint_hook.key_indicators == ['FID', 'IS']
+        assert checkpoint_hook.rules == ['less', 'greater']
 
         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
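As the multi-`save_best` assertions above suggest, a single `rule` string is broadcast across every key indicator, while a list must match the indicators one-to-one. A hypothetical helper sketching that pairing (the hook's real `__init__` performs more validation than this):

    def broadcast_rules(key_indicators, rule):
        # One rule string applies to all indicators; a list must line up 1:1.
        rules = [rule] * len(key_indicators) if isinstance(rule, str) else list(rule)
        assert len(rules) == len(key_indicators), (
            'Number of "rule" must be 1 or the same as "save_best"')
        return rules

    print(broadcast_rules(['acc', 'mIoU'], 'greater'))          # ['greater', 'greater']
    print(broadcast_rules(['FID', 'IS'], ['less', 'greater']))  # ['less', 'greater']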
@@ -293,15 +287,15 @@ class TestCheckpointHook:
 
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -309,10 +303,10 @@ class TestCheckpointHook:
 
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -326,21 +320,21 @@ class TestCheckpointHook:
             interval=2, save_best='acc', rule=['greater', 'less'])
 
         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
            runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \