From 64bf78afee65ab3b20f74879edee8e2172017484 Mon Sep 17 00:00:00 2001
From: liaoxingyu2 <liaoxingyu5@jd.com>
Date: Fri, 6 Nov 2020 11:00:13 +0800
Subject: [PATCH] update rank_cylib

---
 fastreid/evaluation/rank_cylib/rank_cy.pyx | 47 ++++++----------------
 1 file changed, 13 insertions(+), 34 deletions(-)

diff --git a/fastreid/evaluation/rank_cylib/rank_cy.pyx b/fastreid/evaluation/rank_cylib/rank_cy.pyx
index 4dedd23..be45e10 100644
--- a/fastreid/evaluation/rank_cylib/rank_cy.pyx
+++ b/fastreid/evaluation/rank_cylib/rank_cy.pyx
@@ -18,35 +18,22 @@ Credit to https://github.com/luzai
 
 
 # Main interface
-cpdef evaluate_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False,
-                  use_distmat=False):
+cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
     distmat = np.asarray(distmat, dtype=np.float32)
-    q_feats = np.asarray(q_feats, dtype=np.float32)
-    g_feats = np.asarray(g_feats, dtype=np.float32)
     q_pids = np.asarray(q_pids, dtype=np.int64)
     g_pids = np.asarray(g_pids, dtype=np.int64)
     q_camids = np.asarray(q_camids, dtype=np.int64)
     g_camids = np.asarray(g_camids, dtype=np.int64)
     if use_metric_cuhk03:
-        return eval_cuhk03_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
-    return eval_market1501_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
+        return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
+    return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
 
 
-cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
-                     long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
+cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
+                     long[:]q_camids, long[:]g_camids, long max_rank):
+    cdef long num_q = distmat.shape[0]
+    cdef long num_g = distmat.shape[1]
 
-    cdef long num_q = q_feats.shape[0]
-    cdef long num_g = g_feats.shape[0]
-    cdef long dim = q_feats.shape[1]
-
-    cdef long[:,:] indices
-    cdef index = faiss.IndexFlatL2(dim)
-    index.add(np.asarray(g_feats))
-
-    if use_distmat:
-        indices = np.argsort(distmat, axis=1)
-    else:
-        indices = index.search(np.asarray(q_feats), k=num_g)[1]
 
     if num_g < max_rank:
         max_rank = num_g
@@ -54,6 +41,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats,
 
     cdef:
         long num_repeats = 10
+        long[:,:] indices = np.argsort(distmat, axis=1)
         long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
 
         float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
@@ -160,27 +148,18 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats,
     return np.asarray(avg_cmc).astype(np.float32), mAP
 
 
-cpdef eval_market1501_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
-                         long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
+cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
+                         long[:]q_camids, long[:]g_camids, long max_rank):
 
-    cdef long num_q = q_feats.shape[0]
-    cdef long num_g = g_feats.shape[0]
-    cdef long dim = q_feats.shape[1]
-
-    cdef long[:,:] indices
-    cdef index = faiss.IndexFlatL2(dim)
-    index.add(np.asarray(g_feats))
-
-    if use_distmat:
-        indices = np.argsort(distmat, axis=1)
-    else:
-        indices = index.search(np.asarray(q_feats), k=num_g)[1]
+    cdef long num_q = distmat.shape[0]
+    cdef long num_g = distmat.shape[1]
 
     if num_g < max_rank:
         max_rank = num_g
         print('Note: number of gallery samples is quite small, got {}'.format(num_g))
 
     cdef:
+        long[:,:] indices = np.argsort(distmat, axis=1)
         long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
 
         float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)