[Feature] Support support and class-wise evaluation results (#143)

* support support, support class-wise evaluation results and move eval_metrics.py

* Fix docstring

* change average to be non-optional

* revise according to comments

* add more unittest
pull/144/head
LXXXXR 2021-01-19 16:42:16 +08:00 committed by GitHub
parent 4f8fc9cbf3
commit 8e990b5654
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 201 additions and 91 deletions

View File

@ -1,8 +1,11 @@
from .eval_hooks import DistEvalHook, EvalHook
from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
recall, support)
from .mean_ap import average_precision, mAP
from .multilabel_eval_metrics import average_performance
__all__ = [
'DistEvalHook', 'EvalHook', 'average_precision', 'mAP',
'average_performance'
'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support',
'average_precision', 'mAP', 'average_performance',
'calculate_confusion_matrix'
]

View File

@ -0,0 +1,138 @@
import numpy as np
import torch
def calculate_confusion_matrix(pred, target):
"""Calculate confusion matrix according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
Returns:
torch.Tensor: Confusion matrix with shape (C, C), where C is the number
of classes.
"""
if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
pred = torch.from_numpy(pred)
target = torch.from_numpy(target)
elif not (isinstance(pred, torch.Tensor)
and isinstance(target, torch.Tensor)):
raise TypeError('pred and target should both be'
'torch.Tensor or np.ndarray')
_, pred_label = pred.topk(1, dim=1)
num_classes = pred.size(1)
pred_label = pred_label.view(-1)
target_label = target.view(-1)
assert len(pred_label) == len(target_label)
confusion_matrix = torch.zeros(num_classes, num_classes)
with torch.no_grad():
for t, p in zip(target_label, pred_label):
confusion_matrix[t.long(), p.long()] += 1
return confusion_matrix
def precision(pred, target, average='macro'):
"""Calculate precision according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
average (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. Defaults to 'macro'.
Returns:
np.array: Precision value with shape determined by average.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(0), min=1) * 100
if average == 'macro':
res = res.mean().numpy()
elif average == 'none':
res = res.numpy()
else:
raise ValueError(f'Unsupport type of averaging {average}.')
return res
def recall(pred, target, average='macro'):
"""Calculate recall according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
average (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. Defaults to 'macro'.
Returns:
np.array: Recall value with shape determined by average.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(1), min=1) * 100
if average == 'macro':
res = res.mean().numpy()
elif average == 'none':
res = res.numpy()
else:
raise ValueError(f'Unsupport type of averaging {average}.')
return res
def f1_score(pred, target, average='macro'):
"""Calculate F1 score according to the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
average (str): The type of averaging performed on the result.
Options are 'macro' and 'none'. Defaults to 'macro'.
Returns:
np.array: F1 score with shape determined by average.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
precision = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(1), min=1)
recall = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(0), min=1)
res = 2 * precision * recall / torch.clamp(
precision + recall, min=1e-20) * 100
res = torch.where(torch.isnan(res), torch.full_like(res, 0), res)
if average == 'macro':
res = res.mean().numpy()
elif average == 'none':
res = res.numpy()
else:
raise ValueError(f'Unsupport type of averaging {average}.')
return res
def support(pred, target, average='macro'):
"""Calculate the total number of occurrences of each label according to
the prediction and target.
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
average (str): The type of reduction performed on the result.
Options are 'macro' and 'none'. 'macro' gives the sum and 'none'
gives class-wise results. Defaults to 'macro'.
Returns:
np.array: Support with shape determined by average.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.sum(1)
if average == 'macro':
res = res.sum().numpy()
elif average == 'none':
res = res.numpy()
else:
raise ValueError(f'Unsupport type of averaging {average}.')
return res

View File

@ -5,7 +5,8 @@ import mmcv
import numpy as np
from torch.utils.data import Dataset
from mmcls.models.losses import accuracy, f1_score, precision, recall
from mmcls.core.evaluation import f1_score, precision, recall, support
from mmcls.models.losses import accuracy
from .pipelines import Compose
@ -122,6 +123,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
Default value is `accuracy`.
metric_options (dict): Options for calculating metrics. Allowed
keys are 'topk' and 'average'.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
Returns:
@ -131,7 +134,9 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
metrics = [metric]
else:
metrics = metric
allowed_metrics = ['accuracy', 'precision', 'recall', 'f1_score']
allowed_metrics = [
'accuracy', 'precision', 'recall', 'f1_score', 'support'
]
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
@ -145,13 +150,28 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
acc = accuracy(results, gt_labels, topk)
eval_result = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
elif metric == 'precision':
precision_value = precision(results, gt_labels)
precision_value = precision(
results,
gt_labels,
average=metric_options.get('average', 'macro'))
eval_result = {'precision': precision_value}
elif metric == 'recall':
recall_value = recall(results, gt_labels)
recall_value = recall(
results,
gt_labels,
average=metric_options.get('average', 'macro'))
eval_result = {'recall': recall_value}
elif metric == 'f1_score':
f1_score_value = f1_score(results, gt_labels)
f1_score_value = f1_score(
results,
gt_labels,
average=metric_options.get('average', 'macro'))
eval_result = {'f1_score': f1_score_value}
elif metric == 'support':
support_value = support(
results,
gt_labels,
average=metric_options.get('average', 'macro'))
eval_result = {'support': support_value}
eval_results.update(eval_result)
return eval_results

