mirror of https://github.com/JDAI-CV/fast-reid.git
fix rank_cylib bug about valid query index
fix bugs of rank_cylib Reviewed by: l1aoxingyupull/224/head
commit
f74cebcd88
|
@ -181,6 +181,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
long valid_index = 0
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
|
@ -213,7 +215,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
if (raw_cmc[g_idx] == 1) and (g_idx > max_pos_idx):
|
||||
max_pos_idx = g_idx
|
||||
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
|
||||
all_INP[q_idx] = inp
|
||||
all_INP[valid_index] = inp
|
||||
|
||||
for g_idx in range(num_g_real):
|
||||
if cmc[g_idx] > 1:
|
||||
|
@ -231,7 +233,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
for g_idx in range(num_g_real):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||
num_rel += raw_cmc[g_idx]
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
all_AP[valid_index] = tmp_cmc_sum / num_rel
|
||||
valid_index+=1
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
|
@ -242,7 +245,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||
avg_cmc[rank_idx] /= num_valid_q
|
||||
|
||||
return np.asarray(avg_cmc).astype(np.float32), all_AP, all_INP
|
||||
return np.asarray(avg_cmc).astype(np.float32), all_AP[:valid_index], all_INP[:valid_index]
|
||||
|
||||
|
||||
# Compute the cumulative sum
|
||||
|
|
|
@ -76,4 +76,4 @@ g_camids = np.random.randint(0, 5, size=num_g)
|
|||
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
|
||||
print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
|
||||
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
|
||||
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
|
||||
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(np.array(mAP), cmc, np.array(mINP)))
|
Loading…
Reference in New Issue