This commit is contained in:
KaiyangZhou 2018-03-11 21:47:45 +00:00
parent b9150adc45
commit ef99dd489a

View File

@ -2,12 +2,15 @@ from __future__ import absolute_import
import numpy as np
import copy
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
num_q = distmat.shape[0]
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
num_q, num_g = distmat.shape
if num_g < max_rank:
max_rank = num_g
print("Note: number of gallery samples is quite small, got {}".format(num_g))
indices = np.argsort(distmat, axis=1)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
# compute cmc curve for each query, maximum rank is fixed to _MAX_RANK
# compute cmc curve for each query
all_cmc = []
all_AP = []
num_valid_q = 0.
@ -30,7 +33,7 @@ def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
cmc = cmc.cumsum()
cmc[cmc > 1] = 1
all_cmc.append(cmc)
all_cmc.append(cmc[:max_rank])
num_valid_q += 1.
num_rel = cmc.sum()