parent
cdc9919be8
commit
5398255f6f
|
@ -2,6 +2,7 @@
|
|||
import os.path as osp
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from math import inf
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
@ -296,6 +297,18 @@ class CheckpointHook(Hook):
|
|||
"""
|
||||
self._save_best_checkpoint(runner, metrics)
|
||||
|
||||
def _get_metric_score(self, metrics, key_indicator):
|
||||
eval_res = OrderedDict()
|
||||
if metrics is not None:
|
||||
eval_res.update(metrics)
|
||||
if len(eval_res) == 0:
|
||||
warnings.warn(
|
||||
'Since `eval_res` is an empty dict, the behavior to save '
|
||||
'the best checkpoint will be skipped in this evaluation.')
|
||||
return None
|
||||
|
||||
return eval_res[key_indicator]
|
||||
|
||||
def _save_checkpoint(self, runner) -> None:
|
||||
"""Save the current checkpoint and delete outdated checkpoint.
|
||||
|
||||
|
@ -372,7 +385,7 @@ class CheckpointHook(Hook):
|
|||
# save best logic
|
||||
# get score from messagehub
|
||||
for key_indicator, rule in zip(self.key_indicators, self.rules):
|
||||
key_score = metrics.get(key_indicator, None)
|
||||
key_score = self._get_metric_score(metrics, key_indicator)
|
||||
|
||||
if len(self.key_indicators) == 1:
|
||||
best_score_key = 'best_score'
|
||||
|
|
Loading…
Reference in New Issue