[Feature] Add retrieval mAP metric. (#1552)

* rebase

* refine

* fix lint

* update readme

* rebase

* fix lint

* update docstring

* update docstring

* rebase

* rename corresponding names

* rebase
Ezra-Yu 2023-05-26 10:40:08 +08:00 committed by GitHub
parent 9bb692e440
commit 1f07c92ed1
9 changed files with 339 additions and 13 deletions


@@ -55,7 +55,10 @@ gallery_dataloader = dict(
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = query_dataloader
-val_evaluator = dict(type='RetrievalRecall', topk=1)
+val_evaluator = [
+    dict(type='RetrievalRecall', topk=1),
+    dict(type='RetrievalAveragePrecision', topk=10),
+]
test_dataloader = val_dataloader
test_evaluator = val_evaluator
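The evaluator list above wires both retrieval metrics into the validation loop. As a minimal sketch of the same two metrics used outside a config file, through their static `calculate` methods (toy scores, assuming an mmpretrain installation):

```python
import torch

from mmpretrain.evaluation.metrics import (RetrievalAveragePrecision,
                                            RetrievalRecall)

# One query scored against a gallery of 10 items; 1 marks a relevant item.
scores = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]])
relevance = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 0, 0, 0]])

print(RetrievalRecall.calculate(scores, relevance, topk=1))
print(RetrievalAveragePrecision.calculate(scores, relevance, topk=10))
```

With these toy inputs the top-ranked item is relevant, so Recall@1 is 100; AP@10 under the default 'IR' mode works out to (1/1 + 2/3 + 3/6) / 3 × 100 ≈ 72.2.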


@@ -21,7 +21,7 @@ Recently, a popular line of research in face recognition is adopting margins in
```python
from mmpretrain import ImageRetrievalInferencer
-inferencer = ImageRetrievalInferencer('resnet50-arcface_8xb32_inshop', prototype='demo/')
+inferencer = ImageRetrievalInferencer('resnet50-arcface_inshop', prototype='demo/')
predict = inferencer('demo/dog.jpg', topk=2)[0]
print(predict[0])
print(predict[1])
@@ -33,7 +33,7 @@ print(predict[1])
import torch
from mmpretrain import get_model
-model = get_model('resnet50-arcface_8xb32_inshop', pretrained=True)
+model = get_model('resnet50-arcface_inshop', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
@@ -64,9 +64,9 @@ python tools/test.py configs/arcface/resnet50-arcface_8xb32_inshop.py https://do
### Image Retrieval on InShop
-| Model | Pretrain | Params (M) | Flops (G) | Recall@1 | Config | Download |
-| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
-| `resnet50-arcface_8xb32_inshop` | From scratch | 31.69 | 16.57 | 90.18 | [config](resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
+| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | mAP@10 | Config | Download |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| `resnet50-arcface_inshop` | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | 69.30 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
## Citation


@@ -13,7 +13,7 @@ Collections:
      URL: https://github.com/open-mmlab/mmpretrain/blob/v1.0.0rc3/mmcls/models/heads/margin_head.py
Models:
-  - Name: resnet50-arcface_8xb32_inshop
+  - Name: resnet50-arcface_inshop
    Metadata:
      FLOPs: 16571226112
      Parameters: 31693888
@@ -22,6 +22,7 @@ Models:
      - Dataset: InShop
        Metrics:
          Recall@1: 90.18
+         mAP@10: 69.30
        Task: Image Retrieval
    Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth
    Config: configs/arcface/resnet50-arcface_8xb32_inshop.py


@@ -44,3 +44,4 @@ Retrieval Metric
   :template: classtemplate.rst
   RetrievalRecall
+  RetrievalAveragePrecision


@@ -46,7 +46,7 @@ class ImageRetrievalInferencer(BaseInferencer):
    Example:
        >>> from mmpretrain import ImageRetrievalInferencer
        >>> inferencer = ImageRetrievalInferencer(
-       ...     'resnet50-arcface_8xb32_inshop',
+       ...     'resnet50-arcface_inshop',
        ...     prototype='./demo/',
        ...     prototype_cache='img_retri.pth')
        >>> inferencer('demo/cat-dog.png', topk=2)[0][1]


@@ -4,7 +4,7 @@ from .gqa import GQAAcc
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .nocaps import NocapsSave
-from .retrieval import RetrievalRecall
+from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .scienceqa import ScienceQAMetric
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .visual_grounding_eval import VisualGroundingMetric
@@ -15,5 +15,6 @@ __all__ = [
    'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
    'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
    'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
-   'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave'
+   'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
+   'RetrievalAveragePrecision'
]


@@ -65,7 +65,8 @@ class RetrievalRecall(BaseMetric):
        .. code:: python
-           val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+           val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+           test_evaluator = val_evaluator
    """
    default_prefix: Optional[str] = 'retrieval'
@@ -183,6 +184,218 @@ class RetrievalRecall(BaseMetric):
        return results
@METRICS.register_module()
class RetrievalAveragePrecision(BaseMetric):
r"""Calculate the average precision for image retrieval.
Args:
topk (int, optional): Predictions with the k-th highest scores are
considered as positive.
mode (str, optional): The mode to calculate AP, choose from
'IR' (information retrieval) and 'integrate'. Defaults to 'IR'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Note:
If ``mode`` is set to 'IR', the standard AP calculation used in
information retrieval is applied, as described on the Wikipedia
page [1]; if set to 'integrate', the method integrates over the
precision-recall curve by averaging two adjacent precision points
and then multiplying by the recall step, as mAP is computed in
detection tasks. This is the convention for the Revisited
Oxford/Paris datasets [2].
References:
[1] `Wikipedia entry for the Average precision <https://en.wikipedia.
org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
[2] `The Oxford Buildings Dataset
<https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
Examples:
Use in code:
>>> import torch
>>> import numpy as np
>>> from mmpretrain.evaluation import RetrievalAveragePrecision
>>> # using index format inputs
>>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
>>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
>>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
29.246031746031747
>>> # using tensor format inputs
>>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
>>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
>>> RetrievalAveragePrecision.calculate(pred, target, 10)
62.222222222222214
Use in OpenMMLab config files:
.. code:: python
val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
test_evaluator = val_evaluator
"""
default_prefix: Optional[str] = 'retrieval'
def __init__(self,
topk: Optional[int] = None,
mode: Optional[str] = 'IR',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
if topk is None or (isinstance(topk, int) and topk <= 0):
raise ValueError('`topk` must be an integer larger than 0.')
mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'
self.topk = topk
self.mode = mode
super().__init__(collect_device=collect_device, prefix=prefix)
def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data and predictions.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_score = data_sample.get('pred_score').clone()
if 'gt_score' in data_sample:
target = data_sample.get('gt_score').clone()
else:
gt_label = data_sample.get('gt_label')
num_classes = pred_score.size()[-1]
target = label_to_onehot(gt_label, num_classes)
# Because the retrieval output logit vector will be much larger
# compared to the normal classification, to save resources, the
# evaluation results are computed each batch here and then reduce
# all results at the end.
result = RetrievalAveragePrecision.calculate(
pred_score.unsqueeze(0),
target.unsqueeze(0),
self.topk,
mode=self.mode)
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
result_metrics = dict()
result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()
return result_metrics
@staticmethod
def calculate(pred: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
topk: Optional[int] = None,
pred_indices: bool = False,
target_indices: bool = False,
mode: str = 'IR') -> float:
"""Calculate the average precision.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
target (torch.Tensor | np.ndarray | Sequence): The target of each
prediction. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
topk (int, optional): Predictions with the k-th highest scores
are considered as positive.
pred_indices (bool): Whether the ``pred`` is a sequence of
category index labels. Defaults to False.
target_indices (bool): Whether the ``target`` is a sequence of
category index labels. Defaults to False.
mode (Optional[str]): The mode to calculate AP, choose from
'IR' (information retrieval) and 'integrate'. Defaults to 'IR'.
Note:
If ``mode`` is set to 'IR', the standard AP calculation used in
information retrieval is applied, as described on the Wikipedia
page; if set to 'integrate', the method integrates over the
precision-recall curve by averaging two adjacent precision points
and then multiplying by the recall step, as mAP is computed in
detection tasks. This is the convention for the Revisited
Oxford/Paris datasets.
Returns:
float: The average precision, averaged over all query samples.
References:
[1] `Wikipedia entry for Average precision(information_retrieval)
<https://en.wikipedia.org/wiki/Evaluation_measures_
(information_retrieval)#Average_precision>`_
[2] `The Oxford Buildings Dataset <https://www.robots.ox.ac.uk/
~vgg/data/oxbuildings/>`_
"""
if topk is None or (isinstance(topk, int) and topk <= 0):
raise ValueError('`topk` must be an integer larger than 0.')
mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'
pred = _format_pred(pred, topk, pred_indices)
target = _format_target(target, target_indices)
assert len(pred) == len(target), (
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
f'must be the same.')
num_samples = len(pred)
aps = np.zeros(num_samples)
for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
aps[i] = _calculateAp_for_sample(sample_pred, sample_target, mode)
return aps.mean()
def _calculateAp_for_sample(pred, target, mode):
pred = np.array(to_tensor(pred).cpu())
target = np.array(to_tensor(target).cpu())
num_preds = len(pred)
# TODO: use ``torch.isin`` in torch1.10.
positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]
ap = 0
for i, rank in enumerate(positive_ranks):
if mode == 'IR':
precision = (i + 1) / (rank + 1)
ap += precision
elif mode == 'integrate':
# code adapted from https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp # noqa:
old_precision = i / rank if rank > 0 else 1
cur_precision = (i + 1) / (rank + 1)
precision = (old_precision + cur_precision) / 2
ap += precision
ap = ap / len(target)
return ap * 100
def _format_pred(label, topk=None, is_indices=False):
    """format various label to List[indices]."""
    if is_indices:
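The 'IR' and 'integrate' conventions described in the class docstring differ only in the precision term credited to each relevant hit. A self-contained sketch (hypothetical helper `toy_average_precision`, mirroring the per-sample logic of `_calculateAp_for_sample` above):

```python
def toy_average_precision(positive_ranks, num_relevant, mode='IR'):
    """AP for one query, given the 0-based ranks of its relevant hits."""
    ap = 0.0
    for i, rank in enumerate(positive_ranks):
        if mode == 'IR':
            # precision at the position of the (i + 1)-th relevant hit
            ap += (i + 1) / (rank + 1)
        else:  # 'integrate': average adjacent precision points (Oxford/Paris)
            old_precision = i / rank if rank > 0 else 1
            cur_precision = (i + 1) / (rank + 1)
            ap += (old_precision + cur_precision) / 2
    return 100 * ap / num_relevant


# Relevant items retrieved at ranks 0 and 2, with 5 relevant items in total.
print(toy_average_precision([0, 2], 5, mode='IR'))         # ≈ 33.33
print(toy_average_precision([0, 2], 5, mode='integrate'))  # ≈ 31.67
```

For the same ranking, the 'integrate' value is never larger than the 'IR' one, since each hit is credited with the average of the precision just before and at the hit.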


@@ -199,7 +199,7 @@ class ImageRetrievalTab:
                elem_id='image_retri_models',
                elem_classes='select_model',
                choices=self.model_list,
-               value='resnet50-arcface_8xb32_inshop',
+               value='resnet50-arcface_inshop',
            )
            topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
        with gr.Column():


@@ -4,7 +4,8 @@ from unittest import TestCase
import numpy as np
import torch
-from mmpretrain.evaluation.metrics import RetrievalRecall
+from mmpretrain.evaluation.metrics import (RetrievalAveragePrecision,
+                                           RetrievalRecall)
from mmpretrain.registry import METRICS
from mmpretrain.structures import DataSample
@@ -118,3 +119,109 @@ class TestRetrievalRecall(TestCase):
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
class TestRetrievalAveragePrecision(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
y_true = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
y_pred = torch.tensor([np.linspace(0.95, 0.05, 10)] * 2)
pred = [
DataSample().set_pred_score(i).set_gt_score(j)
for i, j in zip(y_pred, y_true)
]
# Test with the default arguments
metric = METRICS.build(dict(type='RetrievalAveragePrecision', topk=10))
metric.process([], pred)
res = metric.evaluate(len(pred))
self.assertIsInstance(res, dict)
self.assertAlmostEqual(
res['retrieval/mAP@10'], 53.25396825396825, places=4)
# Test with invalid topk
with self.assertRaisesRegex(ValueError, '`topk` must be a'):
METRICS.build(dict(type='RetrievalAveragePrecision', topk=-1))
# Test with invalid mode
with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
METRICS.build(
dict(type='RetrievalAveragePrecision', topk=5, mode='m'))
def test_calculate(self):
"""Test using the metric from static method."""
# Test IR mode
# example from https://zhuanlan.zhihu.com/p/35983818
# or https://www.youtube.com/watch?v=pM6DJ0ZZee0
# seq of indices format
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10)] * 2
# test with sequence-of-indices inputs
ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10,
True, True)
expect_ap = 53.25396825396825
self.assertEqual(ap_score.item(), expect_ap)
# test with tensor input
y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10)
expect_ap = 53.25396825396825
self.assertEqual(ap_score.item(), expect_ap)
# test with topk set to 5
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, topk=5)
expect_ap = 31.666666666666664
self.assertEqual(ap_score.item(), expect_ap)
# Test with invalid mode
with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
RetrievalAveragePrecision.calculate(
y_pred, y_true, True, True, mode='m')
# Test with invalid pred
y_pred = dict()
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
# Test with invalid target
y_true = dict()
y_pred = [np.arange(10)] * 2
with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
# Test with different length `pred` with `target`
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10)] * 3
with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
# Test with an invalid element in target
y_true = [[0, 2, 5, 8, 9], dict()]
y_pred = [np.arange(10)] * 2
with self.assertRaisesRegex(AssertionError, '`target` should be'):
RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
# Test with an invalid element in pred
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10), dict()]
with self.assertRaisesRegex(AssertionError, '`pred` should be'):
RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)
# Test with mode 'integrate'
y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
ap_score = RetrievalAveragePrecision.calculate(
y_pred, y_true, topk=5, mode='integrate')
expect_ap = 25.416666666666664
self.assertEqual(ap_score.item(), expect_ap)
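As a sanity check on the expected value 53.25396825396825 used twice above: under the 'IR' convention, the first sample has hits at 1-based ranks 1, 3, 6, 9, 10 out of 5 relevant items, and the second at ranks 2, 5, 7 out of 3, so

```python
ap1 = 100 * (1 / 1 + 2 / 3 + 3 / 6 + 4 / 9 + 5 / 10) / 5  # 62.222...
ap2 = 100 * (1 / 2 + 2 / 5 + 3 / 7) / 3                   # 44.285...
print((ap1 + ap2) / 2)                                     # ≈ 53.25396825396825
```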