[Feature] Add retrieval mAP metric. (#1552)
* rebase * refine * fix lint * update readme * rebase * fix lint * update docstring * update docstring * rebase * rename corresponding names * rebase
pull/1637/head
parent 9bb692e440
commit 1f07c92ed1
@@ -55,7 +55,10 @@ gallery_dataloader = dict(
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = query_dataloader
-val_evaluator = dict(type='RetrievalRecall', topk=1)
+val_evaluator = [
+    dict(type='RetrievalRecall', topk=1),
+    dict(type='RetrievalAveragePrecision', topk=10),
+]

test_dataloader = val_dataloader
test_evaluator = val_evaluator
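To sanity-check the new evaluator list outside a full run, the metrics can be built from the registry and fed `DataSample`s directly. A minimal sketch, mirroring the unit test added later in this diff (the toy scores and targets below are illustrative only):

```python
import numpy as np
import torch

from mmpretrain.registry import METRICS
from mmpretrain.structures import DataSample

# Two fake queries: one-hot relevance targets and strictly decreasing
# matching scores over a gallery of 10 items.
y_true = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                       [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
y_pred = torch.tensor(np.array([np.linspace(0.95, 0.05, 10)] * 2))

data_samples = [
    DataSample().set_pred_score(score).set_gt_score(target)
    for score, target in zip(y_pred, y_true)
]

metric = METRICS.build(dict(type='RetrievalAveragePrecision', topk=10))
metric.process([], data_samples)
print(metric.evaluate(len(data_samples)))  # {'retrieval/mAP@10': 53.25...}
```

The printed key is `default_prefix` ('retrieval') plus `mAP@{topk}`, which is where the `retrieval/mAP@10` name in the tests further down comes from.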
@@ -21,7 +21,7 @@ Recently, a popular line of research in face recognition is adopting margins in
```python
from mmpretrain import ImageRetrievalInferencer

-inferencer = ImageRetrievalInferencer('resnet50-arcface_8xb32_inshop', prototype='demo/')
+inferencer = ImageRetrievalInferencer('resnet50-arcface_inshop', prototype='demo/')
predict = inferencer('demo/dog.jpg', topk=2)[0]
print(predict[0])
print(predict[1])
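For a gallery larger than `demo/`, the inferencer docstring updated later in this diff also shows a `prototype_cache` argument that stores the extracted gallery embeddings; a hedged sketch (the cache file name here is made up for illustration):

```python
from mmpretrain import ImageRetrievalInferencer

# Assumption: the cache only stores the precomputed gallery (prototype)
# features so they are not re-extracted on the next run; 'inshop_protos.pth'
# is a hypothetical file name.
inferencer = ImageRetrievalInferencer(
    'resnet50-arcface_inshop',
    prototype='demo/',
    prototype_cache='inshop_protos.pth')
predict = inferencer('demo/dog.jpg', topk=2)[0]
```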
@@ -33,7 +33,7 @@ print(predict[1])
import torch
from mmpretrain import get_model

-model = get_model('resnet50-arcface_8xb32_inshop', pretrained=True)
+model = get_model('resnet50-arcface_inshop', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
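The snippet above only prints the type of `out`. Assuming the retriever exposes per-image embedding vectors (an assumption, not something this diff states), ranking a gallery reduces to a cosine-similarity lookup; a minimal torch-only sketch with made-up tensors:

```python
import torch
import torch.nn.functional as F

# Hypothetical embeddings: 1 query and 100 gallery images with 256-d features.
# How they are obtained from the model is deliberately left open here.
query_feat = torch.rand(1, 256)
gallery_feats = torch.rand(100, 256)

# Cosine similarity = dot product of L2-normalized vectors.
sims = F.normalize(query_feat, dim=1) @ F.normalize(gallery_feats, dim=1).T
scores, indices = sims.topk(k=10, dim=1)  # top-10 gallery indices per query
```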
@@ -64,9 +64,9 @@ python tools/test.py configs/arcface/resnet50-arcface_8xb32_inshop.py https://do

### Image Retrieval on InShop

-| Model | Pretrain | Params (M) | Flops (G) | Recall@1 | Config | Download |
-| :------------------------------ | :----------: | :--------: | :-------: | :------: | :----------------------------------------: | :--------------------------------------------------------------------------------------: |
-| `resnet50-arcface_8xb32_inshop` | From scratch | 31.69 | 16.57 | 90.18 | [config](resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
+| Model | Pretrain | Params (M) | Flops (G) | Recall@1 | mAP@10 | Config | Download |
+| :-----------------------: | :------------------------------------------------: | :-------: | :------: | :------: | :----: | :------------------------------------------: | :------------------------------------------------: |
+| `resnet50-arcface_inshop` | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | 69.30 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |

## Citation

@@ -13,7 +13,7 @@ Collections:
      URL: https://github.com/open-mmlab/mmpretrain/blob/v1.0.0rc3/mmcls/models/heads/margin_head.py

Models:
-  - Name: resnet50-arcface_8xb32_inshop
+  - Name: resnet50-arcface_inshop
    Metadata:
      FLOPs: 16571226112
      Parameters: 31693888
@@ -22,6 +22,7 @@ Models:
      - Dataset: InShop
        Metrics:
          Recall@1: 90.18
+         mAP@10: 69.30
        Task: Image Retrieval
    Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth
    Config: configs/arcface/resnet50-arcface_8xb32_inshop.py
@@ -44,3 +44,4 @@ Retrieval Metric
   :template: classtemplate.rst

   RetrievalRecall
+  RetrievalAveragePrecision
@@ -46,7 +46,7 @@ class ImageRetrievalInferencer(BaseInferencer):
    Example:
        >>> from mmpretrain import ImageRetrievalInferencer
        >>> inferencer = ImageRetrievalInferencer(
-       ...     'resnet50-arcface_8xb32_inshop',
+       ...     'resnet50-arcface_inshop',
        ...     prototype='./demo/',
        ...     prototype_cache='img_retri.pth')
        >>> inferencer('demo/cat-dog.png', topk=2)[0][1]
@@ -4,7 +4,7 @@ from .gqa import GQAAcc
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .nocaps import NocapsSave
-from .retrieval import RetrievalRecall
+from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .scienceqa import ScienceQAMetric
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .visual_grounding_eval import VisualGroundingMetric
@@ -15,5 +15,6 @@ __all__ = [
    'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
    'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
    'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
-   'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave'
+   'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
+   'RetrievalAveragePrecision'
]
@@ -65,7 +65,8 @@ class RetrievalRecall(BaseMetric):

    .. code:: python

-       val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+       val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
+       test_evaluator = val_evaluator
    """
    default_prefix: Optional[str] = 'retrieval'

@@ -183,6 +184,218 @@ class RetrievalRecall(BaseMetric):
        return results


@METRICS.register_module()
class RetrievalAveragePrecision(BaseMetric):
    r"""Calculate the average precision for image retrieval.

    Args:
        topk (int, optional): Only the predictions with the ``topk`` highest
            scores are considered as retrieved results.
        mode (str, optional): The mode to calculate AP, choose from
            'IR' (information retrieval) and 'integrate'. Defaults to 'IR'.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.

    Note:
        If ``mode`` is set to 'IR', use the Stanford AP calculation of
        information retrieval as in the Wikipedia page [1]_; if set to
        'integrate', the method integrates over the precision-recall curve
        by averaging two adjacent precision points and multiplying by the
        recall step, like mAP in the detection task. This is the convention
        for the Revisited Oxford/Paris datasets [2]_.

    References:
        [1] `Wikipedia entry for Average precision <https://en.wikipedia.
            org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_

        [2] `The Oxford Buildings Dataset
            <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_

    Examples:
        Use in code:

        >>> import torch
        >>> import numpy as np
        >>> from mmpretrain.evaluation import RetrievalAveragePrecision
        >>> # using index format inputs
        >>> pred = [torch.Tensor([idx for idx in range(100)])] * 3
        >>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
        >>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
        29.246031746031747
        >>> # using tensor format inputs
        >>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        >>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
        >>> RetrievalAveragePrecision.calculate(pred, target, 10)
        62.222222222222214

        Use in OpenMMLab config files:

        .. code:: python

            val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
            test_evaluator = val_evaluator
    """

    default_prefix: Optional[str] = 'retrieval'

    def __init__(self,
                 topk: Optional[int] = None,
                 mode: Optional[str] = 'IR',
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

        mode_options = ['IR', 'integrate']
        assert mode in mode_options, \
            f'Invalid `mode` argument, please specify from {mode_options}.'

        self.topk = topk
        self.mode = mode
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]):
        """Process one batch of data and predictions.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            pred_score = data_sample.get('pred_score').clone()

            if 'gt_score' in data_sample:
                target = data_sample.get('gt_score').clone()
            else:
                gt_label = data_sample.get('gt_label')
                num_classes = pred_score.size()[-1]
                target = label_to_onehot(gt_label, num_classes)

            # Because the retrieval output logit vector will be much larger
            # compared to the normal classification, to save resources, the
            # evaluation results are computed for each batch here and then
            # reduced over all results at the end.
            result = RetrievalAveragePrecision.calculate(
                pred_score.unsqueeze(0),
                target.unsqueeze(0),
                self.topk,
                mode=self.mode)
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        result_metrics = dict()
        result_metrics[f'mAP@{self.topk}'] = np.mean(self.results).item()

        return result_metrics

    @staticmethod
    def calculate(pred: Union[np.ndarray, torch.Tensor],
                  target: Union[np.ndarray, torch.Tensor],
                  topk: Optional[int] = None,
                  pred_indices: bool = False,
                  target_indices: bool = False,
                  mode: str = 'IR') -> float:
        """Calculate the average precision.

        Args:
            pred (torch.Tensor | np.ndarray | Sequence): The prediction
                results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            target (torch.Tensor | np.ndarray | Sequence): The retrieval
                targets. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            topk (int, optional): Only the predictions with the ``topk``
                highest scores are considered as retrieved results.
            pred_indices (bool): Whether the ``pred`` is a sequence of
                category index labels. Defaults to False.
            target_indices (bool): Whether the ``target`` is a sequence of
                category index labels. Defaults to False.
            mode (Optional[str]): The mode to calculate AP, choose from
                'IR' (information retrieval) and 'integrate'. Defaults to
                'IR'.

        Note:
            If ``mode`` is set to 'IR', use the Stanford AP calculation of
            information retrieval as in the Wikipedia page; if set to
            'integrate', the method integrates over the precision-recall
            curve by averaging two adjacent precision points and multiplying
            by the recall step, like mAP in the detection task. This is the
            convention for the Revisited Oxford/Paris datasets.

        Returns:
            float: the average precision of the query image.

        References:
            [1] `Wikipedia entry for Average precision (information retrieval)
                <https://en.wikipedia.org/wiki/Evaluation_measures_
                (information_retrieval)#Average_precision>`_
            [2] `The Oxford Buildings Dataset
                <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
        """
        if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

        mode_options = ['IR', 'integrate']
        assert mode in mode_options, \
            f'Invalid `mode` argument, please specify from {mode_options}.'

        pred = _format_pred(pred, topk, pred_indices)
        target = _format_target(target, target_indices)

        assert len(pred) == len(target), (
            f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
            f'must be the same.')

        num_samples = len(pred)
        aps = np.zeros(num_samples)
        for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
            aps[i] = _calculateAp_for_sample(sample_pred, sample_target, mode)

        return aps.mean()


def _calculateAp_for_sample(pred, target, mode):
    pred = np.array(to_tensor(pred).cpu())
    target = np.array(to_tensor(target).cpu())

    num_preds = len(pred)

    # TODO: use ``torch.isin`` in torch1.10.
    positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]

    ap = 0
    for i, rank in enumerate(positive_ranks):
        if mode == 'IR':
            precision = (i + 1) / (rank + 1)
            ap += precision
        elif mode == 'integrate':
            # code is modified from https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp  # noqa:
            old_precision = i / rank if rank > 0 else 1
            cur_precision = (i + 1) / (rank + 1)
            prediction = (old_precision + cur_precision) / 2
            ap += prediction
    ap = ap / len(target)

    return ap * 100


def _format_pred(label, topk=None, is_indices=False):
    """format various label to List[indices]."""
    if is_indices:
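To see the practical difference between the two modes handled by `_calculateAp_for_sample` above, the new static method can be called directly; the values below match the class docstring example and the unit tests added later in this diff:

```python
import numpy as np
import torch

from mmpretrain.evaluation.metrics import RetrievalAveragePrecision

# One-hot relevance for two queries over a 10-item gallery whose scores are
# strictly decreasing, so gallery item i sits at rank i + 1.
target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                       [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)

# 'IR' mode: sum of precision-at-each-hit, divided by the number of
# relevant items for the query.
print(RetrievalAveragePrecision.calculate(pred, target, topk=10))  # ~53.254

# Restricting to the top 5 and switching to 'integrate' (the Revisited
# Oxford/Paris convention, averaging adjacent precision points) lowers
# the score further.
print(RetrievalAveragePrecision.calculate(pred, target, topk=5))    # ~31.667
print(RetrievalAveragePrecision.calculate(
    pred, target, topk=5, mode='integrate'))                        # ~25.417
```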
@@ -199,7 +199,7 @@ class ImageRetrievalTab:
                elem_id='image_retri_models',
                elem_classes='select_model',
                choices=self.model_list,
-               value='resnet50-arcface_8xb32_inshop',
+               value='resnet50-arcface_inshop',
            )
            topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
        with gr.Column():
@@ -4,7 +4,8 @@ from unittest import TestCase
import numpy as np
import torch

-from mmpretrain.evaluation.metrics import RetrievalRecall
+from mmpretrain.evaluation.metrics import (RetrievalAveragePrecision,
+                                           RetrievalRecall)
from mmpretrain.registry import METRICS
from mmpretrain.structures import DataSample

@@ -118,3 +119,109 @@ class TestRetrievalRecall(TestCase):
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalRecall.calculate(
                y_pred, y_true, topk=1, pred_indices=True, target_indices=True)


class TestRetrievalAveragePrecision(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        y_true = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = torch.tensor([np.linspace(0.95, 0.05, 10)] * 2)

        pred = [
            DataSample().set_pred_score(i).set_gt_score(j)
            for i, j in zip(y_pred, y_true)
        ]

        # Test with the default 'IR' mode
        metric = METRICS.build(dict(type='RetrievalAveragePrecision', topk=10))
        metric.process([], pred)
        res = metric.evaluate(len(pred))
        self.assertIsInstance(res, dict)
        self.assertAlmostEqual(
            res['retrieval/mAP@10'], 53.25396825396825, places=4)

        # Test with invalid topk
        with self.assertRaisesRegex(ValueError, '`topk` must be a'):
            METRICS.build(dict(type='RetrievalAveragePrecision', topk=-1))

        # Test with invalid mode
        with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
            METRICS.build(
                dict(type='RetrievalAveragePrecision', topk=5, mode='m'))

    def test_calculate(self):
        """Test using the metric from static method."""
        # Test IR mode
        # example from https://zhuanlan.zhihu.com/p/35983818
        # or https://www.youtube.com/watch?v=pM6DJ0ZZee0

        # sequence-of-indices format
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 2

        # test with indices format input
        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10,
                                                       True, True)
        expect_ap = 53.25396825396825
        self.assertEqual(ap_score.item(), expect_ap)

        # test with tensor input
        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, 10)
        expect_ap = 53.25396825396825
        self.assertEqual(ap_score.item(), expect_ap)

        # test with topk of 5
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
        ap_score = RetrievalAveragePrecision.calculate(y_pred, y_true, topk=5)
        expect_ap = 31.666666666666664
        self.assertEqual(ap_score.item(), expect_ap)

        # Test with invalid mode
        with self.assertRaisesRegex(AssertionError, 'Invalid `mode` '):
            RetrievalAveragePrecision.calculate(
                y_pred, y_true, True, True, mode='m')

        # Test with invalid pred
        y_pred = dict()
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)

        # Test with invalid target
        y_true = dict()
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)

        # Test with different lengths of `pred` and `target`
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10)] * 3
        with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)

        # Test with an invalid target element
        y_true = [[0, 2, 5, 8, 9], dict()]
        y_pred = [np.arange(10)] * 2
        with self.assertRaisesRegex(AssertionError, '`target` should be'):
            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)

        # Test with an invalid pred element
        y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
        y_pred = [np.arange(10), dict()]
        with self.assertRaisesRegex(AssertionError, '`pred` should be'):
            RetrievalAveragePrecision.calculate(y_pred, y_true, 10, True, True)

        # Test with mode 'integrate'
        y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
                               [0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
        y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)

        ap_score = RetrievalAveragePrecision.calculate(
            y_pred, y_true, topk=5, mode='integrate')
        expect_ap = 25.416666666666664
        self.assertEqual(ap_score.item(), expect_ap)
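The 53.25 expectation used in both tests above can be checked by hand: under the 'IR' convention, each query's AP is the sum of precision at every hit divided by that query's number of relevant items (all of which land in the top 10 here). A plain-NumPy sketch, independent of the metric class, with an illustrative helper name:

```python
import numpy as np


def ir_ap(relevance):
    """Stanford/IR-style AP for one ranked list of 0/1 relevance flags."""
    hit_ranks = np.flatnonzero(relevance)                   # 0-based ranks of hits
    precisions = (np.arange(len(hit_ranks)) + 1) / (hit_ranks + 1)
    return 100 * precisions.sum() / relevance.sum()


q1 = np.array([1, 0, 1, 0, 0, 1, 0, 0, 1, 1])  # hits at ranks 1, 3, 6, 9, 10 -> ~62.22
q2 = np.array([0, 1, 0, 0, 1, 0, 1, 0, 0, 0])  # hits at ranks 2, 5, 7        -> ~44.29
print((ir_ap(q1) + ir_ap(q2)) / 2)             # 53.2539..., i.e. the mAP@10 above
```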