# PyRetri/pyretri/index/metric/metric_impl/knn.py
#
# 54 lines
# 1.7 KiB
# Python

# -*- coding: utf-8 -*-
from typing import Dict, Optional, Tuple

import torch

from ..metric_base import MetricBase
from ...registry import METRICS
@METRICS.register
class KNN(MetricBase):
    """
    Similarity measure based on the (squared) euclidean distance.

    Hyper-Params:
        top_k (int): top_k nearest neighbors will be output in sorted order. If it is 0, all neighbors will be output.
    """
    default_hyper_params = {
        "top_k": 0,
    }

    def __init__(self, hps: Optional[Dict] = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super().__init__(hps)

    def _cal_dis(self, query_fea: torch.Tensor, gallery_fea: torch.Tensor) -> torch.Tensor:
        """
        Calculate the squared euclidean distance between query set features and gallery set features.

        The distance is computed with the expansion
        ||g - q||^2 = ||g||^2 + ||q||^2 - 2 * <g, q>, which needs a single matrix product.

        Args:
            query_fea (torch.Tensor): query set features, a 2-D tensor with one row per query.
            gallery_fea (torch.Tensor): gallery set features, a 2-D tensor with one row per gallery item.

        Returns:
            dis (torch.Tensor): squared distances, one row per query and one column per gallery item.
        """
        query_t = query_fea.transpose(1, 0)
        inner_dot = gallery_fea.mm(query_t)
        gallery_sq = (gallery_fea ** 2).sum(dim=1, keepdim=True)
        query_sq = (query_t ** 2).sum(dim=0, keepdim=True)
        dis = gallery_sq + query_sq - 2 * inner_dot
        # Floating-point cancellation in the expansion can yield tiny negative values for
        # (near-)identical features; a squared distance is never negative, so clamp at zero.
        dis = dis.clamp(min=0)
        return dis.transpose(1, 0)

    def __call__(self, query_fea: torch.Tensor, gallery_fea: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rank the gallery set for every query by ascending distance.

        Args:
            query_fea (torch.Tensor): query set features.
            gallery_fea (torch.Tensor): gallery set features.

        Returns:
            dis (torch.Tensor): the full (untruncated) distance matrix returned by ``_cal_dis``.
            sorted_index (torch.Tensor): for every query, gallery indices sorted from nearest to farthest,
                truncated to the first ``top_k`` columns when ``top_k`` is non-zero.

        Raises:
            ValueError: if the ``top_k`` hyper-parameter is negative.
        """
        top_k = self._hyper_params["top_k"]
        # A negative value would be interpreted by slicing as "drop the last |top_k| neighbors",
        # which is never what the hyper-parameter means; fail loudly instead.
        if top_k and top_k < 0:
            raise ValueError("top_k must be a non-negative integer, got {}".format(top_k))

        dis = self._cal_dis(query_fea, gallery_fea)
        sorted_index = torch.argsort(dis, dim=1)
        if top_k:
            sorted_index = sorted_index[:, :top_k]
        return dis, sorted_index