312 lines
12 KiB
Python
312 lines
12 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Callable, List, Optional, Union
|
|
|
|
import mmengine.dist as dist
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.runner import Runner
|
|
from torch.utils.data import DataLoader
|
|
|
|
from mmcls.registry import MODELS
|
|
from mmcls.structures import ClsDataSample
|
|
from mmcls.utils import track_on_main_process
|
|
from .base import BaseRetriever
|
|
|
|
|
|
@MODELS.register_module()
|
|
class ImageToImageRetriever(BaseRetriever):
|
|
"""Image To Image Retriever for supervised retrieval task.
|
|
|
|
Args:
|
|
image_encoder (Union[dict, List[dict]]): Encoder for extracting
|
|
features.
|
|
prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be
|
|
retrieved. The following four types are supported.
|
|
|
|
- DataLoader: The original dataloader serves as the prototype.
|
|
- dict: The configuration to construct Dataloader.
|
|
- str: The path of the saved vector.
|
|
- torch.Tensor: The saved tensor whose dimension should be dim.
|
|
|
|
head (dict, optional): The head module to calculate loss from
|
|
processed features. See :mod:`mmcls.models.heads`. Notice
|
|
that if the head is not set, `loss` method cannot be used.
|
|
Defaults to None.
|
|
similarity_fn (Union[str, Callable]): The way that the similarity
|
|
is calculated. If `similarity` is callable, it is used directly
|
|
as the measure function. If it is a string, the appropriate
|
|
method will be used. The larger the calculated value, the
|
|
greater the similarity. Defaults to "cosine_similarity".
|
|
train_cfg (dict, optional): The training setting. The acceptable
|
|
fields are:
|
|
|
|
- augments (List[dict]): The batch augmentation methods to use.
|
|
More details can be found in :mod:`mmcls.model.utils.augment`.
|
|
|
|
Defaults to None.
|
|
data_preprocessor (dict, optional): The config for preprocessing input
|
|
data. If None or no specified type, it will use
|
|
"ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
|
|
more details. Defaults to None.
|
|
topk (int): Return the topk of the retrieval result. `-1` means
|
|
return all. Defaults to -1.
|
|
init_cfg (dict, optional): the config to control the initialization.
|
|
Defaults to None.
|
|
"""
|
|
|
|
def __init__(self,
|
|
image_encoder: Union[dict, List[dict]],
|
|
prototype: Union[DataLoader, dict, str, torch.Tensor],
|
|
head: Optional[dict] = None,
|
|
similarity_fn: Union[str, Callable] = 'cosine_similarity',
|
|
train_cfg: Optional[dict] = None,
|
|
data_preprocessor: Optional[dict] = None,
|
|
topk: int = -1,
|
|
init_cfg: Optional[dict] = None):
|
|
|
|
if data_preprocessor is None:
|
|
data_preprocessor = {}
|
|
# The build process is in MMEngine, so we need to add scope here.
|
|
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
|
|
|
|
if train_cfg is not None and 'augments' in train_cfg:
|
|
# Set batch augmentations by `train_cfg`
|
|
data_preprocessor['batch_augments'] = train_cfg
|
|
|
|
super(ImageToImageRetriever, self).__init__(
|
|
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
|
|
|
|
if not isinstance(image_encoder, nn.Module):
|
|
image_encoder = MODELS.build(image_encoder)
|
|
if head is not None and not isinstance(head, nn.Module):
|
|
head = MODELS.build(head)
|
|
|
|
self.image_encoder = image_encoder
|
|
self.head = head
|
|
|
|
self.similarity = similarity_fn
|
|
|
|
assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
|
|
'The `prototype` in `ImageToImageRetriever` must be a path, '
|
|
'a torch.Tensor, a dataloader or a dataloader dict format config.')
|
|
self.prototype = prototype
|
|
self.prototype_inited = False
|
|
self.topk = topk
|
|
|
|
@property
|
|
def similarity_fn(self):
|
|
"""Returns a function that calculates the similarity."""
|
|
# If self.similarity_way is callable, return it directly
|
|
if isinstance(self.similarity, Callable):
|
|
return self.similarity
|
|
|
|
if self.similarity == 'cosine_similarity':
|
|
# a is a tensor with shape (N, C)
|
|
# b is a tensor with shape (M, C)
|
|
# "cosine_similarity" will get the matrix of similarity
|
|
# with shape (N, M).
|
|
# The higher the score is, the more similar is
|
|
return lambda a, b: torch.cosine_similarity(
|
|
a.unsqueeze(1), b.unsqueeze(0), dim=-1)
|
|
else:
|
|
raise RuntimeError(f'Invalid function "{self.similarity_fn}".')
|
|
|
|
def forward(self,
|
|
inputs: torch.Tensor,
|
|
data_samples: Optional[List[ClsDataSample]] = None,
|
|
mode: str = 'tensor'):
|
|
"""The unified entry for a forward process in both training and test.
|
|
|
|
The method should accept three modes: "tensor", "predict" and "loss":
|
|
|
|
- "tensor": Forward the whole network and return tensor without any
|
|
post-processing, same as a common nn.Module.
|
|
- "predict": Forward and return the predictions, which are fully
|
|
processed to a list of :obj:`ClsDataSample`.
|
|
- "loss": Forward and return a dict of losses according to the given
|
|
inputs and data samples.
|
|
|
|
Note that this method doesn't handle neither back propagation nor
|
|
optimizer updating, which are done in the :meth:`train_step`.
|
|
|
|
Args:
|
|
inputs (torch.Tensor, tuple): The input tensor with shape
|
|
(N, C, ...) in general.
|
|
data_samples (List[ClsDataSample], optional): The annotation
|
|
data of every samples. It's required if ``mode="loss"``.
|
|
Defaults to None.
|
|
mode (str): Return what kind of value. Defaults to 'tensor'.
|
|
|
|
Returns:
|
|
The return type depends on ``mode``.
|
|
|
|
- If ``mode="tensor"``, return a tensor.
|
|
- If ``mode="predict"``, return a list of
|
|
:obj:`mmcls.structures.ClsDataSample`.
|
|
- If ``mode="loss"``, return a dict of tensor.
|
|
"""
|
|
if mode == 'tensor':
|
|
return self.extract_feat(inputs)
|
|
elif mode == 'loss':
|
|
return self.loss(inputs, data_samples)
|
|
elif mode == 'predict':
|
|
return self.predict(inputs, data_samples)
|
|
else:
|
|
raise RuntimeError(f'Invalid mode "{mode}".')
|
|
|
|
def extract_feat(self, inputs):
|
|
"""Extract features from the input tensor with shape (N, C, ...).
|
|
|
|
Args:
|
|
inputs (Tensor): A batch of inputs. The shape of it should be
|
|
``(num_samples, num_channels, *img_shape)``.
|
|
Returns:
|
|
Tensor: The output of encoder.
|
|
"""
|
|
|
|
feat = self.image_encoder(inputs)
|
|
return feat
|
|
|
|
def loss(self, inputs: torch.Tensor,
|
|
data_samples: List[ClsDataSample]) -> dict:
|
|
"""Calculate losses from a batch of inputs and data samples.
|
|
|
|
Args:
|
|
inputs (torch.Tensor): The input tensor with shape
|
|
(N, C, ...) in general.
|
|
data_samples (List[ClsDataSample]): The annotation data of
|
|
every samples.
|
|
|
|
Returns:
|
|
dict[str, Tensor]: a dictionary of loss components
|
|
"""
|
|
feats = self.extract_feat(inputs)
|
|
return self.head.loss(feats, data_samples)
|
|
|
|
def matching(self, inputs: torch.Tensor):
|
|
"""Compare the prototype and calculate the similarity.
|
|
|
|
Args:
|
|
inputs (torch.Tensor): The input tensor with shape (N, C).
|
|
Returns:
|
|
dict: a dictionary of score and prediction label based on fn.
|
|
"""
|
|
sim = self.similarity_fn(inputs, self.prototype_vecs)
|
|
sorted_sim, indices = torch.sort(sim, descending=True, dim=-1)
|
|
predictions = dict(
|
|
score=sim, pred_label=indices, pred_score=sorted_sim)
|
|
return predictions
|
|
|
|
def predict(self,
|
|
inputs: tuple,
|
|
data_samples: Optional[List[ClsDataSample]] = None,
|
|
**kwargs) -> List[ClsDataSample]:
|
|
"""Predict results from the extracted features.
|
|
|
|
Args:
|
|
inputs (tuple): The features extracted from the backbone.
|
|
data_samples (List[ClsDataSample], optional): The annotation
|
|
data of every samples. Defaults to None.
|
|
**kwargs: Other keyword arguments accepted by the ``predict``
|
|
method of :attr:`head`.
|
|
Returns:
|
|
List[ClsDataSample]: the raw data_samples with
|
|
the predicted results
|
|
"""
|
|
if not self.prototype_inited:
|
|
self.prepare_prototype()
|
|
|
|
feats = self.extract_feat(inputs)
|
|
if isinstance(feats, tuple):
|
|
feats = feats[-1]
|
|
|
|
# Matching of similarity
|
|
result = self.matching(feats)
|
|
return self._get_predictions(result, data_samples)
|
|
|
|
def _get_predictions(self, result, data_samples):
|
|
"""Post-process the output of retriever."""
|
|
pred_scores = result['score']
|
|
pred_labels = result['pred_label']
|
|
if self.topk != -1:
|
|
topk = min(self.topk, pred_scores.size()[-1])
|
|
pred_labels = pred_labels[:, :topk]
|
|
|
|
if data_samples is not None:
|
|
for data_sample, score, label in zip(data_samples, pred_scores,
|
|
pred_labels):
|
|
data_sample.set_pred_score(score).set_pred_label(label)
|
|
else:
|
|
data_samples = []
|
|
for score, label in zip(pred_scores, pred_labels):
|
|
data_samples.append(ClsDataSample().set_pred_score(
|
|
score).set_pred_label(label))
|
|
return data_samples
|
|
|
|
def _get_prototype_vecs_from_dataloader(self, data_loader):
|
|
"""get prototype_vecs from dataloader."""
|
|
self.eval()
|
|
num = len(data_loader.dataset)
|
|
|
|
prototype_vecs = None
|
|
for data_batch in track_on_main_process(data_loader):
|
|
data = self.data_preprocessor(data_batch, False)
|
|
feat = self(**data)
|
|
if isinstance(feat, tuple):
|
|
feat = feat[-1]
|
|
|
|
if prototype_vecs is None:
|
|
dim = feat.shape[-1]
|
|
prototype_vecs = torch.zeros(num, dim)
|
|
for i, data_sample in enumerate(data_batch['data_samples']):
|
|
sample_idx = data_sample.get('sample_idx')
|
|
prototype_vecs[sample_idx] = feat[i]
|
|
|
|
assert prototype_vecs is not None
|
|
dist.all_reduce(prototype_vecs)
|
|
return prototype_vecs
|
|
|
|
def _get_prototype_vecs_from_path(self, proto_path):
|
|
"""get prototype_vecs from prototype path."""
|
|
data = [None]
|
|
if dist.is_main_process():
|
|
data[0] = torch.load(proto_path)
|
|
dist.broadcast_object_list(data, src=0)
|
|
prototype_vecs = data[0]
|
|
assert prototype_vecs is not None
|
|
return prototype_vecs
|
|
|
|
@torch.no_grad()
|
|
def prepare_prototype(self):
|
|
"""Used in meta testing. This function will be called before the meta
|
|
testing. Obtain the vector based on the prototype.
|
|
|
|
- torch.Tensor: The prototype vector is the prototype
|
|
- str: The path of the extracted feature path, parse data structure,
|
|
and generate the prototype feature vector set
|
|
- Dataloader or config: Extract and save the feature vectors according
|
|
to the dataloader
|
|
"""
|
|
device = next(self.image_encoder.parameters()).device
|
|
if isinstance(self.prototype, torch.Tensor):
|
|
prototype_vecs = self.prototype
|
|
elif isinstance(self.prototype, str):
|
|
prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
|
|
elif isinstance(self.prototype, (dict, DataLoader)):
|
|
loader = Runner.build_dataloader(self.prototype)
|
|
prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
|
|
|
|
self.register_buffer(
|
|
'prototype_vecs', prototype_vecs.to(device), persistent=False)
|
|
self.prototype_inited = True
|
|
|
|
def dump_prototype(self, path):
|
|
"""Save the features extracted from the prototype to specific path.
|
|
|
|
Args:
|
|
path (str): Path to save feature.
|
|
"""
|
|
if not self.prototype_inited:
|
|
self.prepare_prototype()
|
|
torch.save(self.prototype_vecs, path)
|