feat: add aqe support in test phase

query expansion will combine the retrived topk nearest neighbors with the original query feature,
it will enhance mAP by a large margin.feat:
pull/64/head
liaoxingyu 2020-05-13 16:27:22 +08:00
parent 320010f2ae
commit 5ae3d4fecf
3 changed files with 82 additions and 16 deletions

View File

@ -221,6 +221,14 @@ _C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50
_C.TEST.IMS_PER_BATCH = 128
_C.TEST.METRIC = "cosine"
# Average query expansion
_C.TEST.AQE = CN()
_C.TEST.AQE.ENABLED = False
_C.TEST.AQE.ALPHA = 3.0
_C.TEST.AQE.QE_TIME = 1
_C.TEST.AQE.QE_K = 5
# Re-rank
_C.TEST.RERANK = CN()
@ -229,15 +237,12 @@ _C.TEST.RERANK.K1 = 20
_C.TEST.RERANK.K2 = 6
_C.TEST.RERANK.LAMBDA = 0.3
# Average query expansion
_C.TEST.AQE = CN()
_C.TEST.AQE.ENABLED = True
# Precise batchnorm
_C.TEST.PRECISE_BN = CN()
_C.TEST.PRECISE_BN.ENABLED = False
_C.TEST.PRECISE_BN.DATASET = 'Market1501'
_C.TEST.PRECISE_BN.NUM_ITER = 300
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #

View File

@ -0,0 +1,46 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on
# https://github.com/PyRetri/PyRetri/blob/master/pyretri/index/re_ranker/re_ranker_impl/query_expansion.py
import numpy as np
import torch
import torch.nn.functional as F
def aqe(query_feat: torch.tensor, gallery_feat: torch.tensor,
qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
"""
Combining the retrieved topk nearest neighbors with the original query and doing another retrieval.
c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
Args :
query_feat (torch.tensor):
gallery_feat (torch.tensor):
qe_times (int): number of query expansion times.
qe_k (int): number of the neighbors to be combined.
alpha (float):
"""
num_query = query_feat.shape[0]
all_feat = torch.cat((query_feat, gallery_feat), dim=0)
norm_feat = F.normalize(all_feat, p=2, dim=1)
all_feat = all_feat.numpy()
for i in range(qe_times):
all_feat_list = []
sims = torch.mm(norm_feat, norm_feat.t())
sims = sims.data.cpu().numpy()
for sim in sims:
init_rank = np.argpartition(-sim, range(1, qe_k + 1))
weights = sim[init_rank[:qe_k]].reshape((-1, 1))
weights = np.power(weights, alpha)
all_feat_list.append(np.mean(all_feat[init_rank[:qe_k], :] * weights, axis=0))
all_feat = np.stack(all_feat_list, axis=0)
norm_feat = F.normalize(torch.from_numpy(all_feat), p=2, dim=1)
query_feat = torch.from_numpy(all_feat[:num_query])
gallery_feat = torch.from_numpy(all_feat[num_query:])
return query_feat, gallery_feat

View File

@ -6,6 +6,7 @@
import logging
import copy
from collections import OrderedDict
from functools import partial
import numpy as np
import torch
@ -14,6 +15,7 @@ import torch.nn.functional as F
from .evaluator import DatasetEvaluator
from .rank import evaluate_rank
from .rerank import re_ranking
from .query_expansion import aqe
logger = logging.getLogger(__name__)
@ -34,16 +36,25 @@ class ReidEvaluator(DatasetEvaluator):
self.camids = []
def process(self, outputs):
self.features.append(outputs[0])
self.features.append(outputs[0].cpu())
self.pids.extend(outputs[1].cpu().numpy())
self.camids.extend(outputs[2].cpu().numpy())
@staticmethod
def cal_dist(query_feat: torch.tensor, gallery_feat: torch.tensor):
query_feat = F.normalize(query_feat, dim=1)
gallery_feat = F.normalize(gallery_feat, dim=1)
cos_dist = 1 - torch.mm(query_feat, gallery_feat.t()).cpu().numpy()
return cos_dist
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
if metric == "cosine":
query_feat = F.normalize(query_feat, dim=1)
gallery_feat = F.normalize(gallery_feat, dim=1)
dist = 1 - torch.mm(query_feat, gallery_feat.t())
else:
m, n = query_feat.size(0), gallery_feat.size(0)
xx = torch.pow(query_feat, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).expand(n, m).t()
dist = xx + yy
dist.addmm_(1, -2, query_feat, gallery_feat.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
return dist.cpu().numpy()
def evaluate(self):
features = torch.cat(self.features, dim=0)
@ -60,20 +71,24 @@ class ReidEvaluator(DatasetEvaluator):
self._results = OrderedDict()
dist = self.cal_dist(query_features, gallery_features)
if self.cfg.TEST.AQE.ENABLED:
logger.info("Test with AQE setting")
qe_time = self.cfg.TEST.AQE.QE_TIME
qe_k = self.cfg.TEST.AQE.QE_K
alpha = self.cfg.TEST.AQE.ALPHA
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)
if self.cfg.TEST.RERANK.ENABLED:
logger.info("Test with rerank setting")
k1 = self.cfg.TEST.RERANK.K1
k2 = self.cfg.TEST.RERANK.K1
lambda_value = self.cfg.TEST.RERANK.LAMBDA
q_q_dist = self.cal_dist(query_features, query_features)
g_g_dist = self.cal_dist(gallery_features, gallery_features)
q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
if self.cfg.TEST.AQE.ENABLED:
pass
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)