Mirror of https://github.com/open-mmlab/mmocr.git (synced 2025-06-03 21:54:47 +08:00)
[Evaluator] MultiDatasetEvaluator (#1250)

* multi datasets evaluator
* fix comment
* fix typo

Parent: 1978075577
Commit: bcc245efd3
@@ -69,3 +69,4 @@ class ConcatDataset(MMENGINE_CONCATDATASET):
         self._fully_initialized = False
         if not lazy_init:
             self.full_init()
+        self._metainfo.update(dict(cumulative_sizes=self.cumulative_sizes))
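The single added line above exposes the concatenated dataset's cumulative lengths through its metainfo, which is what lets the new evaluator below slice the collected results back into per-dataset chunks. A minimal sketch of that slicing arithmetic, using made-up lengths (3 and 5) rather than any real dataset:

# Illustrative only: two concatenated datasets of length 3 and 5.
cumulative_sizes = [3, 8]             # as reported by ConcatDataset
starts = [0] + cumulative_sizes[:-1]  # [0, 3]
for start, end in zip(starts, cumulative_sizes):
    print(f'results[{start}:{end}]')  # -> results[0:3], results[3:8]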
@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .evaluator import *  # NOQA
 from .metrics import *  # NOQA
mmocr/evaluation/evaluator/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_datasets_evaluator import MultiDatasetsEvaluator

__all__ = ['MultiDatasetsEvaluator']
mmocr/evaluation/evaluator/multi_datasets_evaluator.py (new file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import OrderedDict
from typing import Sequence, Union

from mmengine.dist import (broadcast_object_list, collect_results,
                           is_main_process)
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.evaluator.metric import _to_cpu

from mmocr.registry import EVALUATOR
from mmocr.utils.typing import ConfigType


@EVALUATOR.register_module()
class MultiDatasetsEvaluator(Evaluator):
    """Wrapper class to compose :class:`ConcatDataset` and multiple
    :class:`BaseMetric` instances.

    The metrics will be evaluated on each dataset slice separately. The name
    of each metric is the concatenation of the dataset prefix, the metric
    prefix and the key of the metric, e.g.
    ``dataset_prefix/metric_prefix/accuracy``.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
        dataset_prefixes (Sequence[str]): The prefix of each dataset. The
            length of this sequence should be the same as the length of the
            datasets.
    """

    def __init__(self, metrics: Union[ConfigType, BaseMetric, Sequence],
                 dataset_prefixes: Sequence[str]) -> None:
        super().__init__(metrics)
        self.dataset_prefixes = dataset_prefixes

    def evaluate(self, size: int) -> dict:
        """Invoke ``evaluate`` method of each metric and collect the metrics
        dictionary.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
            of the metrics, and the values are corresponding results.
        """
        metrics_results = OrderedDict()
        dataset_slices = self.dataset_meta.get('cumulative_sizes', [size])
        assert len(dataset_slices) == len(self.dataset_prefixes)
        for metric in self.metrics:
            if len(metric.results) == 0:
                warnings.warn(
                    f'{metric.__class__.__name__} got empty `self.results`. '
                    'Please ensure that the processed results are properly '
                    'added into `self.results` in `process` method.')

            results = collect_results(metric.results, size,
                                      metric.collect_device)

            if is_main_process():
                # cast all tensors in results list to cpu
                results = _to_cpu(results)
                for start, end, dataset_prefix in zip([0] +
                                                      dataset_slices[:-1],
                                                      dataset_slices,
                                                      self.dataset_prefixes):
                    metric_results = metric.compute_metrics(
                        results[start:end])  # type: ignore
                    # Add prefix to metric names
                    if metric.prefix:
                        final_prefix = '/'.join(
                            (dataset_prefix, metric.prefix))
                    else:
                        final_prefix = dataset_prefix
                    metric_results = {
                        '/'.join((final_prefix, k)): v
                        for k, v in metric_results.items()
                    }

                    metric.results.clear()
                    # Check metric name conflicts
                    for name in metric_results.keys():
                        if name in metrics_results:
                            raise ValueError(
                                'There are multiple evaluation results with '
                                f'the same metric name {name}. Please make '
                                'sure all metrics have different prefixes.')
                    metrics_results.update(metric_results)
        if is_main_process():
            metrics_results = [metrics_results]
        else:
            metrics_results = [None]  # type: ignore
        broadcast_object_list(metrics_results)

        return metrics_results[0]
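For context, a hypothetical config sketch (not part of this commit) showing how the evaluator could be wired to a ConcatDataset of two test sets; the dataset names and the metric config below are placeholders, not taken from the patch:

# Hypothetical snippet; 'IIIT5K'/'SVT' and the WordMetric entry are
# placeholders and would need to match an actual mmocr setup.
val_evaluator = dict(
    type='MultiDatasetsEvaluator',
    metrics=[dict(type='WordMetric')],
    # One prefix per sub-dataset of the ConcatDataset, in the same order.
    dataset_prefixes=['IIIT5K', 'SVT'])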
New test file (126 lines)
@@ -0,0 +1,126 @@
# Copyright (c) OpenMMLab. All rights reserved.

import math
from typing import Dict, List, Optional
from unittest import TestCase

import numpy as np
from mmengine import METRICS, BaseDataElement, DefaultScope
from mmengine.evaluator import BaseMetric

from mmocr.evaluation import MultiDatasetsEvaluator


@METRICS.register_module()
class ToyMetric(BaseMetric):
    """Evaluator that calculates the metric `accuracy` from predictions and
    labels. Alternatively, this evaluator can return arbitrary dummy metrics
    set in the config.

    Default prefix: Toy

    Metrics:
        - accuracy (float): The classification accuracy. Only when
            `dummy_metrics` is None.
        - size (int): The number of test samples. Only when `dummy_metrics`
            is None.

        If `dummy_metrics` is set as a dict in the config, it will be
        returned as the metrics and override `accuracy` and `size`.
    """

    default_prefix = None

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = 'Toy',
                 dummy_metrics: Optional[Dict] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.dummy_metrics = dummy_metrics

    def process(self, data_batch, predictions):
        results = [{
            'pred': pred.get('pred'),
            'label': data['data_sample'].get('label')
        } for pred, data in zip(predictions, data_batch)]
        self.results.extend(results)

    def compute_metrics(self, results: List):
        if self.dummy_metrics is not None:
            assert isinstance(self.dummy_metrics, dict)
            return self.dummy_metrics.copy()

        pred = np.array([result['pred'] for result in results])
        label = np.array([result['label'] for result in results])
        acc = (pred == label).sum() / pred.size

        metrics = {
            'accuracy': acc,
            'size': pred.size,  # To check the number of testing samples
        }

        return metrics


def generate_test_results(size, batch_size, pred, label):
    num_batch = math.ceil(size / batch_size)
    bs_residual = size % batch_size
    for i in range(num_batch):
        bs = bs_residual if i == num_batch - 1 else batch_size
        data_batch = [
            dict(
                inputs=np.zeros((3, 10, 10)),
                data_sample=BaseDataElement(label=label)) for _ in range(bs)
        ]
        predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
        yield (data_batch, predictions)


class TestMultiDatasetsEvaluator(TestCase):

    def test_composed_metrics(self):
        DefaultScope.get_instance('mmocr', scope_name='mmocr')
        cfg = [
            dict(type='ToyMetric'),
            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

        evaluator = MultiDatasetsEvaluator(cfg, dataset_prefixes=['Fake'])
        evaluator.dataset_meta = {}
        size = 10
        batch_size = 4

        for data_samples, predictions in generate_test_results(
                size, batch_size, pred=1, label=1):
            evaluator.process(data_samples, predictions)

        metrics = evaluator.evaluate(size=size)

        self.assertAlmostEqual(metrics['Fake/Toy/accuracy'], 1.0)
        self.assertAlmostEqual(metrics['Fake/Toy/mAP'], 0.0)
        self.assertEqual(metrics['Fake/Toy/size'], size)
        with self.assertWarns(Warning):
            evaluator.evaluate(size=0)

        cfg = [dict(type='ToyMetric'), dict(type='ToyMetric')]

        evaluator = MultiDatasetsEvaluator(cfg, dataset_prefixes=['Fake'])
        evaluator.dataset_meta = {}

        for data_samples, predictions in generate_test_results(
                size, batch_size, pred=1, label=1):
            evaluator.process(data_samples, predictions)
        with self.assertRaises(ValueError):
            evaluator.evaluate(size=size)

        cfg = [dict(type='ToyMetric'), dict(type='ToyMetric', prefix=None)]

        evaluator = MultiDatasetsEvaluator(cfg, dataset_prefixes=['Fake'])
        evaluator.dataset_meta = {}

        for data_samples, predictions in generate_test_results(
                size, batch_size, pred=1, label=1):
            evaluator.process(data_samples, predictions)
        metrics = evaluator.evaluate(size=size)
        self.assertIn('Fake/Toy/accuracy', metrics)
        self.assertIn('Fake/accuracy', metrics)
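The tests above exercise a single dataset prefix ('Fake'). As a small, purely illustrative sketch of the naming rule with two dataset slices (the prefixes below are made up, not from this commit), the evaluator would emit keys such as:

# Illustrative only: how result keys are composed for two dataset slices.
dataset_prefixes = ['IC13', 'IC15']  # placeholder dataset names
metric_prefix = 'Toy'                # prefix used by ToyMetric in the tests
for dataset_prefix in dataset_prefixes:
    final_prefix = '/'.join((dataset_prefix, metric_prefix))
    print('/'.join((final_prefix, 'accuracy')))
# IC13/Toy/accuracy
# IC15/Toy/accuracy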