[Feature] Support RetrieverRecall metric & Add ArcFace config (#1316)
* rebase * add ap metric * fix mlti-gpu bug in retrevel * rebase * rebase * add training cfgs and update readme.md * fix bugs(cannot load vecs in dist and diff test-val recall\) * update configs and readme * fix ut * fix doc * rebase * fix rebase conflicts * fix rebase error * fix UT error * fix docs * fix typo --------- Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>pull/1287/head^2
parent
1c1273abca
commit
841256b630
|
@ -0,0 +1,61 @@
|
|||
# dataset settings
|
||||
dataset_type = 'InShop'
|
||||
data_preprocessor = dict(
|
||||
num_classes=3997,
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=512),
|
||||
dict(type='RandomCrop', crop_size=448),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=512),
|
||||
dict(type='CenterCrop', crop_size=448),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
query_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='query',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
|
||||
gallery_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/inshop',
|
||||
split='gallery',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_dataloader = query_dataloader
|
||||
val_evaluator = dict(type='RetrievalRecall', topk=1)
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,32 @@
|
|||
# ArcFace
|
||||
|
||||
> [ArcFace: Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Recently, a popular line of research in face recognition is adopting margins in the well-established softmax loss function to maximize class separability. In this paper, we first introduce an Additive Angular Margin Loss (ArcFace), which not only has a clear geometric interpretation but also significantly enhances the discriminative power. Since ArcFace is susceptible to the massive label noise, we further propose sub-center ArcFace, in which each class contains K sub-centers and training samples only need to be close to any of the K positive sub-centers. Sub-center ArcFace encourages one dominant sub-class that contains the majority of clean faces and non-dominant sub-classes that include hard or noisy faces. Based on this self-propelled isolation, we boost the performance through automatically purifying raw web faces under massive real-world noise. Besides discriminative feature embedding, we also explore the inverse problem, mapping feature vectors to face images. Without training any additional generator or discriminator, the pre-trained ArcFace model can generate identity-preserved face images for both subjects inside and outside the training data only by using the network gradient and Batch Normalization (BN) priors. Extensive experiments demonstrate that ArcFace can enhance the discriminative feature embedding as well as strengthen the generative face synthesis.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/24734142/212606212-8ffc3cd2-dbc1-4abf-8924-22167f3f6e34.png" width="80%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### InShop
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | Config | Download |
|
||||
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: |
|
||||
| Resnet50-ArcFace | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{deng2018arcface,
|
||||
title={ArcFace: Additive Angular Margin Loss for Deep Face Recognition},
|
||||
author={Deng, Jiankang and Guo, Jia and Niannan, Xue and Zafeiriou, Stefanos},
|
||||
booktitle={CVPR},
|
||||
year={2019}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,27 @@
|
|||
Collections:
|
||||
- Name: ArcFace
|
||||
Metadata:
|
||||
Training Data: InShop
|
||||
Architecture:
|
||||
- Additive Angular Margin Loss
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/1801.07698
|
||||
Title: 'ArcFace: Additive Angular Margin Loss for Deep Face Recognition'
|
||||
README: configs/arcface/README.md
|
||||
Code:
|
||||
Version: v1.0.0rc3
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/heads/margin_head.py
|
||||
|
||||
Models:
|
||||
- Name: resnet50-arcface_inshop
|
||||
Metadata:
|
||||
FLOPs: 16571226112
|
||||
Parameters: 31693888
|
||||
In Collection: ArcFace
|
||||
Results:
|
||||
- Dataset: InShop
|
||||
Metrics:
|
||||
Recall@1: 90.18
|
||||
Task: Metric Learning
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth
|
||||
Config: configs/arcface/resnet50-arcface_8xb32_inshop.py
|
|
@ -0,0 +1,71 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/inshop_bs32_448.py',
|
||||
'../_base_/schedules/cub_bs64.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
pretrained = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth' # noqa
|
||||
model = dict(
|
||||
type='ImageToImageRetriever',
|
||||
image_encoder=[
|
||||
dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
init_cfg=dict(
|
||||
type='Pretrained', checkpoint=pretrained, prefix='backbone')),
|
||||
dict(type='GlobalAveragePooling'),
|
||||
],
|
||||
head=dict(
|
||||
type='ArcFaceClsHead',
|
||||
num_classes=3997,
|
||||
in_channels=2048,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
init_cfg=None),
|
||||
prototype={{_base_.gallery_dataloader}})
|
||||
|
||||
# runtime settings
|
||||
default_hooks = dict(
|
||||
# log every 20 intervals
|
||||
logger=dict(type='LoggerHook', interval=20),
|
||||
# save last three checkpoints
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook',
|
||||
save_best='auto',
|
||||
interval=1,
|
||||
max_keep_ckpts=3,
|
||||
rule='greater'))
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0005, nesterov=True))
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
# warm up learning rate scheduler
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.01,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=5,
|
||||
# update by iter
|
||||
convert_to_iter_based=True),
|
||||
# main learning rate scheduler
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=45,
|
||||
by_epoch=True,
|
||||
begin=5,
|
||||
end=50,
|
||||
)
|
||||
]
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=50, val_interval=1)
|
||||
|
||||
auto_scale_lr = dict(enable=True, base_batch_size=256)
|
||||
|
||||
custom_hooks = [
|
||||
dict(type='PrepareProtoBeforeValLoopHook'),
|
||||
dict(type='SyncBuffersHook')
|
||||
]
|
|
@ -33,3 +33,13 @@ Multi Label Metric
|
|||
MultiLabelMetric
|
||||
VOCAveragePrecision
|
||||
VOCMultiLabelMetric
|
||||
|
||||
Retrieval Metric
|
||||
----------------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
RetrievalRecall
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine import get_file_backend, list_from_file
|
||||
|
||||
from mmcls.datasets.base_dataset import BaseDataset
|
||||
from mmcls.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -14,11 +14,13 @@ class InShop(BaseDataset):
|
|||
(In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
|
||||
Eval/list_eval_partition.txt), and organize them as follows way: ::
|
||||
|
||||
In-shop dataset directory: ::
|
||||
|
||||
In-shop Clothes Retrieval Benchmark (data_root)/
|
||||
├── Eval /
|
||||
│ └── list_eval_partition.txt (ann_file)
|
||||
├── Img
|
||||
│ └── img/ (img_prefix)
|
||||
├── Img (img_prefix)
|
||||
│ └── img/
|
||||
├── README.txt
|
||||
└── .....
|
||||
|
||||
|
@ -27,7 +29,7 @@ class InShop(BaseDataset):
|
|||
split (str): Choose from 'train', 'query' and 'gallery'.
|
||||
Defaults to 'train'.
|
||||
data_prefix (str | dict): Prefix for training data.
|
||||
Defaults to 'Img/img'.
|
||||
Defaults to 'Img'.
|
||||
ann_file (str): Annotation file path, path relative to
|
||||
``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
|
@ -66,7 +68,7 @@ class InShop(BaseDataset):
|
|||
def __init__(self,
|
||||
data_root: str,
|
||||
split: str = 'train',
|
||||
data_prefix: str = 'Img/img',
|
||||
data_prefix: str = 'Img',
|
||||
ann_file: str = 'Eval/list_eval_partition.txt',
|
||||
**kwargs):
|
||||
|
||||
|
@ -149,9 +151,6 @@ class InShop(BaseDataset):
|
|||
"""
|
||||
data_info = self._process_annotations()
|
||||
data_list = data_info['data_list']
|
||||
for data in data_list:
|
||||
data['img_path'] = self.backend.join_path(self.data_root,
|
||||
data['img_path'])
|
||||
return data_list
|
||||
|
||||
def extra_repr(self):
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
from .multi_task import MultiTasksMetric
|
||||
from .retrieval import RetrievalRecall
|
||||
from .single_label import Accuracy, SingleLabelMetric
|
||||
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
|
||||
|
||||
__all__ = [
|
||||
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
|
||||
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric'
|
||||
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
|
||||
'RetrievalRecall'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,234 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.structures import LabelData
|
||||
from mmengine.utils import is_seq_of
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from .single_label import to_tensor
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class RetrievalRecall(BaseMetric):
|
||||
r"""Recall evaluation metric for image retrieval.
|
||||
|
||||
Args:
|
||||
topk (int | Sequence[int]): If the ground truth label matches one of
|
||||
the best **k** predictions, the sample will be regard as a positive
|
||||
prediction. If the parameter is a tuple, all of top-k recall will
|
||||
be calculated and outputted together. Defaults to 1.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
|
||||
Examples:
|
||||
Use in the code:
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcls.evaluation import RetrievalRecall
|
||||
>>> # -------------------- The Basic Usage --------------------
|
||||
>>> y_pred = [[0], [1], [2], [3]]
|
||||
>>> y_true = [[0, 1], [2], [1], [0, 3]]
|
||||
>>> RetrievalRecall.calculate(
|
||||
>>> y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
||||
[tensor([50.])]
|
||||
>>> # Calculate the recall@1 and recall@5 for non-indices input.
|
||||
>>> y_score = torch.rand((1000, 10))
|
||||
>>> import torch.nn.functional as F
|
||||
>>> y_true = F.one_hot(torch.arange(0, 1000) % 10, num_classes=10)
|
||||
>>> RetrievalRecall.calculate(y_score, y_true, topk=(1, 5))
|
||||
[tensor(9.3000), tensor(48.4000)]
|
||||
>>>
|
||||
>>> # ------------------- Use with Evalutor -------------------
|
||||
>>> from mmcls.structures import ClsDataSample
|
||||
>>> from mmengine.evaluator import Evaluator
|
||||
>>> data_samples = [
|
||||
... ClsDataSample().set_gt_label([0, 1]).set_pred_score(
|
||||
... torch.rand(10))
|
||||
... for i in range(1000)
|
||||
... ]
|
||||
>>> evaluator = Evaluator(metrics=RetrievalRecall(topk=(1, 5)))
|
||||
>>> evaluator.process(data_samples)
|
||||
>>> evaluator.evaluate(1000)
|
||||
{'retrieval/Recall@1': 20.700000762939453,
|
||||
'retrieval/Recall@5': 78.5999984741211}
|
||||
|
||||
Use in OpenMMLab configs:
|
||||
|
||||
.. code:: python
|
||||
|
||||
val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
|
||||
"""
|
||||
default_prefix: Optional[str] = 'retrieval'
|
||||
|
||||
def __init__(self,
|
||||
topk: Union[int, Sequence[int]],
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
topk = (topk, ) if isinstance(topk, int) else topk
|
||||
|
||||
for k in topk:
|
||||
if k <= 0:
|
||||
raise ValueError('`topk` must be a ingter larger than 0 '
|
||||
'or seq of ingter larger than 0.')
|
||||
|
||||
self.topk = topk
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
data_samples: Sequence[dict]):
|
||||
"""Process one batch of data and predictions.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
|
||||
pred = pred_label['score'].clone()
|
||||
if 'score' in gt_label:
|
||||
target = gt_label['score'].clone()
|
||||
else:
|
||||
num_classes = pred_label['score'].size()[-1]
|
||||
target = LabelData.label_to_onehot(gt_label['label'],
|
||||
num_classes)
|
||||
|
||||
# Because the retrieval output logit vector will be much larger
|
||||
# compared to the normal classification, to save resources, the
|
||||
# evaluation results are computed each batch here and then reduce
|
||||
# all results at the end.
|
||||
result = RetrievalRecall.calculate(
|
||||
pred.unsqueeze(0), target.unsqueeze(0), topk=self.topk)
|
||||
self.results.append(result)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
Dict: The computed metrics. The keys are the names of the metrics,
|
||||
and the values are corresponding results.
|
||||
"""
|
||||
result_metrics = dict()
|
||||
for i, k in enumerate(self.topk):
|
||||
recall_at_k = sum([r[i].item() for r in results]) / len(results)
|
||||
result_metrics[f'Recall@{k}'] = recall_at_k
|
||||
|
||||
return result_metrics
|
||||
|
||||
@staticmethod
|
||||
def calculate(pred: Union[np.ndarray, torch.Tensor],
|
||||
target: Union[np.ndarray, torch.Tensor],
|
||||
topk: Union[int, Sequence[int]],
|
||||
pred_indices: (bool) = False,
|
||||
target_indices: (bool) = False) -> float:
|
||||
"""Calculate the average recall.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor | np.ndarray | Sequence): The prediction
|
||||
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
|
||||
shape ``(N, M)`` or a sequence of index/onehot
|
||||
format labels.
|
||||
target (torch.Tensor | np.ndarray | Sequence): The prediction
|
||||
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
|
||||
shape ``(N, M)`` or a sequence of index/onehot
|
||||
format labels.
|
||||
topk (int, Sequence[int]): Predictions with the k-th highest
|
||||
scores are considered as positive.
|
||||
pred_indices (bool): Whether the ``pred`` is a sequence of
|
||||
category index labels. Defaults to False.
|
||||
target_indices (bool): Whether the ``target`` is a sequence of
|
||||
category index labels. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[float]: the average recalls.
|
||||
"""
|
||||
topk = (topk, ) if isinstance(topk, int) else topk
|
||||
for k in topk:
|
||||
if k <= 0:
|
||||
raise ValueError('`topk` must be a ingter larger than 0 '
|
||||
'or seq of ingter larger than 0.')
|
||||
|
||||
max_keep = max(topk)
|
||||
pred = _format_pred(pred, max_keep, pred_indices)
|
||||
target = _format_target(target, target_indices)
|
||||
|
||||
assert len(pred) == len(target), (
|
||||
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
|
||||
f'must be the same.')
|
||||
|
||||
num_samples = len(pred)
|
||||
results = []
|
||||
for k in topk:
|
||||
recalls = torch.zeros(num_samples)
|
||||
for i, (sample_pred,
|
||||
sample_target) in enumerate(zip(pred, target)):
|
||||
sample_pred = np.array(to_tensor(sample_pred).cpu())
|
||||
sample_target = np.array(to_tensor(sample_target).cpu())
|
||||
recalls[i] = int(np.in1d(sample_pred[:k], sample_target).max())
|
||||
results.append(recalls.mean() * 100)
|
||||
return results
|
||||
|
||||
|
||||
def _format_pred(label, topk=None, is_indices=False):
|
||||
"""format various label to List[indices]."""
|
||||
if is_indices:
|
||||
assert isinstance(label, Sequence), \
|
||||
'`pred` must be Sequence of indices when' \
|
||||
f' `pred_indices` set to True, but get {type(label)}'
|
||||
for i, sample_pred in enumerate(label):
|
||||
assert is_seq_of(sample_pred, int) or isinstance(
|
||||
sample_pred, (np.ndarray, torch.Tensor)), \
|
||||
'`pred` should be Sequence of indices when `pred_indices`' \
|
||||
f'set to True. but pred[{i}] is {sample_pred}'
|
||||
if topk:
|
||||
label[i] = sample_pred[:min(topk, len(sample_pred))]
|
||||
return label
|
||||
if isinstance(label, np.ndarray):
|
||||
label = torch.from_numpy(label)
|
||||
elif not isinstance(label, torch.Tensor):
|
||||
raise TypeError(f'The pred must be type of torch.tensor, '
|
||||
f'np.ndarray or Sequence but get {type(label)}.')
|
||||
topk = topk if topk else label.size()[-1]
|
||||
_, indices = label.topk(topk)
|
||||
return indices
|
||||
|
||||
|
||||
def _format_target(label, is_indices=False):
|
||||
"""format various label to List[indices]."""
|
||||
if is_indices:
|
||||
assert isinstance(label, Sequence), \
|
||||
'`target` must be Sequence of indices when' \
|
||||
f' `target_indices` set to True, but get {type(label)}'
|
||||
for i, sample_gt in enumerate(label):
|
||||
assert is_seq_of(sample_gt, int) or isinstance(
|
||||
sample_gt, (np.ndarray, torch.Tensor)), \
|
||||
'`target` should be Sequence of indices when ' \
|
||||
f'`target_indices` set to True. but target[{i}] is {sample_gt}'
|
||||
return label
|
||||
|
||||
if isinstance(label, np.ndarray):
|
||||
label = torch.from_numpy(label)
|
||||
elif isinstance(label, Sequence) and not mmengine.is_str(label):
|
||||
label = torch.tensor(label)
|
||||
elif not isinstance(label, torch.Tensor):
|
||||
raise TypeError(f'The pred must be type of torch.tensor, '
|
||||
f'np.ndarray or Sequence but get {type(label)}.')
|
||||
|
||||
indices = [LabelData.onehot_to_label(sample_gt) for sample_gt in label]
|
||||
return indices
|
|
@ -266,6 +266,16 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
dist.all_reduce(prototype_vecs)
|
||||
return prototype_vecs
|
||||
|
||||
def _get_prototype_vecs_from_path(self, proto_path):
|
||||
"""get prototype_vecs from prototype path."""
|
||||
data = [None]
|
||||
if dist.is_main_process():
|
||||
data[0] = torch.load(proto_path)
|
||||
dist.broadcast_object_list(data, src=0)
|
||||
prototype_vecs = data[0]
|
||||
assert prototype_vecs is not None
|
||||
return prototype_vecs
|
||||
|
||||
@torch.no_grad()
|
||||
def prepare_prototype(self):
|
||||
"""Used in meta testing. This function will be called before the meta
|
||||
|
@ -281,7 +291,7 @@ class ImageToImageRetriever(BaseRetriever):
|
|||
if isinstance(self.prototype, torch.Tensor):
|
||||
prototype_vecs = self.prototype
|
||||
elif isinstance(self.prototype, str):
|
||||
prototype_vecs = torch.load(self.prototype)
|
||||
prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
|
||||
elif isinstance(self.prototype, (dict, DataLoader)):
|
||||
loader = Runner.build_dataloader(self.prototype)
|
||||
prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
|
||||
|
|
|
@ -51,3 +51,4 @@ Import:
|
|||
- configs/convnext_v2/metafile.yml
|
||||
- configs/levit/metafile.yml
|
||||
- configs/vig/metafile.yml
|
||||
- configs/arcface/metafile.yml
|
||||
|
|
|
@ -957,14 +957,14 @@ class TestInShop(TestBaseDataset):
|
|||
f.write('\n'.join([
|
||||
'8',
|
||||
'image_name item_id evaluation_status',
|
||||
'02_1_front.jpg id_00000002 train',
|
||||
'02_2_side.jpg id_00000002 train',
|
||||
'12_3_back.jpg id_00007982 gallery',
|
||||
'12_7_additional.jpg id_00007982 gallery',
|
||||
'13_1_front.jpg id_00007982 query',
|
||||
'13_2_side.jpg id_00007983 gallery',
|
||||
'13_3_back.jpg id_00007983 query ',
|
||||
'13_7_additional.jpg id_00007983 query',
|
||||
f'{osp.join("img", "02_1_front.jpg")} id_00000002 train',
|
||||
f'{osp.join("img", "02_2_side.jpg")} id_00000002 train',
|
||||
f'{osp.join("img", "12_3_back.jpg")} id_00007982 gallery',
|
||||
f'{osp.join("img", "12_7_addition.jpg")} id_00007982 gallery',
|
||||
f'{osp.join("img", "13_1_front.jpg")} id_00007982 query',
|
||||
f'{osp.join("img", "13_2_side.jpg")} id_00007983 gallery',
|
||||
f'{osp.join("img", "13_3_back.jpg")} id_00007983 query ',
|
||||
f'{osp.join("img", "13_7_additional.jpg")} id_00007983 query',
|
||||
]))
|
||||
|
||||
def test_initialize(self):
|
||||
|
@ -1004,8 +1004,9 @@ class TestInShop(TestBaseDataset):
|
|||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 2)
|
||||
data_info = dataset[0]
|
||||
self.assertEqual(data_info['img_path'],
|
||||
os.path.join(self.root, 'Img/img', '02_1_front.jpg'))
|
||||
self.assertEqual(
|
||||
data_info['img_path'],
|
||||
os.path.join(self.root, 'Img', 'img', '02_1_front.jpg'))
|
||||
self.assertEqual(data_info['gt_label'], 0)
|
||||
|
||||
# Test with mode=query
|
||||
|
@ -1013,8 +1014,9 @@ class TestInShop(TestBaseDataset):
|
|||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 3)
|
||||
data_info = dataset[0]
|
||||
self.assertEqual(data_info['img_path'],
|
||||
os.path.join(self.root, 'Img/img', '13_1_front.jpg'))
|
||||
self.assertEqual(
|
||||
data_info['img_path'],
|
||||
os.path.join(self.root, 'Img', 'img', '13_1_front.jpg'))
|
||||
self.assertEqual(data_info['gt_label'], [0, 1])
|
||||
|
||||
# Test with mode=gallery
|
||||
|
@ -1024,7 +1026,7 @@ class TestInShop(TestBaseDataset):
|
|||
data_info = dataset[0]
|
||||
self.assertEqual(
|
||||
data_info['img_path'],
|
||||
os.path.join(self.root, self.root, 'Img/img', '12_3_back.jpg'))
|
||||
os.path.join(self.root, 'Img', 'img', '12_3_back.jpg'))
|
||||
self.assertEqual(data_info['sample_idx'], 0)
|
||||
|
||||
def test_extra_repr(self):
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmcls.evaluation.metrics import RetrievalRecall
|
||||
from mmcls.registry import METRICS
|
||||
from mmcls.structures import ClsDataSample
|
||||
|
||||
|
||||
class TestRetrievalRecall(TestCase):
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
pred = [
|
||||
ClsDataSample().set_pred_score(i).set_gt_label(k).to_dict()
|
||||
for i, k in zip([
|
||||
torch.tensor([0.7, 0.0, 0.3]),
|
||||
torch.tensor([0.5, 0.2, 0.3]),
|
||||
torch.tensor([0.4, 0.5, 0.1]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
torch.tensor([0.0, 0.0, 1.0]),
|
||||
], [[0], [0, 1], [1], [2], [1, 2], [0, 1]])
|
||||
]
|
||||
|
||||
# Test with score (use score instead of label if score exists)
|
||||
metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
|
||||
metric.process(None, pred)
|
||||
recall = metric.evaluate(6)
|
||||
self.assertIsInstance(recall, dict)
|
||||
self.assertAlmostEqual(
|
||||
recall['retrieval/Recall@1'], 5 / 6 * 100, places=4)
|
||||
|
||||
# Test with invalid topk
|
||||
with self.assertRaisesRegex(RuntimeError, 'selected index k'):
|
||||
metric = METRICS.build(dict(type='RetrievalRecall', topk=10))
|
||||
metric.process(None, pred)
|
||||
metric.evaluate(6)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, '`topk` must be a'):
|
||||
METRICS.build(dict(type='RetrievalRecall', topk=-1))
|
||||
|
||||
# Test initialization
|
||||
metric = METRICS.build(dict(type='RetrievalRecall', topk=5))
|
||||
self.assertEqual(metric.topk, (5, ))
|
||||
|
||||
# Test initialization
|
||||
metric = METRICS.build(dict(type='RetrievalRecall', topk=(1, 2, 5)))
|
||||
self.assertEqual(metric.topk, (1, 2, 5))
|
||||
|
||||
def test_calculate(self):
|
||||
"""Test using the metric from static method."""
|
||||
|
||||
# seq of indices format
|
||||
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
|
||||
y_pred = [np.arange(10)] * 2
|
||||
|
||||
# test with average is 'macro'
|
||||
recall_score = RetrievalRecall.calculate(
|
||||
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
||||
expect_recall = 50.
|
||||
self.assertEqual(recall_score[0].item(), expect_recall)
|
||||
|
||||
# test with tensor input
|
||||
y_true = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1],
|
||||
[0, 1, 0, 0, 1, 0, 1, 0, 0, 0]])
|
||||
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
|
||||
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=1)
|
||||
expect_recall = 50.
|
||||
self.assertEqual(recall_score[0].item(), expect_recall)
|
||||
|
||||
# test with topk is 5
|
||||
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
|
||||
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=2)
|
||||
expect_recall = 100.
|
||||
self.assertEqual(recall_score[0].item(), expect_recall)
|
||||
|
||||
# test with topk is (1, 5)
|
||||
y_pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
|
||||
recall_score = RetrievalRecall.calculate(y_pred, y_true, topk=(1, 5))
|
||||
expect_recalls = [50., 100.]
|
||||
self.assertEqual(len(recall_score), len(expect_recalls))
|
||||
for i in range(len(expect_recalls)):
|
||||
self.assertEqual(recall_score[i].item(), expect_recalls[i])
|
||||
|
||||
# Test with invalid pred
|
||||
y_pred = dict()
|
||||
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
|
||||
with self.assertRaisesRegex(AssertionError, '`pred` must be Seq'):
|
||||
RetrievalRecall.calculate(y_pred, y_true, True, True)
|
||||
|
||||
# Test with invalid target
|
||||
y_true = dict()
|
||||
y_pred = [np.arange(10)] * 2
|
||||
with self.assertRaisesRegex(AssertionError, '`target` must be Seq'):
|
||||
RetrievalRecall.calculate(
|
||||
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
||||
|
||||
# Test with different length `pred` with `target`
|
||||
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
|
||||
y_pred = [np.arange(10)] * 3
|
||||
with self.assertRaisesRegex(AssertionError, 'Length of `pred`'):
|
||||
RetrievalRecall.calculate(
|
||||
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
||||
|
||||
# Test with invalid pred
|
||||
y_true = [[0, 2, 5, 8, 9], dict()]
|
||||
y_pred = [np.arange(10)] * 2
|
||||
with self.assertRaisesRegex(AssertionError, '`target` should be'):
|
||||
RetrievalRecall.calculate(
|
||||
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
||||
|
||||
# Test with invalid target
|
||||
y_true = [[0, 2, 5, 8, 9], [1, 4, 6]]
|
||||
y_pred = [np.arange(10), dict()]
|
||||
with self.assertRaisesRegex(AssertionError, '`pred` should be'):
|
||||
RetrievalRecall.calculate(
|
||||
y_pred, y_true, topk=1, pred_indices=True, target_indices=True)
|
Loading…
Reference in New Issue