From 8e990b5654f4478b258965335ca5663193d33453 Mon Sep 17 00:00:00 2001
From: LXXXXR <73265258+LXXXXR@users.noreply.github.com>
Date: Tue, 19 Jan 2021 16:42:16 +0800
Subject: [PATCH] [Feature] Support support and class-wise evaluation results
 (#143)

* support support, support class-wise evaluation results and move eval_metrics.py

* Fix docstring

* change average to be non-optional

* revise according to comments

* add more unittest
---
 mmcls/core/evaluation/__init__.py     |   7 +-
 mmcls/core/evaluation/eval_metrics.py | 138 ++++++++++++++++++++++++++
 mmcls/datasets/base_dataset.py        |  30 +++++-
 mmcls/models/losses/__init__.py       |   3 +-
 mmcls/models/losses/eval_metrics.py   |  81 ---------------
 tests/test_dataset.py                 |  33 +++++-
 6 files changed, 201 insertions(+), 91 deletions(-)
 create mode 100644 mmcls/core/evaluation/eval_metrics.py
 delete mode 100644 mmcls/models/losses/eval_metrics.py

diff --git a/mmcls/core/evaluation/__init__.py b/mmcls/core/evaluation/__init__.py
index 0906ff2b..269a9026 100644
--- a/mmcls/core/evaluation/__init__.py
+++ b/mmcls/core/evaluation/__init__.py
@@ -1,8 +1,11 @@
 from .eval_hooks import DistEvalHook, EvalHook
+from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
+                           recall, support)
 from .mean_ap import average_precision, mAP
 from .multilabel_eval_metrics import average_performance
 
 __all__ = [
-    'DistEvalHook', 'EvalHook', 'average_precision', 'mAP',
-    'average_performance'
+    'DistEvalHook', 'EvalHook', 'precision', 'recall', 'f1_score', 'support',
+    'average_precision', 'mAP', 'average_performance',
+    'calculate_confusion_matrix'
 ]
