fix rank_cylib bug about valid query index

fix bugs of rank_cylib

Reviewed by: l1aoxingyu
pull/224/head
Xingyu Liao 2020-08-10 19:28:05 +08:00 committed by GitHub
commit f74cebcd88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -181,6 +181,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
long valid_index = 0
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
@ -213,7 +215,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
if (raw_cmc[g_idx] == 1) and (g_idx > max_pos_idx):
max_pos_idx = g_idx
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
all_INP[q_idx] = inp
all_INP[valid_index] = inp
for g_idx in range(num_g_real):
if cmc[g_idx] > 1:
@ -231,7 +233,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
for g_idx in range(num_g_real):
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
num_rel += raw_cmc[g_idx]
all_AP[q_idx] = tmp_cmc_sum / num_rel
all_AP[valid_index] = tmp_cmc_sum / num_rel
valid_index+=1
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
@ -242,7 +245,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
return np.asarray(avg_cmc).astype(np.float32), all_AP, all_INP
return np.asarray(avg_cmc).astype(np.float32), all_AP[:valid_index], all_INP[:valid_index]
# Compute the cumulative sum

View File

@ -76,4 +76,4 @@ g_camids = np.random.randint(0, 5, size=num_g)
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(np.array(mAP), cmc, np.array(mINP)))