From 5398255f6fb3dac8341f7d808f0d7d09350fcaae Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 7 Dec 2022 13:02:50 +0800 Subject: [PATCH] Revert "remove _get_metric_scope" This reverts commit eeb7a8c5ed2766bf773a9ed28f731fddacd10ac1. --- mmengine/hooks/checkpoint_hook.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 7cdae836..51732bad 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -2,6 +2,7 @@ import os.path as osp import re import warnings +from collections import OrderedDict from math import inf from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Union @@ -296,6 +297,18 @@ class CheckpointHook(Hook): """ self._save_best_checkpoint(runner, metrics) + def _get_metric_score(self, metrics, key_indicator): + eval_res = OrderedDict() + if metrics is not None: + eval_res.update(metrics) + if len(eval_res) == 0: + warnings.warn( + 'Since `eval_res` is an empty dict, the behavior to save ' + 'the best checkpoint will be skipped in this evaluation.') + return None + + return eval_res[key_indicator] + def _save_checkpoint(self, runner) -> None: """Save the current checkpoint and delete outdated checkpoint. @@ -372,7 +385,7 @@ class CheckpointHook(Hook): # save best logic # get score from messagehub for key_indicator, rule in zip(self.key_indicators, self.rules): - key_score = metrics.get(key_indicator, None) + key_score = self._get_metric_score(metrics, key_indicator) if len(self.key_indicators) == 1: best_score_key = 'best_score'