2020-02-10 07:38:56 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: liaoxingyu
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
import copy
|
2020-05-19 20:45:26 +08:00
|
|
|
import logging
|
2020-02-10 07:38:56 +08:00
|
|
|
from collections import OrderedDict
|
|
|
|
|
2020-02-27 12:16:57 +08:00
|
|
|
import numpy as np
|
2020-04-27 14:51:39 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
2020-02-10 07:38:56 +08:00
|
|
|
|
|
|
|
from .evaluator import DatasetEvaluator
|
2020-05-19 20:45:26 +08:00
|
|
|
from .query_expansion import aqe
|
2020-02-10 07:38:56 +08:00
|
|
|
from .rank import evaluate_rank
|
2020-05-19 20:45:26 +08:00
|
|
|
from .roc import evaluate_roc
|
2020-05-13 11:47:52 +08:00
|
|
|
from .rerank import re_ranking
|
|
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
2020-02-10 07:38:56 +08:00
|
|
|
|
|
|
|
|
|
|
|
class ReidEvaluator(DatasetEvaluator):
    """Accumulate re-id outputs batch by batch and compute retrieval metrics.

    Samples are split by arrival order: the first ``num_query`` samples
    passed to :meth:`process` form the query set; all remaining samples form
    the gallery.

    :meth:`evaluate` reports Rank-1/5/10, mAP, mINP and TPR at three fixed
    FPRs. Before ranking, it can optionally apply AQE (``cfg.TEST.AQE``)
    and/or re-ranking (``cfg.TEST.RERANK``).
    """

    def __init__(self, cfg, num_query, output_dir=None):
        """
        Args:
            cfg: config node. :meth:`evaluate` reads ``TEST.METRIC``,
                ``TEST.AQE.{ENABLED,QE_TIME,QE_K,ALPHA}`` and
                ``TEST.RERANK.{ENABLED,K1,K2,LAMBDA}`` from it.
            num_query (int): number of leading samples treated as queries.
            output_dir (str, optional): stored as ``self._output_dir``; not
                read by this class.
        """
        self.cfg = cfg
        self._num_query = num_query
        self._output_dir = output_dir

        self.features = []
        self.pids = []
        self.camids = []
        # Filled by ``evaluate``; initialised here so the attribute always exists.
        self._results = OrderedDict()

    def reset(self):
        """Discard all accumulated features, person ids and camera ids."""
        self.features = []
        self.pids = []
        self.camids = []

    def process(self, outputs):
        """Store one batch of model outputs.

        Args:
            outputs: indexable as ``(features, pids, camids)``. Each item is a
                tensor. Features are kept as CPU tensors (concatenated along
                dim 0 in :meth:`evaluate`). Ids are converted to numpy and
                appended element-wise to ``self.pids`` / ``self.camids``.
                (Assumes the ids are 1-D per batch; TODO confirm with the model.)
        """
        self.features.append(outputs[0].cpu())
        self.pids.extend(outputs[1].cpu().numpy())
        self.camids.extend(outputs[2].cpu().numpy())

    @staticmethod
    def cal_dist(metric: str, query_feat: torch.Tensor, gallery_feat: torch.Tensor):
        """Compute the pairwise distance matrix between two sets of features.

        Args:
            metric: ``"cosine"`` (1 minus the dot product of L2-normalised
                rows) or ``"euclidean"`` (L2 distance).
            query_feat: 2-D tensor of shape (m, d).
            gallery_feat: 2-D tensor of shape (n, d).

        Returns:
            np.ndarray of shape (m, n) with the distance between every
            query row and every gallery row.
        """
        assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
        if metric == "cosine":
            query_feat = F.normalize(query_feat, dim=1)
            gallery_feat = F.normalize(gallery_feat, dim=1)
            dist = 1 - torch.mm(query_feat, gallery_feat.t())
        else:
            m, n = query_feat.size(0), gallery_feat.size(0)
            xx = torch.pow(query_feat, 2).sum(1, keepdim=True).expand(m, n)
            yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).expand(n, m).t()
            # Squared distance: ||q||^2 + ||g||^2 - 2 * q . g
            dist = xx + yy
            # Keyword form; the positional addmm_(beta, alpha, mat1, mat2)
            # overload is deprecated in PyTorch.
            dist.addmm_(query_feat, gallery_feat.t(), beta=1, alpha=-2)
            # Rounding can push tiny squared distances below zero; clamp before sqrt.
            dist = dist.clamp(min=1e-12).sqrt()
        return dist.cpu().numpy()

    def evaluate(self):
        """Compute retrieval metrics over everything passed to :meth:`process`.

        Returns:
            OrderedDict: a deep copy of the metrics. Keys are ``Rank-1``,
            ``Rank-5``, ``Rank-10``, ``mAP``, ``mINP``, and
            ``TPR@FPR=<fpr>`` for each fpr in (1e-4, 1e-3, 1e-2).
        """
        features = torch.cat(self.features, dim=0)

        # Leading ``num_query`` samples are queries ...
        query_features = features[:self._num_query]
        query_pids = np.asarray(self.pids[:self._num_query])
        query_camids = np.asarray(self.camids[:self._num_query])

        # ... and everything after them is the gallery.
        gallery_features = features[self._num_query:]
        gallery_pids = np.asarray(self.pids[self._num_query:])
        gallery_camids = np.asarray(self.camids[self._num_query:])

        self._results = OrderedDict()

        if self.cfg.TEST.AQE.ENABLED:
            logger.info("Test with AQE setting")
            qe_time = self.cfg.TEST.AQE.QE_TIME
            qe_k = self.cfg.TEST.AQE.QE_K
            alpha = self.cfg.TEST.AQE.ALPHA
            # AQE returns replacement query and gallery features.
            query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)

        dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)

        if self.cfg.TEST.RERANK.ENABLED:
            logger.info("Test with rerank setting")
            k1 = self.cfg.TEST.RERANK.K1
            k2 = self.cfg.TEST.RERANK.K2
            lambda_value = self.cfg.TEST.RERANK.LAMBDA
            # re_ranking also needs the query-query and gallery-gallery distances.
            q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
            g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
            dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)

        cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
        mAP = np.mean(all_AP)
        mINP = np.mean(all_INP)
        for r in [1, 5, 10]:
            # cmc is 0-indexed, so entry r - 1 is reported as Rank-r.
            self._results['Rank-{}'.format(r)] = cmc[r - 1]
        self._results['mAP'] = mAP
        self._results['mINP'] = mINP

        tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
        fprs = [1e-4, 1e-3, 1e-2]
        for i, fpr in enumerate(fprs):
            # Index (rather than zip) so a short ``tprs`` still raises instead
            # of silently dropping metrics.
            self._results["TPR@FPR={}".format(fpr)] = tprs[i]

        return copy.deepcopy(self._results)
|