mirror of https://github.com/JDAI-CV/fast-reid.git
47 lines
1.7 KiB
Python
47 lines
1.7 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
# based on
|
|
# https://github.com/PyRetri/PyRetri/blob/master/pyretri/index/re_ranker/re_ranker_impl/query_expansion.py
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def aqe(query_feat: torch.Tensor, gallery_feat: torch.Tensor,
        qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
    """
    Combining the retrieved topk nearest neighbors with the original query and doing another retrieval.
    c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf

    Query and gallery features are concatenated and every row is expanded
    against the whole pool, so both sets are rewritten on each round. A row
    is replaced by the mean of its ``qe_k`` most similar rows (cosine
    similarity, the row itself included), each scaled by ``similarity ** alpha``.
    The weighted rows are averaged with a plain mean, i.e. divided by the
    number of neighbors rather than by the sum of the weights.

    Args:
        query_feat (torch.Tensor): 2-D query features, one row per sample.
        gallery_feat (torch.Tensor): 2-D gallery features with the same
            feature dimension as ``query_feat``.
        qe_times (int): number of query expansion times. ``0`` returns the
            input features unchanged (as detached CPU tensors).
        qe_k (int): number of the neighbors to be combined. Values larger
            than the total number of features are clamped to that total.
        alpha (float): exponent applied to each neighbor's cosine similarity
            to obtain its weight. NOTE: a non-integer ``alpha`` combined with
            a negative similarity makes ``np.power`` return NaN.

    Returns:
        tuple: ``(query_feat, gallery_feat)`` expanded features as detached
        CPU tensors, split back at the original number of queries.

    Raises:
        ValueError: if ``qe_k`` is smaller than 1.
    """
    if qe_k < 1:
        raise ValueError("qe_k must be a positive integer, got {}".format(qe_k))

    num_query = query_feat.shape[0]
    # Detach so that inputs requiring grad can be converted to numpy below.
    all_feat = torch.cat((query_feat, gallery_feat), dim=0).detach()
    device = all_feat.device
    num_feat = all_feat.shape[0]

    if num_feat == 0:
        # Nothing to expand; np.stack would fail on an empty list.
        empty = all_feat.cpu()
        return empty[:num_query], empty[num_query:]

    # argpartition needs every kth index to be < num_feat, so never ask for
    # more neighbors than there are rows.
    top_k = min(qe_k, num_feat)

    norm_feat = F.normalize(all_feat, p=2, dim=1)
    # .cpu() first: Tensor.numpy() only works on CPU tensors.
    feats = all_feat.cpu().numpy()

    for _ in range(qe_times):
        # Cosine similarity between every pair of rows (computed on the
        # original device, then moved to numpy for the per-row ranking).
        sims = torch.mm(norm_feat, norm_feat.t()).cpu().numpy()

        expanded = []
        for sim in sims:
            if top_k < num_feat:
                # Partial sort: the first top_k entries are the most similar rows.
                nearest = np.argpartition(-sim, range(1, top_k + 1))[:top_k]
            else:
                # Every row is a neighbor; a full sort avoids an
                # out-of-bounds kth in argpartition.
                nearest = np.argsort(-sim)
            weights = np.power(sim[nearest], alpha).reshape((-1, 1))
            expanded.append(np.mean(feats[nearest, :] * weights, axis=0))

        feats = np.stack(expanded, axis=0)
        # Re-normalize the expanded features for the next round's similarities.
        norm_feat = F.normalize(torch.from_numpy(feats).to(device), p=2, dim=1)

    query_feat = torch.from_numpy(feats[:num_query])
    gallery_feat = torch.from_numpy(feats[num_query:])
    return query_feat, gallery_feat
|