diff --git a/mmcls/core/evaluation/eval_metrics.py b/mmcls/core/evaluation/eval_metrics.py
new file mode 100644
index 00000000..1b2e9f32
--- /dev/null
+++ b/mmcls/core/evaluation/eval_metrics.py
@@ -0,0 +1,138 @@
+import numpy as np
+import torch
+
+
+def calculate_confusion_matrix(pred, target):
+    """Calculate confusion matrix according to the prediction and target.
+
+    Args:
+        pred (torch.Tensor | np.array): The model prediction.
+        target (torch.Tensor | np.array): The target of each prediction.
+
+    Returns:
+        torch.Tensor: Confusion matrix with shape (C, C), where C is the number
+            of classes.
+    """
+    if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
+        pred = torch.from_numpy(pred)
+        target = torch.from_numpy(target)
+    elif not (isinstance(pred, torch.Tensor)
+              and isinstance(target, torch.Tensor)):
+        raise TypeError('pred and target should both be '
+                        'torch.Tensor or np.ndarray')
+    _, pred_label = pred.topk(1, dim=1)
+    num_classes = pred.size(1)
+    pred_label = pred_label.view(-1)
+    target_label = target.view(-1)
+    assert len(pred_label) == len(target_label)
+    confusion_matrix = torch.zeros(num_classes, num_classes)
+    with torch.no_grad():
+        for t, p in zip(target_label, pred_label):
+            confusion_matrix[t.long(), p.long()] += 1
+    return confusion_matrix
+
+
+def precision(pred, target, average='macro'):
+    """Calculate precision according to the prediction and target.
+
+    Args:
+        pred (torch.Tensor | np.array): The model prediction.
+        target (torch.Tensor | np.array): The target of each prediction.
+        average (str): The type of averaging performed on the result.
+            Options are 'macro' and 'none'. Defaults to 'macro'.
+
+    Returns:
+        np.array: Precision value with shape determined by average.
+ """ + confusion_matrix = calculate_confusion_matrix(pred, target) + with torch.no_grad(): + res = confusion_matrix.diag() / torch.clamp( + confusion_matrix.sum(0), min=1) * 100 + if average == 'macro': + res = res.mean().numpy() + elif average == 'none': + res = res.numpy() + else: + raise ValueError(f'Unsupport type of averaging {average}.') + return res + + +def recall(pred, target, average='macro'): + """Calculate recall according to the prediction and target. + + Args: + pred (torch.Tensor | np.array): The model prediction. + target (torch.Tensor | np.array): The target of each prediction. + average (str): The type of averaging performed on the result. + Options are 'macro' and 'none'. Defaults to 'macro'. + + Returns: + np.array: Recall value with shape determined by average. + """ + confusion_matrix = calculate_confusion_matrix(pred, target) + with torch.no_grad(): + res = confusion_matrix.diag() / torch.clamp( + confusion_matrix.sum(1), min=1) * 100 + if average == 'macro': + res = res.mean().numpy() + elif average == 'none': + res = res.numpy() + else: + raise ValueError(f'Unsupport type of averaging {average}.') + return res + + +def f1_score(pred, target, average='macro'): + """Calculate F1 score according to the prediction and target. + + Args: + pred (torch.Tensor | np.array): The model prediction. + target (torch.Tensor | np.array): The target of each prediction. + average (str): The type of averaging performed on the result. + Options are 'macro' and 'none'. Defaults to 'macro'. + + Returns: + np.array: F1 score with shape determined by average. + """ + confusion_matrix = calculate_confusion_matrix(pred, target) + with torch.no_grad(): + precision = confusion_matrix.diag() / torch.clamp( + confusion_matrix.sum(1), min=1) + recall = confusion_matrix.diag() / torch.clamp( + confusion_matrix.sum(0), min=1) + res = 2 * precision * recall / torch.clamp( + precision + recall, min=1e-20) * 100 + res = torch.where(torch.isnan(res), torch.full_like(res, 0), res) + if average == 'macro': + res = res.mean().numpy() + elif average == 'none': + res = res.numpy() + else: + raise ValueError(f'Unsupport type of averaging {average}.') + return res + + +def support(pred, target, average='macro'): + """Calculate the total number of occurrences of each label according to + the prediction and target. + + Args: + pred (torch.Tensor | np.array): The model prediction. + target (torch.Tensor | np.array): The target of each prediction. + average (str): The type of reduction performed on the result. + Options are 'macro' and 'none'. 'macro' gives the sum and 'none' + gives class-wise results. Defaults to 'macro'. + + Returns: + np.array: Support with shape determined by average. 
+ """ + confusion_matrix = calculate_confusion_matrix(pred, target) + with torch.no_grad(): + res = confusion_matrix.sum(1) + if average == 'macro': + res = res.sum().numpy() + elif average == 'none': + res = res.numpy() + else: + raise ValueError(f'Unsupport type of averaging {average}.') + return res diff --git a/mmcls/datasets/base_dataset.py b/mmcls/datasets/base_dataset.py index 03574184..3455f84a 100644 --- a/mmcls/datasets/base_dataset.py +++ b/mmcls/datasets/base_dataset.py @@ -5,7 +5,8 @@ import mmcv import numpy as np from torch.utils.data import Dataset -from mmcls.models.losses import accuracy, f1_score, precision, recall +from mmcls.core.evaluation import f1_score, precision, recall, support +from mmcls.models.losses import accuracy from .pipelines import Compose @@ -122,6 +123,8 @@ class BaseDataset(Dataset, metaclass=ABCMeta): results (list): Testing results of the dataset. metric (str | list[str]): Metrics to be evaluated. Default value is `accuracy`. + metric_options (dict): Options for calculating metrics. Allowed + keys are 'topk' and 'average'. logger (logging.Logger | None | str): Logger used for printing related information during evaluation. Default: None. Returns: @@ -131,7 +134,9 @@ class BaseDataset(Dataset, metaclass=ABCMeta): metrics = [metric] else: metrics = metric - allowed_metrics = ['accuracy', 'precision', 'recall', 'f1_score'] + allowed_metrics = [ + 'accuracy', 'precision', 'recall', 'f1_score', 'support' + ] eval_results = {} results = np.vstack(results) gt_labels = self.get_gt_labels() @@ -145,13 +150,28 @@ class BaseDataset(Dataset, metaclass=ABCMeta): acc = accuracy(results, gt_labels, topk) eval_result = {f'top-{k}': a.item() for k, a in zip(topk, acc)} elif metric == 'precision': - precision_value = precision(results, gt_labels) + precision_value = precision( + results, + gt_labels, + average=metric_options.get('average', 'macro')) eval_result = {'precision': precision_value} elif metric == 'recall': - recall_value = recall(results, gt_labels) + recall_value = recall( + results, + gt_labels, + average=metric_options.get('average', 'macro')) eval_result = {'recall': recall_value} elif metric == 'f1_score': - f1_score_value = f1_score(results, gt_labels) + f1_score_value = f1_score( + results, + gt_labels, + average=metric_options.get('average', 'macro')) eval_result = {'f1_score': f1_score_value} + elif metric == 'support': + support_value = support( + results, + gt_labels, + average=metric_options.get('average', 'macro')) + eval_result = {'support': support_value} eval_results.update(eval_result) return eval_results diff --git a/mmcls/models/losses/__init__.py b/mmcls/models/losses/__init__.py index 774b4917..a16cd49a 100644 --- a/mmcls/models/losses/__init__.py +++ b/mmcls/models/losses/__init__.py @@ -2,7 +2,6 @@ from .accuracy import Accuracy, accuracy from .asymmetric_loss import AsymmetricLoss, asymmetric_loss from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, cross_entropy) -from .eval_metrics import f1_score, precision, recall from .focal_loss import FocalLoss, sigmoid_focal_loss from .label_smooth_loss import LabelSmoothLoss, label_smooth from .utils import reduce_loss, weight_reduce_loss, weighted_loss @@ -11,5 +10,5 @@ __all__ = [ 'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss', 'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 'weight_reduce_loss', 'label_smooth', 'LabelSmoothLoss', 'weighted_loss', - 'precision', 'recall', 'f1_score', 'FocalLoss', 'sigmoid_focal_loss' + 
+    'FocalLoss', 'sigmoid_focal_loss'
 ]
