mirror of https://github.com/PyRetri/PyRetri.git
54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from typing import Dict, Optional, Tuple

import torch

from ..metric_base import MetricBase
from ...registry import METRICS
|
|
|
|
@METRICS.register
class KNN(MetricBase):
    """
    Similarity measure based on the squared euclidean distance.

    Hyper-Params:
        top_k (int): top_k nearest neighbors will be output in sorted order. If it is 0, all neighbors will be output.
    """
    default_hyper_params = {
        "top_k": 0,
    }

    def __init__(self, hps: Optional[Dict] = None):
        """
        Args:
            hps (dict or None): hyper parameters in a dict (keys, values); forwarded unchanged to the base class.
        """
        super().__init__(hps)

    def _cal_dis(self, query_fea: torch.Tensor, gallery_fea: torch.Tensor) -> torch.Tensor:
        """
        Calculate the pairwise squared euclidean distance between query set features and gallery set features.

        The distance is expanded as ||g||^2 + ||q||^2 - 2 * <g, q> so that it can be computed
        with a single matrix multiplication.

        Args:
            query_fea (torch.Tensor): query set features, a 2-D tensor with one row per query.
            gallery_fea (torch.Tensor): gallery set features, a 2-D tensor with one row per gallery item.

        Returns:
            dis (torch.Tensor): squared distances, one row per query and one column per gallery item.
        """
        # Work with the queries as columns so that gallery.mm(query) is (num_gallery, num_query).
        query_fea = query_fea.transpose(1, 0)
        inner_dot = gallery_fea.mm(query_fea)
        # Broadcast the per-gallery (column vector) and per-query (row vector) squared norms.
        dis = (gallery_fea ** 2).sum(dim=1, keepdim=True) + (query_fea ** 2).sum(dim=0, keepdim=True)
        dis = dis - 2 * inner_dot
        # Floating-point cancellation can push values for (near-)identical features slightly
        # below zero; a squared distance is never negative, so clamp at 0.
        dis = torch.clamp(dis, min=0)
        # Return with queries along the rows.
        dis = dis.transpose(1, 0)
        return dis

    def __call__(self, query_fea: torch.Tensor, gallery_fea: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rank the gallery items for every query by ascending distance.

        Args:
            query_fea (torch.Tensor): query set features.
            gallery_fea (torch.Tensor): gallery set features.

        Returns:
            dis (torch.Tensor): the full distance matrix produced by ``_cal_dis`` (not truncated by top_k).
            sorted_index (torch.Tensor): per-query gallery indices sorted by ascending distance,
                sliced to the first top_k columns when top_k is not 0.
        """
        dis = self._cal_dis(query_fea, gallery_fea)
        sorted_index = torch.argsort(dis, dim=1)
        # NOTE(review): _hyper_params is presumably populated by MetricBase.__init__ — confirm.
        top_k = self._hyper_params["top_k"]
        if top_k != 0:
            sorted_index = sorted_index[:, :top_k]
        return dis, sorted_index
|