[Feature] Support RetrieverRecall metric & Add ArcFace config (#1316)

* rebase

* add ap metric

* fix multi-gpu bug in retrieval

* rebase

* rebase

* add training cfgs and update readme.md

* fix bugs (cannot load vecs in dist and diff test-val recall)

* update configs and readme

* fix ut

* fix doc

* rebase

* fix rebase conflicts

* fix rebase error

* fix UT error

* fix docs

* fix typo

---------

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
takuoko 2023-02-14 13:46:21 +09:00 committed by GitHub
parent 1c1273abca
commit 841256b630
12 changed files with 592 additions and 23 deletions


@@ -0,0 +1,61 @@
# dataset settings
dataset_type = 'InShop'
data_preprocessor = dict(
num_classes=3997,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=512),
dict(type='RandomCrop', crop_size=448),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=512),
dict(type='CenterCrop', crop_size=448),
dict(type='PackClsInputs'),
]
train_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
query_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='query',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
gallery_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/inshop',
split='gallery',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = query_dataloader
val_evaluator = dict(type='RetrievalRecall', topk=1)
test_dataloader = val_dataloader
test_evaluator = val_evaluator
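
`RetrievalRecall` also accepts a tuple of `k` values, so a single evaluator can report several recalls in one pass; a minimal sketch of that variant (not part of this config):

```python
# Report Recall@1, Recall@5 and Recall@10 from one evaluation pass.
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))
test_evaluator = val_evaluator
```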


@@ -0,0 +1,32 @@
# ArcFace
> [ArcFace: Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698)
<!-- [ALGORITHM] -->
## Abstract
Recently, a popular line of research in face recognition is adopting margins in the well-established softmax loss function to maximize class separability. In this paper, we first introduce an Additive Angular Margin Loss (ArcFace), which not only has a clear geometric interpretation but also significantly enhances the discriminative power. Since ArcFace is susceptible to the massive label noise, we further propose sub-center ArcFace, in which each class contains K sub-centers and training samples only need to be close to any of the K positive sub-centers. Sub-center ArcFace encourages one dominant sub-class that contains the majority of clean faces and non-dominant sub-classes that include hard or noisy faces. Based on this self-propelled isolation, we boost the performance through automatically purifying raw web faces under massive real-world noise. Besides discriminative feature embedding, we also explore the inverse problem, mapping feature vectors to face images. Without training any additional generator or discriminator, the pre-trained ArcFace model can generate identity-preserved face images for both subjects inside and outside the training data only by using the network gradient and Batch Normalization (BN) priors. Extensive experiments demonstrate that ArcFace can enhance the discriminative feature embedding as well as strengthen the generative face synthesis.
<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/212606212-8ffc3cd2-dbc1-4abf-8924-22167f3f6e34.png" width="80%"/>
</div>
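
ArcFace keeps the ordinary softmax cross-entropy but computes logits as scaled cosines between L2-normalised embeddings and class centres, adding an angular margin to the ground-truth class. A minimal sketch of that idea (illustrative only; not the `ArcFaceClsHead` implementation used by these configs):

```python
import torch
import torch.nn.functional as F


def arcface_logits(features, weight, labels, s=64.0, m=0.5):
    """Additive angular margin logits (sketch).

    features: (N, C) embeddings, weight: (num_classes, C) class centres,
    labels: (N,) ground-truth indices, s/m: the usual scale and margin.
    """
    # cosine similarity between normalised embeddings and class centres
    cosine = F.linear(F.normalize(features), F.normalize(weight))
    theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
    # add the margin m only to the angle of the ground-truth class
    onehot = F.one_hot(labels, num_classes=weight.size(0)).bool()
    logits = torch.where(onehot, torch.cos(theta + m), cosine)
    return s * logits  # feed into a standard cross-entropy loss
```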
## Results and models
### InShop
| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | Config | Download |
| :--------------: | :------: | :-------: | :------: | :------: | :----: | :------: |
| Resnet50-ArcFace | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
## Citation
```bibtex
@inproceedings{deng2018arcface,
title={ArcFace: Additive Angular Margin Loss for Deep Face Recognition},
author={Deng, Jiankang and Guo, Jia and Niannan, Xue and Zafeiriou, Stefanos},
booktitle={CVPR},
year={2019}
}
```


@@ -0,0 +1,27 @@
Collections:
- Name: ArcFace
Metadata:
Training Data: InShop
Architecture:
- Additive Angular Margin Loss
Paper:
URL: https://arxiv.org/abs/1801.07698
Title: 'ArcFace: Additive Angular Margin Loss for Deep Face Recognition'
README: configs/arcface/README.md
Code:
Version: v1.0.0rc3
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/heads/margin_head.py
Models:
- Name: resnet50-arcface_inshop
Metadata:
FLOPs: 16571226112
Parameters: 31693888
In Collection: ArcFace
Results:
- Dataset: InShop
Metrics:
Recall@1: 90.18
Task: Metric Learning
Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth
Config: configs/arcface/resnet50-arcface_8xb32_inshop.py


@@ -0,0 +1,71 @@
_base_ = [
'../_base_/datasets/inshop_bs32_448.py',
'../_base_/schedules/cub_bs64.py',
'../_base_/default_runtime.py',
]
pretrained = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth' # noqa
model = dict(
type='ImageToImageRetriever',
image_encoder=[
dict(
type='ResNet',
depth=50,
init_cfg=dict(
type='Pretrained', checkpoint=pretrained, prefix='backbone')),
dict(type='GlobalAveragePooling'),
],
head=dict(
type='ArcFaceClsHead',
num_classes=3997,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=None),
prototype={{_base_.gallery_dataloader}})
# runtime settings
default_hooks = dict(
# log every 20 iterations
logger=dict(type='LoggerHook', interval=20),
# save last three checkpoints
checkpoint=dict(
type='CheckpointHook',
save_best='auto',
interval=1,
max_keep_ckpts=3,
rule='greater'))
# optimizer
optim_wrapper = dict(
optimizer=dict(
type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005, nesterov=True))
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=0.01,
by_epoch=True,
begin=0,
end=5,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type='CosineAnnealingLR',
T_max=45,
by_epoch=True,
begin=5,
end=50,
)
]
train_cfg = dict(by_epoch=True, max_epochs=50, val_interval=1)
auto_scale_lr = dict(enable=True, base_batch_size=256)
custom_hooks = [
dict(type='PrepareProtoBeforeValLoopHook'),
dict(type='SyncBuffersHook')
]


@@ -33,3 +33,13 @@ Multi Label Metric
MultiLabelMetric
VOCAveragePrecision
VOCMultiLabelMetric
Retrieval Metric
----------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
RetrievalRecall


@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import get_file_backend, list_from_file
from mmcls.datasets.base_dataset import BaseDataset
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
@@ -14,11 +14,13 @@ class InShop(BaseDataset):
(In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
Eval/list_eval_partition.txt), and organize them as follows: ::
In-shop dataset directory: ::
In-shop Clothes Retrieval Benchmark (data_root)/
Eval /
list_eval_partition.txt (ann_file)
Img
img/ (img_prefix)
Img (img_prefix)
img/
README.txt
.....
@@ -27,7 +29,7 @@ class InShop(BaseDataset):
split (str): Choose from 'train', 'query' and 'gallery'.
Defaults to 'train'.
data_prefix (str | dict): Prefix for training data.
Defaults to 'Img/img'.
Defaults to 'Img'.
ann_file (str): Annotation file path, path relative to
``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
@@ -66,7 +68,7 @@ class InShop(BaseDataset):
def __init__(self,
data_root: str,
split: str = 'train',
data_prefix: str = 'Img/img',
data_prefix: str = 'Img',
ann_file: str = 'Eval/list_eval_partition.txt',
**kwargs):
@@ -149,9 +151,6 @@ class InShop(BaseDataset):
"""
data_info = self._process_annotations()
data_list = data_info['data_list']
for data in data_list:
data['img_path'] = self.backend.join_path(self.data_root,
data['img_path'])
return data_list
def extra_repr(self):


@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
from .single_label import Accuracy, SingleLabelMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
__all__ = [
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric'
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'RetrievalRecall'
]


@@ -0,0 +1,234 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import mmengine
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.structures import LabelData
from mmengine.utils import is_seq_of
from mmcls.registry import METRICS
from .single_label import to_tensor
@METRICS.register_module()
class RetrievalRecall(BaseMetric):
r"""Recall evaluation metric for image retrieval.
Args:
topk (int | Sequence[int]): If the ground truth label matches one of
the best **k** predictions, the sample will be regarded as a positive
prediction. If the parameter is a tuple, all of top-k recall will
be calculated and outputted together. Defaults to 1.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
Examples:
Use in the code:
>>> import torch
>>> from mmcls.evaluation import RetrievalRecall
>>> # -------------------- The Basic Usage --------------------
>>> y_pred = [[0], [1], [2], [3]]
>>> y_true = [[0, 1], [2], [1], [0, 3]]
>>> RetrievalRecall.calculate(
>>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
[tensor([50.])]
>>> # Calculate the recall@1 and recall@5 for non-indices input.
>>> y_score = torch.rand((1000, 10))
>>> import torch.nn.functional as F
>>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
>>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
[tensor(9.3000), tensor(48.4000)]
>>>
>>> # ------------------- Use with Evaluator -------------------
>>> from mmcls.structures import ClsDataSample
>>> from mmengine.evaluator import Evaluator
>>> data_samples = [
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(
... torch.rand(10))
... for i in range(1000)
... ]
>>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
>>> evaluator.process(data_samples)
>>> evaluator.evaluate(1000)
{'retrieval/Recall@1': 20.700000762939453,
'retrieval/Recall@5': 78.5999984741211}
Use in OpenMMLab configs:
.. code:: python
val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
"""
default_prefix: Optional[str] = 'retrieval'
def __init__(self,
topk: Union[int, Sequence[int]],
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
topk = (topk, ) if isinstance(topk, int) else topk
for k in topk:
if k <= 0:
raise ValueError('`topk` must be an integer larger than 0 '
'or a sequence of integers larger than 0.')
self.topk = topk
super().__init__(collect_device=collect_device, prefix=prefix)
def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data and predictions.
The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']
pred = pred_label['score'].clone()
if 'score' in gt_label:
target = gt_label['score'].clone()
else:
num_classes = pred_label['score'].size()[-1]
target = LabelData.label_to_onehot(gt_label['label'],
num_classes)
# Because the retrieval output logit vector will be much larger
# compared to the normal classification, to save resources, the
# evaluation results are computed each batch here and then reduce
# all results at the end.
result = RetrievalRecall.calculate(
pred.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
result_metrics = dict()
for i, k in enumerate(self.topk):
recall_at_k = sum([r[i].item() for r in results]) / len(results)
result_metrics[f'Recall@{k}'] = recall_at_k
return result_metrics
@staticmethod
def calculate(pred: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
topk: Union[int, Sequence[int]],
pred_indices: bool = False,
target_indices: bool = False) -> List[float]:
"""Calculate the average recall.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
target (torch.Tensor | np.ndarray | Sequence): The ground truth
labels. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
topk (int, Sequence[int]): Predictions with the k-th highest
scores are considered as positive.
pred_indices (bool): Whether the ``pred`` is a sequence of
category index labels. Defaults to False.
target_indices (bool): Whether the ``target`` is a sequence of
category index labels. Defaults to False.
Returns:
List[float]: the average recalls.
"""
topk = (topk, ) if isinstance(topk, int) else topk
for k in topk:
if k <= 0:
raise ValueError('`topk` must be an integer larger than 0 '
'or a sequence of integers larger than 0.')
max_keep = max(topk)
pred = _format_pred(pred, max_keep, pred_indices)
target = _format_target(target, target_indices)
assert len(pred) == len(target), (
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
f'must be the same.')
num_samples = len(pred)
results = []
for k in topk:
recalls = torch.zeros(num_samples)
for i, (sample_pred,
sample_target) in enumerate(zip(pred, target)):
sample_pred = np.array(to_tensor(sample_pred).cpu())
sample_target = np.array(to_tensor(sample_target).cpu())
recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max())
results.append(recalls.mean() * 100)
return results
def _format_pred(label, topk=None, is_indices=False):
"""format various label to List[indices]."""
if is_indices:
assert isinstance(label, Sequence), \
'`pred` must be Sequence of indices when' \
f' `pred_indices` set to True, but get {type(label)}'
for i, sample_pred in enumerate(label):
assert is_seq_of(sample_pred, int) or isinstance(
sample_pred, (np.ndarray, torch.Tensor)), \
'`pred` should be Sequence of indices when `pred_indices`' \
f'set to True. but pred[{i}] is {sample_pred}'
if topk:
label[i] = sample_pred[:min(topk, len(sample_pred))]
return label
if isinstance(label, np.ndarray):
label = torch.from_numpy(label)
elif not isinstance(label, torch.Tensor):
raise TypeError(f'The pred must be type of torch.tensor, '
f'np.ndarray or Sequence but get {type(label)}.')
topk = topk if topk else label.size()[-1]
_, indices = label.topk(topk)
return indices
def _format_target(label, is_indices=False):
"""format various label to List[indices]."""
if is_indices:
assert isinstance(label, Sequence), \
'`target` must be Sequence of indices when' \
f' `target_indices` set to True, but get {type(label)}'
for i, sample_gt in enumerate(label):
assert is_seq_of(sample_gt, int) or isinstance(
sample_gt, (np.ndarray, torch.Tensor)), \
'`target` should be Sequence of indices when ' \
f'`target_indices` set to True. but target[{i}] is {sample_gt}'
return label
if isinstance(label, np.ndarray):
label = torch.from_numpy(label)
elif isinstance(label, Sequence) and not mmengine.is_str(label):
label = torch.tensor(label)
elif not isinstance(label, torch.Tensor):
raise TypeError(f'The pred must be type of torch.tensor, '
f'np.ndarray or Sequence but get {type(label)}.')
indices = [LabelData.onehot_to_label(sample_gt) for sample_gt in label]
return indices


@@ -266,6 +266,16 @@ class ImageToImageRetriever(BaseRetriever):
dist.all_reduce(prototype_vecs)
return prototype_vecs
def _get_prototype_vecs_from_path(self, proto_path):
"""get prototype_vecs from prototype path."""
data = [None]
if dist.is_main_process():
data[0] = torch.load(proto_path)
dist.broadcast_object_list(data, src=0)
prototype_vecs = data[0]
assert prototype_vecs is not None
return prototype_vecs
@torch.no_grad()
def prepare_prototype(self):
"""Used in meta testing. This function will be called before the meta
@@ -281,7 +291,7 @@ class ImageToImageRetriever(BaseRetriever):
if isinstance(self.prototype, torch.Tensor):
prototype_vecs = self.prototype
elif isinstance(self.prototype, str):
prototype_vecs = torch.load(self.prototype)
prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
elif isinstance(self.prototype, (dict, DataLoader)):
loader = Runner.build_dataloader(self.prototype)
prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
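
`_get_prototype_vecs_from_path` loads a saved prototype file only on the main rank and broadcasts it, so a string path now also works under distributed evaluation. A partial config sketch, assuming a pre-computed file (the path is illustrative):

```python
model = dict(
    type='ImageToImageRetriever',
    # image_encoder and head as in resnet50-arcface_8xb32_inshop.py.
    # `prototype` may be a Tensor, a dataloader config, or, as sketched
    # here, the path of a saved .pth file; the file is read on rank 0
    # and the loaded vectors are broadcast to the other ranks.
    prototype='work_dirs/arcface_inshop/gallery_proto_vecs.pth')
```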


@@ -51,3 +51,4 @@ Import:
- configs/convnext_v2/metafile.yml
- configs/levit/metafile.yml
- configs/vig/metafile.yml
- configs/arcface/metafile.yml


@@ -957,14 +957,14 @@ class TestInShop(TestBaseDataset):
f.write('\n'.join([
'8',
'image_name item_id evaluation_status',
'02_1_front.jpg id_00000002 train',
'02_2_side.jpg id_00000002 train',
'12_3_back.jpg id_00007982 gallery',
'12_7_additional.jpg id_00007982 gallery',
'13_1_front.jpg id_00007982 query',
'13_2_side.jpg id_00007983 gallery',
'13_3_back.jpg id_00007983 query ',
'13_7_additional.jpg id_00007983 query',
f'{osp.join("img", "02_1_front.jpg")} id_00000002 train',
f'{osp.join("img", "02_2_side.jpg")} id_00000002 train',
f'{osp.join("img", "12_3_back.jpg")} id_00007982 gallery',
f'{osp.join("img", "12_7_addition.jpg")} id_00007982 gallery',
f'{osp.join("img", "13_1_front.jpg")} id_00007982 query',
f'{osp.join("img", "13_2_side.jpg")} id_00007983 gallery',
f'{osp.join("img", "13_3_back.jpg")} id_00007983 query ',
f'{osp.join("img", "13_7_additional.jpg")} id_00007983 query',
]))
def test_initialize(self):
@@ -1004,8 +1004,9 @@ class TestInShop(TestBaseDataset):
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 2)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'Img/img', '02_1_front.jpg'))
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, 'Img', 'img', '02_1_front.jpg'))
self.assertEqual(data_info['gt_label'], 0)
# Test with mode=query
@@ -1013,8 +1014,9 @@ class TestInShop(TestBaseDataset):
dataset = dataset_class(**cfg)
self.assertEqual(len(dataset), 3)
data_info = dataset[0]
self.assertEqual(data_info['img_path'],
os.path.join(self.root, 'Img/img', '13_1_front.jpg'))
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, 'Img', 'img', '13_1_front.jpg'))
self.assertEqual(data_info['gt_label'], [0, 1])
# Test with mode=gallery
@@ -1024,7 +1026,7 @@ class TestInShop(TestBaseDataset):
data_info = dataset[0]
self.assertEqual(
data_info['img_path'],
os.path.join(self.root, self.root, 'Img/img', '12_3_back.jpg'))
os.path.join(self.root, 'Img', 'img', '12_3_back.jpg'))
self.assertEqual(data_info['sample_idx'], 0)
def test_extra_repr(self):


@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmcls.evaluation.metrics import RetrievalRecall
from mmcls.registry import METRICS
from mmcls.structures import ClsDataSample
class TestRetrievalRecall(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict()
for i, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
], [[0], [0, 1], [1], [2], [1, 2], [0, 1]])
]
# Test with score (use score instead of label if score exists)
metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
metric.process(None, pred)
recall = metric.evaluate(6)
self.assertIsInstance(recall, dict)
self.assertAlmostEqual(
recall['retrieval/Recall@1'], 5 / 6 * 100, places=4)
# Test with invalid topk
with self.assertRaisesRegex(RuntimeError, 'selected index k'):
metric = METRICS.build(dict(type='RetrievalRecall', topk=10))
metric.process(None, pred)
metric.evaluate(6)
with self.assertRaisesRegex(ValueError, '`topk` must be a'):
METRICS.build(dict(type='RetrievalRecall', topk=-1))
# Test initialization with an integer topk
metric = METRICS.build(dict(type='RetrievalRecall', topk=5))
self.assertEqual(metric.topk, (5, ))
# Test initialization with a tuple of topk
metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5)))
self.assertEqual(metric.topk, (1, 2, 5))
def test_calculate(self):
"""Test using the metric from static method."""
# seq of indices format
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10)] * 2
# test with topk=1
recall_score = RetrievalRecall.calculate(
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
expect_recall = 50.
self.assertEqual(recall_score[0].item(), expect_recall)
# test with tensor input
y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
[0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1)
expect_recall = 50.
self.assertEqual(recall_score[0].item(), expect_recall)
# test with topk=2
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2)
expect_recall = 100.
self.assertEqual(recall_score[0].item(), expect_recall)
# test with topk is (1, 5)
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5))
expect_recalls = [50., 100.]
self.assertEqual(len(recall_score), len(expect_recalls))
for i in range(len(expect_recalls)):
self.assertEqual(recall_score[i].item(), expect_recalls[i])
# Test with invalid pred
y_pred = dict()
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
RetrievalRecall.calculate(y_pred, y_true, True, True)
# Test with invalid target
y_true = dict()
y_pred = [np.arange(10)] * 2
with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
RetrievalRecall.calculate(
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
# Test with different length `pred` with `target`
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10)] * 3
with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
RetrievalRecall.calculate(
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
# Test with invalid pred
y_true = [[0, 2, 5, 8, 9], dict()]
y_pred = [np.arange(10)] * 2
with self.assertRaisesRegex(AssertionError, '`target` should be'):
RetrievalRecall.calculate(
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
# Test with invalid target
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
y_pred = [np.arange(10), dict()]
with self.assertRaisesRegex(AssertionError, '`pred` should be'):
RetrievalRecall.calculate(
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)