View File

@ -2,7 +2,6 @@ from .accuracy import Accuracy, accuracy
from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy)
from .eval_metrics import f1_score, precision, recall
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .label_smooth_loss import LabelSmoothLoss, label_smooth
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@ -11,5 +10,5 @@ __all__ = [
'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss',
'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss'
'FocalLoss', 'sigmoid_focal_loss'
]

View File

@ -1,81 +0,0 @@
import numpy as np
import torch
def calculate_confusion_matrix(pred, target):
if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
pred = torch.from_numpy(pred)
target = torch.from_numpy(target)
elif not (isinstance(pred, torch.Tensor)
and isinstance(target, torch.Tensor)):
raise TypeError('pred and target should both be'
'torch.Tensor or np.ndarray')
_, pred_label = pred.topk(1, dim=1)
num_classes = pred.size(1)
pred_label = pred_label.view(-1)
target_label = target.view(-1)
assert len(pred_label) == len(target_label)
confusion_matrix = torch.zeros(num_classes, num_classes)
with torch.no_grad():
for t, p in zip(target_label, pred_label):
confusion_matrix[t.long(), p.long()] += 1
return confusion_matrix
def precision(pred, target):
"""Calculate macro-averaged precision according to the prediction and target
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
Returns:
float: The function will return a single float as precision.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(0), min=1)
res = res.mean().item() * 100
return res
def recall(pred, target):
"""Calculate macro-averaged recall according to the prediction and target
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
Returns:
float: The function will return a single float as recall.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
res = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(1), min=1)
res = res.mean().item() * 100
return res
def f1_score(pred, target):
"""Calculate macro-averaged F1 score according to the prediction and target
Args:
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction.
Returns:
float: The function will return a single float as F1 score.
"""
confusion_matrix = calculate_confusion_matrix(pred, target)
with torch.no_grad():
precision = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(1), min=1)
recall = confusion_matrix.diag() / torch.clamp(
confusion_matrix.sum(0), min=1)
res = 2 * precision * recall / torch.clamp(
precision + recall, min=1e-20)
res = torch.where(torch.isnan(res), torch.full_like(res, 0), res)
res = res.mean().item() * 100
return res

View File

@ -108,13 +108,44 @@ def test_dataset_evaluation():
fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
[0, 0, 1], [0, 0, 1]])
eval_results = dataset.evaluate(
fake_results, metric=['precision', 'recall', 'f1_score'])
fake_results, metric=['precision', 'recall', 'f1_score', 'support'])
assert eval_results['precision'] == pytest.approx(
(1 + 1 + 1 / 3) / 3 * 100.0)
assert eval_results['recall'] == pytest.approx(
(2 / 3 + 1 / 2 + 1) / 3 * 100.0)
assert eval_results['f1_score'] == pytest.approx(
(4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0)
assert eval_results['support'] == 6
# test evaluation results for classes
eval_results = dataset.evaluate(
fake_results,
metric=['precision', 'recall', 'f1_score', 'support'],
metric_options={'average': 'none'})
assert eval_results['precision'].shape == (3, )
assert eval_results['recall'].shape == (3, )
assert eval_results['f1_score'].shape == (3, )
assert eval_results['support'].shape == (3, )
# the average method must be valid
with pytest.raises(ValueError):
eval_results = dataset.evaluate(
fake_results,
metric='precision',
metric_options={'average': 'micro'})
with pytest.raises(ValueError):
eval_results = dataset.evaluate(
fake_results, metric='recall', metric_options={'average': 'micro'})
with pytest.raises(ValueError):
eval_results = dataset.evaluate(
fake_results,
metric='f1_score',
metric_options={'average': 'micro'})
with pytest.raises(ValueError):
eval_results = dataset.evaluate(
fake_results,
metric='support',
metric_options={'average': 'micro'})
# test multi-label evalutation
dataset = MultiLabelDataset(data_prefix='', pipeline=[], test_mode=True)