Add multi label metrics
parent 62b046521e
commit a9057e88c4
@@ -22,35 +22,13 @@ test_pipeline = [
    dict(type='PackClsInputs'),
]

data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='',
        ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/VOCdevkit/VOC2007/',
        ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_prefix='data/VOCdevkit/VOC2007/',
        ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt',
        pipeline=test_pipeline))
evaluation = dict(
    interval=1, metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1'])

train_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/VOCdevkit/VOC2007/',
        # manually split the `trainval.txt` for standard training.
        ann_file='ImageSets/Main/trainval.txt',
        data_root='data/VOCdevkit/VOC2007',
        image_set_path='ImageSets/Layout/val.txt',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
@@ -61,27 +39,28 @@ val_dataloader = dict(
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/VOCdevkit/VOC2007/',
        # manually split the `trainval.txt` for standard validation.
        ann_file='ImageSets/Main/test.txt',
        data_root='data/VOCdevkit/VOC2007',
        image_set_path='ImageSets/Layout/val.txt',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(
    type='MultiLabelMetric',
    items=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1'])

test_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/VOCdevkit/VOC2007/',
        ann_file='ImageSets/Main/test.txt',
        data_prefix='val',
        data_root='data/VOCdevkit/VOC2007',
        image_set_path='ImageSets/Layout/val.txt',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

# calculate precision_recall_f1 and mAP
val_evaluator = [dict(type='MultiLabelMetric'), dict(type='AveragePrecision')]

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
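As a side note, the evaluator list above can be tuned through the thr/topk/average/items arguments that `MultiLabelMetric` defines later in this commit. A minimal sketch of an alternative setup (hypothetical, not taken from this commit):

# Sketch only -- every argument shown here exists on MultiLabelMetric below.
val_evaluator = [
    dict(type='MultiLabelMetric', topk=1),        # precision/recall/f1-score at top-1
    dict(type='MultiLabelMetric', average=None),  # class-wise scores instead of macro average
    dict(type='AveragePrecision'),                # mAP over all classes
]
test_evaluator = val_evaluator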
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_label import AveragePrecision, MultiLabelMetric
from .single_label import Accuracy, SingleLabelMetric

__all__ = ['Accuracy', 'SingleLabelMetric']
__all__ = [
    'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision'
]
@@ -0,0 +1,593 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from mmengine import LabelData, MMLogger
from mmengine.evaluator import BaseMetric

from mmcls.registry import METRICS
from .single_label import _precision_recall_f1_support, to_tensor


@METRICS.register_module()
class MultiLabelMetric(BaseMetric):
    """A collection of metrics for multi-label multi-class classification task
    based on confusion matrix.

    It includes precision, recall, f1-score and support.

    Args:
        thr (float, optional): Predictions with scores under the thresholds
            are considered as negative. Defaults to None.
        topk (int, optional): Predictions with the k-th highest scores are
            considered as positive. Defaults to None.
        items (Sequence[str]): The detailed metric items to evaluate. Here is
            the available options:

            - `"precision"`: The ratio tp / (tp + fp) where tp is the
              number of true positives and fp the number of false
              positives.
            - `"recall"`: The ratio tp / (tp + fn) where tp is the number
              of true positives and fn the number of false negatives.
            - `"f1-score"`: The f1-score is the harmonic mean of the
              precision and recall.
            - `"support"`: The total number of positive of each category
              in the target.

            Defaults to ('precision', 'recall', 'f1-score').
        average (str | None): The average method. It supports three average
            modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `"micro"`: Calculate metrics globally by counting the total
              true positives, false negatives and false positives.
            - `None`: Return scores of all categories.

            Defaults to "macro".
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmcls.metrics import MultiLabelMetric
        >>> # ------ The Basic Usage for category indices labels -------
        >>> y_pred = [[0], [1], [0, 1], [3]]
        >>> y_true = [[0, 3], [0, 2], [1], [3]]
        >>> # Output precision, recall, f1-score and support
        >>> MultiLabelMetric.calculate(
        ...     y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4)
        (tensor(50.), tensor(50.), tensor(45.8333), tensor(6))
        >>> # ----------- The Basic Usage for one-hot labels -----------
        >>> y_pred = torch.tensor([[1, 1, 0, 0],
        ...                        [1, 1, 0, 0],
        ...                        [0, 0, 1, 0],
        ...                        [0, 1, 0, 0],
        ...                        [0, 1, 0, 0]])
        >>> y_true = torch.Tensor([[1, 1, 0, 0],
        ...                        [0, 0, 1, 0],
        ...                        [1, 1, 1, 0],
        ...                        [1, 0, 0, 0],
        ...                        [1, 0, 0, 0]])
        >>> MultiLabelMetric.calculate(y_pred, y_true)
        (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8))
        >>> # --------- The Basic Usage for one-hot pred scores ---------
        >>> y_pred = torch.rand(y_true.size())
        >>> y_pred
        tensor([[0.4575, 0.7335, 0.3934, 0.2572],
                [0.1318, 0.1004, 0.8248, 0.6448],
                [0.8349, 0.6294, 0.7896, 0.2061],
                [0.4037, 0.7308, 0.6713, 0.8374],
                [0.3779, 0.4836, 0.0313, 0.0067]])
        >>> # Calculate with different threshold.
        >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1)
        (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8))
        >>> # Calculate with topk.
        >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1)
        (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8))
        >>>
        >>> # ------------------- Use with Evaluator -------------------
        >>> from mmcls.core import ClsDataSample
        >>> from mmengine.evaluator import Evaluator
        >>> # The `data_batch` won't be used in this case, just use a fake.
        >>> data_batch = [
        ...     {'inputs': None, 'data_sample': ClsDataSample()}
        ...     for i in range(1000)]
        >>> pred = [
        ...     ClsDataSample().set_pred_score(torch.rand((5, ))).set_gt_score(torch.randint(2, size=(5, )))
        ...     for i in range(1000)]
        >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5))
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(1000)
        {
            'multi-label/precision': 50.72898037055408,
            'multi-label/recall': 50.06836461357571,
            'multi-label/f1-score': 50.384466955258475
        }
        >>> # Evaluate on each class by using topk strategy
        >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None))
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(1000)
        {
            'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5],
            'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27],
            'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25]
        }
        >>> # Evaluate with label data from the head
        >>> pred = [
        ...     ClsDataSample().set_pred_score(torch.rand((5, ))).set_pred_label(
        ...     torch.randint(2, size=(5, ))).set_gt_score(torch.randint(2, size=(5, )))
        ...     for i in range(1000)]
        >>> evaluator = Evaluator(metrics=MultiLabelMetric())
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(1000)
        {
            'multi-label/precision': 20.28921606216292,
            'multi-label/recall': 38.628095855722314,
            'multi-label/f1-score': 26.603530359627918
        }
    """ # noqa: E501
    default_prefix: Optional[str] = 'multi-label'

    def __init__(self,
                 thr: Optional[float] = None,
                 topk: Optional[int] = None,
                 items: Sequence[str] = ('precision', 'recall', 'f1-score'),
                 average: Optional[str] = 'macro',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:

        logger = MMLogger.get_current_instance()
        if thr is None and topk is None:
            thr = 0.5
            logger.warning('Neither thr nor topk is given, set thr as 0.5 by '
                           'default.')
        elif thr is not None and topk is not None:
            logger.warning('Both thr and topk are given, '
                           'use threshold in favor of top-k.')

        self.thr = thr
        self.topk = topk
        self.average = average

        for item in items:
            assert item in ['precision', 'recall', 'f1-score', 'support'], \
                f'The metric {item} is not supported by `MultiLabelMetric`,' \
                ' please choose from "precision", "recall", "f1-score" and ' \
                '"support".'
        self.items = tuple(items)

        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
        """
        for pred in predictions:
            result = dict()
            pred_label = pred['pred_label']
            gt_label = pred['gt_label']

            result['pred_score'] = pred_label['score'].clone()
            num_classes = result['pred_score'].size()[-1]

            if 'score' in gt_label:
                result['gt_score'] = gt_label['score'].clone()
            else:
                result['gt_score'] = LabelData.label_to_onehot(
                    gt_label['label'], num_classes)

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
                and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method. `self.results`
        # is a list of results from multiple batches, while the input `results`
        # are the collected results.
        metrics = {}

        target = torch.stack([res['gt_score'] for res in results])
        pred = torch.stack([res['pred_score'] for res in results])

        metric_res = self.calculate(
            pred,
            target,
            pred_indices=False,
            target_indices=False,
            average=self.average,
            thr=self.thr,
            topk=self.topk)

        def pack_results(precision, recall, f1_score, support):
            single_metrics = {}
            if 'precision' in self.items:
                single_metrics['precision'] = precision
            if 'recall' in self.items:
                single_metrics['recall'] = recall
            if 'f1-score' in self.items:
                single_metrics['f1-score'] = f1_score
            if 'support' in self.items:
                single_metrics['support'] = support
            return single_metrics

        if self.thr:
            suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}'
            for k, v in pack_results(*metric_res).items():
                metrics[k + suffix] = v
        else:
            for k, v in pack_results(*metric_res).items():
                metrics[k + f'_top{self.topk}'] = v

        result_metrics = dict()
        for k, v in metrics.items():
            if self.average is None:
                result_metrics[k + '_classwise'] = v.detach().cpu().tolist()
            elif self.average == 'macro':
                result_metrics[k] = v.item()
            else:
                result_metrics[k + f'_{self.average}'] = v.item()
        return result_metrics

    @staticmethod
    def calculate(
        pred: Union[torch.Tensor, np.ndarray, Sequence],
        target: Union[torch.Tensor, np.ndarray, Sequence],
        pred_indices: bool = False,
        target_indices: bool = False,
        average: Optional[str] = 'macro',
        thr: Optional[float] = None,
        topk: Optional[int] = None,
        num_classes: Optional[int] = None
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Calculate the precision, recall, f1-score.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, num_classes)`` or a sequence of index/onehot
                format labels.
            target (torch.Tensor | np.ndarray | Sequence): The target of
                predictions. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, num_classes)`` or a sequence of index/onehot
                format labels.
            pred_indices (bool): Whether the ``pred`` is a sequence of
                category index labels. If True, ``num_classes`` must be set.
                Defaults to False.
            target_indices (bool): Whether the ``target`` is a sequence of
                category index labels. If True, ``num_classes`` must be set.
                Defaults to False.
            average (str | None): The average method. It supports three average
                modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `"micro"`: Calculate metrics globally by counting the
                  total true positives, false negatives and false
                  positives.
                - `None`: Return scores of all categories.

                Defaults to "macro".
            thr (float, optional): Predictions with scores under the thresholds
                are considered as negative. Defaults to None.
            topk (int, optional): Predictions with the k-th highest scores are
                considered as positive. Defaults to None.
            num_classes (Optional, int): The number of classes. If the ``pred``
                is indices instead of onehot, this argument is required.
                Defaults to None.

        Returns:
            Tuple: The tuple contains precision, recall and f1-score.
            And the type of each item is:

            - torch.Tensor: A tensor for each metric. The shape is (1, ) if
              ``average`` is not None, and (C, ) if ``average`` is None.

        Notes:
            If both ``thr`` and ``topk`` are set, use ``thr`` to determine
            positive predictions. If neither is set, use ``thr=0.5`` as
            default.
        """
        average_options = ['micro', 'macro', None]
        assert average in average_options, 'Invalid `average` argument, ' \
            f'please specify from {average_options}.'

        def _format_label(label, is_indices):
            """format various label to torch.Tensor."""
            if isinstance(label, np.ndarray):
                assert label.ndim == 2, 'The shape of `pred` and `target` ' \
                    'array must be (N, num_classes).'
                label = torch.from_numpy(label)
            elif isinstance(label, torch.Tensor):
                assert label.ndim == 2, 'The shape of `pred` and `target` ' \
                    'tensor must be (N, num_classes).'
            elif isinstance(label, Sequence):
                if is_indices:
                    assert num_classes is not None, 'For index-type labels, ' \
                        'please specify `num_classes`.'
                    label = torch.stack([
                        LabelData.label_to_onehot(
                            to_tensor(indices), num_classes)
                        for indices in label
                    ])
                else:
                    label = torch.stack(
                        [to_tensor(onehot) for onehot in label])
            else:
                raise TypeError(
                    'The `pred` and `target` must be type of torch.tensor or '
                    f'np.ndarray or sequence but get {type(label)}.')
            return label

        pred = _format_label(pred, pred_indices)
        target = _format_label(target, target_indices).long()

        assert pred.shape == target.shape, \
            f"The size of pred ({pred.shape}) doesn't match "\
            f'the target ({target.shape}).'

        if num_classes is not None:
            assert pred.size(1) == num_classes, \
                f'The shape of `pred` ({pred.shape}) '\
                f"doesn't match the num_classes ({num_classes})."
        num_classes = pred.size(1)

        thr = 0.5 if (thr is None and topk is None) else thr

        if thr is not None:
            # a label is predicted positive if its score is no less than thr
            pos_inds = (pred >= thr).long()
        else:
            # top-k labels will be predicted positive for any example
            _, topk_indices = pred.topk(topk)
            pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1)
            pos_inds = pos_inds.long()

        return _precision_recall_f1_support(pos_inds, target, average)


def _average_precision(pred: torch.Tensor,
                       target: torch.Tensor) -> torch.Tensor:
    r"""Calculate the average precision for a single class.

    AP summarizes a precision-recall curve as the weighted mean of maximum
    precisions obtained for any r'>r, where r is the recall:

    .. math::
        \text{AP} = \sum_n (R_n - R_{n-1}) P_n

    Note that no approximation is involved since the curve is piecewise
    constant.

    Args:
        pred (torch.Tensor): The model prediction with shape
            ``(N, num_classes)``.
        target (torch.Tensor): The target of predictions with shape
            ``(N, num_classes)``.

    Returns:
        torch.Tensor: average precision result.
    """
    assert pred.shape == target.shape, \
        f"The size of pred ({pred.shape}) doesn't match "\
        f'the target ({target.shape}).'

    # a small value for division by zero errors
    eps = torch.finfo(torch.float32).eps

    # sort examples
    sorted_pred_inds = torch.argsort(pred, dim=0, descending=True)
    sorted_target = target[sorted_pred_inds]

    # get indexes where the ground truth is positive
    pos_inds = sorted_target == 1

    # Calculate cumulative tp case numbers
    tps = torch.cumsum(pos_inds, 0)
    total_pos = tps[-1].item()  # the last element of the tensor may change later

    # Calculate cumulative tp&fp(pred_poss) case numbers
    pred_pos_nums = torch.arange(1, len(sorted_target) + 1)
    pred_pos_nums[pred_pos_nums < eps] = eps

    tps[torch.logical_not(pos_inds)] = 0
    precision = tps / pred_pos_nums
    ap = torch.sum(precision, 0) / max(total_pos, eps)
    return ap


@METRICS.register_module()
class AveragePrecision(BaseMetric):
    """Calculate the average precision with respect of classes.

    Args:
        average (str | None): The average method. It supports two modes:

            - `"macro"`: Calculate metrics for each category, and calculate
              the mean value over all categories.
            - `None`: Return scores of all categories.

            Defaults to "macro".
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    References
    ----------
    .. [1] `Wikipedia entry for the Average precision
           <https://en.wikipedia.org/w/index.php?title=Information_retrieval&
           oldid=793358396#Average_precision>`_

    Examples:
        >>> import torch
        >>> from mmcls.metrics import AveragePrecision
        >>> # --------- The Basic Usage for one-hot pred scores ---------
        >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2],
        ...                        [0.1, 0.2, 0.2, 0.1],
        ...                        [0.7, 0.5, 0.9, 0.3],
        ...                        [0.8, 0.1, 0.1, 0.2]])
        >>> y_true = torch.Tensor([[1, 1, 0, 0],
        ...                        [0, 1, 0, 0],
        ...                        [0, 0, 1, 0],
        ...                        [1, 0, 0, 0]])
        >>> AveragePrecision.calculate(y_pred, y_true)
        tensor(70.833)
        >>> # ------------------- Use with Evaluator -------------------
        >>> from mmcls.core import ClsDataSample
        >>> from mmengine.evaluator import Evaluator
        >>> # The `data_batch` won't be used in this case, just use a fake.
        >>> data_batch = [
        ...     {'inputs': None, 'data_sample': ClsDataSample()}
        ...     for i in range(4)]
        >>> pred = [
        ...     ClsDataSample().set_pred_score(i).set_gt_score(j)
        ...     for i, j in zip(y_pred, y_true)
        ... ]
        >>> evaluator = Evaluator(metrics=AveragePrecision())
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(5)
        {'multi-label/mAP': 70.83333587646484}
        >>> # Evaluate on each class
        >>> evaluator = Evaluator(metrics=AveragePrecision(average=None))
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(5)
        {'multi-label/AP_classwise': [100., 83.33, 100., 0.]}
    """
    default_prefix: Optional[str] = 'multi-label'

    def __init__(self,
                 average: Optional[str] = 'macro',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.average = average

    def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
        """

        for pred in predictions:
            result = dict()
            pred_label = pred['pred_label']
            gt_label = pred['gt_label']

            result['pred_score'] = pred_label['score']
            num_classes = result['pred_score'].size()[-1]

            if 'score' in gt_label:
                result['gt_score'] = gt_label['score']
            else:
                result['gt_score'] = LabelData.label_to_onehot(
                    gt_label['label'], num_classes)

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
                and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method. `self.results`
        # is a list of results from multiple batches, while the input `results`
        # are the collected results.

        # concat
        target = torch.stack([res['gt_score'] for res in results])
        pred = torch.stack([res['pred_score'] for res in results])

        ap = self.calculate(pred, target, self.average)

        result_metrics = dict()

        if self.average is None:
            result_metrics['AP_classwise'] = ap.detach().cpu().tolist()
        else:
            result_metrics['mAP'] = ap.item()

        return result_metrics

    @staticmethod
    def calculate(pred: Union[torch.Tensor, np.ndarray],
                  target: Union[torch.Tensor, np.ndarray],
                  average: Optional[str] = 'macro') -> torch.Tensor:
        r"""Calculate the average precision for each class.

        AP summarizes a precision-recall curve as the weighted mean of maximum
        precisions obtained for any r'>r, where r is the recall:

        .. math::
            \text{AP} = \sum_n (R_n - R_{n-1}) P_n

        Note that no approximation is involved since the curve is piecewise
        constant.

        Args:
            pred (torch.Tensor | np.ndarray): The model predictions with
                shape ``(N, num_classes)``.
            target (torch.Tensor | np.ndarray): The target of predictions
                with shape ``(N, num_classes)``.
            average (str | None): The average method. It supports two modes:

                - `"macro"`: Calculate metrics for each category, and
                  calculate the mean value over all categories.
                - `None`: Return scores of all categories.

                Defaults to "macro".

        Returns:
            torch.Tensor: the average precision of all classes.
        """
        average_options = ['macro', None]
        assert average in average_options, 'Invalid `average` argument, ' \
            f'please specify from {average_options}.'

        pred = to_tensor(pred)
        target = to_tensor(target)
        assert pred.ndim == 2 and pred.shape == target.shape, \
            'Both `pred` and `target` should have shape `(N, num_classes)`.'

        num_classes = pred.shape[1]
        ap = pred.new_zeros(num_classes)
        for k in range(num_classes):
            ap[k] = _average_precision(pred[:, k], target[:, k])
        if average == 'macro':
            return ap.mean() * 100.0
        else:
            return ap * 100
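For quick reference, `MultiLabelMetric.compute_metrics` above builds the result keys from the chosen options; assuming the default 'multi-label' prefix, the returned keys follow this pattern (a summary sketch, not additional code from the commit):

# Key naming produced by MultiLabelMetric.compute_metrics (default prefix 'multi-label'):
#   thr=0.5 (the default)  -> 'multi-label/precision', 'multi-label/recall', ...
#   thr=0.25               -> 'multi-label/precision_thr-0.25', ...
#   topk=1                 -> 'multi-label/precision_top1', ...
#   average=None           -> a '_classwise' suffix is appended and values become per-class lists
#   average='micro'        -> a '_micro' suffix is appended
# AveragePrecision reports 'multi-label/mAP', or 'multi-label/AP_classwise' when average=None.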
@@ -21,6 +21,37 @@ def to_tensor(value):
    return value


def _precision_recall_f1_support(pred_positive, gt_positive, average):
    """calculate base classification task metrics, such as precision, recall,
    f1_score, support."""
    average_options = ['micro', 'macro', None]
    assert average in average_options, 'Invalid `average` argument, ' \
        f'please specify from {average_options}.'

    class_correct = (pred_positive & gt_positive)
    if average == 'micro':
        tp_sum = class_correct.sum()
        pred_sum = pred_positive.sum()
        gt_sum = gt_positive.sum()
    else:
        tp_sum = class_correct.sum(0)
        pred_sum = pred_positive.sum(0)
        gt_sum = gt_positive.sum(0)

    precision = tp_sum / torch.clamp(pred_sum, min=1.) * 100
    recall = tp_sum / torch.clamp(gt_sum, min=1.) * 100
    f1_score = 2 * precision * recall / torch.clamp(
        precision + recall, min=torch.finfo(torch.float32).eps)
    if average in ['macro', 'micro']:
        precision = precision.mean(0)
        recall = recall.mean(0)
        f1_score = f1_score.mean(0)
        support = gt_sum.sum(0)
    else:
        support = gt_sum
    return precision, recall, f1_score, support


@METRICS.register_module()
class Accuracy(BaseMetric):
    """Top-k accuracy evaluation metric.
@@ -327,9 +358,9 @@ class SingleLabelMetric(BaseMetric):
        >>> evaluator.process(data_batch, pred)
        >>> evaluator.evaluate(1000)
        {
            'single-label/precision': [21.14, 18.69, 17.17, 19.42, 16.14],
            'single-label/recall': [18.5, 18.5, 17.0, 20.0, 18.0],
            'single-label/f1-score': [19.73, 18.59, 17.09, 19.70, 17.02]
            'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1],
            'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0],
            'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0]
        }
    """
    default_prefix: Optional[str] = 'single-label'
@@ -438,13 +469,17 @@ class SingleLabelMetric(BaseMetric):
                num_classes=results[0]['num_classes'])
            metrics = pack_results(*res)

        result_metrics = dict()
        for k, v in metrics.items():
            if self.average is not None:
                metrics[k] = v.item()
            else:
                metrics[k] = v.cpu().detach().tolist()

        return metrics
            if self.average is None:
                result_metrics[k + '_classwise'] = v.cpu().detach().tolist()
            elif self.average == 'micro':
                result_metrics[k + f'_{self.average}'] = v.item()
            else:
                result_metrics[k] = v.item()

        return result_metrics

    @staticmethod
    def calculate(
@@ -503,38 +538,14 @@ class SingleLabelMetric(BaseMetric):
            f"The size of pred ({pred.size(0)}) doesn't match "\
            f'the target ({target.size(0)}).'

        def _do_calculate(pred_positive, gt_positive):
            class_correct = (pred_positive & gt_positive)
            if average == 'micro':
                tp_sum = class_correct.sum()
                pred_sum = pred_positive.sum()
                gt_sum = gt_positive.sum()
            else:
                tp_sum = class_correct.sum(0)
                pred_sum = pred_positive.sum(0)
                gt_sum = gt_positive.sum(0)

            precision = tp_sum / np.maximum(pred_sum, 1.) * 100
            recall = tp_sum / np.maximum(gt_sum, 1.) * 100
            f1_score = 2 * precision * recall / np.maximum(
                precision + recall,
                torch.finfo(torch.float32).eps)
            if average in ['macro', 'micro']:
                precision = precision.mean(0, keepdim=True)
                recall = recall.mean(0, keepdim=True)
                f1_score = f1_score.mean(0, keepdim=True)
                support = gt_sum.sum(0, keepdim=True)
            else:
                support = gt_sum
            return precision, recall, f1_score, support

        if pred.ndim == 1:
            assert num_classes is not None, \
                'Please specify the `num_classes` if the `pred` is labels ' \
                'instead of scores.'
            gt_positive = F.one_hot(target.flatten(), num_classes)
            pred_positive = F.one_hot(pred.to(torch.int64), num_classes)
            return _do_calculate(pred_positive, gt_positive)
            return _precision_recall_f1_support(pred_positive, gt_positive,
                                                average)
        else:
            # For pred score, calculate on all thresholds.
            num_classes = pred.size(1)
@@ -549,6 +560,8 @@ class SingleLabelMetric(BaseMetric):
            pred_positive = F.one_hot(pred_label, num_classes)
            if thr is not None:
                pred_positive[pred_score <= thr] = 0
            results.append(_do_calculate(pred_positive, gt_positive))
            results.append(
                _precision_recall_f1_support(pred_positive, gt_positive,
                                             average))

        return results
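The hunks above factor the shared precision/recall/F1 computation into `_precision_recall_f1_support`, which both `SingleLabelMetric` and the new `MultiLabelMetric` call. A minimal usage sketch, assuming the helper remains importable from `mmcls.metrics.single_label`:

import torch

# Assumed import path; the helper is defined in single_label.py above.
from mmcls.metrics.single_label import _precision_recall_f1_support

# Made-up one-hot positives for 3 samples and 4 classes.
pred_positive = torch.tensor([[1, 0, 0, 1],
                              [0, 1, 0, 0],
                              [1, 0, 1, 0]])
gt_positive = torch.tensor([[1, 0, 0, 0],
                            [0, 1, 0, 1],
                            [1, 1, 1, 0]])

# Macro-averaged precision, recall and f1-score in percent, plus total support.
precision, recall, f1, support = _precision_recall_f1_support(
    pred_positive, gt_positive, 'macro')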
@@ -4,5 +4,6 @@ interrogate
isort==4.3.21
mmdet
pytest
sklearn
xdoctest >= 0.10.0
yapf
@@ -0,0 +1,398 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import sklearn.metrics
import torch
from mmengine.evaluator import Evaluator

from mmcls.core import ClsDataSample
from mmcls.metrics import AveragePrecision, MultiLabelMetric
from mmcls.utils import register_all_modules

register_all_modules()


class TestMultiLabel(TestCase):

    def test_calculate(self):
        """Test using the metric from static method."""

        y_true = [[0], [1, 3], [0, 1, 2], [3]]
        y_pred = [[0, 3], [0, 2], [1, 2], [2, 3]]
        y_true_binary = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred_binary = np.array([
            [1, 0, 0, 1],
            [1, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
        ])
        y_pred_score = np.array([
            [0.8, 0, 0, 0.6],
            [0.2, 0, 0.6, 0],
            [0, 0.9, 0.6, 0],
            [0, 0, 0.2, 0.3],
        ])

        # Test with sequence of category indexes
        res = MultiLabelMetric.calculate(
            y_pred,
            y_true,
            pred_indices=True,
            target_indices=True,
            num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with onehot input
        res = MultiLabelMetric.calculate(y_pred_binary,
                                         torch.from_numpy(y_true_binary))
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with topk argument
        res = MultiLabelMetric.calculate(
            y_pred_score, y_true, target_indices=True, topk=1, num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        top1_y_pred = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with thr argument
        res = MultiLabelMetric.calculate(
            y_pred_score, y_true, target_indices=True, thr=0.25, num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        thr_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            MultiLabelMetric.calculate(y_pred, 'hi', num_classes=10)

        # Test with invalid input
        with self.assertRaisesRegex(AssertionError,
                                    'Invalid `average` argument,'):
            MultiLabelMetric.calculate(
                y_pred, y_true, average='m', num_classes=10)

        y_true_binary = np.array([[1, 0, 0, 0], [0, 1, 0, 1]])
        y_pred_binary = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]])
        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError, 'The size of pred'):
            MultiLabelMetric.calculate(y_pred_binary, y_true_binary)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, 'The `pred` and `target` must'):
            MultiLabelMetric.calculate(y_pred_binary, 5)

    def test_evaluate(self):
        fake_data_batch = [{
            'inputs': None,
            'data_sample': ClsDataSample()
        } for _ in range(4)]

        y_true = [[0], [1, 3], [0, 1, 2], [3]]
        y_true_binary = torch.tensor([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred_score = torch.tensor([
            [0.8, 0, 0, 0.6],
            [0.2, 0, 0.6, 0],
            [0, 0.9, 0.6, 0],
            [0, 0, 0.2, 0.3],
        ])

        pred = [
            ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
            for i, j in zip(y_pred_score, y_true)
        ]

        # Test with default argument
        evaluator = Evaluator(dict(type='MultiLabelMetric'))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        thr05_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 0],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision'], expect_precision)
        self.assertEqual(res['multi-label/recall'], expect_recall)
        self.assertEqual(res['multi-label/f1-score'], expect_f1)

        # Test with topk argument
        evaluator = Evaluator(dict(type='MultiLabelMetric', topk=1))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        top1_y_pred = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision_top1'], expect_precision)
        self.assertEqual(res['multi-label/recall_top1'], expect_recall)
        self.assertEqual(res['multi-label/f1-score_top1'], expect_f1)

        # Test with both arguments
        evaluator = Evaluator(dict(type='MultiLabelMetric', thr=0.25, topk=1))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        thr_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision_thr-0.25'],
                         expect_precision)
        self.assertEqual(res['multi-label/recall_thr-0.25'], expect_recall)
        self.assertEqual(res['multi-label/f1-score_thr-0.25'], expect_f1)

        # Test with average micro
        evaluator = Evaluator(dict(type='MultiLabelMetric', average='micro'))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        self.assertAlmostEqual(
            res['multi-label/precision_micro'], expect_precision, places=4)
        self.assertAlmostEqual(
            res['multi-label/recall_micro'], expect_recall, places=4)
        self.assertAlmostEqual(
            res['multi-label/f1-score_micro'], expect_f1, places=4)

        # Test with average None
        evaluator = Evaluator(dict(type='MultiLabelMetric', average=None))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        np.testing.assert_allclose(res['multi-label/precision_classwise'],
                                   expect_precision)
        np.testing.assert_allclose(res['multi-label/recall_classwise'],
                                   expect_recall)
        np.testing.assert_allclose(res['multi-label/f1-score_classwise'],
                                   expect_f1)

        # Test with gt_score
        pred = [
            ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
            for i, j in zip(y_pred_score, y_true_binary)
        ]

        evaluator = Evaluator(dict(type='MultiLabelMetric', items=['support']))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        self.assertEqual(res['multi-label/support'], 7)

    def assertTensorEqual(self,
                          tensor: torch.Tensor,
                          value: float,
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        value = torch.FloatTensor([value])
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e) + str(tensor)))


class TestAveragePrecision(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        y_pred = torch.tensor([
            [0.9, 0.8, 0.3, 0.2],
            [0.1, 0.2, 0.2, 0.1],
            [0.7, 0.5, 0.9, 0.3],
            [0.8, 0.1, 0.1, 0.2],
        ])
        y_true = torch.tensor([
            [1, 1, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [1, 0, 0, 0],
        ])

        fake_data_batch = [{
            'inputs': None,
            'data_sample': ClsDataSample()
        } for _ in range(4)]

        pred = [
            ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
            for i, j in zip(y_pred, y_true)
        ]

        # Test with default macro average
        evaluator = Evaluator(dict(type='AveragePrecision'))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(5)
        self.assertIsInstance(res, dict)
        self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)

        # Test with average mode None
        evaluator = Evaluator(dict(type='AveragePrecision', average=None))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(5)
        self.assertIsInstance(res, dict)
        aps = res['multi-label/AP_classwise']
        self.assertAlmostEqual(aps[0], 100., places=4)
        self.assertAlmostEqual(aps[1], 83.3333, places=4)
        self.assertAlmostEqual(aps[2], 100, places=4)
        self.assertAlmostEqual(aps[3], 0, places=4)

        # Test with gt_label without score
        pred = [
            ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
            for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
        ]
        evaluator = Evaluator(dict(type='AveragePrecision'))
        evaluator.process(fake_data_batch, pred)
        res = evaluator.evaluate(5)
        self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)

    def test_calculate(self):
        """Test using the metric from static method."""

        y_true = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred = np.array([
            [0.9, 0.8, 0.3, 0.2],
            [0.1, 0.2, 0.2, 0.1],
            [0.7, 0.5, 0.9, 0.3],
            [0.8, 0.1, 0.1, 0.2],
        ])

        ap_score = AveragePrecision.calculate(y_pred, y_true)
        expect_ap = sklearn.metrics.average_precision_score(y_true,
                                                            y_pred) * 100
        self.assertTensorEqual(ap_score, expect_ap)

        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError,
                                    'Invalid `average` argument,'):
            AveragePrecision.calculate(y_pred, y_true, average='m')

        y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 1]])
        y_pred = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]])
        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError,
                                    'Both `pred` and `target`'):
            AveragePrecision.calculate(y_pred, y_true)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, "<class 'int'> is not an"):
            AveragePrecision.calculate(y_pred, 5)

    def assertTensorEqual(self,
                          tensor: torch.Tensor,
                          value: float,
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        value = torch.FloatTensor([value])
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e) + str(tensor)))
@@ -183,10 +183,13 @@ class TestSingleLabel(TestCase):
        metric.process(data_batch, pred)
        res = metric.evaluate(6)
        self.assertIsInstance(res, dict)
        self.assertAlmostEqual(res['single-label/precision'], 66.666, places=2)
        self.assertAlmostEqual(res['single-label/recall'], 66.666, places=2)
        self.assertAlmostEqual(res['single-label/f1-score'], 66.666, places=2)
        self.assertEqual(res['single-label/support'], 6)
        self.assertAlmostEqual(
            res['single-label/precision_micro'], 66.666, places=2)
        self.assertAlmostEqual(
            res['single-label/recall_micro'], 66.666, places=2)
        self.assertAlmostEqual(
            res['single-label/f1-score_micro'], 66.666, places=2)
        self.assertEqual(res['single-label/support_micro'], 6)

        # Test with average mode None
        metric = METRICS.build(
@@ -197,19 +200,19 @@ class TestSingleLabel(TestCase):
        metric.process(data_batch, pred)
        res = metric.evaluate(6)
        self.assertIsInstance(res, dict)
        precision = res['single-label/precision']
        precision = res['single-label/precision_classwise']
        self.assertAlmostEqual(precision[0], 100., places=4)
        self.assertAlmostEqual(precision[1], 100., places=4)
        self.assertAlmostEqual(precision[2], 1 / 3 * 100, places=4)
        recall = res['single-label/recall']
        recall = res['single-label/recall_classwise']
        self.assertAlmostEqual(recall[0], 2 / 3 * 100, places=4)
        self.assertAlmostEqual(recall[1], 50., places=4)
        self.assertAlmostEqual(recall[2], 100., places=4)
        f1_score = res['single-label/f1-score']
        f1_score = res['single-label/f1-score_classwise']
        self.assertAlmostEqual(f1_score[0], 80., places=4)
        self.assertAlmostEqual(f1_score[1], 2 / 3 * 100, places=4)
        self.assertAlmostEqual(f1_score[2], 50., places=4)
        self.assertEqual(res['single-label/support'], [3, 2, 1])
        self.assertEqual(res['single-label/support_classwise'], [3, 2, 1])

        # Test with label, the thrs will be ignored
        pred_no_score = copy.deepcopy(pred)
@@ -293,7 +296,7 @@ class TestSingleLabel(TestCase):
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        value = torch.FloatTensor([value])
        value = torch.tensor(value).float()
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e: