Update type hint and unit tests of evaluator. (#110)

* update type hint and unit tests for evaluator
* update doc
* fix
parent: cfccabc657
commit: be6f18988e
@@ -85,7 +85,7 @@ validation_cfg=dict(

 First, the custom evaluator class should inherit from `BaseEvaluator` and be added to the registry `EVALUATORS` (see the [registry documentation](docs/zh_cn/tutorials/registry.md) for details on registries).

-The `process()` method takes 2 input arguments: the test data samples `data_samples` and the model predictions `predictions`. From these we extract the category labels and the classification predictions, respectively, and store them in `self.results`.
+The `process()` method takes 2 input arguments: a batch of test data samples `data_batch` and the model predictions `predictions`. From these we extract the category labels and the classification predictions, respectively, and store them in `self.results`.

 The `compute_metrics()` method takes 1 input argument `results`, which holds the results of all batches of test data after they have been processed by `process()`. From these we extract the category labels and the classification predictions to compute the classification accuracy `acc`. Finally, the computed metrics are returned as a dictionary.
@@ -111,14 +111,17 @@ class Accuracy(BaseEvaluator):

     default_prefix = 'ACC'

-    def process(self, data_samples: Dict, predictions: Dict):
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]):
         """Process one batch of data and predictions. The processed
         results should be stored in `self.results`, which will be used
         to compute the metrics when all batches have been processed.

         Args:
-            data_samples (dict): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """

         # Extract the classification predictions and category labels
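Putting the two hunks above together: a minimal, self-contained sketch of the tutorial's `Accuracy` evaluator under the new interface. The `pred`/`label` attribute access on `BaseDataSample` and the `@EVALUATORS.register_module()` decorator follow the unit tests further down in this commit; treat the exact metric key as illustrative, not as the documentation's verbatim code.

```python
from typing import Any, Sequence, Tuple

import numpy as np

from mmengine.data import BaseDataSample
from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS


@EVALUATORS.register_module()
class Accuracy(BaseEvaluator):
    """Classification accuracy, accumulated over all processed batches."""

    default_prefix = 'ACC'

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
                predictions: Sequence[BaseDataSample]):
        # Pair each model output with the ground-truth label of its sample
        # and buffer the per-sample results for compute_metrics().
        self.results.extend(
            {'pred': pred.pred, 'label': data[1].label}
            for pred, data in zip(predictions, data_batch))

    def compute_metrics(self, results: list) -> dict:
        preds = np.array([res['pred'] for res in results])
        labels = np.array([res['label'] for res in results])
        return {'accuracy': float((preds == labels).sum() / preds.size)}
```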
@@ -5,7 +5,7 @@ import shutil
 import tempfile
 import warnings
 from abc import ABCMeta, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -62,14 +62,17 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta = dataset_meta

     @abstractmethod
-    def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]) -> None:
         """Process one batch of data samples and predictions. The processed
         results should be stored in ``self.results``, which will be used to
         compute the metrics when all batches have been processed.

         Args:
-            data_samples (BaseDataSample): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """

     @abstractmethod
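The new abstract signature implies the calling convention sketched below: the test loop hands each batch of `(inputs, data_sample)` tuples plus the matching model outputs to `process()`, then requests the metrics once with the dataset size. The dataloader and model here are hypothetical stand-ins exercising the `Accuracy` sketch above; mmengine's actual runner code is not part of this commit.

```python
from mmengine.data import BaseDataSample


def fake_dataloader(num_batches=2, batch_size=4):
    # Each batch is a sequence of (inputs, data_sample) tuples.
    for _ in range(num_batches):
        yield [(object(), BaseDataSample(data={'label': 1}))
               for _ in range(batch_size)]


def fake_model(data_batch):
    # One output data sample per input sample.
    return [BaseDataSample(data={'pred': 1}) for _ in data_batch]


evaluator = Accuracy()  # the sketch class from above
for data_batch in fake_dataloader():
    evaluator.process(data_batch, fake_model(data_batch))
print(evaluator.evaluate(size=8))  # expected to report an accuracy of 1.0
```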
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Tuple, Union

 from mmengine.data import BaseDataSample
 from .base import BaseEvaluator
@@ -32,16 +32,19 @@ class ComposedEvaluator:
         for evaluator in self.evaluators:
             evaluator.dataset_meta = dataset_meta

-    def process(self, data_samples: BaseDataSample, predictions: dict):
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]):
         """Invoke process method of each wrapped evaluator.

         Args:
-            data_samples (BaseDataSample): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """

         for evaluator in self.evaluators:
-            evaluator.process(data_samples, predictions)
+            evaluator.process(data_batch, predictions)

     def evaluate(self, size: int) -> dict:
         """Invoke evaluate method of each wrapped evaluator and collect the
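A hypothetical illustration of the fan-out: the composed evaluator forwards the same `(data_batch, predictions)` pair to every wrapped evaluator, so several metrics can be accumulated in one pass over the test set. Building it from a list config via `build_evaluator` is an assumption based on this commit's test suite, not something shown in the hunks above.

```python
from mmengine.evaluator import build_evaluator

# Two instances of the `Accuracy` sketch from above, distinguished by prefix
# (assumed to be wrapped in a ComposedEvaluator when a list config is given).
evaluator = build_evaluator([
    dict(type='Accuracy'),
    dict(type='Accuracy', prefix='top1'),
])
for data_batch in fake_dataloader():
    evaluator.process(data_batch, fake_model(data_batch))
print(evaluator.evaluate(size=8))
```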
@@ -5,6 +5,7 @@ from unittest import TestCase

 import numpy as np

+from mmengine.data import BaseDataSample
 from mmengine.evaluator import BaseEvaluator, build_evaluator
 from mmengine.registry import EVALUATORS
@@ -36,17 +37,20 @@ class ToyEvaluator(BaseEvaluator):
         super().__init__(collect_device=collect_device, prefix=prefix)
         self.dummy_metrics = dummy_metrics

-    def process(self, data_samples, predictions):
-        result = {'pred': predictions['pred'], 'label': data_samples['label']}
-        self.results.append(result)
+    def process(self, data_batch, predictions):
+        results = [{
+            'pred': pred.pred,
+            'label': data[1].label
+        } for pred, data in zip(predictions, data_batch)]
+        self.results.extend(results)

     def compute_metrics(self, results: List):
         if self.dummy_metrics is not None:
             assert isinstance(self.dummy_metrics, dict)
             return self.dummy_metrics.copy()

-        pred = np.concatenate([result['pred'] for result in results])
-        label = np.concatenate([result['label'] for result in results])
+        pred = np.array([result['pred'] for result in results])
+        label = np.array([result['label'] for result in results])
         acc = (pred == label).sum() / pred.size

         metrics = {
@@ -74,9 +78,11 @@ def generate_test_results(size, batch_size, pred, label):
     bs_residual = size % batch_size
     for i in range(num_batch):
         bs = bs_residual if i == num_batch - 1 else batch_size
-        data_samples = {'label': np.full(bs, label)}
-        predictions = {'pred': np.full(bs, pred)}
-        yield (data_samples, predictions)
+        data_batch = [(np.zeros(
+            (3, 10, 10)), BaseDataSample(data={'label': label}))
+                      for _ in range(bs)]
+        predictions = [BaseDataSample(data={'pred': pred}) for _ in range(bs)]
+        yield (data_batch, predictions)


 class TestBaseEvaluator(TestCase):
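The updated helper pairs naturally with the evaluator API: each yielded `(data_batch, predictions)` tuple is fed straight to `process()`, and `evaluate()` is called once with the dataset size. A sketch of the pattern the test methods follow (default `ToyEvaluator` arguments assumed), not a verbatim copy of the tests.

```python
evaluator = ToyEvaluator()  # dummy_metrics=None, so real accuracy is computed
size, batch_size = 10, 4

for data_batch, predictions in generate_test_results(
        size, batch_size, pred=1, label=1):
    evaluator.process(data_batch, predictions)

metrics = evaluator.evaluate(size)
# pred == label for every sample, so the reported accuracy should be 1.0.
```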