Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
[Feature] Support support and class-wise evaluation results (#143)
* support the 'support' metric and class-wise evaluation results, and move eval_metrics.py
* Fix docstring
* change average to be non-optional
* revise according to comments
* add more unit tests
parent 4f8fc9cbf3
commit 8e990b5654
mmcls/core/evaluation/__init__.py
@@ -1,8 +1,11 @@
 from .eval_hooks import DistEvalHook, EvalHook
+from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
+                           recall, support)
 from .mean_ap import average_precision, mAP
 from .multilabel_eval_metrics import average_performance
 
 __all__ = [
-    'DistEvalHook', 'EvalHook', 'average_precision', 'mAP',
-    'average_performance'
+    'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support',
+    'average_precision', 'mAP', 'average_performance',
+    'calculate_confusion_matrix'
 ]
mmcls/core/evaluation/eval_metrics.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import numpy as np
import torch


def calculate_confusion_matrix(pred, target):
    """Calculate confusion matrix according to the prediction and target.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.

    Returns:
        torch.Tensor: Confusion matrix with shape (C, C), where C is the number
            of classes.
    """
    if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
        pred = torch.from_numpy(pred)
        target = torch.from_numpy(target)
    elif not (isinstance(pred, torch.Tensor)
              and isinstance(target, torch.Tensor)):
        raise TypeError('pred and target should both be '
                        'torch.Tensor or np.ndarray')
    _, pred_label = pred.topk(1, dim=1)
    num_classes = pred.size(1)
    pred_label = pred_label.view(-1)
    target_label = target.view(-1)
    assert len(pred_label) == len(target_label)
    confusion_matrix = torch.zeros(num_classes, num_classes)
    with torch.no_grad():
        for t, p in zip(target_label, pred_label):
            confusion_matrix[t.long(), p.long()] += 1
    return confusion_matrix


def precision(pred, target, average='macro'):
    """Calculate precision according to the prediction and target.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.
        average (str): The type of averaging performed on the result.
            Options are 'macro' and 'none'. Defaults to 'macro'.

    Returns:
        np.array: Precision value with shape determined by average.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        res = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(0), min=1) * 100
        if average == 'macro':
            res = res.mean().numpy()
        elif average == 'none':
            res = res.numpy()
        else:
            raise ValueError(f'Unsupported type of averaging {average}.')
    return res


def recall(pred, target, average='macro'):
    """Calculate recall according to the prediction and target.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.
        average (str): The type of averaging performed on the result.
            Options are 'macro' and 'none'. Defaults to 'macro'.

    Returns:
        np.array: Recall value with shape determined by average.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        res = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(1), min=1) * 100
        if average == 'macro':
            res = res.mean().numpy()
        elif average == 'none':
            res = res.numpy()
        else:
            raise ValueError(f'Unsupported type of averaging {average}.')
    return res


def f1_score(pred, target, average='macro'):
    """Calculate F1 score according to the prediction and target.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.
        average (str): The type of averaging performed on the result.
            Options are 'macro' and 'none'. Defaults to 'macro'.

    Returns:
        np.array: F1 score with shape determined by average.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        precision = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(1), min=1)
        recall = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(0), min=1)
        res = 2 * precision * recall / torch.clamp(
            precision + recall, min=1e-20) * 100
        res = torch.where(torch.isnan(res), torch.full_like(res, 0), res)
        if average == 'macro':
            res = res.mean().numpy()
        elif average == 'none':
            res = res.numpy()
        else:
            raise ValueError(f'Unsupported type of averaging {average}.')
    return res


def support(pred, target, average='macro'):
    """Calculate the total number of occurrences of each label according to
    the prediction and target.

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.
        average (str): The type of reduction performed on the result.
            Options are 'macro' and 'none'. 'macro' gives the sum and 'none'
            gives class-wise results. Defaults to 'macro'.

    Returns:
        np.array: Support with shape determined by average.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        res = confusion_matrix.sum(1)
        if average == 'macro':
            res = res.sum().numpy()
        elif average == 'none':
            res = res.numpy()
        else:
            raise ValueError(f'Unsupported type of averaging {average}.')
    return res
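A quick sketch of how these functions are meant to be used (not part of the commit; the sample arrays below are made up for illustration). With the default average='macro' each metric returns a single percentage, while average='none' returns one value per class:

import numpy as np

from mmcls.core.evaluation import f1_score, precision, recall, support

# Toy scores for 4 samples over 3 classes; the predicted label is the argmax.
pred = np.array([[0.9, 0.1, 0.0],
                 [0.1, 0.8, 0.1],
                 [0.2, 0.2, 0.6],
                 [0.7, 0.2, 0.1]])
target = np.array([0, 1, 2, 2])

print(precision(pred, target))                  # single macro-averaged percentage
print(precision(pred, target, average='none'))  # per-class percentages, shape (3,)
print(recall(pred, target, average='none'))     # per-class percentages, shape (3,)
print(f1_score(pred, target))                   # single macro-averaged percentage
print(support(pred, target, average='none'))    # ground-truth count per class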
mmcls/datasets/base_dataset.py
@@ -5,7 +5,8 @@ import mmcv
 import numpy as np
 from torch.utils.data import Dataset
 
-from mmcls.models.losses import accuracy, f1_score, precision, recall
+from mmcls.core.evaluation import f1_score, precision, recall, support
+from mmcls.models.losses import accuracy
 from .pipelines import Compose
 
 
@@ -122,6 +123,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
             results (list): Testing results of the dataset.
             metric (str | list[str]): Metrics to be evaluated.
                 Default value is `accuracy`.
+            metric_options (dict): Options for calculating metrics. Allowed
+                keys are 'topk' and 'average'.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.
         Returns:
@@ -131,7 +134,9 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
             metrics = [metric]
         else:
             metrics = metric
-        allowed_metrics = ['accuracy', 'precision', 'recall', 'f1_score']
+        allowed_metrics = [
+            'accuracy', 'precision', 'recall', 'f1_score', 'support'
+        ]
         eval_results = {}
         results = np.vstack(results)
         gt_labels = self.get_gt_labels()
@@ -145,13 +150,28 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
                 acc = accuracy(results, gt_labels, topk)
                 eval_result = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
             elif metric == 'precision':
-                precision_value = precision(results, gt_labels)
+                precision_value = precision(
+                    results,
+                    gt_labels,
+                    average=metric_options.get('average', 'macro'))
                 eval_result = {'precision': precision_value}
             elif metric == 'recall':
-                recall_value = recall(results, gt_labels)
+                recall_value = recall(
+                    results,
+                    gt_labels,
+                    average=metric_options.get('average', 'macro'))
                 eval_result = {'recall': recall_value}
             elif metric == 'f1_score':
-                f1_score_value = f1_score(results, gt_labels)
+                f1_score_value = f1_score(
+                    results,
+                    gt_labels,
+                    average=metric_options.get('average', 'macro'))
                 eval_result = {'f1_score': f1_score_value}
+            elif metric == 'support':
+                support_value = support(
+                    results,
+                    gt_labels,
+                    average=metric_options.get('average', 'macro'))
+                eval_result = {'support': support_value}
             eval_results.update(eval_result)
         return eval_results
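With this change, callers can switch between macro and class-wise results through the new metric_options argument. A minimal sketch (assuming `dataset` stands for any mmcls dataset instance and `results` for the stacked prediction scores; neither is defined here):

# Macro (scalar) results, as before:
eval_results = dataset.evaluate(results, metric=['precision', 'recall'])

# Class-wise results: each entry becomes an array with one value per class.
eval_results = dataset.evaluate(
    results,
    metric=['precision', 'recall', 'f1_score', 'support'],
    metric_options={'average': 'none'})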
mmcls/models/losses/__init__.py
@@ -2,7 +2,6 @@ from .accuracy import Accuracy, accuracy
 from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy)
-from .eval_metrics import f1_score, precision, recall
 from .focal_loss import FocalLoss, sigmoid_focal_loss
 from .label_smooth_loss import LabelSmoothLoss, label_smooth
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@@ -11,5 +10,5 @@ __all__ = [
     'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
     'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
     'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
-    'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
+    'FocalLoss', 'sigmoid_focal_loss'
 ]
mmcls/models/losses/eval_metrics.py (deleted file, 81 lines)
@@ -1,81 +0,0 @@
import numpy as np
import torch


def calculate_confusion_matrix(pred, target):
    if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
        pred = torch.from_numpy(pred)
        target = torch.from_numpy(target)
    elif not (isinstance(pred, torch.Tensor)
              and isinstance(target, torch.Tensor)):
        raise TypeError('pred and target should both be '
                        'torch.Tensor or np.ndarray')
    _, pred_label = pred.topk(1, dim=1)
    num_classes = pred.size(1)
    pred_label = pred_label.view(-1)
    target_label = target.view(-1)
    assert len(pred_label) == len(target_label)
    confusion_matrix = torch.zeros(num_classes, num_classes)
    with torch.no_grad():
        for t, p in zip(target_label, pred_label):
            confusion_matrix[t.long(), p.long()] += 1
    return confusion_matrix


def precision(pred, target):
    """Calculate macro-averaged precision according to the prediction and target

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.

    Returns:
        float: The function will return a single float as precision.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        res = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(0), min=1)
        res = res.mean().item() * 100
    return res


def recall(pred, target):
    """Calculate macro-averaged recall according to the prediction and target

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.

    Returns:
        float: The function will return a single float as recall.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        res = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(1), min=1)
        res = res.mean().item() * 100
    return res


def f1_score(pred, target):
    """Calculate macro-averaged F1 score according to the prediction and target

    Args:
        pred (torch.Tensor | np.array): The model prediction.
        target (torch.Tensor | np.array): The target of each prediction.

    Returns:
        float: The function will return a single float as F1 score.
    """
    confusion_matrix = calculate_confusion_matrix(pred, target)
    with torch.no_grad():
        precision = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(1), min=1)
        recall = confusion_matrix.diag() / torch.clamp(
            confusion_matrix.sum(0), min=1)
        res = 2 * precision * recall / torch.clamp(
            precision + recall, min=1e-20)
        res = torch.where(torch.isnan(res), torch.full_like(res, 0), res)
        res = res.mean().item() * 100
    return res
@@ -108,13 +108,44 @@ def test_dataset_evaluation():
     fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                              [0, 0, 1], [0, 0, 1]])
     eval_results = dataset.evaluate(
-        fake_results, metric=['precision', 'recall', 'f1_score'])
+        fake_results, metric=['precision', 'recall', 'f1_score', 'support'])
     assert eval_results['precision'] == pytest.approx(
         (1 + 1 + 1 / 3) / 3 * 100.0)
     assert eval_results['recall'] == pytest.approx(
         (2 / 3 + 1 / 2 + 1) / 3 * 100.0)
     assert eval_results['f1_score'] == pytest.approx(
         (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0)
+    assert eval_results['support'] == 6
+
+    # test evaluation results for classes
+    eval_results = dataset.evaluate(
+        fake_results,
+        metric=['precision', 'recall', 'f1_score', 'support'],
+        metric_options={'average': 'none'})
+    assert eval_results['precision'].shape == (3, )
+    assert eval_results['recall'].shape == (3, )
+    assert eval_results['f1_score'].shape == (3, )
+    assert eval_results['support'].shape == (3, )
+
+    # the average method must be valid
+    with pytest.raises(ValueError):
+        eval_results = dataset.evaluate(
+            fake_results,
+            metric='precision',
+            metric_options={'average': 'micro'})
+    with pytest.raises(ValueError):
+        eval_results = dataset.evaluate(
+            fake_results, metric='recall', metric_options={'average': 'micro'})
+    with pytest.raises(ValueError):
+        eval_results = dataset.evaluate(
+            fake_results,
+            metric='f1_score',
+            metric_options={'average': 'micro'})
+    with pytest.raises(ValueError):
+        eval_results = dataset.evaluate(
+            fake_results,
+            metric='support',
+            metric_options={'average': 'micro'})
 
     # test multi-label evaluation
     dataset = MultiLabelDataset(data_prefix='', pipeline=[], test_mode=True)
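As a cross-check of the asserted macro values, the following sketch reproduces them with the new metric functions. The gt_labels array below is an assumption that is merely consistent with the asserted fractions; the actual fixture is defined earlier in the test file and is not shown in this hunk:

import numpy as np

from mmcls.core.evaluation import f1_score, precision, recall, support

fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                         [0, 0, 1], [0, 0, 1]])
gt_labels = np.array([0, 0, 1, 0, 1, 2])  # assumed labels, not from the diff

print(precision(fake_results, gt_labels))  # (1 + 1 + 1/3) / 3 * 100, about 77.78
print(recall(fake_results, gt_labels))     # (2/3 + 1/2 + 1) / 3 * 100, about 72.22
print(f1_score(fake_results, gt_labels))   # (4/5 + 2/3 + 1/2) / 3 * 100, about 65.56
print(support(fake_results, gt_labels))    # 6.0, the total number of samples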