mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Update evaluator prefix (#114)
* update evaluator prefix
* update docstring and comments
* update doc
This commit is contained in:
parent
3e0c064f49
commit
61fecabea6
@ -56,7 +56,7 @@ validation_cfg=dict(
|
||||
],
|
||||
# 指定使用前缀为 COCO 的 AP 为主要评测指标
|
||||
# 在没有重名指标歧义的情况下,此处可以不写前缀,只写评测指标名
|
||||
main_metric='COCO.AP',
|
||||
main_metric='COCO/AP',
|
||||
interval=10,
|
||||
by_epoch=True,
|
||||
)
|
||||
@ -106,7 +106,7 @@ class Accuracy(BaseEvaluator):
|
||||
Default prefix: ACC
|
||||
|
||||
Metrics:
|
||||
- accuracy: classification accuracy
|
||||
- accuracy (float): classification accuracy
|
||||
"""
|
||||
|
||||
default_prefix = 'ACC'
|
||||
|
@ -2,5 +2,8 @@
|
||||
from .base import BaseEvaluator
|
||||
from .builder import build_evaluator
|
||||
from .composed_evaluator import ComposedEvaluator
|
||||
from .utils import get_metric_value
|
||||
|
||||
__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator']
|
||||
__all__ = [
|
||||
'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
|
||||
]
|
||||
|
@ -120,7 +120,7 @@ class BaseEvaluator(metaclass=ABCMeta):
|
||||
# Add prefix to metric names
|
||||
if self.prefix:
|
||||
metrics = {
|
||||
'.'.join((self.prefix, k)): v
|
||||
'/'.join((self.prefix, k)): v
|
||||
for k, v in metrics.items()
|
||||
}
|
||||
metrics = [metrics] # type: ignore
|
||||
|
37
mmengine/evaluator/utils.py
Normal file
37
mmengine/evaluator/utils.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def get_metric_value(indicator: str, metrics: Dict) -> Any:
    """Get the metric value specified by an indicator, which can be either a
    metric name or a full name with evaluator prefix.

    Args:
        indicator (str): The metric indicator, which can be the metric name
            (e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP')
        metrics (dict): The evaluation results output by the evaluator

    Returns:
        Any: The specified metric value

    Raises:
        ValueError: If the indicator can not match any metric in ``metrics``,
            or if a prefix-less indicator matches more than one metric.
    """

    if '/' in indicator:
        # The indicator is a full name with prefix; require an exact key match
        if indicator in metrics:
            return metrics[indicator]
        else:
            # Quote the indicator for consistency with the other error
            # messages in this function
            raise ValueError(
                f'The indicator "{indicator}" can not match any metric in '
                f'{list(metrics.keys())}')
    else:
        # The indicator is a metric name without prefix; compare it against
        # the part of each key after the last '/', so both prefixed and
        # non-prefixed metric keys are considered
        matched = [k for k in metrics.keys() if k.split('/')[-1] == indicator]

        if not matched:
            raise ValueError(
                f'The indicator "{indicator}" can not match any metric in '
                f'{list(metrics.keys())}')
        elif len(matched) > 1:
            raise ValueError(f'The indicator "{indicator}" matches multiple '
                             f'metrics {matched}')
        else:
            return metrics[matched[0]]
|
@ -6,7 +6,7 @@ from unittest import TestCase
|
||||
import numpy as np
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.evaluator import BaseEvaluator, build_evaluator
|
||||
from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
|
||||
from mmengine.registry import EVALUATORS
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ class ToyEvaluator(BaseEvaluator):
|
||||
|
||||
|
||||
@EVALUATORS.register_module()
|
||||
class UnprefixedEvaluator(BaseEvaluator):
|
||||
class NonPrefixedEvaluator(BaseEvaluator):
|
||||
"""Evaluator with unassigned `default_prefix` to test the warning
|
||||
information."""
|
||||
|
||||
@ -100,8 +100,8 @@ class TestBaseEvaluator(TestCase):
|
||||
evaluator.process(data_samples, predictions)
|
||||
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
|
||||
self.assertEqual(metrics['Toy.size'], size)
|
||||
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
|
||||
self.assertEqual(metrics['Toy/size'], size)
|
||||
|
||||
# Test empty results
|
||||
cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
|
||||
@ -126,9 +126,9 @@ class TestBaseEvaluator(TestCase):
|
||||
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
|
||||
self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
|
||||
self.assertAlmostEqual(metrics['Toy.mAP'], 0.0)
|
||||
self.assertEqual(metrics['Toy.size'], size)
|
||||
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
|
||||
self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
|
||||
self.assertEqual(metrics['Toy/size'], size)
|
||||
|
||||
def test_ambiguate_metric(self):
|
||||
|
||||
@ -167,6 +167,45 @@ class TestBaseEvaluator(TestCase):
|
||||
self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)
|
||||
|
||||
    def test_prefix(self):
        """Building an evaluator whose class leaves ``default_prefix`` unset
        should emit a UserWarning about the missing prefix."""
        cfg = dict(type='NonPrefixedEvaluator')
        with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
            _ = build_evaluator(cfg)
||||
|
||||
    def test_get_metric_value(self):
        """``get_metric_value`` resolves both prefixed ('prefix/name') and
        bare metric indicators, and raises ValueError for unmatched or
        ambiguous ones."""

        metrics = {
            'prefix_0/metric_0': 0,
            'prefix_1/metric_0': 1,
            'prefix_1/metric_1': 2,
            'nonprefixed': 3,
        }

        # Test indicator with prefix
        indicator = 'prefix_0/metric_0'  # correct indicator
        self.assertEqual(get_metric_value(indicator, metrics), 0)

        indicator = 'prefix_1/metric_0'  # correct indicator
        self.assertEqual(get_metric_value(indicator, metrics), 1)

        indicator = 'prefix_0/metric_1'  # unmatched indicator (wrong metric)
        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
            _ = get_metric_value(indicator, metrics)

        indicator = 'prefix_2/metric'  # unmatched indicator (wrong prefix)
        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
            _ = get_metric_value(indicator, metrics)

        # Test indicator without prefix
        indicator = 'metric_1'  # correct indicator (prefixed metric)
        self.assertEqual(get_metric_value(indicator, metrics), 2)

        indicator = 'nonprefixed'  # correct indicator (non-prefixed metric)
        self.assertEqual(get_metric_value(indicator, metrics), 3)

        indicator = 'metric_0'  # ambiguous indicator
        with self.assertRaisesRegex(ValueError, 'matches multiple metrics'):
            _ = get_metric_value(indicator, metrics)

        indicator = 'metric_2'  # unmatched indicator
        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
            _ = get_metric_value(indicator, metrics)
||||
|
Loading…
x
Reference in New Issue
Block a user