From a9057e88c4ac429c3655b1fb8fe74cba5d208206 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <1105212286@qq.com> Date: Thu, 23 Jun 2022 07:18:18 +0000 Subject: [PATCH] Add multi label metrics --- configs/_base_/datasets/voc_bs16.py | 45 +- mmcls/metrics/__init__.py | 5 +- mmcls/metrics/multi_label.py | 593 ++++++++++++++++++++++++ mmcls/metrics/single_label.py | 83 ++-- requirements/tests.txt | 1 + tests/test_metrics/test_multi_label.py | 398 ++++++++++++++++ tests/test_metrics/test_single_label.py | 21 +- 7 files changed, 1068 insertions(+), 78 deletions(-) create mode 100644 mmcls/metrics/multi_label.py create mode 100644 tests/test_metrics/test_multi_label.py diff --git a/configs/_base_/datasets/voc_bs16.py b/configs/_base_/datasets/voc_bs16.py index c6e1aade..96fbc909 100644 --- a/configs/_base_/datasets/voc_bs16.py +++ b/configs/_base_/datasets/voc_bs16.py @@ -22,35 +22,13 @@ test_pipeline = [ dict(type='PackClsInputs'), ] -data = dict( - samples_per_gpu=16, - workers_per_gpu=2, - train=dict( - type=dataset_type, - data_prefix='', - ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/trainval.txt', - pipeline=train_pipeline), - val=dict( - type=dataset_type, - data_prefix='data/VOCdevkit/VOC2007/', - ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt', - pipeline=test_pipeline), - test=dict( - type=dataset_type, - data_prefix='data/VOCdevkit/VOC2007/', - ann_file='data/VOCdevkit/VOC2007/ImageSets/Main/test.txt', - pipeline=test_pipeline)) -evaluation = dict( - interval=1, metric=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1']) - train_dataloader = dict( batch_size=16, num_workers=5, dataset=dict( type=dataset_type, - data_root='data/VOCdevkit/VOC2007/', - # manually split the `trainval.txt` for standard training. - ann_file='ImageSets/Main/trainval.txt', + data_root='data/VOCdevkit/VOC2007', + image_set_path='ImageSets/Layout/val.txt', pipeline=train_pipeline), sampler=dict(type='DefaultSampler', shuffle=True), persistent_workers=True, @@ -61,27 +39,28 @@ val_dataloader = dict( num_workers=5, dataset=dict( type=dataset_type, - data_root='data/VOCdevkit/VOC2007/', - # manually split the `trainval.txt` for standard validation. - ann_file='ImageSets/Main/test.txt', + data_root='data/VOCdevkit/VOC2007', + image_set_path='ImageSets/Layout/val.txt', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), persistent_workers=True, ) -val_evaluator = dict( - type='MultiLabelMetric', - items=['mAP', 'CP', 'OP', 'CR', 'OR', 'CF1', 'OF1']) test_dataloader = dict( batch_size=16, num_workers=5, dataset=dict( type=dataset_type, - data_root='data/VOCdevkit/VOC2007/', - ann_file='ImageSets/Main/test.txt', - data_prefix='val', + data_root='data/VOCdevkit/VOC2007', + image_set_path='ImageSets/Layout/val.txt', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), persistent_workers=True, ) + +# calculate precision_recall_f1 and mAP +val_evaluator = [dict(type='MultiLabelMetric'), dict(type='AveragePrecision')] + +# If you want standard test, please manually configure the test dataset +test_dataloader = val_dataloader test_evaluator = val_evaluator diff --git a/mmcls/metrics/__init__.py b/mmcls/metrics/__init__.py index 99053469..a4d575e4 100644 --- a/mmcls/metrics/__init__.py +++ b/mmcls/metrics/__init__.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .multi_label import AveragePrecision, MultiLabelMetric from .single_label import Accuracy, SingleLabelMetric -__all__ = ['Accuracy', 'SingleLabelMetric'] +__all__ = [ + 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision' +] diff --git a/mmcls/metrics/multi_label.py b/mmcls/metrics/multi_label.py new file mode 100644 index 00000000..c3b96417 --- /dev/null +++ b/mmcls/metrics/multi_label.py @@ -0,0 +1,593 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine import LabelData, MMLogger +from mmengine.evaluator import BaseMetric + +from mmcls.registry import METRICS +from .single_label import _precision_recall_f1_support, to_tensor + + +@METRICS.register_module() +class MultiLabelMetric(BaseMetric): + """A collection of metrics for multi-label multi-class classification task + based on confusion matrix. + + It includes precision, recall, f1-score and support. + + Args: + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + items (Sequence[str]): The detailed metric items to evaluate. Here is + the available options: + + - `"precision"`: The ratio tp / (tp + fp) where tp is the + number of true positives and fp the number of false + positives. + - `"recall"`: The ratio tp / (tp + fn) where tp is the number + of true positives and fn the number of false negatives. + - `"f1-score"`: The f1-score is the harmonic mean of the + precision and recall. + - `"support"`: The total number of positive of each category + in the target. + + Defaults to ('precision', 'recall', 'f1-score'). + average (str | None): The average method. It supports three average + modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `"micro"`: Calculate metrics globally by counting the total + true positives, false negatives and false positives. + - `None`: Return scores of all categories. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + Examples: + >>> import torch + >>> from mmcls.metrics import MultiLabelMetric + >>> # ------ The Basic Usage for category indices labels ------- + >>> y_pred = [[0], [1], [0, 1], [3]] + >>> y_true = [[0, 3], [0, 2], [1], [3]] + >>> # Output precision, recall, f1-score and support + >>> MultiLabelMetric.calculate( + ... y_pred, y_true, pred_indices=True, target_indices=True, num_classes=4) + (tensor(50.), tensor(50.), tensor(45.8333), tensor(6)) + >>> # ----------- The Basic Usage for one-hot labels ----------- + >>> y_pred = torch.tensor([[1, 1, 0, 0], + ... [1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [0, 1, 0, 0]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 1, 1, 0], + ... [1, 0, 0, 0], + ... 
[1, 0, 0, 0]]) + >>> MultiLabelMetric.calculate(y_pred, y_true) + (tensor(43.7500), tensor(31.2500), tensor(33.3333), tensor(8)) + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.rand(y_true.size()) + >>> y_pred + tensor([[0.4575, 0.7335, 0.3934, 0.2572], + [0.1318, 0.1004, 0.8248, 0.6448], + [0.8349, 0.6294, 0.7896, 0.2061], + [0.4037, 0.7308, 0.6713, 0.8374], + [0.3779, 0.4836, 0.0313, 0.0067]]) + >>> # Calculate with different threshold. + >>> MultiLabelMetric.calculate(y_pred, y_true, thr=0.1) + (tensor(42.5000), tensor(75.), tensor(53.1746), tensor(8)) + >>> # Calculate with topk. + >>> MultiLabelMetric.calculate(y_pred, y_true, topk=1) + (tensor(62.5000), tensor(31.2500), tensor(39.1667), tensor(8)) + >>> + >>> # ------------------- Use with Evaluator ------------------- + >>> from mmcls.core import ClsDataSample + >>> from mmengine.evaluator import Evaluator + >>> # The `data_batch` won't be used in this case, just use a fake. + >>> data_batch = [ + ... {'inputs': None, 'data_sample': ClsDataSample()} + ... for i in range(1000)] + >>> pred = [ + ... ClsDataSample().set_pred_score(torch.rand((5, ))).set_gt_score(torch.randint(2, size=(5, ))) + ... for i in range(1000)] + >>> evaluator = Evaluator(metrics=MultiLabelMetric(thr=0.5)) + >>> evaluator.process(data_batch, pred) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision': 50.72898037055408, + 'multi-label/recall': 50.06836461357571, + 'multi-label/f1-score': 50.384466955258475 + } + >>> # Evaluate on each class by using the topk strategy + >>> evaluator = Evaluator(metrics=MultiLabelMetric(topk=1, average=None)) + >>> evaluator.process(data_batch, pred) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision_top1_classwise': [48.22, 50.54, 50.99, 44.18, 52.5], + 'multi-label/recall_top1_classwise': [18.92, 19.22, 19.92, 20.0, 20.27], + 'multi-label/f1-score_top1_classwise': [27.18, 27.85, 28.65, 27.54, 29.25] + } + >>> # Evaluate using label data from the classification head + >>> pred = [ + ... ClsDataSample().set_pred_score(torch.rand((5, ))).set_pred_label( + ... torch.randint(2, size=(5, ))).set_gt_score(torch.randint(2, size=(5, ))) + ... for i in range(1000)] + >>> evaluator = Evaluator(metrics=MultiLabelMetric()) + >>> evaluator.process(data_batch, pred) + >>> evaluator.evaluate(1000) + { + 'multi-label/precision': 20.28921606216292, + 'multi-label/recall': 38.628095855722314, + 'multi-label/f1-score': 26.603530359627918 + } + """ # noqa: E501 + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + thr: Optional[float] = None, + topk: Optional[int] = None, + items: Sequence[str] = ('precision', 'recall', 'f1-score'), + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + + logger = MMLogger.get_current_instance() + if thr is None and topk is None: + thr = 0.5 + logger.warning('Neither thr nor topk is given, set thr as 0.5 by ' + 'default.') + elif thr is not None and topk is not None: + logger.warning('Both thr and topk are given, ' + 'use threshold in favor of top-k.') + + self.thr = thr + self.topk = topk + self.average = average + + for item in items: + assert item in ['precision', 'recall', 'f1-score', 'support'], \ + f'The metric {item} is not supported by `MultiLabelMetric`,' \ + ' please choose from "precision", "recall", "f1-score" and ' \ + '"support".'
+ self.items = tuple(items) + + super().__init__(collect_device=collect_device, prefix=prefix) + + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + for pred in predictions: + result = dict() + pred_label = pred['pred_label'] + gt_label = pred['gt_label'] + + result['pred_score'] = pred_label['score'].clone() + num_classes = result['pred_score'].size()[-1] + + if 'score' in gt_label: + result['gt_score'] = gt_label['score'].clone() + else: + result['gt_score'] = LabelData.label_to_onehot( + gt_label['label'], num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batch, while the input `results` + # are the collected results. + metrics = {} + + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + metric_res = self.calculate( + pred, + target, + pred_indices=False, + target_indices=False, + average=self.average, + thr=self.thr, + topk=self.topk) + + def pack_results(precision, recall, f1_score, support): + single_metrics = {} + if 'precision' in self.items: + single_metrics['precision'] = precision + if 'recall' in self.items: + single_metrics['recall'] = recall + if 'f1-score' in self.items: + single_metrics['f1-score'] = f1_score + if 'support' in self.items: + single_metrics['support'] = support + return single_metrics + + if self.thr: + suffix = '' if self.thr == 0.5 else f'_thr-{self.thr:.2f}' + for k, v in pack_results(*metric_res).items(): + metrics[k + suffix] = v + else: + for k, v in pack_results(*metric_res).items(): + metrics[k + f'_top{self.topk}'] = v + + result_metrics = dict() + for k, v in metrics.items(): + if self.average is None: + result_metrics[k + '_classwise'] = v.detach().cpu().tolist() + elif self.average == 'macro': + result_metrics[k] = v.item() + else: + result_metrics[k + f'_{self.average}'] = v.item() + return result_metrics + + @staticmethod + def calculate( + pred: Union[torch.Tensor, np.ndarray, Sequence], + target: Union[torch.Tensor, np.ndarray, Sequence], + pred_indices: bool = False, + target_indices: bool = False, + average: Optional[str] = 'macro', + thr: Optional[float] = None, + topk: Optional[int] = None, + num_classes: Optional[int] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Calculate the precision, recall, f1-score. + + Args: + pred (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. + target (torch.Tensor | np.ndarray | Sequence): The prediction + results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with + shape ``(N, num_classes)`` or a sequence of index/onehot + format labels. 
+ pred_indices (bool): Whether the ``pred`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + target_indices (bool): Whether the ``target`` is a sequence of + category index labels. If True, ``num_classes`` must be set. + Defaults to False. + average (str | None): The average method. It supports three average + modes: + + - `"macro"`: Calculate metrics for each category, and + calculate the mean value over all categories. + - `"micro"`: Calculate metrics globally by counting the + total true positives, false negatives and false + positives. + - `None`: Return scores of all categories. + + Defaults to "macro". + thr (float, optional): Predictions with scores under the thresholds + are considered as negative. Defaults to None. + topk (int, optional): Predictions with the k-th highest scores are + considered as positive. Defaults to None. + num_classes (Optional, int): The number of classes. If the ``pred`` + is indices instead of onehot, this argument is required. + Defaults to None. + + Returns: + Tuple: The tuple contains precision, recall and f1-score. + And the type of each item is: + + - torch.Tensor: A tensor for each metric. The shape is (1, ) if + ``average`` is not None, and (C, ) if ``average`` is None. + + Notes: + If both ``thr`` and ``topk`` are set, use ``thr` to determine + positive predictions. If neither is set, use ``thr=0.5`` as + default. + """ + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + def _format_label(label, is_indices): + """format various label to torch.Tensor.""" + if isinstance(label, np.ndarray): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'array must be (N, num_classes).' + label = torch.from_numpy(label) + elif isinstance(label, torch.Tensor): + assert label.ndim == 2, 'The shape `pred` and `target` ' \ + 'tensor must be (N, num_classes).' + elif isinstance(label, Sequence): + if is_indices: + assert num_classes is not None, 'For index-type labels, ' \ + 'please specify `num_classes`.' + label = torch.stack([ + LabelData.label_to_onehot( + to_tensor(indices), num_classes) + for indices in label + ]) + else: + label = torch.stack( + [to_tensor(onehot) for onehot in label]) + else: + raise TypeError( + 'The `pred` and `target` must be type of torch.tensor or ' + f'np.ndarray or sequence but get {type(label)}.') + return label + + pred = _format_label(pred, pred_indices) + target = _format_label(target, target_indices).long() + + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + if num_classes is not None: + assert pred.size(1) == num_classes, \ + f'The shape of `pred` ({pred.shape}) '\ + f"doesn't match the num_classes ({num_classes})." + num_classes = pred.size(1) + + thr = 0.5 if (thr is None and topk is None) else thr + + if thr is not None: + # a label is predicted positive if larger than thr + pos_inds = (pred >= thr).long() + else: + # top-k labels will be predicted positive for any example + _, topk_indices = pred.topk(topk) + pos_inds = torch.zeros_like(pred).scatter_(1, topk_indices, 1) + pos_inds = pos_inds.long() + + return _precision_recall_f1_support(pos_inds, target, average) + + +def _average_precision(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + r"""Calculate the average precision for a single class. 
+ + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + pred (torch.Tensor): The model prediction with shape + ``(N, num_classes)``. + target (torch.Tensor): The target of predictions with shape + ``(N, num_classes)``. + + Returns: + torch.Tensor: average precision result. + """ + assert pred.shape == target.shape, \ + f"The size of pred ({pred.shape}) doesn't match "\ + f'the target ({target.shape}).' + + # a small value for division by zero errors + eps = torch.finfo(torch.float32).eps + + # sort examples + sorted_pred_inds = torch.argsort(pred, dim=0, descending=True) + sorted_target = target[sorted_pred_inds] + + # get indexes when gt_true is positive + pos_inds = sorted_target == 1 + + # Calculate cumulative tp case numbers + tps = torch.cumsum(pos_inds, 0) + total_pos = tps[-1].item() # the last of tensor may change later + + # Calculate cumulative tp&fp(pred_poss) case numbers + pred_pos_nums = torch.arange(1, len(sorted_target) + 1) + pred_pos_nums[pred_pos_nums < eps] = eps + + tps[torch.logical_not(pos_inds)] = 0 + precision = tps / pred_pos_nums + ap = torch.sum(precision, 0) / max(total_pos, eps) + return ap + + +@METRICS.register_module() +class AveragePrecision(BaseMetric): + """Calculate the average precision with respect of classes. + + Args: + average (str | None): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and calculate + the mean value over all categories. + - `None`: Return scores of all categories. + + Defaults to "macro". + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + + References + ---------- + .. [1] `Wikipedia entry for the Average precision + `_ + + Examples: + >>> import torch + >>> from mmcls.metrics import AveragePrecision + >>> # --------- The Basic Usage for one-hot pred scores --------- + >>> y_pred = torch.Tensor([[0.9, 0.8, 0.3, 0.2], + ... [0.1, 0.2, 0.2, 0.1], + ... [0.7, 0.5, 0.9, 0.3], + ... [0.8, 0.1, 0.1, 0.2]]) + >>> y_true = torch.Tensor([[1, 1, 0, 0], + ... [0, 1, 0, 0], + ... [0, 0, 1, 0], + ... [1, 0, 0, 0]]) + >>> AveragePrecision.calculate(y_pred, y_true) + tensor(70.833) + >>> # ------------------- Use with Evalutor ------------------- + >>> from mmcls.core import ClsDataSample + >>> from mmengine.evaluator import Evaluator + >>> # The `data_batch` won't be used in this case, just use a fake. + >>> data_batch = [ + ... {'inputs': None, 'data_sample': ClsDataSample()} + ... for i in range(4)] + >>> pred = [ + ... ClsDataSample().set_pred_score(i).set_gt_score(j) + ... for i, j in zip(y_pred, y_true) + ... 
] + >>> evaluator = Evaluator(metrics=AveragePrecision()) + >>> evaluator.process(data_batch, pred) + >>> evaluator.evaluate(5) + {'multi-label/mAP': 70.83333587646484} + >>> # Evaluate on each class + >>> evaluator = Evaluator(metrics=AveragePrecision(average=None)) + >>> evaluator.process(data_batch, pred) + >>> evaluator.evaluate(5) + {'multi-label/AP_classwise': [100., 83.33, 100., 0.]} + """ + default_prefix: Optional[str] = 'multi-label' + + def __init__(self, + average: Optional[str] = 'macro', + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.average = average + + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]): + """Process one batch of data and predictions. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (Sequence[dict]): A batch of data from the dataloader. + predictions (Sequence[dict]): A batch of outputs from the model. + """ + + for pred in predictions: + result = dict() + pred_label = pred['pred_label'] + gt_label = pred['gt_label'] + + result['pred_score'] = pred_label['score'] + num_classes = result['pred_score'].size()[-1] + + if 'score' in gt_label: + result['gt_score'] = gt_label['score'] + else: + result['gt_score'] = LabelData.label_to_onehot( + gt_label['label'], num_classes) + + # Save the result to `self.results`. + self.results.append(result) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + # NOTICE: don't access `self.results` from the method. `self.results` + # are a list of results from multiple batches, while the input `results` + # are the collected results. + + # concat + target = torch.stack([res['gt_score'] for res in results]) + pred = torch.stack([res['pred_score'] for res in results]) + + ap = self.calculate(pred, target, self.average) + + result_metrics = dict() + + if self.average is None: + result_metrics['AP_classwise'] = ap.detach().cpu().tolist() + else: + result_metrics['mAP'] = ap.item() + + return result_metrics + + @staticmethod + def calculate(pred: Union[torch.Tensor, np.ndarray], + target: Union[torch.Tensor, np.ndarray], + average: Optional[str] = 'macro') -> torch.Tensor: + r"""Calculate the average precision for each class. + + AP summarizes a precision-recall curve as the weighted mean of maximum + precisions obtained for any r'>r, where r is the recall: + + .. math:: + \text{AP} = \sum_n (R_n - R_{n-1}) P_n + + Note that no approximation is involved since the curve is piecewise + constant. + + Args: + pred (torch.Tensor | np.ndarray): The model predictions with + shape ``(N, num_classes)``. + target (torch.Tensor | np.ndarray): The target of predictions + with shape ``(N, num_classes)``. + average (str | None): The average method. It supports two modes: + + - `"macro"`: Calculate metrics for each category, and + calculate the mean value over all categories. + - `None`: Return scores of all categories. + + Defaults to "macro". + + Returns: + torch.Tensor: The average precision of all classes. + """ + average_options = ['macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specify from {average_options}.'
+ + pred = to_tensor(pred) + target = to_tensor(target) + assert pred.ndim == 2 and pred.shape == target.shape, \ + 'Both `pred` and `target` should have shape `(N, num_classes)`.' + + num_classes = pred.shape[1] + ap = pred.new_zeros(num_classes) + for k in range(num_classes): + ap[k] = _average_precision(pred[:, k], target[:, k]) + if average == 'macro': + return ap.mean() * 100.0 + else: + return ap * 100 diff --git a/mmcls/metrics/single_label.py b/mmcls/metrics/single_label.py index 42586583..eff97c3f 100644 --- a/mmcls/metrics/single_label.py +++ b/mmcls/metrics/single_label.py @@ -21,6 +21,37 @@ def to_tensor(value): return value +def _precision_recall_f1_support(pred_positive, gt_positive, average): + """calculate base classification task metrics, such as precision, recall, + f1_score, support.""" + average_options = ['micro', 'macro', None] + assert average in average_options, 'Invalid `average` argument, ' \ + f'please specicy from {average_options}.' + + class_correct = (pred_positive & gt_positive) + if average == 'micro': + tp_sum = class_correct.sum() + pred_sum = pred_positive.sum() + gt_sum = gt_positive.sum() + else: + tp_sum = class_correct.sum(0) + pred_sum = pred_positive.sum(0) + gt_sum = gt_positive.sum(0) + + precision = tp_sum / torch.clamp(pred_sum, min=1.) * 100 + recall = tp_sum / torch.clamp(gt_sum, min=1.) * 100 + f1_score = 2 * precision * recall / torch.clamp( + precision + recall, min=torch.finfo(torch.float32).eps) + if average in ['macro', 'micro']: + precision = precision.mean(0) + recall = recall.mean(0) + f1_score = f1_score.mean(0) + support = gt_sum.sum(0) + else: + support = gt_sum + return precision, recall, f1_score, support + + @METRICS.register_module() class Accuracy(BaseMetric): """Top-k accuracy evaluation metric. @@ -327,9 +358,9 @@ class SingleLabelMetric(BaseMetric): >>> evaluator.process(data_batch, pred) >>> evaluator.evaluate(1000) { - 'single-label/precision': [21.14, 18.69, 17.17, 19.42, 16.14], - 'single-label/recall': [18.5, 18.5, 17.0, 20.0, 18.0], - 'single-label/f1-score': [19.73, 18.59, 17.09, 19.70, 17.02] + 'single-label/precision_classwise': [21.1, 18.7, 17.8, 19.4, 16.1], + 'single-label/recall_classwise': [18.5, 18.5, 17.0, 20.0, 18.0], + 'single-label/f1-score_classwise': [19.7, 18.6, 17.1, 19.7, 17.0] } """ default_prefix: Optional[str] = 'single-label' @@ -438,13 +469,17 @@ class SingleLabelMetric(BaseMetric): num_classes=results[0]['num_classes']) metrics = pack_results(*res) + result_metrics = dict() for k, v in metrics.items(): - if self.average is not None: - metrics[k] = v.item() - else: - metrics[k] = v.cpu().detach().tolist() - return metrics + if self.average is None: + result_metrics[k + '_classwise'] = v.cpu().detach().tolist() + elif self.average == 'micro': + result_metrics[k + f'_{self.average}'] = v.item() + else: + result_metrics[k] = v.item() + + return result_metrics @staticmethod def calculate( @@ -503,38 +538,14 @@ class SingleLabelMetric(BaseMetric): f"The size of pred ({pred.size(0)}) doesn't match "\ f'the target ({target.size(0)}).' - def _do_calculate(pred_positive, gt_positive): - class_correct = (pred_positive & gt_positive) - if average == 'micro': - tp_sum = class_correct.sum() - pred_sum = pred_positive.sum() - gt_sum = gt_positive.sum() - else: - tp_sum = class_correct.sum(0) - pred_sum = pred_positive.sum(0) - gt_sum = gt_positive.sum(0) - - precision = tp_sum / np.maximum(pred_sum, 1.) * 100 - recall = tp_sum / np.maximum(gt_sum, 1.) 
* 100 - f1_score = 2 * precision * recall / np.maximum( - precision + recall, - torch.finfo(torch.float32).eps) - if average in ['macro', 'micro']: - precision = precision.mean(0, keepdim=True) - recall = recall.mean(0, keepdim=True) - f1_score = f1_score.mean(0, keepdim=True) - support = gt_sum.sum(0, keepdim=True) - else: - support = gt_sum - return precision, recall, f1_score, support - if pred.ndim == 1: assert num_classes is not None, \ 'Please specicy the `num_classes` if the `pred` is labels ' \ 'intead of scores.' gt_positive = F.one_hot(target.flatten(), num_classes) pred_positive = F.one_hot(pred.to(torch.int64), num_classes) - return _do_calculate(pred_positive, gt_positive) + return _precision_recall_f1_support(pred_positive, gt_positive, + average) else: # For pred score, calculate on all thresholds. num_classes = pred.size(1) @@ -549,6 +560,8 @@ class SingleLabelMetric(BaseMetric): pred_positive = F.one_hot(pred_label, num_classes) if thr is not None: pred_positive[pred_score <= thr] = 0 - results.append(_do_calculate(pred_positive, gt_positive)) + results.append( + _precision_recall_f1_support(pred_positive, gt_positive, + average)) return results diff --git a/requirements/tests.txt b/requirements/tests.txt index 29d351b5..0afe4eba 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -4,5 +4,6 @@ interrogate isort==4.3.21 mmdet pytest +sklearn xdoctest >= 0.10.0 yapf diff --git a/tests/test_metrics/test_multi_label.py b/tests/test_metrics/test_multi_label.py new file mode 100644 index 00000000..f87443e7 --- /dev/null +++ b/tests/test_metrics/test_multi_label.py @@ -0,0 +1,398 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import sklearn.metrics +import torch +from mmengine.evaluator import Evaluator + +from mmcls.core import ClsDataSample +from mmcls.metrics import AveragePrecision, MultiLabelMetric +from mmcls.utils import register_all_modules + +register_all_modules() + + +class TestMultiLabel(TestCase): + + def test_calculate(self): + """Test using the metric from static method.""" + + y_true = [[0], [1, 3], [0, 1, 2], [3]] + y_pred = [[0, 3], [0, 2], [1, 2], [2, 3]] + y_true_binary = np.array([ + [1, 0, 0, 0], + [0, 1, 0, 1], + [1, 1, 1, 0], + [0, 0, 0, 1], + ]) + y_pred_binary = np.array([ + [1, 0, 0, 1], + [1, 0, 1, 0], + [0, 1, 1, 0], + [0, 0, 1, 1], + ]) + y_pred_score = np.array([ + [0.8, 0, 0, 0.6], + [0.2, 0, 0.6, 0], + [0, 0.9, 0.6, 0], + [0, 0, 0.2, 0.3], + ]) + + # Test with sequence of category indexes + res = MultiLabelMetric.calculate( + y_pred, + y_true, + pred_indices=True, + target_indices=True, + num_classes=4) + self.assertIsInstance(res, tuple) + precision, recall, f1_score, support = res + expect_precision = sklearn.metrics.precision_score( + y_true_binary, y_pred_binary, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, y_pred_binary, average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, y_pred_binary, average='macro') * 100 + self.assertTensorEqual(precision, expect_precision) + self.assertTensorEqual(recall, expect_recall) + self.assertTensorEqual(f1_score, expect_f1) + self.assertTensorEqual(support, 7) + + # Test with onehot input + res = MultiLabelMetric.calculate(y_pred_binary, + torch.from_numpy(y_true_binary)) + self.assertIsInstance(res, tuple) + precision, recall, f1_score, support = res + # Expected values come from sklearn + self.assertTensorEqual(precision, expect_precision) + 
self.assertTensorEqual(recall, expect_recall) + self.assertTensorEqual(f1_score, expect_f1) + self.assertTensorEqual(support, 7) + + # Test with topk argument + res = MultiLabelMetric.calculate( + y_pred_score, y_true, target_indices=True, topk=1, num_classes=4) + self.assertIsInstance(res, tuple) + precision, recall, f1_score, support = res + # Expected values come from sklearn + top1_y_pred = np.array([ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ]) + expect_precision = sklearn.metrics.precision_score( + y_true_binary, top1_y_pred, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, top1_y_pred, average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, top1_y_pred, average='macro') * 100 + self.assertTensorEqual(precision, expect_precision) + self.assertTensorEqual(recall, expect_recall) + self.assertTensorEqual(f1_score, expect_f1) + self.assertTensorEqual(support, 7) + + # Test with thr argument + res = MultiLabelMetric.calculate( + y_pred_score, y_true, target_indices=True, thr=0.25, num_classes=4) + self.assertIsInstance(res, tuple) + precision, recall, f1_score, support = res + # Expected values come from sklearn + thr_y_pred = np.array([ + [1, 0, 0, 1], + [0, 0, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1], + ]) + expect_precision = sklearn.metrics.precision_score( + y_true_binary, thr_y_pred, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, thr_y_pred, average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, thr_y_pred, average='macro') * 100 + self.assertTensorEqual(precision, expect_precision) + self.assertTensorEqual(recall, expect_recall) + self.assertTensorEqual(f1_score, expect_f1) + self.assertTensorEqual(support, 7) + + # Test with invalid inputs + with self.assertRaisesRegex(TypeError, " is not"): + MultiLabelMetric.calculate(y_pred, 'hi', num_classes=10) + + # Test with invalid input + with self.assertRaisesRegex(AssertionError, + 'Invalid `average` argument,'): + MultiLabelMetric.calculate( + y_pred, y_true, average='m', num_classes=10) + + y_true_binary = np.array([[1, 0, 0, 0], [0, 1, 0, 1]]) + y_pred_binary = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]]) + # Test with invalid inputs + with self.assertRaisesRegex(AssertionError, 'The size of pred'): + MultiLabelMetric.calculate(y_pred_binary, y_true_binary) + + # Test with invalid inputs + with self.assertRaisesRegex(TypeError, 'The `pred` and `target` must'): + MultiLabelMetric.calculate(y_pred_binary, 5) + + def test_evaluate(self): + fake_data_batch = [{ + 'inputs': None, + 'data_sample': ClsDataSample() + } for _ in range(4)] + + y_true = [[0], [1, 3], [0, 1, 2], [3]] + y_true_binary = torch.tensor([ + [1, 0, 0, 0], + [0, 1, 0, 1], + [1, 1, 1, 0], + [0, 0, 0, 1], + ]) + y_pred_score = torch.tensor([ + [0.8, 0, 0, 0.6], + [0.2, 0, 0.6, 0], + [0, 0.9, 0.6, 0], + [0, 0, 0.2, 0.3], + ]) + + pred = [ + ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j) + for i, j in zip(y_pred_score, y_true) + ] + + # Test with default argument + evaluator = Evaluator(dict(type='MultiLabelMetric')) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + thr05_y_pred = np.array([ + [1, 0, 0, 1], + [0, 0, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0], + ]) + expect_precision = sklearn.metrics.precision_score( + y_true_binary, thr05_y_pred, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, thr05_y_pred, 
average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, thr05_y_pred, average='macro') * 100 + self.assertEqual(res['multi-label/precision'], expect_precision) + self.assertEqual(res['multi-label/recall'], expect_recall) + self.assertEqual(res['multi-label/f1-score'], expect_f1) + + # Test with topk argument + evaluator = Evaluator(dict(type='MultiLabelMetric', topk=1)) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + top1_y_pred = np.array([ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ]) + expect_precision = sklearn.metrics.precision_score( + y_true_binary, top1_y_pred, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, top1_y_pred, average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, top1_y_pred, average='macro') * 100 + self.assertEqual(res['multi-label/precision_top1'], expect_precision) + self.assertEqual(res['multi-label/recall_top1'], expect_recall) + self.assertEqual(res['multi-label/f1-score_top1'], expect_f1) + + # Test with both argument + evaluator = Evaluator(dict(type='MultiLabelMetric', thr=0.25, topk=1)) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + # Expected values come from sklearn + thr_y_pred = np.array([ + [1, 0, 0, 1], + [0, 0, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1], + ]) + expect_precision = sklearn.metrics.precision_score( + y_true_binary, thr_y_pred, average='macro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, thr_y_pred, average='macro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, thr_y_pred, average='macro') * 100 + self.assertEqual(res['multi-label/precision_thr-0.25'], + expect_precision) + self.assertEqual(res['multi-label/recall_thr-0.25'], expect_recall) + self.assertEqual(res['multi-label/f1-score_thr-0.25'], expect_f1) + + # Test with average micro + evaluator = Evaluator(dict(type='MultiLabelMetric', average='micro')) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + # Expected values come from sklearn + expect_precision = sklearn.metrics.precision_score( + y_true_binary, thr05_y_pred, average='micro') * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, thr05_y_pred, average='micro') * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, thr05_y_pred, average='micro') * 100 + self.assertAlmostEqual( + res['multi-label/precision_micro'], expect_precision, places=4) + self.assertAlmostEqual( + res['multi-label/recall_micro'], expect_recall, places=4) + self.assertAlmostEqual( + res['multi-label/f1-score_micro'], expect_f1, places=4) + + # Test with average None + evaluator = Evaluator(dict(type='MultiLabelMetric', average=None)) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + # Expected values come from sklearn + expect_precision = sklearn.metrics.precision_score( + y_true_binary, thr05_y_pred, average=None) * 100 + expect_recall = sklearn.metrics.recall_score( + y_true_binary, thr05_y_pred, average=None) * 100 + expect_f1 = sklearn.metrics.f1_score( + y_true_binary, thr05_y_pred, average=None) * 100 + np.testing.assert_allclose(res['multi-label/precision_classwise'], + expect_precision) + np.testing.assert_allclose(res['multi-label/recall_classwise'], + expect_recall) + np.testing.assert_allclose(res['multi-label/f1-score_classwise'], 
+ expect_f1) + + # Test with gt_score + pred = [ + ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j) + for i, j in zip(y_pred_score, y_true_binary) + ] + + evaluator = Evaluator(dict(type='MultiLabelMetric', items=['support'])) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(4) + self.assertIsInstance(res, dict) + self.assertEqual(res['multi-label/support'], 7) + + def assertTensorEqual(self, + tensor: torch.Tensor, + value: float, + msg=None, + **kwarg): + tensor = tensor.to(torch.float32) + if tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + value = torch.FloatTensor([value]) + try: + torch.testing.assert_allclose(tensor, value, **kwarg) + except AssertionError as e: + self.fail(self._formatMessage(msg, str(e) + str(tensor))) + + +class TestAveragePrecision(TestCase): + + def test_evaluate(self): + """Test using the metric in the same way as Evalutor.""" + y_pred = torch.tensor([ + [0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2], + ]) + y_true = torch.tensor([ + [1, 1, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 0], + ]) + + fake_data_batch = [{ + 'inputs': None, + 'data_sample': ClsDataSample() + } for _ in range(4)] + + pred = [ + ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(j) + for i, j in zip(y_pred, y_true) + ] + + # Test with default macro avergae + evaluator = Evaluator(dict(type='AveragePrecision')) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(5) + self.assertIsInstance(res, dict) + self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4) + + # Test with average mode None + evaluator = Evaluator(dict(type='AveragePrecision', average=None)) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(5) + self.assertIsInstance(res, dict) + aps = res['multi-label/AP_classwise'] + self.assertAlmostEqual(aps[0], 100., places=4) + self.assertAlmostEqual(aps[1], 83.3333, places=4) + self.assertAlmostEqual(aps[2], 100, places=4) + self.assertAlmostEqual(aps[3], 0, places=4) + + # Test with gt_label without score + pred = [ + ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j) + for i, j in zip(y_pred, [[0, 1], [1], [2], [0]]) + ] + evaluator = Evaluator(dict(type='AveragePrecision')) + evaluator.process(fake_data_batch, pred) + res = evaluator.evaluate(5) + self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4) + + def test_calculate(self): + """Test using the metric from static method.""" + + y_true = np.array([ + [1, 0, 0, 0], + [0, 1, 0, 1], + [1, 1, 1, 0], + [0, 0, 0, 1], + ]) + y_pred = np.array([ + [0.9, 0.8, 0.3, 0.2], + [0.1, 0.2, 0.2, 0.1], + [0.7, 0.5, 0.9, 0.3], + [0.8, 0.1, 0.1, 0.2], + ]) + + ap_score = AveragePrecision.calculate(y_pred, y_true) + expect_ap = sklearn.metrics.average_precision_score(y_true, + y_pred) * 100 + self.assertTensorEqual(ap_score, expect_ap) + + # Test with invalid inputs + with self.assertRaisesRegex(AssertionError, + 'Invalid `average` argument,'): + AveragePrecision.calculate(y_pred, y_true, average='m') + + y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 1]]) + y_pred = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]]) + # Test with invalid inputs + with self.assertRaisesRegex(AssertionError, + 'Both `pred` and `target`'): + AveragePrecision.calculate(y_pred, y_true) + + # Test with invalid inputs + with self.assertRaisesRegex(TypeError, " is not an"): + AveragePrecision.calculate(y_pred, 5) + + def assertTensorEqual(self, + tensor: torch.Tensor, + value: float, + msg=None, + 
**kwarg): + tensor = tensor.to(torch.float32) + if tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + value = torch.FloatTensor([value]) + try: + torch.testing.assert_allclose(tensor, value, **kwarg) + except AssertionError as e: + self.fail(self._formatMessage(msg, str(e) + str(tensor))) diff --git a/tests/test_metrics/test_single_label.py b/tests/test_metrics/test_single_label.py index 4e37c68e..087ad96a 100644 --- a/tests/test_metrics/test_single_label.py +++ b/tests/test_metrics/test_single_label.py @@ -183,10 +183,13 @@ class TestSingleLabel(TestCase): metric.process(data_batch, pred) res = metric.evaluate(6) self.assertIsInstance(res, dict) - self.assertAlmostEqual(res['single-label/precision'], 66.666, places=2) - self.assertAlmostEqual(res['single-label/recall'], 66.666, places=2) - self.assertAlmostEqual(res['single-label/f1-score'], 66.666, places=2) - self.assertEqual(res['single-label/support'], 6) + self.assertAlmostEqual( + res['single-label/precision_micro'], 66.666, places=2) + self.assertAlmostEqual( + res['single-label/recall_micro'], 66.666, places=2) + self.assertAlmostEqual( + res['single-label/f1-score_micro'], 66.666, places=2) + self.assertEqual(res['single-label/support_micro'], 6) # Test with average mode None metric = METRICS.build( @@ -197,19 +200,19 @@ class TestSingleLabel(TestCase): metric.process(data_batch, pred) res = metric.evaluate(6) self.assertIsInstance(res, dict) - precision = res['single-label/precision'] + precision = res['single-label/precision_classwise'] self.assertAlmostEqual(precision[0], 100., places=4) self.assertAlmostEqual(precision[1], 100., places=4) self.assertAlmostEqual(precision[2], 1 / 3 * 100, places=4) - recall = res['single-label/recall'] + recall = res['single-label/recall_classwise'] self.assertAlmostEqual(recall[0], 2 / 3 * 100, places=4) self.assertAlmostEqual(recall[1], 50., places=4) self.assertAlmostEqual(recall[2], 100., places=4) - f1_score = res['single-label/f1-score'] + f1_score = res['single-label/f1-score_classwise'] self.assertAlmostEqual(f1_score[0], 80., places=4) self.assertAlmostEqual(f1_score[1], 2 / 3 * 100, places=4) self.assertAlmostEqual(f1_score[2], 50., places=4) - self.assertEqual(res['single-label/support'], [3, 2, 1]) + self.assertEqual(res['single-label/support_classwise'], [3, 2, 1]) # Test with label, the thrs will be ignored pred_no_score = copy.deepcopy(pred) @@ -293,7 +296,7 @@ class TestSingleLabel(TestCase): msg=None, **kwarg): tensor = tensor.to(torch.float32) - value = torch.FloatTensor([value]) + value = torch.tensor(value).float() try: torch.testing.assert_allclose(tensor, value, **kwarg) except AssertionError as e:
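
The thresholded multi-label precision/recall/F1 introduced by this patch can be sanity-checked by hand: binarize the scores, count per-class true positives, and apply the same clamp-and-scale logic as `_precision_recall_f1_support`. The sketch below assumes the patched `mmcls.metrics` package is importable; the scores and targets are made-up toy inputs for illustration only.

# Hand-check of the default (thr=0.5, macro) behaviour of MultiLabelMetric.
import torch
from mmcls.metrics import MultiLabelMetric

pred_score = torch.tensor([[0.8, 0.1, 0.6, 0.2],
                           [0.2, 0.7, 0.4, 0.1],
                           [0.6, 0.6, 0.1, 0.9]])
target = torch.tensor([[1, 0, 1, 0],
                       [0, 1, 0, 0],
                       [1, 0, 0, 1]])

# Binarize with the default threshold (0.5) and build per-class counts.
pos = (pred_score >= 0.5).long()
tp = (pos & target).sum(0).float()
precision = tp / pos.sum(0).clamp(min=1) * 100
recall = tp / target.sum(0).clamp(min=1) * 100
f1 = 2 * precision * recall / (precision + recall).clamp(
    min=torch.finfo(torch.float32).eps)

# Macro averages should agree with the metric's default output.
print(precision.mean(), recall.mean(), f1.mean())
print(MultiLabelMetric.calculate(pred_score, target, thr=0.5))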
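
Since the patch adds scikit-learn to the test requirements, `sklearn.metrics.average_precision_score` can be used as an external reference for `_average_precision`. The following sketch reuses the toy arrays from the `AveragePrecision` docstring and tests; it only assumes the patched `mmcls.metrics` and scikit-learn are installed.

# Cross-check AveragePrecision.calculate against scikit-learn.
import numpy as np
import sklearn.metrics
from mmcls.metrics import AveragePrecision

y_pred = np.array([[0.9, 0.8, 0.3, 0.2],
                   [0.1, 0.2, 0.2, 0.1],
                   [0.7, 0.5, 0.9, 0.3],
                   [0.8, 0.1, 0.1, 0.2]])
y_true = np.array([[1, 1, 0, 0],
                   [0, 1, 0, 0],
                   [0, 0, 1, 0],
                   [1, 0, 0, 0]])

# The metric reports percentages; sklearn reports fractions in [0, 1].
mAP = AveragePrecision.calculate(y_pred, y_true)  # macro average by default
ref = sklearn.metrics.average_precision_score(y_true, y_pred) * 100
print(float(mAP), ref)  # both should be ~70.83 for this toy input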
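
For completeness, the evaluator section of the new VOC config can report several variants of these metrics at once. This is an illustrative sketch rather than part of the patch: the metric types and key suffixes follow the implementation above, but the particular combination shown here is an assumption about how a user might configure it.

# Illustrative evaluator config combining thresholded, top-k and class-wise
# multi-label metrics with mAP during validation.
val_evaluator = [
    dict(type='MultiLabelMetric'),                # thr defaults to 0.5, no suffix
    dict(type='MultiLabelMetric', topk=1),        # keys get a `_top1` suffix
    dict(type='MultiLabelMetric', average=None),  # class-wise `_classwise` lists
    dict(type='AveragePrecision'),                # reports `multi-label/mAP`
]
test_evaluator = val_evaluator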