From 64bf78afee65ab3b20f74879edee8e2172017484 Mon Sep 17 00:00:00 2001 From: liaoxingyu2 <liaoxingyu5@jd.com> Date: Fri, 6 Nov 2020 11:00:13 +0800 Subject: [PATCH] update rank_cylib --- fastreid/evaluation/rank_cylib/rank_cy.pyx | 47 ++++++---------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/fastreid/evaluation/rank_cylib/rank_cy.pyx b/fastreid/evaluation/rank_cylib/rank_cy.pyx index 4dedd23..be45e10 100644 --- a/fastreid/evaluation/rank_cylib/rank_cy.pyx +++ b/fastreid/evaluation/rank_cylib/rank_cy.pyx @@ -18,35 +18,22 @@ Credit to https://github.com/luzai # Main interface -cpdef evaluate_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False, - use_distmat=False): +cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): distmat = np.asarray(distmat, dtype=np.float32) - q_feats = np.asarray(q_feats, dtype=np.float32) - g_feats = np.asarray(g_feats, dtype=np.float32) q_pids = np.asarray(q_pids, dtype=np.int64) g_pids = np.asarray(g_pids, dtype=np.int64) q_camids = np.asarray(q_camids, dtype=np.int64) g_camids = np.asarray(g_camids, dtype=np.int64) if use_metric_cuhk03: - return eval_cuhk03_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat) - return eval_market1501_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat) + return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) + return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) -cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids, - long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat): +cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, + long[:]q_camids, long[:]g_camids, long max_rank): + cdef long num_q = distmat.shape[0] + cdef long num_g = distmat.shape[1] - cdef long num_q = q_feats.shape[0] - cdef long num_g = g_feats.shape[0] - cdef long dim = q_feats.shape[1] - - cdef long[:,:] indices - cdef index = faiss.IndexFlatL2(dim) - index.add(np.asarray(g_feats)) - - if use_distmat: - indices = np.argsort(distmat, axis=1) - else: - indices = index.search(np.asarray(q_feats), k=num_g)[1] if num_g < max_rank: max_rank = num_g @@ -54,6 +41,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, cdef: long num_repeats = 10 + long[:,:] indices = np.argsort(distmat, axis=1) long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) @@ -160,27 +148,18 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, return np.asarray(avg_cmc).astype(np.float32), mAP -cpdef eval_market1501_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids, - long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat): +cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, + long[:]q_camids, long[:]g_camids, long max_rank): - cdef long num_q = q_feats.shape[0] - cdef long num_g = g_feats.shape[0] - cdef long dim = q_feats.shape[1] - - cdef long[:,:] indices - cdef index = faiss.IndexFlatL2(dim) - index.add(np.asarray(g_feats)) - - if use_distmat: - indices = np.argsort(distmat, axis=1) - else: - indices = index.search(np.asarray(q_feats), k=num_g)[1] + cdef long num_q = distmat.shape[0] + cdef long num_g = distmat.shape[1] if num_g < max_rank: max_rank = num_g print('Note: number of gallery samples is quite small, got {}'.format(num_g)) cdef: + long[:,:] indices = np.argsort(distmat, axis=1) long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)