mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
add macro-averaged precision,recall,f1 options in evaluation (#93)
* add macro-averaged precision,recall,f1 options in evaluation * remove unnecessary comments * Revise according to comments * Revise according to comments
This commit is contained in:
parent
e75f2b7c35
commit
21fd5019fb
@ -5,7 +5,7 @@ import mmcv
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmcls.models.losses import accuracy
|
||||
from mmcls.models.losses import accuracy, f1_score, precision, recall
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ -127,20 +127,31 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
||||
Returns:
|
||||
dict: evaluation results
|
||||
"""
|
||||
if not isinstance(metric, str):
|
||||
assert len(metric) == 1
|
||||
metric = metric[0]
|
||||
allowed_metrics = ['accuracy']
|
||||
if metric not in allowed_metrics:
|
||||
raise KeyError(f'metric {metric} is not supported')
|
||||
|
||||
if isinstance(metric, str):
|
||||
metrics = [metric]
|
||||
else:
|
||||
metrics = metric
|
||||
allowed_metrics = ['accuracy', 'precision', 'recall', 'f1_score']
|
||||
eval_results = {}
|
||||
if metric == 'accuracy':
|
||||
topk = metric_options.get('topk')
|
||||
for metric in metrics:
|
||||
if metric not in allowed_metrics:
|
||||
raise KeyError(f'metric {metric} is not supported.')
|
||||
results = np.vstack(results)
|
||||
gt_labels = self.get_gt_labels()
|
||||
num_imgs = len(results)
|
||||
assert len(gt_labels) == num_imgs
|
||||
acc = accuracy(results, gt_labels, topk)
|
||||
eval_results = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
|
||||
if metric == 'accuracy':
|
||||
topk = metric_options.get('topk')
|
||||
acc = accuracy(results, gt_labels, topk)
|
||||
eval_result = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
|
||||
elif metric == 'precision':
|
||||
precision_value = precision(results, gt_labels)
|
||||
eval_result = {'precision': precision_value}
|
||||
elif metric == 'recall':
|
||||
recall_value = recall(results, gt_labels)
|
||||
eval_result = {'recall': recall_value}
|
||||
elif metric == 'f1_score':
|
||||
f1_score_value = f1_score(results, gt_labels)
|
||||
eval_result = {'f1_score': f1_score_value}
|
||||
eval_results.update(eval_result)
|
||||
return eval_results
|
||||
|
@ -1,9 +1,11 @@
|
||||
from .accuracy import Accuracy, accuracy
|
||||
from .cross_entropy_loss import CrossEntropyLoss, cross_entropy
|
||||
from .eval_metrics import f1_score, precision, recall
|
||||
from .label_smooth_loss import LabelSmoothLoss, label_smooth
|
||||
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
|
||||
|
||||
__all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss'
|
||||
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
|
||||
'precision', 'recall', 'f1_score'
|
||||
]
|
||||
|
80
mmcls/models/losses/eval_metrics.py
Normal file
80
mmcls/models/losses/eval_metrics.py
Normal file
@ -0,0 +1,80 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def calculate_confusion_matrix(pred, target):
|
||||
if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
|
||||
pred = torch.from_numpy(pred)
|
||||
target = torch.from_numpy(target)
|
||||
elif not (isinstance(pred, torch.Tensor)
|
||||
and isinstance(target, torch.Tensor)):
|
||||
raise TypeError('pred and target should both be'
|
||||
'torch.Tensor or np.ndarray')
|
||||
_, pred_label = pred.topk(1, dim=1)
|
||||
num_classes = pred.size(1)
|
||||
pred_label = pred_label.view(-1)
|
||||
target_label = target.view(-1)
|
||||
assert len(pred_label) == len(target_label)
|
||||
confusion_matrix = torch.zeros(num_classes, num_classes)
|
||||
with torch.no_grad():
|
||||
confusion_matrix[target_label.long(), pred_label.long()] += 1
|
||||
return confusion_matrix
|
||||
|
||||
|
||||
def precision(pred, target):
|
||||
"""Calculate macro-averaged precision according to the prediction and target
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.array): The model prediction.
|
||||
target (torch.Tensor | np.array): The target of each prediction.
|
||||
|
||||
Returns:
|
||||
float: The function will return a single float as precision.
|
||||
"""
|
||||
confusion_matrix = calculate_confusion_matrix(pred, target)
|
||||
with torch.no_grad():
|
||||
res = confusion_matrix.diag() / torch.clamp(
|
||||
confusion_matrix.sum(1), min=1)
|
||||
res = res.mean().item() * 100
|
||||
return res
|
||||
|
||||
|
||||
def recall(pred, target):
|
||||
"""Calculate macro-averaged recall according to the prediction and target
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.array): The model prediction.
|
||||
target (torch.Tensor | np.array): The target of each prediction.
|
||||
|
||||
Returns:
|
||||
float: The function will return a single float as recall.
|
||||
"""
|
||||
confusion_matrix = calculate_confusion_matrix(pred, target)
|
||||
with torch.no_grad():
|
||||
res = confusion_matrix.diag() / torch.clamp(
|
||||
confusion_matrix.sum(0), min=1)
|
||||
res = res.mean().item() * 100
|
||||
return res
|
||||
|
||||
|
||||
def f1_score(pred, target):
|
||||
"""Calculate macro-averaged F1 score according to the prediction and target
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.array): The model prediction.
|
||||
target (torch.Tensor | np.array): The target of each prediction.
|
||||
|
||||
Returns:
|
||||
float: The function will return a single float as F1 score.
|
||||
"""
|
||||
confusion_matrix = calculate_confusion_matrix(pred, target)
|
||||
with torch.no_grad():
|
||||
precision = confusion_matrix.diag() / torch.clamp(
|
||||
confusion_matrix.sum(1), min=1)
|
||||
recall = confusion_matrix.diag() / torch.clamp(
|
||||
confusion_matrix.sum(0), min=1)
|
||||
res = 2 * precision * recall / torch.clamp(
|
||||
precision + recall, min=1e-20)
|
||||
res = torch.where(torch.isnan(res), torch.full_like(res, 0), res)
|
||||
res = res.mean().item() * 100
|
||||
return res
|
@ -61,6 +61,25 @@ def test_datasets_override_default(dataset_name):
|
||||
assert dataset.CLASSES == original_classes
|
||||
|
||||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
def test_dataset_evaluation():
|
||||
dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
||||
dataset.data_infos = [
|
||||
dict(gt_label=0),
|
||||
dict(gt_label=1),
|
||||
dict(gt_label=2),
|
||||
dict(gt_label=1)
|
||||
]
|
||||
fake_results = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results, metric=['precision', 'recall', 'f1_score'])
|
||||
assert eval_results['precision'] == pytest.approx(
|
||||
(1 + 1 + 1 / 2) / 3 * 100.0)
|
||||
assert eval_results['recall'] == pytest.approx((1 + 1 / 2 + 1) / 3 * 100.0)
|
||||
assert eval_results['f1_score'] == pytest.approx(
|
||||
(1 + 2 / 3 + 2 / 3) / 3 * 100.0)
|
||||
|
||||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
def test_dataset_wrapper():
|
||||
BaseDataset.CLASSES = ('foo', 'bar')
|
||||
|
Loading…
x
Reference in New Issue
Block a user