deep-person-reid/eval_metrics.py

67 lines
2.5 KiB
Python
Raw Normal View History

2018-03-12 21:53:08 +08:00
from __future__ import print_function, absolute_import
2018-03-12 05:17:48 +08:00
import numpy as np
import copy
2018-04-23 04:18:01 +08:00
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with cuhk03 metric.

    Placeholder only: the CUHK03 protocol is not implemented here, so every
    call raises ``NotImplementedError`` regardless of its arguments.
    """
    raise NotImplementedError()
def _average_precision(raw_cmc):
    """Return the average precision of one query's ranked 0/1 relevance vector.

    ``raw_cmc`` lists, in rank order, 1 for a correct match and 0 otherwise,
    and must contain at least one 1. Precision@k is computed for every rank,
    kept only at ranks holding a correct match, and averaged over the number
    of correct matches.
    Reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
    """
    num_rel = raw_cmc.sum()
    # vectorized precision@k: (#hits in top k) / k for k = 1..len
    precision_at_k = raw_cmc.cumsum() / np.arange(1., raw_cmc.size + 1.)
    return (precision_at_k * raw_cmc).sum() / num_rel


def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.

    Args:
        distmat: 2-D (num_query, num_gallery) distance matrix; each row is
            ranked in ascending order, so smaller distances rank first.
        q_pids, g_pids: identity labels of the query / gallery samples.
        q_camids, g_camids: camera ids of the query / gallery samples.
        max_rank: length of the returned CMC curve; clipped to the gallery size.

    Returns:
        (cmc, mAP): ``cmc`` is a float32 array of length ``max_rank`` (after
        clipping) holding the rank-k hit rate averaged over valid queries;
        ``mAP`` is the mean average precision over the same queries. Queries
        whose identity has no remaining gallery match are skipped.

    Raises:
        AssertionError: if no query has a correct match in the gallery.
    """
    # accept any array-like; fancy indexing below requires numpy arrays
    distmat = np.asarray(distmat)
    q_pids, g_pids = np.asarray(q_pids), np.asarray(g_pids)
    q_camids, g_camids = np.asarray(q_camids), np.asarray(g_camids)

    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve and AP for each query
    all_cmc = []
    all_AP = []
    for q_idx in range(num_q):
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = raw_cmc.cumsum()
        cmc[cmc > 1] = 1
        cmc = cmc[:max_rank]
        if cmc.size < max_rank:
            # Fewer than max_rank candidates survived the same-camera filter.
            # A correct match was found (checked above), so every deeper rank
            # is also a hit; pad with ones so all curves share one length.
            cmc = np.concatenate([cmc, np.ones(max_rank - cmc.size, dtype=cmc.dtype)])
        all_cmc.append(cmc)
        all_AP.append(_average_precision(raw_cmc))

    num_valid_q = len(all_cmc)
    if num_valid_q == 0:
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        raise AssertionError("Error: all query identities do not appear in gallery")

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)
    return all_cmc, mAP
2018-04-23 17:21:12 +08:00
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False):
    """Run the chosen re-id evaluation protocol and return its result.

    ``eval_cuhk03`` is used when ``use_metric_cuhk03`` is truthy, otherwise
    ``eval_market1501``; all other arguments, including ``max_rank``, are
    forwarded positionally and the selected routine's return value is
    passed back unchanged.
    """
    metric_fn = eval_cuhk03 if use_metric_cuhk03 else eval_market1501
    return metric_fn(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)