From 1f07c92ed1cdf9e6db951c15311296d1057999e1 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Fri, 26 May 2023 10:40:08 +0800
Subject: [PATCH] [Feature] Add retrieval mAP metric. (#1552)

* rebase

* refine

* fix lint

* update readme

* rebase

* fix lint

* update docstring

* update docstring

* rebase

* rename corresponding names

* rebase

---
 configs/_base_/datasets/inshop_bs32_448.py  |   5 +-
 configs/arcface/README.md                   |  10 +-
 configs/arcface/metafile.yml                |   3 +-
 docs/en/api/evaluation.rst                  |   1 +
 mmpretrain/apis/image_retrieval.py          |   2 +-
 mmpretrain/evaluation/metrics/__init__.py   |   5 +-
 mmpretrain/evaluation/metrics/retrieval.py  | 215 +++++++++++++++++-
 projects/gradio_demo/launch.py              |   2 +-
 .../test_metrics/test_retrieval.py          | 109 ++++++++-
 9 files changed, 339 insertions(+), 13 deletions(-)

diff --git a/configs/_base_/datasets/inshop_bs32_448.py b/configs/_base_/datasets/inshop_bs32_448.py
index 585f301d..f9772fa6 100644
--- a/configs/_base_/datasets/inshop_bs32_448.py
+++ b/configs/_base_/datasets/inshop_bs32_448.py
@@ -55,7 +55,10 @@ gallery_dataloader = dict(
     sampler=dict(type='DefaultSampler', shuffle=False),
 )
 val_dataloader = query_dataloader
-val_evaluator = dict(type='RetrievalRecall', topk=1)
+val_evaluator = [
+    dict(type='RetrievalRecall', topk=1),
+    dict(type='RetrievalAveragePrecision', topk=10),
+]
 
 test_dataloader = val_dataloader
 test_evaluator = val_evaluator
diff --git a/configs/arcface/README.md b/configs/arcface/README.md
index c1384da7..6b2ee6a3 100644
--- a/configs/arcface/README.md
+++ b/configs/arcface/README.md
@@ -21,7 +21,7 @@ Recently, a popular line of research in face recognition is adopting margins in
 
 ```python
 from mmpretrain import ImageRetrievalInferencer
 
-inferencer = ImageRetrievalInferencer('resnet50-arcface_8xb32_inshop', prototype='demo/')
+inferencer = ImageRetrievalInferencer('resnet50-arcface_inshop', prototype='demo/')
 predict = inferencer('demo/dog.jpg', topk=2)[0]
 print(predict[0])
 print(predict[1])
@@ -33,7 +33,7 @@ print(predict[1])
 import torch
 from mmpretrain import get_model
 
-model = get_model('resnet50-arcface_8xb32_inshop', pretrained=True)
+model = get_model('resnet50-arcface_inshop', pretrained=True)
 inputs = torch.rand(1, 3, 224, 224)
 out = model(inputs)
 print(type(out))
@@ -64,9 +64,9 @@ python tools/test.py configs/arcface/resnet50-arcface_8xb32_inshop.py https://do
 
 ### Image Retrieval on InShop
 
-| Model                           | Pretrain     | Params (M) | Flops (G) | Recall@1 | Config                                     | Download                                                                                   |
-| :------------------------------ | :----------: | :--------: | :-------: | :------: | :----------------------------------------: | :--------------------------------------------------------------------------------------: |
-| `resnet50-arcface_8xb32_inshop` | From scratch | 31.69 | 16.57 | 90.18 | [config](resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
+| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | mAP@10 | Config | Download |
+| :-----------------------: | :------------------------------------------------: | :-------: | :------: | :------: | :----: | :------------------------------------------: | :------------------------------------------------: |
+| `resnet50-arcface_inshop` | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 
31.69 | 16.48 | 90.18 | 69.30 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) | ## Citation diff --git a/configs/arcface/metafile.yml b/configs/arcface/metafile.yml index 20080ddd..050aba5b 100644 --- a/configs/arcface/metafile.yml +++ b/configs/arcface/metafile.yml @@ -13,7 +13,7 @@ Collections: URL: https://github.com/open-mmlab/mmpretrain/blob/v1.0.0rc3/mmcls/models/heads/margin_head.py Models: - - Name: resnet50-arcface_8xb32_inshop + - Name: resnet50-arcface_inshop Metadata: FLOPs: 16571226112 Parameters: 31693888 @@ -22,6 +22,7 @@ Models: - Dataset: InShop Metrics: Recall@1: 90.18 + mAP@10: 69.30 Task: Image Retrieval Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth Config: configs/arcface/resnet50-arcface_8xb32_inshop.py diff --git a/docs/en/api/evaluation.rst b/docs/en/api/evaluation.rst index 53440714..bddea207 100644 --- a/docs/en/api/evaluation.rst +++ b/docs/en/api/evaluation.rst @@ -44,3 +44,4 @@ Retrieval Metric :template: classtemplate.rst RetrievalRecall + RetrievalAveragePrecision diff --git a/mmpretrain/apis/image_retrieval.py b/mmpretrain/apis/image_retrieval.py index 980d65cc..b88fa658 100644 --- a/mmpretrain/apis/image_retrieval.py +++ b/mmpretrain/apis/image_retrieval.py @@ -46,7 +46,7 @@ class ImageRetrievalInferencer(BaseInferencer): Example: >>> from mmpretrain import ImageRetrievalInferencer >>> inferencer = ImageRetrievalInferencer( - ... 'resnet50-arcface_8xb32_inshop', + ... 'resnet50-arcface_inshop', ... prototype='./demo/', ... prototype_cache='img_retri.pth') >>> inferencer('demo/cat-dog.png', topk=2)[0][1] diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py index a5fa179f..7f5a4f36 100644 --- a/mmpretrain/evaluation/metrics/__init__.py +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -4,7 +4,7 @@ from .gqa import GQAAcc from .multi_label import AveragePrecision, MultiLabelMetric from .multi_task import MultiTasksMetric from .nocaps import NocapsSave -from .retrieval import RetrievalRecall +from .retrieval import RetrievalAveragePrecision, RetrievalRecall from .scienceqa import ScienceQAMetric from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric from .visual_grounding_eval import VisualGroundingMetric @@ -15,5 +15,6 @@ __all__ = [ 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric', 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption', - 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave' + 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave', + 'RetrievalAveragePrecision' ] diff --git a/mmpretrain/evaluation/metrics/retrieval.py b/mmpretrain/evaluation/metrics/retrieval.py index 3269faeb..9813486b 100644 --- a/mmpretrain/evaluation/metrics/retrieval.py +++ b/mmpretrain/evaluation/metrics/retrieval.py @@ -65,7 +65,8 @@ class RetrievalRecall(BaseMetric): .. 
code:: python

-        val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+        val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+        test_evaluator = val_evaluator
     """
 
     default_prefix: Optional[str] = 'retrieval'
@@ -183,6 +184,218 @@ class RetrievalRecall(BaseMetric):
         return results
 
 
+@METRICS.register_module()
+class RetrievalAveragePrecision(BaseMetric):
+    r"""Calculate the average precision for image retrieval.
+
+    Args:
+        topk (int, optional): Predictions with the top-k highest scores are
+            considered positive.
+        mode (str, optional): The mode to calculate AP, choose from
+            'IR' (information retrieval) and 'integrate'. Defaults to 'IR'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+
+    Note:
+        If ``mode`` is set to 'IR', use the Stanford AP calculation for
+        information retrieval, as described in the Wikipedia page [1]; if
+        set to 'integrate', integrate over the precision-recall curve by
+        averaging two adjacent precision points and then multiplying by the
+        recall step, like the mAP used in detection tasks. The latter is the
+        convention for the Revisited Oxford/Paris datasets [2].
+
+    References:
+        [1] `Wikipedia entry for the Average precision
+        <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
+
+        [2] `The Oxford Buildings Dataset
+        <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
+
+    Examples:
+        Use in code:
+
+        >>> import torch
+        >>> import numpy as np
+        >>> from mmpretrain.evaluation import RetrievalAveragePrecision
+        >>> # using index format inputs
+        >>> pred = [torch.Tensor([idx for idx in range(100)])] * 3
+        >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
+        >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
+        29.246031746031747
+        >>> # using tensor format inputs
+        >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
+        >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
+        >>> RetrievalAveragePrecision.calculate(pred, target, 10)
+        62.222222222222214
+
+        Use in OpenMMLab config files:
+
+        .. code:: python
+
+            val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
+            test_evaluator = val_evaluator
+    """
+
+    default_prefix: Optional[str] = 'retrieval'
+
+    def __init__(self,
+                 topk: Optional[int] = None,
+                 mode: str = 'IR',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None) -> None:
+        if topk is None or (isinstance(topk, int) and topk <= 0):
+            raise ValueError('`topk` must be an integer larger than 0.')
+
+        mode_options = ['IR', 'integrate']
+        assert mode in mode_options, \
+            f'Invalid `mode` argument, please specify from {mode_options}.'
+
+        self.topk = topk
+        self.mode = mode
+        super().__init__(collect_device=collect_device, prefix=prefix)
+
+    def process(self, data_batch: Sequence[dict],
+                data_samples: Sequence[dict]):
+        """Process one batch of data and predictions.
+
+        The processed results should be stored in ``self.results``, which
+        will be used to compute the metrics when all batches have been
+        processed.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            pred_score = data_sample.get('pred_score').clone()
+
+            if 'gt_score' in data_sample:
+                target = data_sample.get('gt_score').clone()
+            else:
+                gt_label = data_sample.get('gt_label')
+                num_classes = pred_score.size()[-1]
+                target = label_to_onehot(gt_label, num_classes)
+
+            # Because the retrieval output logit vector is much larger than
+            # in normal classification, to save resources the evaluation
+            # results are computed per batch here and then all results are
+            # reduced at the end.
+            result = RetrievalAveragePrecision.calculate(
+                pred_score.unsqueeze(0),
+                target.unsqueeze(0),
+                self.topk,
+                mode=self.mode)
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the
+            metrics, and the values are corresponding results.
+        """
+        result_metrics = dict()
+        result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()
+
+        return result_metrics
+
+    @staticmethod
+    def calculate(pred: Union[np.ndarray, torch.Tensor],
+                  target: Union[np.ndarray, torch.Tensor],
+                  topk: Optional[int] = None,
+                  pred_indices: bool = False,
+                  target_indices: bool = False,
+                  mode: str = 'IR') -> float:
+        """Calculate the average precision.
+
+        Args:
+            pred (torch.Tensor | np.ndarray | Sequence): The prediction
+                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
+                shape ``(N, M)`` or a sequence of index/onehot
+                format labels.
+            target (torch.Tensor | np.ndarray | Sequence): The ground truth
+                of the predictions. A :obj:`torch.Tensor` or
+                :obj:`np.ndarray` with shape ``(N, M)`` or a sequence of
+                index/onehot format labels.
+            topk (int, optional): Predictions with the top-k highest scores
+                are considered positive.
+            pred_indices (bool): Whether ``pred`` is a sequence of
+                category index labels. Defaults to False.
+            target_indices (bool): Whether ``target`` is a sequence of
+                category index labels. Defaults to False.
+            mode (str, optional): The mode to calculate AP, choose from
+                'IR' (information retrieval) and 'integrate'. Defaults
+                to 'IR'.
+
+        Note:
+            If ``mode`` is set to 'IR', use the Stanford AP calculation for
+            information retrieval, as described in the Wikipedia page; if
+            set to 'integrate', integrate over the precision-recall curve
+            by averaging two adjacent precision points and then multiplying
+            by the recall step, like the mAP used in detection tasks. The
+            latter is the convention for the Revisited Oxford/Paris
+            datasets.
+
+        Returns:
+            float: The average precision over the query images.
+
+        References:
+            [1] `Wikipedia entry for Average precision(information_retrieval)
+            <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
+            [2] `The Oxford Buildings Dataset
+            <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
+        """
+        if topk is None or (isinstance(topk, int) and topk <= 0):
+            raise ValueError('`topk` must be an integer larger than 0.')
+
+        mode_options = ['IR', 'integrate']
+        assert mode in mode_options, \
+            f'Invalid `mode` argument, please specify from {mode_options}.'
+
+        pred = _format_pred(pred, topk, pred_indices)
+        target = _format_target(target, target_indices)
+
+        assert len(pred) == len(target), (
+            f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
+            'must be the same.')
+
+        num_samples = len(pred)
+        aps = np.zeros(num_samples)
+        for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
+            aps[i] = _calculate_ap_for_sample(sample_pred, sample_target,
+                                              mode)
+
+        return aps.mean()
+
+
+def _calculate_ap_for_sample(pred, target, mode):
+    """Calculate the AP of a single query, given its ranked prediction
+    indices and the indices of its relevant items."""
+    pred = torch.as_tensor(pred).cpu().numpy()
+    target = torch.as_tensor(target).cpu().numpy()
+
+    num_preds = len(pred)
+
+    # the 0-based ranks at which a positive item is retrieved
+    positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]
+
+    ap = 0
+    for i, rank in enumerate(positive_ranks):
+        if mode == 'IR':
+            # Stanford AP of information retrieval: accumulate the
+            # precision at every positive rank.
+            ap += (i + 1) / (rank + 1)
+        else:
+            # 'integrate' mode: average two adjacent precision points,
+            # as in the Revisited Oxford/Paris evaluation.
+            old_precision = i / rank if rank > 0 else 1
+            cur_precision = (i + 1) / (rank + 1)
+            prediction = (old_precision + cur_precision) / 2
+            ap += prediction
+    ap = ap / len(target)
+
+    return ap * 100
+
+
 def _format_pred(label, topk=None, is_indices=False):
     """format various label to List[indices]."""
     if is_indices:
diff --git a/projects/gradio_demo/launch.py b/projects/gradio_demo/launch.py
index bd4fa780..191ae094 100644
--- a/projects/gradio_demo/launch.py
+++ b/projects/gradio_demo/launch.py
@@ -199,7 +199,7 @@ class ImageRetrievalTab:
                 elem_id='image_retri_models',
                 elem_classes='select_model',
                 choices=self.model_list,
-                value='resnet50-arcface_8xb32_inshop',
+                value='resnet50-arcface_inshop',
             )
             topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
         with gr.Column():
diff --git a/tests/test_evaluation/test_metrics/test_retrieval.py b/tests/test_evaluation/test_metrics/test_retrieval.py
index 94c79913..de49754a 100644
--- a/tests/test_evaluation/test_metrics/test_retrieval.py
+++ b/tests/test_evaluation/test_metrics/test_retrieval.py
@@ -4,7 +4,8 @@ from unittest import TestCase
 import numpy as np
 import torch
 
-from mmpretrain.evaluation.metrics import RetrievalRecall
+from mmpretrain.evaluation.metrics import (RetrievalAveragePrecision,
+                                           RetrievalRecall)
 from mmpretrain.registry import METRICS
 from mmpretrain.structures import DataSample
 
@@ -118,3 +119,109 @@ class TestRetrievalRecall(TestCase):
         with self.assertRaisesRegex(AssertionError, '`pred` should be'):
             RetrievalRecall.calculate(
                 y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
+
+
+class TestRetrievalAveragePrecision(TestCase):
+
+    def test_evaluate(self):
+        """Test using the metric in the same way as an Evaluator."""
+        y_true = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
+                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
+        y_pred = torch.tensor([np.linspace(0.95, 0.05, 10)] * 2)
+
+        pred = [
+            DataSample().set_pred_score(i).set_gt_score(j)
+            for i, j in zip(y_pred, y_true)
+        ]
+
+        # Test with the default IR mode
+        metric = METRICS.build(dict(type='RetrievalAveragePrecision', topk=10))
+        metric.process([], pred)
+        res = metric.evaluate(len(pred))
+        self.assertIsInstance(res, dict)
+        self.assertAlmostEqual(
+            res['retrieval/mAP@10'], 53.25396825396825, places=4)
+
+        # Test with invalid topk
+        with self.assertRaisesRegex(ValueError, '`topk` must be a'):
+            METRICS.build(dict(type='RetrievalAveragePrecision', topk=-1))
+
+        # Test with invalid mode
+        with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
+            METRICS.build(
+                dict(type='RetrievalAveragePrecision', topk=5, mode='m'))
+
+    def test_calculate(self):
+        """Test using the metric's static method."""
+        # Test IR mode
+        # example from https://zhuanlan.zhihu.com/p/35983818
+        # or https://www.youtube.com/watch?v=pM6DJ0ZZee0
+
+        # seq of indices format
+        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
+        y_pred = [np.arange(10)] * 2
+
+        # test with index format inputs
+        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10,
+                                                       True, True)
+        expect_ap = 53.25396825396825
+        self.assertEqual(ap_score.item(), expect_ap)
+
+        # test with tensor input
+        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
+                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
+        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
+        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10)
+        expect_ap = 53.25396825396825
+        self.assertEqual(ap_score.item(), expect_ap)
+
+        # test with topk=5
+        y_pred = 
np.array([np.linspace(0.95, 0.05, 10)] * 2)
+        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, topk=5)
+        expect_ap = 31.666666666666664
+        self.assertEqual(ap_score.item(), expect_ap)
+
+        # Test with invalid mode
+        with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
+            RetrievalAveragePrecision.calculate(
+                y_pred, y_true, True, True, mode='m')
+
+        # Test with invalid pred
+        y_pred = dict()
+        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
+        with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
+            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
+
+        # Test with invalid target
+        y_true = dict()
+        y_pred = [np.arange(10)] * 2
+        with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
+            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
+
+        # Test with `pred` and `target` of different lengths
+        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
+        y_pred = [np.arange(10)] * 3
+        with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
+            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
+
+        # Test with an invalid item in `target`
+        y_true = [[0, 2, 5, 8, 9], dict()]
+        y_pred = [np.arange(10)] * 2
+        with self.assertRaisesRegex(AssertionError, '`target` should be'):
+            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
+
+        # Test with an invalid item in `pred`
+        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
+        y_pred = [np.arange(10), dict()]
+        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
+            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
+
+        # Test with mode 'integrate'
+        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
+                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
+        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
+
+        ap_score = RetrievalAveragePrecision.calculate(
+            y_pred, y_true, topk=5, mode='integrate')
+        expect_ap = 25.416666666666664
+        self.assertEqual(ap_score.item(), expect_ap)
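
Editor's note: for readers comparing the two AP conventions used in this patch, here is a minimal standalone sketch (plain Python, not part of the patch; the helper name `average_precision` is illustrative) that reproduces the per-query numbers behind the unit-test expectations:

```python
def average_precision(positive_ranks, num_positives, mode='IR'):
    """AP of a single query, given the 0-based ranks of its relevant items."""
    ap = 0.0
    for i, rank in enumerate(positive_ranks):
        cur_precision = (i + 1) / (rank + 1)
        if mode == 'IR':
            # Stanford/IR convention: average the precision at positive ranks.
            ap += cur_precision
        else:
            # 'integrate' convention: average two adjacent precision points.
            old_precision = i / rank if rank > 0 else 1
            ap += (old_precision + cur_precision) / 2
    return 100 * ap / num_positives

# Query 1's relevant items sit at ranks 0, 2, 5, 8 and 9 of its ranking:
print(average_precision([0, 2, 5, 8, 9], 5))           # 62.22... ('IR')
# Cut at topk=5, only ranks 0 and 2 survive (still 5 relevant items):
print(average_precision([0, 2], 5))                    # 33.33... ('IR')
print(average_precision([0, 2], 5, mode='integrate'))  # 31.66...
```

Query 2 (positives at ranks 1, 4 and 6) gives 44.29, 30.00 and 19.17 under the same three settings; averaging both queries yields exactly the 53.2539..., 31.6666... and 25.4166... values asserted in the tests above.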
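And a short usage sketch showing that the score/onehot and index input formats of `RetrievalAveragePrecision.calculate` describe the same ranking and agree (assumes mmpretrain with this patch applied; the expected value is taken from the unit tests above):

```python
import numpy as np
import torch

from mmpretrain.evaluation.metrics import RetrievalAveragePrecision

# Score format: two queries with 10 monotonically decreasing scores,
# ground truth given as one-hot rows.
scores = np.array([np.linspace(0.95, 0.05, 10)] * 2)
onehot = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                       [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
print(RetrievalAveragePrecision.calculate(scores, onehot, 10))

# Index format: the same ranking written as ranked gallery indices,
# with per-query lists of relevant indices.
ranked = [np.arange(10)] * 2
relevant = [[0, 2, 5, 8, 9], [1, 4, 6]]
print(RetrievalAveragePrecision.calculate(
    ranked, relevant, 10, pred_indices=True, target_indices=True))

# Both calls print 53.25396825396825, matching the unit tests.
```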