diff --git a/mmcls/models/losses/eval_metrics.py b/mmcls/models/losses/eval_metrics.py
deleted file mode 100644
index ef04fa26..00000000
--- a/mmcls/models/losses/eval_metrics.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import numpy as np
-import torch
-
-
-def calculate_confusion_matrix(pred, target):
-    if isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
-        pred = torch.from_numpy(pred)
-        target = torch.from_numpy(target)
-    elif not (isinstance(pred, torch.Tensor)
-              and isinstance(target, torch.Tensor)):
-        raise TypeError('pred and target should both be'
-                        'torch.Tensor or np.ndarray')
-    _, pred_label = pred.topk(1, dim=1)
-    num_classes = pred.size(1)
-    pred_label = pred_label.view(-1)
-    target_label = target.view(-1)
-    assert len(pred_label) == len(target_label)
-    confusion_matrix = torch.zeros(num_classes, num_classes)
-    with torch.no_grad():
-        for t, p in zip(target_label, pred_label):
-            confusion_matrix[t.long(), p.long()] += 1
-    return confusion_matrix
-
-
-def precision(pred, target):
-    """Calculate macro-averaged precision according to the prediction and target
-
-    Args:
-        pred (torch.Tensor | np.array): The model prediction.
-        target (torch.Tensor | np.array): The target of each prediction.
-
-    Returns:
-        float: The function will return a single float as precision.
-    """
-    confusion_matrix = calculate_confusion_matrix(pred, target)
-    with torch.no_grad():
-        res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(0), min=1)
-        res = res.mean().item() * 100
-    return res
-
-
-def recall(pred, target):
-    """Calculate macro-averaged recall according to the prediction and target
-
-    Args:
-        pred (torch.Tensor | np.array): The model prediction.
-        target (torch.Tensor | np.array): The target of each prediction.
-
-    Returns:
-        float: The function will return a single float as recall.
-    """
-    confusion_matrix = calculate_confusion_matrix(pred, target)
-    with torch.no_grad():
-        res = confusion_matrix.diag() / torch.clamp(
-            confusion_matrix.sum(1), min=1)
-        res = res.mean().item() * 100
-    return res
-
-
-def f1_score(pred, target):
-    """Calculate macro-averaged F1 score according to the prediction and target
-
-    Args:
-        pred (torch.Tensor | np.array): The model prediction.
-        target (torch.Tensor | np.array): The target of each prediction.
-
-    Returns:
-        float: The function will return a single float as F1 score.
- """ - confusion_matrix = calculate_confusion_matrix(pred, target) - with torch.no_grad(): - precision = confusion_matrix.diag() / torch.clamp( - confusion_matrix.sum(1), min=1) - recall = confusion_matrix.diag() / torch.clamp( - confusion_matrix.sum(0), min=1) - res = 2 * precision * recall / torch.clamp( - precision + recall, min=1e-20) - res = torch.where(torch.isnan(res), torch.full_like(res, 0), res) - res = res.mean().item() * 100 - return res diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6243cda0..821e10f6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -108,13 +108,44 @@ def test_dataset_evaluation(): fake_results = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1]]) eval_results = dataset.evaluate( - fake_results, metric=['precision', 'recall', 'f1_score']) + fake_results, metric=['precision', 'recall', 'f1_score', 'support']) assert eval_results['precision'] == pytest.approx( (1 + 1 + 1 / 3) / 3 * 100.0) assert eval_results['recall'] == pytest.approx( (2 / 3 + 1 / 2 + 1) / 3 * 100.0) assert eval_results['f1_score'] == pytest.approx( (4 / 5 + 2 / 3 + 1 / 2) / 3 * 100.0) + assert eval_results['support'] == 6 + + # test evaluation results for classes + eval_results = dataset.evaluate( + fake_results, + metric=['precision', 'recall', 'f1_score', 'support'], + metric_options={'average': 'none'}) + assert eval_results['precision'].shape == (3, ) + assert eval_results['recall'].shape == (3, ) + assert eval_results['f1_score'].shape == (3, ) + assert eval_results['support'].shape == (3, ) + + # the average method must be valid + with pytest.raises(ValueError): + eval_results = dataset.evaluate( + fake_results, + metric='precision', + metric_options={'average': 'micro'}) + with pytest.raises(ValueError): + eval_results = dataset.evaluate( + fake_results, metric='recall', metric_options={'average': 'micro'}) + with pytest.raises(ValueError): + eval_results = dataset.evaluate( + fake_results, + metric='f1_score', + metric_options={'average': 'micro'}) + with pytest.raises(ValueError): + eval_results = dataset.evaluate( + fake_results, + metric='support', + metric_options={'average': 'micro'}) # test multi-label evalutation dataset = MultiLabelDataset(data_prefix='', pipeline=[], test_mode=True)