mirror of https://github.com/JDAI-CV/fast-reid.git
feat: add aqe support in test phase
query expansion will combine the retrived topk nearest neighbors with the original query feature, it will enhance mAP by a large margin.feat:pull/64/head
parent
320010f2ae
commit
5ae3d4fecf
|
@ -221,6 +221,14 @@ _C.TEST = CN()
|
|||
|
||||
_C.TEST.EVAL_PERIOD = 50
|
||||
_C.TEST.IMS_PER_BATCH = 128
|
||||
_C.TEST.METRIC = "cosine"
|
||||
|
||||
# Average query expansion
|
||||
_C.TEST.AQE = CN()
|
||||
_C.TEST.AQE.ENABLED = False
|
||||
_C.TEST.AQE.ALPHA = 3.0
|
||||
_C.TEST.AQE.QE_TIME = 1
|
||||
_C.TEST.AQE.QE_K = 5
|
||||
|
||||
# Re-rank
|
||||
_C.TEST.RERANK = CN()
|
||||
|
@ -229,15 +237,12 @@ _C.TEST.RERANK.K1 = 20
|
|||
_C.TEST.RERANK.K2 = 6
|
||||
_C.TEST.RERANK.LAMBDA = 0.3
|
||||
|
||||
# Average query expansion
|
||||
_C.TEST.AQE = CN()
|
||||
_C.TEST.AQE.ENABLED = True
|
||||
|
||||
# Precise batchnorm
|
||||
_C.TEST.PRECISE_BN = CN()
|
||||
_C.TEST.PRECISE_BN.ENABLED = False
|
||||
_C.TEST.PRECISE_BN.DATASET = 'Market1501'
|
||||
_C.TEST.PRECISE_BN.NUM_ITER = 300
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Misc options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: liaoxingyu5@jd.com
|
||||
"""
|
||||
|
||||
# based on
|
||||
# https://github.com/PyRetri/PyRetri/blob/master/pyretri/index/re_ranker/re_ranker_impl/query_expansion.py
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def aqe(query_feat: torch.tensor, gallery_feat: torch.tensor,
|
||||
qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
|
||||
"""
|
||||
Combining the retrieved topk nearest neighbors with the original query and doing another retrieval.
|
||||
c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
|
||||
Args :
|
||||
query_feat (torch.tensor):
|
||||
gallery_feat (torch.tensor):
|
||||
qe_times (int): number of query expansion times.
|
||||
qe_k (int): number of the neighbors to be combined.
|
||||
alpha (float):
|
||||
"""
|
||||
num_query = query_feat.shape[0]
|
||||
all_feat = torch.cat((query_feat, gallery_feat), dim=0)
|
||||
norm_feat = F.normalize(all_feat, p=2, dim=1)
|
||||
|
||||
all_feat = all_feat.numpy()
|
||||
for i in range(qe_times):
|
||||
all_feat_list = []
|
||||
sims = torch.mm(norm_feat, norm_feat.t())
|
||||
sims = sims.data.cpu().numpy()
|
||||
for sim in sims:
|
||||
init_rank = np.argpartition(-sim, range(1, qe_k + 1))
|
||||
weights = sim[init_rank[:qe_k]].reshape((-1, 1))
|
||||
weights = np.power(weights, alpha)
|
||||
all_feat_list.append(np.mean(all_feat[init_rank[:qe_k], :] * weights, axis=0))
|
||||
all_feat = np.stack(all_feat_list, axis=0)
|
||||
norm_feat = F.normalize(torch.from_numpy(all_feat), p=2, dim=1)
|
||||
|
||||
query_feat = torch.from_numpy(all_feat[:num_query])
|
||||
gallery_feat = torch.from_numpy(all_feat[num_query:])
|
||||
return query_feat, gallery_feat
|
|
@ -6,6 +6,7 @@
|
|||
import logging
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -14,6 +15,7 @@ import torch.nn.functional as F
|
|||
from .evaluator import DatasetEvaluator
|
||||
from .rank import evaluate_rank
|
||||
from .rerank import re_ranking
|
||||
from .query_expansion import aqe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -34,16 +36,25 @@ class ReidEvaluator(DatasetEvaluator):
|
|||
self.camids = []
|
||||
|
||||
def process(self, outputs):
|
||||
self.features.append(outputs[0])
|
||||
self.features.append(outputs[0].cpu())
|
||||
self.pids.extend(outputs[1].cpu().numpy())
|
||||
self.camids.extend(outputs[2].cpu().numpy())
|
||||
|
||||
@staticmethod
|
||||
def cal_dist(query_feat: torch.tensor, gallery_feat: torch.tensor):
|
||||
query_feat = F.normalize(query_feat, dim=1)
|
||||
gallery_feat = F.normalize(gallery_feat, dim=1)
|
||||
cos_dist = 1 - torch.mm(query_feat, gallery_feat.t()).cpu().numpy()
|
||||
return cos_dist
|
||||
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
|
||||
assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
|
||||
if metric == "cosine":
|
||||
query_feat = F.normalize(query_feat, dim=1)
|
||||
gallery_feat = F.normalize(gallery_feat, dim=1)
|
||||
dist = 1 - torch.mm(query_feat, gallery_feat.t())
|
||||
else:
|
||||
m, n = query_feat.size(0), gallery_feat.size(0)
|
||||
xx = torch.pow(query_feat, 2).sum(1, keepdim=True).expand(m, n)
|
||||
yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).expand(n, m).t()
|
||||
dist = xx + yy
|
||||
dist.addmm_(1, -2, query_feat, gallery_feat.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist.cpu().numpy()
|
||||
|
||||
def evaluate(self):
|
||||
features = torch.cat(self.features, dim=0)
|
||||
|
@ -60,20 +71,24 @@ class ReidEvaluator(DatasetEvaluator):
|
|||
|
||||
self._results = OrderedDict()
|
||||
|
||||
dist = self.cal_dist(query_features, gallery_features)
|
||||
if self.cfg.TEST.AQE.ENABLED:
|
||||
logger.info("Test with AQE setting")
|
||||
qe_time = self.cfg.TEST.AQE.QE_TIME
|
||||
qe_k = self.cfg.TEST.AQE.QE_K
|
||||
alpha = self.cfg.TEST.AQE.ALPHA
|
||||
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
|
||||
|
||||
dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)
|
||||
|
||||
if self.cfg.TEST.RERANK.ENABLED:
|
||||
logger.info("Test with rerank setting")
|
||||
k1 = self.cfg.TEST.RERANK.K1
|
||||
k2 = self.cfg.TEST.RERANK.K1
|
||||
lambda_value = self.cfg.TEST.RERANK.LAMBDA
|
||||
q_q_dist = self.cal_dist(query_features, query_features)
|
||||
g_g_dist = self.cal_dist(gallery_features, gallery_features)
|
||||
q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
|
||||
g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
|
||||
dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
|
||||
|
||||
if self.cfg.TEST.AQE.ENABLED:
|
||||
pass
|
||||
|
||||
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
mAP = np.mean(all_AP)
|
||||
mINP = np.mean(all_INP)
|
||||
|
|
Loading…
Reference in New Issue