From 5ae3d4fecf947e0cc5def90114441d85f5c58550 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Wed, 13 May 2020 16:27:22 +0800 Subject: [PATCH] feat: add aqe support in test phase query expansion will combine the retrieved topk nearest neighbors with the original query feature; it will enhance mAP by a large margin. --- fastreid/config/defaults.py | 13 +++++--- fastreid/evaluation/query_expansion.py | 46 ++++++++++++++++++++++++++ fastreid/evaluation/reid_evaluation.py | 39 +++++++++++++++------- 3 files changed, 82 insertions(+), 16 deletions(-) create mode 100644 fastreid/evaluation/query_expansion.py diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index 90949c7..fe93149 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -221,6 +221,14 @@ _C.TEST = CN() _C.TEST.EVAL_PERIOD = 50 _C.TEST.IMS_PER_BATCH = 128 +_C.TEST.METRIC = "cosine" + +# Average query expansion +_C.TEST.AQE = CN() +_C.TEST.AQE.ENABLED = False +_C.TEST.AQE.ALPHA = 3.0 +_C.TEST.AQE.QE_TIME = 1 +_C.TEST.AQE.QE_K = 5 # Re-rank _C.TEST.RERANK = CN() @@ -229,15 +237,12 @@ _C.TEST.RERANK.K1 = 20 _C.TEST.RERANK.K2 = 6 _C.TEST.RERANK.LAMBDA = 0.3 -# Average query expansion -_C.TEST.AQE = CN() -_C.TEST.AQE.ENABLED = True - # Precise batchnorm _C.TEST.PRECISE_BN = CN() _C.TEST.PRECISE_BN.ENABLED = False _C.TEST.PRECISE_BN.DATASET = 'Market1501' _C.TEST.PRECISE_BN.NUM_ITER = 300 + # ---------------------------------------------------------------------------- # # Misc options # ---------------------------------------------------------------------------- # diff --git a/fastreid/evaluation/query_expansion.py b/fastreid/evaluation/query_expansion.py new file mode 100644 index 0000000..873aed6 --- /dev/null +++ b/fastreid/evaluation/query_expansion.py @@ -0,0 +1,46 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: liaoxingyu5@jd.com +""" + +# based on +# 
def aqe(query_feat: torch.Tensor, gallery_feat: torch.Tensor,
        qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
    """
    Average query expansion: replace every feature by the weighted mean of its
    top-k most similar features, then repeat the retrieval on the new features.

    Query and gallery features are concatenated and expanded together, so
    neighbors are searched among *all* features (the feature itself included).
    Each neighbor is weighted by ``cosine_similarity ** alpha`` and the weighted
    features are averaged over the k neighbors.

    c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf

    Args:
        query_feat (torch.Tensor): query features, one row per sample.
        gallery_feat (torch.Tensor): gallery features, one row per sample, with
            the same feature dimension as ``query_feat``.
        qe_times (int): number of query expansion rounds. ``0`` returns the
            (detached, CPU) input features unchanged.
        qe_k (int): number of neighbors to be combined. Values larger than the
            total number of features are clamped to that number.
        alpha (float): exponent applied to the cosine similarities to obtain
            the neighbor weights. A negative similarity combined with a
            fractional ``alpha`` yields NaN weights.

    Returns:
        tuple(torch.Tensor, torch.Tensor): expanded query features and expanded
        gallery features, both on the CPU, with the same shapes as the inputs.

    Raises:
        ValueError: if ``qe_k`` is smaller than 1.
    """
    if qe_k < 1:
        raise ValueError("qe_k must be >= 1, but got {}".format(qe_k))

    num_query = query_feat.shape[0]
    # detach + cpu so that CUDA tensors and tensors tracked by autograd are accepted
    all_feat = torch.cat((query_feat.detach().cpu(), gallery_feat.detach().cpu()), dim=0)
    num_feat = all_feat.shape[0]
    if num_feat == 0:
        return all_feat[:num_query], all_feat[num_query:]

    # never ask for more neighbors than there are features
    top_k = min(qe_k, num_feat)

    for _ in range(qe_times):
        norm_feat = F.normalize(all_feat, p=2, dim=1)
        sims = torch.mm(norm_feat, norm_feat.t())
        topk_sims, topk_idx = torch.topk(sims, top_k, dim=1)
        weights = topk_sims.pow(alpha)

        # Accumulate one neighbor rank at a time: this keeps the peak memory at
        # (num_feat, dim) instead of materializing a (num_feat, top_k, dim) tensor.
        expanded = torch.zeros_like(all_feat)
        for rank in range(top_k):
            expanded += weights[:, rank:rank + 1] * all_feat[topk_idx[:, rank]]
        all_feat = expanded / top_k

    return all_feat[:num_query], all_feat[num_query:]
@staticmethod
def cal_dist(metric: str, query_feat: torch.Tensor, gallery_feat: torch.Tensor):
    """
    Compute the pairwise distance matrix between query and gallery features.

    Args:
        metric (str): either ``"cosine"`` or ``"euclidean"``.
        query_feat (torch.Tensor): query features, one row per sample.
        gallery_feat (torch.Tensor): gallery features, one row per sample, with
            the same feature dimension as ``query_feat``.

    Returns:
        numpy.ndarray: matrix whose entry ``[i, j]`` is the distance between
        query ``i`` and gallery ``j``. For ``"cosine"`` this is
        ``1 - cosine_similarity``; for ``"euclidean"`` it is the L2 distance.

    Raises:
        AssertionError: if ``metric`` is not one of the supported names.
    """
    assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
    if metric == "cosine":
        query_feat = F.normalize(query_feat, dim=1)
        gallery_feat = F.normalize(gallery_feat, dim=1)
        dist = 1 - torch.mm(query_feat, gallery_feat.t())
    else:
        # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 * q.g, with the two squared norms
        # broadcast against each other. Written as a plain expression because
        # the positional addmm_(beta, alpha, mat1, mat2) overload is no longer
        # accepted by current PyTorch.
        xx = torch.pow(query_feat, 2).sum(1, keepdim=True)
        yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).t()
        dist = xx + yy - 2 * torch.mm(query_feat, gallery_feat.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist.cpu().numpy()
self.cfg.TEST.RERANK.K1 lambda_value = self.cfg.TEST.RERANK.LAMBDA - q_q_dist = self.cal_dist(query_features, query_features) - g_g_dist = self.cal_dist(gallery_features, gallery_features) + q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features) + g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features) dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value) - if self.cfg.TEST.AQE.ENABLED: - pass - cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids) mAP = np.mean(all_AP) mINP = np.mean(all_INP)