[Feature] Add metrics for single-label classification.
parent
93a27c8324
commit
6ad75f0076
mmcls/metrics
tests/test_metrics
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .single_label import Accuracy, SingleLabelMetric
|
||||
|
||||
__all__ = ['Accuracy', 'SingleLabelMetric']
|
|
@ -0,0 +1,554 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
|
||||
|
||||
def to_tensor(value):
|
||||
"""Convert value to torch.Tensor."""
|
||||
if isinstance(value, np.ndarray):
|
||||
value = torch.from_numpy(value)
|
||||
elif isinstance(value, Sequence) and not mmengine.is_str(value):
|
||||
value = torch.tensor(value)
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
raise TypeError(f'{type(value)} is not an available argument.')
|
||||
return value
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class Accuracy(BaseMetric):
|
||||
"""Top-k accuracy evaluation metric.
|
||||
|
||||
Args:
|
||||
topk (int | Sequence[int]): If the predictions in ``topk``
|
||||
matches the target, the predictions will be regarded as
|
||||
correct ones. Defaults to 1.
|
||||
thrs (Sequence[float | None] | float | None): Predictions with scores
|
||||
under the thresholds are considered negative. None means no
|
||||
thresholds. Default to 0.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.metrics import Accuracy
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [0, 2, 1, 3]
|
||||
>>> y_true = [0, 1, 2, 3]
|
||||
>>> Accuracy.calculate(y_pred, y_true)
|
||||
tensor([50.])
|
||||
>>> # Calculate the top1 and top5 accuracy.
|
||||
>>> y_score = torch.rand((1000, 10))
|
||||
>>> y_true = torch.zeros((1000, ))
|
||||
>>> Accuracy.calculate(y_score, y_true, topk=(1, 5))
|
||||
[[tensor([9.9000])], [tensor([51.5000])]]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.core import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_batch = [{
|
||||
... 'inputs': None, # In this example, the `inputs` is not used.
|
||||
... 'data_sample': ClsDataSample().set_gt_label(0)
|
||||
... } for i in range(1000)]
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand(10))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=Accuracy(topk=(1, 5)))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'accuracy/top1': 9.300000190734863,
|
||||
'accuracy/top5': 51.20000076293945
|
||||
}
|
||||
"""
|
||||
default_prefix: Optional[str] = 'accuracy'
|
||||
|
||||
def __init__(self,
|
||||
topk: Union[int, Sequence[int]] = (1, ),
|
||||
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
if isinstance(topk, int):
|
||||
self.topk = (topk, )
|
||||
else:
|
||||
self.topk = tuple(topk)
|
||||
|
||||
if isinstance(thrs, float) or thrs is None:
|
||||
self.thrs = (thrs, )
|
||||
else:
|
||||
self.thrs = tuple(thrs)
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
# Use gt_label in the pred dict preferentially.
|
||||
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
else:
|
||||
result['pred_label'] = pred_label['label'].cpu()
|
||||
result['gt_label'] = gt_label['label'].cpu()
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (dict): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
Dict: The computed metrics. The keys are the names of the metrics,
|
||||
and the values are corresponding results.
|
||||
"""
|
||||
# NOTICE: don't access `self.results` from the method.
|
||||
metrics = {}
|
||||
|
||||
# concat
|
||||
target = torch.cat([res['gt_label'] for res in results])
|
||||
if 'pred_score' in results[0]:
|
||||
pred = torch.stack([res['pred_score'] for res in results])
|
||||
|
||||
try:
|
||||
acc = self.calculate(pred, target, self.topk, self.thrs)
|
||||
except ValueError as e:
|
||||
# If the topk is invalid.
|
||||
raise ValueError(
|
||||
str(e) + ' Please check the `val_evaluator` and '
|
||||
'`test_evaluator` fields in your config file.')
|
||||
|
||||
multi_thrs = len(self.thrs) > 1
|
||||
for i, k in enumerate(self.topk):
|
||||
for j, thr in enumerate(self.thrs):
|
||||
name = f'top{k}'
|
||||
if multi_thrs:
|
||||
name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
|
||||
metrics[name] = acc[i][j].item()
|
||||
else:
|
||||
# If only label in the `pred_label`.
|
||||
pred = torch.cat([res['pred_label'] for res in results])
|
||||
acc = self.calculate(pred, target, self.topk, self.thrs)
|
||||
metrics['top1'] = acc.item()
|
||||
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def calculate(
|
||||
pred: Union[torch.Tensor, np.ndarray, Sequence],
|
||||
target: Union[torch.Tensor, np.ndarray, Sequence],
|
||||
topk: Sequence[int] = (1, ),
|
||||
thrs: Sequence[Union[float, None]] = (0., ),
|
||||
) -> Union[torch.Tensor, List[List[torch.Tensor]]]:
|
||||
"""Calculate the accuracy.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.ndarray | Sequence): The prediction
|
||||
results. It can be labels (N, ), or scores of every
|
||||
class (N, C).
|
||||
target (torch.Tensor | np.ndarray | Sequence): The target of
|
||||
each prediction with shape (N, ).
|
||||
thrs (Sequence[float | None]): Predictions with scores under
|
||||
the thresholds are considered negative. It's only used
|
||||
when ``pred`` is scores. None means no thresholds.
|
||||
Default to (0., ).
|
||||
thrs (Sequence[float]): Predictions with scores under
|
||||
the thresholds are considered negative. It's only used
|
||||
when ``pred`` is scores. Default to (0., ).
|
||||
|
||||
Returns:
|
||||
torch.Tensor | List[List[torch.Tensor]]: Accuracy.
|
||||
|
||||
- torch.Tensor: If the ``pred`` is a sequence of label instead of
|
||||
score (number of dimensions is 1). Only return a top-1 accuracy
|
||||
tensor, and ignore the argument ``topk` and ``thrs``.
|
||||
- List[List[torch.Tensor]]: If the ``pred`` is a sequence of score
|
||||
(number of dimensions is 2). Return the accuracy on each ``topk``
|
||||
and ``thrs``. And the first dim is ``topk``, the second dim is
|
||||
``thrs``.
|
||||
"""
|
||||
|
||||
pred = to_tensor(pred)
|
||||
target = to_tensor(target).to(torch.int64)
|
||||
num = pred.size(0)
|
||||
assert pred.size(0) == target.size(0), \
|
||||
f"The size of pred ({pred.size(0)}) doesn't match "\
|
||||
f'the target ({target.size(0)}).'
|
||||
|
||||
if pred.ndim == 1:
|
||||
# For pred label, ignore topk and acc
|
||||
pred_label = pred.int()
|
||||
correct = pred.eq(target).float().sum(0, keepdim=True)
|
||||
acc = correct.mul_(100. / num)
|
||||
return acc
|
||||
else:
|
||||
# For pred score, calculate on all topk and thresholds.
|
||||
pred = pred.float()
|
||||
maxk = max(topk)
|
||||
|
||||
if maxk > pred.size(1):
|
||||
raise ValueError(
|
||||
f'Top-{maxk} accuracy is unavailable since the number of '
|
||||
f'categories is {pred.size(1)}.')
|
||||
|
||||
pred_score, pred_label = pred.topk(maxk, dim=1)
|
||||
pred_label = pred_label.t()
|
||||
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
|
||||
results = []
|
||||
for k in topk:
|
||||
results.append([])
|
||||
for thr in thrs:
|
||||
# Only prediction values larger than thr are counted
|
||||
# as correct
|
||||
_correct = correct
|
||||
if thr is not None:
|
||||
_correct = _correct & (pred_score.t() > thr)
|
||||
correct_k = _correct[:k].reshape(-1).float().sum(
|
||||
0, keepdim=True)
|
||||
acc = correct_k.mul_(100. / num)
|
||||
results[-1].append(acc)
|
||||
return results
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class SingleLabelMetric(BaseMetric):
|
||||
"""A collection of metrics for single-label multi-class classification task
|
||||
based on confusion matrix.
|
||||
|
||||
It includes precision, recall, f1-score and support. Comparing with
|
||||
:class:`Accuracy`, these metrics doesn't support topk, but supports
|
||||
various average mode.
|
||||
|
||||
Args:
|
||||
thrs (Sequence[float | None] | float | None): Predictions with scores
|
||||
under the thresholds are considered negative. None means no
|
||||
thresholds. Default to 0.
|
||||
items (Sequence[str]): The detailed metric items to evaluate. Here is
|
||||
the available options:
|
||||
|
||||
- `"precision"`: The ratio tp / (tp + fp) where tp is the
|
||||
number of true positives and fp the number of false
|
||||
positives.
|
||||
- `"recall"`: The ratio tp / (tp + fn) where tp is the number
|
||||
of true positives and fn the number of false negatives.
|
||||
- `"f1-score"`: The f1-score is the harmonic mean of the
|
||||
precision and recall.
|
||||
- `"support"`: The total number of occurrences of each category
|
||||
in the target.
|
||||
|
||||
Defaults to ('precision', 'recall', 'f1-score').
|
||||
average (str, optional): The average method. If None, the scores
|
||||
for each class are returned. And it supports two average modes:
|
||||
|
||||
- `"macro"`: Calculate metrics for each category, and calculate
|
||||
the mean value over all categories.
|
||||
- `"micro"`: Calculate metrics globally by counting the total
|
||||
true positives, false negatives and false positives.
|
||||
|
||||
Defaults to "macro".
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from mmcls.metrics import SingleLabelMetric
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [0, 1, 1, 3]
|
||||
>>> y_true = [0, 2, 1, 3]
|
||||
>>> # Output precision, recall, f1-score and support.
|
||||
>>> SingleLabelMetric.calculate(y_pred, y_true, num_classes=4)
|
||||
(tensor(62.5000, dtype=torch.float64),
|
||||
tensor(75., dtype=torch.float64),
|
||||
tensor(66.6667, dtype=torch.float64),
|
||||
tensor(4))
|
||||
>>> # Calculate with different thresholds.
|
||||
>>> y_score = torch.rand((1000, 10))
|
||||
>>> y_true = torch.zeros((1000, ))
|
||||
>>> SingleLabelMetric.calculate(y_score, y_true, thrs=(0., 0.9))
|
||||
[(tensor(10., dtype=torch.float64),
|
||||
tensor(1.2100, dtype=torch.float64),
|
||||
tensor(2.1588, dtype=torch.float64),
|
||||
tensor(1000)),
|
||||
(tensor(10., dtype=torch.float64),
|
||||
tensor(0.8200, dtype=torch.float64),
|
||||
tensor(1.5157, dtype=torch.float64),
|
||||
tensor(1000))]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.core import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_batch = [{
|
||||
... 'inputs': None, # In this example, the `inputs` is not used.
|
||||
... 'data_sample': ClsDataSample().set_gt_label(i%5)
|
||||
... } for i in range(1000)]
|
||||
>>> pred = [
|
||||
... ClsDataSample().set_pred_score(torch.rand(5))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=SingleLabelMetric())
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'single-label/precision': 10.0,
|
||||
'single-label/recall': 0.96,
|
||||
'single-label/f1-score': 1.7518248175182483
|
||||
}
|
||||
>>> # Evaluate on each class
|
||||
>>> evaluator = Evaluator(metrics=SingleLabelMetric(average=None))
|
||||
>>> evaluator.process(data_batch, pred)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{
|
||||
'single-label/precision': [21.14, 18.69, 17.17, 19.42, 16.14],
|
||||
'single-label/recall': [18.5, 18.5, 17.0, 20.0, 18.0],
|
||||
'single-label/f1-score': [19.73, 18.59, 17.09, 19.70, 17.02]
|
||||
}
|
||||
"""
|
||||
default_prefix: Optional[str] = 'single-label'
|
||||
|
||||
def __init__(self,
|
||||
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
|
||||
items: Sequence[str] = ('precision', 'recall', 'f1-score'),
|
||||
average: Optional[str] = 'macro',
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
if isinstance(thrs, float) or thrs is None:
|
||||
self.thrs = (thrs, )
|
||||
else:
|
||||
self.thrs = tuple(thrs)
|
||||
|
||||
for item in items:
|
||||
assert item in ['precision', 'recall', 'f1-score', 'support'], \
|
||||
f'The metric {item} is not supported by `SingleLabelMetric`,' \
|
||||
' please specicy from "precision", "recall", "f1-score" and ' \
|
||||
'"support".'
|
||||
self.items = tuple(items)
|
||||
self.average = average
|
||||
|
||||
def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
result = dict()
|
||||
pred_label = pred['pred_label']
|
||||
# Use gt_label in the pred dict preferentially.
|
||||
gt_label = pred.get('gt_label', data['data_sample']['gt_label'])
|
||||
if 'score' in pred_label:
|
||||
result['pred_score'] = pred_label['score'].cpu()
|
||||
elif ('num_classes' in pred_label
|
||||
or 'num_classes' in data['data_sample']):
|
||||
result['pred_label'] = pred_label['label'].cpu()
|
||||
result['num_classes'] = pred_label.get(
|
||||
'num_classes', None) or data['data_sample']['num_classes']
|
||||
else:
|
||||
raise ValueError('The `pred_label` in predictions do not '
|
||||
'have neither `score` nor `num_classes`.')
|
||||
result['gt_label'] = gt_label['label'].cpu()
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
Dict: The computed metrics. The keys are the names of the metrics,
|
||||
and the values are corresponding results.
|
||||
"""
|
||||
# NOTICE: don't access `self.results` from the method. `self.results`
|
||||
# are a list of results from multiple batch, while the input `results`
|
||||
# are the collected results.
|
||||
metrics = {}
|
||||
|
||||
def pack_results(precision, recall, f1_score, support):
|
||||
single_metrics = {}
|
||||
if 'precision' in self.items:
|
||||
single_metrics['precision'] = precision
|
||||
if 'recall' in self.items:
|
||||
single_metrics['recall'] = recall
|
||||
if 'f1-score' in self.items:
|
||||
single_metrics['f1-score'] = f1_score
|
||||
if 'support' in self.items:
|
||||
single_metrics['support'] = support
|
||||
return single_metrics
|
||||
|
||||
# concat
|
||||
target = torch.cat([res['gt_label'] for res in results])
|
||||
if 'pred_score' in results[0]:
|
||||
pred = torch.stack([res['pred_score'] for res in results])
|
||||
metrics_list = self.calculate(
|
||||
pred, target, thrs=self.thrs, average=self.average)
|
||||
|
||||
multi_thrs = len(self.thrs) > 1
|
||||
for i, thr in enumerate(self.thrs):
|
||||
if multi_thrs:
|
||||
suffix = '_no-thr' if thr is None else f'_thr-{thr:.2f}'
|
||||
else:
|
||||
suffix = ''
|
||||
|
||||
for k, v in pack_results(*metrics_list[i]).items():
|
||||
metrics[k + suffix] = v
|
||||
else:
|
||||
# If only label in the `pred_label`.
|
||||
pred = torch.cat([res['pred_label'] for res in results])
|
||||
res = self.calculate(
|
||||
pred,
|
||||
target,
|
||||
average=self.average,
|
||||
num_classes=results[0]['num_classes'])
|
||||
metrics = pack_results(*res)
|
||||
|
||||
for k, v in metrics.items():
|
||||
if self.average is not None:
|
||||
metrics[k] = v.item()
|
||||
else:
|
||||
metrics[k] = v.cpu().detach().tolist()
|
||||
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def calculate(
|
||||
pred: Union[torch.Tensor, np.ndarray, Sequence],
|
||||
target: Union[torch.Tensor, np.ndarray, Sequence],
|
||||
thrs: Sequence[Union[float, None]] = (0., ),
|
||||
average: Optional[str] = 'macro',
|
||||
num_classes: Optional[int] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Calculate the precision, recall, f1-score and support.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.ndarray | Sequence): The prediction
|
||||
results. It can be labels (N, ), or scores of every
|
||||
class (N, C).
|
||||
target (torch.Tensor | np.ndarray | Sequence): The target of
|
||||
each prediction with shape (N, ).
|
||||
thrs (Sequence[float | None]): Predictions with scores under
|
||||
the thresholds are considered negative. It's only used
|
||||
when ``pred`` is scores. None means no thresholds.
|
||||
Default to (0., ).
|
||||
average (str, optional): The average method. If None, the scores
|
||||
for each class are returned. And it supports two average modes:
|
||||
|
||||
- `"macro"`: Calculate metrics for each category, and
|
||||
calculate the mean value over all categories.
|
||||
- `"micro"`: Calculate metrics globally by counting the
|
||||
total true positives, false negatives and false
|
||||
positives.
|
||||
|
||||
Defaults to "macro".
|
||||
num_classes (Optional, int): The number of classes. If the ``pred``
|
||||
is label instead of scores, this argument is required.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple: The tuple contains precision, recall and f1-score.
|
||||
And the type of each item is:
|
||||
|
||||
- torch.Tensor: If the ``pred`` is a sequence of label instead of
|
||||
score (number of dimensions is 1). Only returns a tensor for
|
||||
each metric. The shape is (1, ) if ``classwise`` is False, and
|
||||
(C, ) if ``classwise`` is True.
|
||||
- List[torch.Tensor]: If the ``pred`` is a sequence of score
|
||||
(number of dimensions is 2). Return the metrics on each ``thrs``.
|
||||
The shape of tensor is (1, ) if ``classwise`` is False, and (C, )
|
||||
if ``classwise`` is True.
|
||||
"""
|
||||
average_options = ['micro', 'macro', None]
|
||||
assert average in average_options, 'Invalid `average` argument, ' \
|
||||
f'please specicy from {average_options}.'
|
||||
|
||||
pred = to_tensor(pred)
|
||||
target = to_tensor(target).to(torch.int64)
|
||||
assert pred.size(0) == target.size(0), \
|
||||
f"The size of pred ({pred.size(0)}) doesn't match "\
|
||||
f'the target ({target.size(0)}).'
|
||||
|
||||
def _do_calculate(pred_positive, gt_positive):
|
||||
class_correct = (pred_positive & gt_positive)
|
||||
if average == 'micro':
|
||||
tp_sum = class_correct.sum()
|
||||
pred_sum = pred_positive.sum()
|
||||
gt_sum = gt_positive.sum()
|
||||
else:
|
||||
tp_sum = class_correct.sum(0)
|
||||
pred_sum = pred_positive.sum(0)
|
||||
gt_sum = gt_positive.sum(0)
|
||||
|
||||
precision = tp_sum / np.maximum(pred_sum, 1.) * 100
|
||||
recall = tp_sum / np.maximum(gt_sum, 1.) * 100
|
||||
f1_score = 2 * precision * recall / np.maximum(
|
||||
precision + recall,
|
||||
torch.finfo(torch.float32).eps)
|
||||
if average in ['macro', 'micro']:
|
||||
precision = precision.mean(0, keepdim=True)
|
||||
recall = recall.mean(0, keepdim=True)
|
||||
f1_score = f1_score.mean(0, keepdim=True)
|
||||
support = gt_sum.sum(0, keepdim=True)
|
||||
else:
|
||||
support = gt_sum
|
||||
return precision, recall, f1_score, support
|
||||
|
||||
if pred.ndim == 1:
|
||||
assert num_classes is not None, \
|
||||
'Please specicy the `num_classes` if the `pred` is labels ' \
|
||||
'intead of scores.'
|
||||
gt_positive = F.one_hot(target.flatten(), num_classes)
|
||||
pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
|
||||
return _do_calculate(pred_positive, gt_positive)
|
||||
else:
|
||||
# For pred score, calculate on all thresholds.
|
||||
num_classes = pred.size(1)
|
||||
pred_score, pred_label = torch.topk(pred, k=1)
|
||||
pred_score = pred_score.flatten()
|
||||
pred_label = pred_label.flatten()
|
||||
|
||||
gt_positive = F.one_hot(target.flatten(), num_classes)
|
||||
|
||||
results = []
|
||||
for thr in thrs:
|
||||
pred_positive = F.one_hot(pred_label, num_classes)
|
||||
if thr is not None:
|
||||
pred_positive[pred_score <= thr] = 0
|
||||
results.append(_do_calculate(pred_positive, gt_positive))
|
||||
|
||||
return results
|
|
@ -0,0 +1,300 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmcls.core import ClsDataSample
|
||||
from mmcls.metrics import Accuracy, SingleLabelMetric
|
||||
from mmcls.registry import METRICS
|
||||
|
||||
|
||||
class TestAccuracy(TestCase):
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
data_batch = [{
|
||||
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
|
||||
} for i in [0, 0, 1, 2, 1, 0]]
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
|
||||
for i, j in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2])
|
||||
]
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=0.6))
|
||||
metric.process(data_batch, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertIsInstance(acc, dict)
|
||||
self.assertAlmostEqual(acc['accuracy/top1'], 2 / 6 * 100, places=4)
|
||||
|
||||
# Test with multiple thrs
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertSetEqual(
|
||||
set(acc.keys()), {
|
||||
'accuracy/top1_thr-0.00', 'accuracy/top1_thr-0.60',
|
||||
'accuracy/top1_no-thr'
|
||||
})
|
||||
|
||||
# Test with invalid topk
|
||||
with self.assertRaisesRegex(ValueError, 'check the `val_evaluator`'):
|
||||
metric = METRICS.build(dict(type='Accuracy', topk=(1, 5)))
|
||||
metric.process(data_batch, pred)
|
||||
metric.evaluate(6)
|
||||
|
||||
# Test with label
|
||||
for sample in pred:
|
||||
del sample['pred_label']['score']
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
acc = metric.evaluate(6)
|
||||
self.assertIsInstance(acc, dict)
|
||||
self.assertAlmostEqual(acc['accuracy/top1'], 4 / 6 * 100, places=4)
|
||||
|
||||
# Test initialization
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=0.6))
|
||||
self.assertTupleEqual(metric.thrs, (0.6, ))
|
||||
metric = METRICS.build(dict(type='Accuracy', thrs=[0.6]))
|
||||
self.assertTupleEqual(metric.thrs, (0.6, ))
|
||||
metric = METRICS.build(dict(type='Accuracy', topk=5))
|
||||
self.assertTupleEqual(metric.topk, (5, ))
|
||||
metric = METRICS.build(dict(type='Accuracy', topk=[5]))
|
||||
self.assertTupleEqual(metric.topk, (5, ))
|
||||
|
||||
def test_calculate(self):
|
||||
"""Test using the metric from static method."""
|
||||
|
||||
# Test with score
|
||||
y_true = np.array([0, 0, 1, 2, 1, 0])
|
||||
y_label = torch.tensor([0, 0, 1, 2, 2, 2])
|
||||
y_score = [
|
||||
[0.7, 0.0, 0.3],
|
||||
[0.5, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.1],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
]
|
||||
|
||||
# Test with score
|
||||
acc = Accuracy.calculate(y_score, y_true, thrs=(0.6, ))
|
||||
self.assertIsInstance(acc, list)
|
||||
self.assertIsInstance(acc[0], list)
|
||||
self.assertIsInstance(acc[0][0], torch.Tensor)
|
||||
self.assertTensorEqual(acc[0][0], 2 / 6 * 100)
|
||||
|
||||
# Test with label
|
||||
acc = Accuracy.calculate(y_label, y_true, thrs=(0.6, ))
|
||||
self.assertIsInstance(acc, torch.Tensor)
|
||||
# the thrs will be ignored
|
||||
self.assertTensorEqual(acc, 4 / 6 * 100)
|
||||
|
||||
# Test with invalid inputs
|
||||
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
|
||||
Accuracy.calculate(y_label, 'hi')
|
||||
|
||||
# Test with invalid topk
|
||||
with self.assertRaisesRegex(ValueError, 'Top-5 accuracy .* is 3'):
|
||||
Accuracy.calculate(y_score, y_true, topk=(1, 5))
|
||||
|
||||
def assertTensorEqual(self,
|
||||
tensor: torch.Tensor,
|
||||
value: float,
|
||||
msg=None,
|
||||
**kwarg):
|
||||
tensor = tensor.to(torch.float32)
|
||||
value = torch.FloatTensor([value])
|
||||
try:
|
||||
torch.testing.assert_allclose(tensor, value, **kwarg)
|
||||
except AssertionError as e:
|
||||
self.fail(self._formatMessage(msg, str(e)))
|
||||
|
||||
|
||||
class TestSingleLabel(TestCase):
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
data_batch = [{
|
||||
'data_sample': ClsDataSample().set_gt_label(i).to_dict()
|
||||
} for i in [0, 0, 1, 2, 1, 0]]
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_pred_label(j).to_dict()
|
||||
for i, j in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [0, 0, 1, 2, 2, 2])
|
||||
]
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
metric = METRICS.build(
|
||||
dict(
|
||||
type='SingleLabelMetric',
|
||||
thrs=0.6,
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertAlmostEqual(
|
||||
res['single-label/precision'], (1 + 0 + 1 / 3) / 3 * 100, places=4)
|
||||
self.assertAlmostEqual(
|
||||
res['single-label/recall'], (1 / 3 + 0 + 1) / 3 * 100, places=4)
|
||||
self.assertAlmostEqual(
|
||||
res['single-label/f1-score'], (1 / 2 + 0 + 1 / 2) / 3 * 100,
|
||||
places=4)
|
||||
self.assertEqual(res['single-label/support'], 6)
|
||||
|
||||
# Test with multiple thrs
|
||||
metric = METRICS.build(
|
||||
dict(type='SingleLabelMetric', thrs=(0., 0.6, None)))
|
||||
metric.process(data_batch, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertSetEqual(
|
||||
set(res.keys()), {
|
||||
'single-label/precision_thr-0.00',
|
||||
'single-label/recall_thr-0.00',
|
||||
'single-label/f1-score_thr-0.00',
|
||||
'single-label/precision_thr-0.60',
|
||||
'single-label/recall_thr-0.60',
|
||||
'single-label/f1-score_thr-0.60',
|
||||
'single-label/precision_no-thr', 'single-label/recall_no-thr',
|
||||
'single-label/f1-score_no-thr'
|
||||
})
|
||||
|
||||
# Test with average mode "micro"
|
||||
metric = METRICS.build(
|
||||
dict(
|
||||
type='SingleLabelMetric',
|
||||
average='micro',
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertAlmostEqual(res['single-label/precision'], 66.666, places=2)
|
||||
self.assertAlmostEqual(res['single-label/recall'], 66.666, places=2)
|
||||
self.assertAlmostEqual(res['single-label/f1-score'], 66.666, places=2)
|
||||
self.assertEqual(res['single-label/support'], 6)
|
||||
|
||||
# Test with average mode None
|
||||
metric = METRICS.build(
|
||||
dict(
|
||||
type='SingleLabelMetric',
|
||||
average=None,
|
||||
items=('precision', 'recall', 'f1-score', 'support')))
|
||||
metric.process(data_batch, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
precision = res['single-label/precision']
|
||||
self.assertAlmostEqual(precision[0], 100., places=4)
|
||||
self.assertAlmostEqual(precision[1], 100., places=4)
|
||||
self.assertAlmostEqual(precision[2], 1 / 3 * 100, places=4)
|
||||
recall = res['single-label/recall']
|
||||
self.assertAlmostEqual(recall[0], 2 / 3 * 100, places=4)
|
||||
self.assertAlmostEqual(recall[1], 50., places=4)
|
||||
self.assertAlmostEqual(recall[2], 100., places=4)
|
||||
f1_score = res['single-label/f1-score']
|
||||
self.assertAlmostEqual(f1_score[0], 80., places=4)
|
||||
self.assertAlmostEqual(f1_score[1], 2 / 3 * 100, places=4)
|
||||
self.assertAlmostEqual(f1_score[2], 50., places=4)
|
||||
self.assertEqual(res['single-label/support'], [3, 2, 1])
|
||||
|
||||
# Test with label, the thrs will be ignored
|
||||
pred_no_score = copy.deepcopy(pred)
|
||||
for sample in pred_no_score:
|
||||
del sample['pred_label']['score']
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=(0., 0.6)))
|
||||
metric.process(data_batch, pred_no_score)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
# Expected values come from sklearn
|
||||
self.assertAlmostEqual(res['single-label/precision'], 77.777, places=2)
|
||||
self.assertAlmostEqual(res['single-label/recall'], 72.222, places=2)
|
||||
self.assertAlmostEqual(res['single-label/f1-score'], 65.555, places=2)
|
||||
|
||||
pred_no_num_classes = copy.deepcopy(pred_no_score)
|
||||
for sample in pred_no_num_classes:
|
||||
del sample['pred_label']['num_classes']
|
||||
with self.assertRaisesRegex(ValueError, 'neither `score` nor'):
|
||||
metric.process(data_batch, pred_no_num_classes)
|
||||
|
||||
# Test with empty items
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', items=tuple()))
|
||||
metric.process(data_batch, pred)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
metric.process(data_batch, pred_no_score)
|
||||
res = metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
self.assertEqual(len(res), 0)
|
||||
|
||||
# Test initialization
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=0.6))
|
||||
self.assertTupleEqual(metric.thrs, (0.6, ))
|
||||
metric = METRICS.build(dict(type='SingleLabelMetric', thrs=[0.6]))
|
||||
self.assertTupleEqual(metric.thrs, (0.6, ))
|
||||
|
||||
def test_calculate(self):
|
||||
"""Test using the metric from static method."""
|
||||
|
||||
# Test with score
|
||||
y_true = np.array([0, 0, 1, 2, 1, 0])
|
||||
y_label = torch.tensor([0, 0, 1, 2, 2, 2])
|
||||
y_score = [
|
||||
[0.7, 0.0, 0.3],
|
||||
[0.5, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.1],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
]
|
||||
|
||||
# Test with score
|
||||
res = SingleLabelMetric.calculate(y_score, y_true, thrs=(0.6, ))
|
||||
self.assertIsInstance(res, list)
|
||||
self.assertIsInstance(res[0], tuple)
|
||||
precision, recall, f1_score, support = res[0]
|
||||
self.assertTensorEqual(precision, (1 + 0 + 1 / 3) / 3 * 100)
|
||||
self.assertTensorEqual(recall, (1 / 3 + 0 + 1) / 3 * 100)
|
||||
self.assertTensorEqual(f1_score, (1 / 2 + 0 + 1 / 2) / 3 * 100)
|
||||
self.assertTensorEqual(support, 6)
|
||||
|
||||
# Test with label
|
||||
res = SingleLabelMetric.calculate(y_label, y_true, num_classes=3)
|
||||
self.assertIsInstance(res, tuple)
|
||||
precision, recall, f1_score, support = res
|
||||
# Expected values come from sklearn
|
||||
self.assertTensorEqual(precision, 77.7777)
|
||||
self.assertTensorEqual(recall, 72.2222)
|
||||
self.assertTensorEqual(f1_score, 65.5555)
|
||||
self.assertTensorEqual(support, 6)
|
||||
|
||||
# Test with invalid inputs
|
||||
with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
|
||||
SingleLabelMetric.calculate(y_label, 'hi')
|
||||
|
||||
def assertTensorEqual(self,
|
||||
tensor: torch.Tensor,
|
||||
value: float,
|
||||
msg=None,
|
||||
**kwarg):
|
||||
tensor = tensor.to(torch.float32)
|
||||
value = torch.FloatTensor([value])
|
||||
try:
|
||||
torch.testing.assert_allclose(tensor, value, **kwarg)
|
||||
except AssertionError as e:
|
||||
self.fail(self._formatMessage(msg, str(e)))
|
Loading…
Reference in New Issue