fast-reid/fastreid/evaluation/reid_evaluation.py

140 lines
5.3 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
2020-02-10 07:38:56 +08:00
from collections import OrderedDict
from sklearn import metrics
2020-02-10 07:38:56 +08:00
2020-02-27 12:16:57 +08:00
import numpy as np
import torch
import torch.nn.functional as F
2020-02-10 07:38:56 +08:00
from .evaluator import DatasetEvaluator
from .query_expansion import aqe
2020-02-10 07:38:56 +08:00
from .rank import evaluate_rank
2020-05-13 11:47:52 +08:00
from .rerank import re_ranking
2020-07-06 16:57:43 +08:00
from .roc import evaluate_roc
2020-07-30 20:15:28 +08:00
from fastreid.utils import comm
2020-05-13 11:47:52 +08:00
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
2020-02-10 07:38:56 +08:00
class ReidEvaluator(DatasetEvaluator):
    """Evaluate re-identification retrieval (CMC Rank-1/5/10, mAP, mINP, optional ROC).

    Per-batch features, person ids and camera ids are accumulated by
    :meth:`process`. When :meth:`evaluate` runs, the first ``num_query``
    accumulated samples form the query set and the remaining samples form the
    gallery set.
    """

    def __init__(self, cfg, num_query, output_dir=None):
        """
        Args:
            cfg: config node; :meth:`evaluate` reads ``cfg.TEST.METRIC``,
                ``cfg.TEST.AQE.*``, ``cfg.TEST.RERANK.*`` and ``cfg.TEST.ROC_ENABLED``.
            num_query (int): number of leading accumulated samples that make up
                the query set.
            output_dir (str, optional): stored on the instance; not read by this
                class itself.
        """
        self.cfg = cfg
        self._num_query = num_query
        self._output_dir = output_dir

        # Single source of truth for the accumulation buffers.
        self.reset()

    def reset(self):
        """Clear all accumulated features, person ids and camera ids."""
        self.features = []
        self.pids = []
        self.camids = []

    def process(self, inputs, outputs):
        """Accumulate one batch.

        Args:
            inputs (dict): batch dict; its ``"targets"`` and ``"camid"`` entries are
                iterated element-wise (via ``list.extend``) into the pid / camid lists.
            outputs: batch feature tensor; moved to CPU before being stored.
        """
        self.pids.extend(inputs["targets"])
        self.camids.extend(inputs["camid"])
        self.features.append(outputs.cpu())

    @staticmethod
    def cal_dist(metric: str, query_feat: torch.Tensor, gallery_feat: torch.Tensor):
        """Compute the pairwise query-to-gallery distance matrix.

        Args:
            metric: ``"cosine"`` or ``"euclidean"``.
            query_feat: 2-D (m, d) query features. For ``"cosine"`` the rows should
                already be L2-normalised (:meth:`evaluate` does this).
            gallery_feat: 2-D (n, d) gallery features.

        Returns:
            np.ndarray: (m, n) distance matrix.
        """
        assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
        if metric == "cosine":
            # 1 - dot product; equals cosine distance only for unit-norm rows.
            dist = 1 - torch.mm(query_feat, gallery_feat.t())
        else:
            # ||q||^2 + ||g||^2 - 2 q.g, built without materialising an (m, n, d) tensor.
            m, n = query_feat.size(0), gallery_feat.size(0)
            xx = torch.pow(query_feat, 2).sum(1, keepdim=True).expand(m, n)
            yy = torch.pow(gallery_feat, 2).sum(1, keepdim=True).expand(n, m).t()
            dist = xx + yy
            dist.addmm_(query_feat, gallery_feat.t(), beta=1, alpha=-2)
            # Rounding can push squared distances slightly below zero; clamp before sqrt.
            dist = dist.clamp(min=1e-12).sqrt()
        return dist.cpu().numpy()

    @staticmethod
    def _flatten_gathered(per_rank_lists):
        """Concatenate the per-rank lists returned by ``comm.gather`` into one list."""
        return [item for rank_list in per_rank_lists for item in rank_list]

    def evaluate(self):
        """Compute retrieval metrics over everything accumulated by :meth:`process`.

        With more than one process, features / pids / camids are collected with
        ``comm.gather``. Only the main process computes metrics; other ranks
        return an empty dict.

        Returns:
            Keys ``Rank-1``, ``Rank-5``, ``Rank-10``, ``mAP`` and ``mINP``, plus
            ``TPR@FPR=...`` entries when ``cfg.TEST.ROC_ENABLED`` is set. The main
            process returns a deep copy of an ``OrderedDict``; other ranks return
            a plain ``{}``.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            features = self._flatten_gathered(comm.gather(self.features))
            pids = self._flatten_gathered(comm.gather(self.pids))
            camids = self._flatten_gathered(comm.gather(self.camids))
            if not comm.is_main_process():
                return {}
        else:
            features = self.features
            pids = self.pids
            camids = self.camids
        features = torch.cat(features, dim=0)

        # Query set: the first ``num_query`` samples.
        query_features = features[:self._num_query]
        query_pids = np.asarray(pids[:self._num_query])
        query_camids = np.asarray(camids[:self._num_query])

        # Gallery set: all remaining samples.
        gallery_features = features[self._num_query:]
        gallery_pids = np.asarray(pids[self._num_query:])
        gallery_camids = np.asarray(camids[self._num_query:])

        self._results = OrderedDict()
        test_cfg = self.cfg.TEST
        metric = test_cfg.METRIC

        if test_cfg.AQE.ENABLED:
            logger.info("Test with AQE setting")
            qe_time = test_cfg.AQE.QE_TIME
            qe_k = test_cfg.AQE.QE_K
            alpha = test_cfg.AQE.ALPHA
            query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)

        if metric == "cosine":
            query_features = F.normalize(query_features, dim=1)
            gallery_features = F.normalize(gallery_features, dim=1)

        dist = self.cal_dist(metric, query_features, gallery_features)

        if test_cfg.RERANK.ENABLED:
            logger.info("Test with rerank setting")
            k1 = test_cfg.RERANK.K1
            k2 = test_cfg.RERANK.K2
            lambda_value = test_cfg.RERANK.LAMBDA
            q_q_dist = self.cal_dist(metric, query_features, query_features)
            g_g_dist = self.cal_dist(metric, gallery_features, gallery_features)
            rank_dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
            use_distmat = True
        else:
            rank_dist = dist
            use_distmat = False

        # Both ranking and ROC below consume numpy features.
        query_features = query_features.numpy()
        gallery_features = gallery_features.numpy()
        cmc, all_AP, all_INP = evaluate_rank(rank_dist, query_features, gallery_features,
                                             query_pids, gallery_pids, query_camids, gallery_camids,
                                             use_distmat=use_distmat)

        for r in (1, 5, 10):
            self._results['Rank-{}'.format(r)] = cmc[r - 1]
        self._results['mAP'] = np.mean(all_AP)
        self._results['mINP'] = np.mean(all_INP)

        if test_cfg.ROC_ENABLED:
            # NOTE: ROC is computed from the plain metric distance ``dist``,
            # not from the re-ranked matrix, even when re-ranking is enabled.
            scores, labels = evaluate_roc(dist, query_features, gallery_features,
                                          query_pids, gallery_pids, query_camids, gallery_camids)
            fprs, tprs, _ = metrics.roc_curve(labels, scores)
            for fpr in (1e-4, 1e-3, 1e-2):
                # roc_curve only reports FPRs at score thresholds; use the nearest one.
                ind = np.argmin(np.abs(fprs - fpr))
                self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]

        return copy.deepcopy(self._results)