[Feature] Update evaluator prefix (#114)

* update evaluator prefix

* update docstring and comments

* update doc
This commit is contained in:
Yining Li 2022-03-10 17:25:20 +08:00 committed by GitHub
parent 3e0c064f49
commit 61fecabea6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 12 deletions

View File

@ -56,7 +56,7 @@ validation_cfg=dict(
],
# 指定使用前缀为 COCO 的 AP 为主要评测指标
# 在没有重名指标歧义的情况下,此处可以不写前缀,只写评测指标名
main_metric='COCO.AP',
main_metric='COCO/AP',
interval=10,
by_epoch=True,
)
@ -106,7 +106,7 @@ class Accuracy(BaseEvaluator):
Default prefix: ACC
Metrics:
- accuracy: classification accuracy
- accuracy (float): classification accuracy
"""
default_prefix = 'ACC'

View File

@ -2,5 +2,8 @@
from .base import BaseEvaluator
from .builder import build_evaluator
from .composed_evaluator import ComposedEvaluator
from .utils import get_metric_value
__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator']
__all__ = [
'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
]

View File

@ -120,7 +120,7 @@ class BaseEvaluator(metaclass=ABCMeta):
# Add prefix to metric names
if self.prefix:
metrics = {
'.'.join((self.prefix, k)): v
'/'.join((self.prefix, k)): v
for k, v in metrics.items()
}
metrics = [metrics] # type: ignore

View File

@ -0,0 +1,37 @@
from typing import Any, Dict
def get_metric_value(indicator: str, metrics: Dict) -> Any:
"""Get the metric value specified by an indicator, which can be either a
metric name or a full name with evaluator prefix.
Args:
indicator (str): The metric indicator, which can be the metric name
(e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP')
metrics (dict): The evaluation results output by the evaluator
Returns:
Any: The specified metric value
"""
if '/' in indicator:
# The indicator is a full name
if indicator in metrics:
return metrics[indicator]
else:
raise ValueError(
f'The indicator "{indicator}" can not match any metric in '
f'{list(metrics.keys())}')
else:
# The indicator is metric name without prefix
matched = [k for k in metrics.keys() if k.split('/')[-1] == indicator]
if not matched:
raise ValueError(
f'The indicator {indicator} can not match any metric in '
f'{list(metrics.keys())}')
elif len(matched) > 1:
raise ValueError(f'The indicator "{indicator}" matches multiple '
f'metrics {matched}')
else:
return metrics[matched[0]]

View File

@ -6,7 +6,7 @@ from unittest import TestCase
import numpy as np
from mmengine.data import BaseDataSample
from mmengine.evaluator import BaseEvaluator, build_evaluator
from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
from mmengine.registry import EVALUATORS
@ -62,7 +62,7 @@ class ToyEvaluator(BaseEvaluator):
@EVALUATORS.register_module()
class UnprefixedEvaluator(BaseEvaluator):
class NonPrefixedEvaluator(BaseEvaluator):
"""Evaluator with unassigned `default_prefix` to test the warning
information."""
@ -100,8 +100,8 @@ class TestBaseEvaluator(TestCase):
evaluator.process(data_samples, predictions)
metrics = evaluator.evaluate(size=size)
self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
self.assertEqual(metrics['Toy.size'], size)
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
self.assertEqual(metrics['Toy/size'], size)
# Test empty results
cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
@ -126,9 +126,9 @@ class TestBaseEvaluator(TestCase):
metrics = evaluator.evaluate(size=size)
self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
self.assertAlmostEqual(metrics['Toy.mAP'], 0.0)
self.assertEqual(metrics['Toy.size'], size)
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
self.assertEqual(metrics['Toy/size'], size)
def test_ambiguate_metric(self):
@ -167,6 +167,45 @@ class TestBaseEvaluator(TestCase):
self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)
def test_prefix(self):
cfg = dict(type='UnprefixedEvaluator')
cfg = dict(type='NonPrefixedEvaluator')
with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
_ = build_evaluator(cfg)
def test_get_metric_value(self):
metrics = {
'prefix_0/metric_0': 0,
'prefix_1/metric_0': 1,
'prefix_1/metric_1': 2,
'nonprefixed': 3,
}
# Test indicator with prefix
indicator = 'prefix_0/metric_0' # correct indicator
self.assertEqual(get_metric_value(indicator, metrics), 0)
indicator = 'prefix_1/metric_0' # correct indicator
self.assertEqual(get_metric_value(indicator, metrics), 1)
indicator = 'prefix_0/metric_1' # unmatched indicator (wrong metric)
with self.assertRaisesRegex(ValueError, 'can not match any metric'):
_ = get_metric_value(indicator, metrics)
indicator = 'prefix_2/metric' # unmatched indicator (wrong prefix)
with self.assertRaisesRegex(ValueError, 'can not match any metric'):
_ = get_metric_value(indicator, metrics)
# Test indicator without prefix
indicator = 'metric_1' # correct indicator (prefixed metric)
self.assertEqual(get_metric_value(indicator, metrics), 2)
indicator = 'nonprefixed' # correct indicator (non-prefixed metric)
self.assertEqual(get_metric_value(indicator, metrics), 3)
indicator = 'metric_0' # ambiguous indicator
with self.assertRaisesRegex(ValueError, 'matches multiple metrics'):
_ = get_metric_value(indicator, metrics)
indicator = 'metric_2' # unmatched indicator
with self.assertRaisesRegex(ValueError, 'can not match any metric'):
_ = get_metric_value(indicator, metrics)