# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import copy
import itertools
import logging
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics

from fastreid.utils import comm
from fastreid.utils.compute_dist import build_dist

from .evaluator import DatasetEvaluator
from .query_expansion import aqe
from .rank_cylib import compile_helper

logger = logging.getLogger(__name__)


class ReidEvaluator(DatasetEvaluator):
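    """
    Evaluator for image-based person re-identification.

    ``process`` caches per-batch embeddings, person ids and camera ids on the
    CPU; ``evaluate`` gathers them across workers and computes CMC, mAP, mINP
    and (optionally) ROC statistics. The first ``num_query`` cached samples are
    treated as queries, the remainder as the gallery.

    A minimal usage sketch (``model`` and ``test_loader`` are assumed to be
    built elsewhere, with the model returning embeddings in eval mode):

    .. code-block:: python

        evaluator = ReidEvaluator(cfg, num_query)
        evaluator.reset()
        for inputs in test_loader:
            outputs = model(inputs)
            evaluator.process(inputs, outputs)
        results = evaluator.evaluate()
    """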

    def __init__(self, cfg, num_query, output_dir=None):
        self.cfg = cfg
        self._num_query = num_query
        self._output_dir = output_dir

        self._cpu_device = torch.device('cpu')

        self._predictions = []
        self._compile_dependencies()

    def reset(self):
        self._predictions = []

    def process(self, inputs, outputs):
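        """Cache the outputs of one inference batch on the CPU.

        ``inputs`` is the batch dict from the test data loader (its 'targets'
        and 'camids' fields are used); ``outputs`` is the feature tensor the
        model produced for that batch.
        """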
        prediction = {
            'feats': outputs.to(self._cpu_device, torch.float32),
            'pids': inputs['targets'].to(self._cpu_device),
            'camids': inputs['camids'].to(self._cpu_device)
        }
        self._predictions.append(prediction)

    def evaluate(self):
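        """Gather predictions from all workers and compute re-id metrics.

        Only the main process returns the populated result dict; other
        processes return an empty dict.
        """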
        if comm.get_world_size() > 1:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process():
                return {}

        else:
            predictions = self._predictions

        features = []
        pids = []
        camids = []
        for prediction in predictions:
            features.append(prediction['feats'])
            pids.append(prediction['pids'])
            camids.append(prediction['camids'])

        features = torch.cat(features, dim=0)
        pids = torch.cat(pids, dim=0).numpy()
        camids = torch.cat(camids, dim=0).numpy()
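
        # The evaluation set is ordered as queries followed by gallery samples,
        # so slicing at self._num_query separates the two.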
        # query features, person ids and camera ids
        query_features = features[:self._num_query]
        query_pids = pids[:self._num_query]
        query_camids = camids[:self._num_query]

        # gallery features, person ids and camera ids
        gallery_features = features[self._num_query:]
        gallery_pids = pids[self._num_query:]
        gallery_camids = camids[self._num_query:]

        self._results = OrderedDict()
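
        # Optional alpha query expansion (AQE): query and gallery features are
        # refined by aggregating their top-k neighbours, weighted by similarity
        # raised to the power alpha, before distances are computed.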
        if self.cfg.TEST.AQE.ENABLED:
            logger.info("Test with AQE setting")
            qe_time = self.cfg.TEST.AQE.QE_TIME
            qe_k = self.cfg.TEST.AQE.QE_K
            alpha = self.cfg.TEST.AQE.ALPHA
            query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)

        dist = build_dist(query_features, gallery_features, self.cfg.TEST.METRIC)
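
        # Optional k-reciprocal re-ranking: a Jaccard distance built from
        # k-reciprocal neighbour sets is blended with the original distance,
        # with the original distance weighted by lambda_value.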
        if self.cfg.TEST.RERANK.ENABLED:
            logger.info("Test with rerank setting")
            k1 = self.cfg.TEST.RERANK.K1
            k2 = self.cfg.TEST.RERANK.K2
            lambda_value = self.cfg.TEST.RERANK.LAMBDA

            if self.cfg.TEST.METRIC == "cosine":
                query_features = F.normalize(query_features, dim=1)
                gallery_features = F.normalize(gallery_features, dim=1)

            rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2)
            dist = rerank_dist * (1 - lambda_value) + dist * lambda_value
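
        # Rank-k CMC, mAP and mINP are reported as percentages; "metric" is the
        # mean of mAP and Rank-1.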
        from .rank import evaluate_rank
        cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)

        mAP = np.mean(all_AP)
        mINP = np.mean(all_INP)
        for r in [1, 5, 10]:
            self._results['Rank-{}'.format(r)] = cmc[r - 1] * 100
        self._results['mAP'] = mAP * 100
        self._results['mINP'] = mINP * 100
        self._results["metric"] = (mAP + cmc[0]) / 2 * 100
        if self.cfg.TEST.ROC.ENABLED:
            from .roc import evaluate_roc
            scores, labels = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
            fprs, tprs, thres = metrics.roc_curve(labels, scores)

            for fpr in [1e-4, 1e-3, 1e-2]:
                ind = np.argmin(np.abs(fprs - fpr))
                self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]

        return copy.deepcopy(self._results)

    def _compile_dependencies(self):
        # Results are only evaluated on rank(0), so the cython evaluation tool
        # only needs to be compiled on rank(0).
        if comm.is_main_process():
            try:
                from .rank_cylib.rank_cy import evaluate_cy
            except ImportError:
                start_time = time.time()
                logger.info("> compiling reid evaluation cython tool")

                compile_helper()

                logger.info(
                    ">>> done with reid evaluation cython tool. Compilation time: {:.3f} "
                    "seconds".format(time.time() - start_time))
        comm.synchronize()