diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index c9aa1a35..91cf382c 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -1,7 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import warnings
-from collections import OrderedDict
 from math import inf
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Sequence, Union
@@ -294,20 +293,13 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
             metrics (dict): Evaluation results of all metrics
         """
-        self._save_best_checkpoint(runner, metrics)
-
-    def _get_metric_score(self, metrics, key_indicator):
-        eval_res = OrderedDict()
-        if metrics is not None:
-            eval_res.update(metrics)
-
-        if len(eval_res) == 0:
-            warnings.warn(
-                'Since `eval_res` is an empty dict, the behavior to save '
+        if len(metrics) == 0:
+            runner.logger.warning(
+                'Since `metrics` is an empty dict, the behavior to save '
                 'the best checkpoint will be skipped in this evaluation.')
-            return None
+            return
 
-        return eval_res[key_indicator]
+        self._save_best_checkpoint(runner, metrics)
 
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -385,7 +377,7 @@ class CheckpointHook(Hook):
         # save best logic
         # get score from messagehub
         for key_indicator, rule in zip(self.key_indicators, self.rules):
-            key_score = self._get_metric_score(metrics, key_indicator)
+            key_score = metrics[key_indicator]
 
             if len(self.key_indicators) == 1:
                 best_score_key = 'best_score'
diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py
index 8fbb1a56..7abbedd2 100644
--- a/tests/test_hooks/test_checkpoint_hook.py
+++ b/tests/test_hooks/test_checkpoint_hook.py
@@ -148,6 +148,7 @@ class TestCheckpointHook:
         runner.work_dir = tmp_path
         runner.epoch = 9
         runner.model = Mock()
+        runner.logger.warning = Mock()
         runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
 
         with pytest.raises(ValueError):
@@ -159,22 +160,11 @@ class TestCheckpointHook:
             CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto', rule='unsupport')
 
-        # if eval_res is an empty dict, print a warning information
-        with pytest.warns(UserWarning) as record_warnings:
-            eval_hook = CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None, None)
-        # Since there will be many warnings thrown, we just need to check
-        # if the expected exceptions are thrown
-        expected_message = (
-            'Since `eval_res` is an empty dict, the behavior to '
-            'save the best checkpoint will be skipped in this '
-            'evaluation.')
-        for warning in record_warnings:
-            if str(warning.message) == expected_message:
-                break
-        else:
-            assert False
+        # if metrics is an empty dict, a warning should be logged
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.after_val_epoch(runner, {})
+        runner.logger.warning.assert_called_once()
 
         # test error when number of rules and metrics are not same
         with pytest.raises(AssertionError) as assert_error:
@@ -187,93 +177,97 @@ class TestCheckpointHook:
                          '"save_best", but got 3.')
         assert error_message in str(assert_error.value)
 
-        # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
+        # if save_best is None, no best_ckpt meta should be stored
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best=None)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, {})
         assert 'best_score' not in runner.message_hub.runtime_info
         assert 'best_ckpt' not in runner.message_hub.runtime_info
 
         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
 
         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc')
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.8
 
         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss')
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.3
 
         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0
 
         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
 
         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        assert checkpoint_hook.key_indicators == ['FID', 'IS']
+        assert checkpoint_hook.rules == ['less', 'greater']
 
         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
@@ -293,15 +287,15 @@ class TestCheckpointHook:
 
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -309,10 +303,10 @@ class TestCheckpointHook:
 
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -326,21 +320,21 @@ class TestCheckpointHook:
             interval=2, save_best='acc', rule=['greater', 'less'])
 
         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
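
Note: with this change, an empty `metrics` dict makes `after_val_epoch` log a warning through `runner.logger` and return early, so `_save_best_checkpoint` and the `metrics[key_indicator]` lookup only run with a non-empty dict. A minimal sketch of exercising the new early-return path, mirroring the updated test (assumes mmengine is installed; the bare Mock runner is an assumption that suffices here because this path only touches runner.logger):

    from unittest.mock import Mock

    from mmengine.hooks import CheckpointHook

    # A Mock stands in for the runner: on the empty-metrics path the hook
    # only calls runner.logger.warning() before returning early.
    runner = Mock()

    hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
    hook.after_val_epoch(runner, {})  # empty dict -> warn and skip
    runner.logger.warning.assert_called_once()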