update rank_cylib

pull/365/head
liaoxingyu2 2020-11-06 11:00:13 +08:00
parent 3bd2fad9a5
commit 64bf78afee
1 changed files with 13 additions and 34 deletions

View File

@ -18,35 +18,22 @@ Credit to https://github.com/luzai
# Main interface
cpdef evaluate_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False,
use_distmat=False):
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
distmat = np.asarray(distmat, dtype=np.float32)
q_feats = np.asarray(q_feats, dtype=np.float32)
g_feats = np.asarray(g_feats, dtype=np.float32)
q_pids = np.asarray(q_pids, dtype=np.int64)
g_pids = np.asarray(g_pids, dtype=np.int64)
q_camids = np.asarray(q_camids, dtype=np.int64)
g_camids = np.asarray(g_camids, dtype=np.int64)
if use_metric_cuhk03:
return eval_cuhk03_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
return eval_market1501_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
cdef long num_q = q_feats.shape[0]
cdef long num_g = g_feats.shape[0]
cdef long dim = q_feats.shape[1]
cdef long[:,:] indices
cdef index = faiss.IndexFlatL2(dim)
index.add(np.asarray(g_feats))
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
indices = index.search(np.asarray(q_feats), k=num_g)[1]
if num_g < max_rank:
max_rank = num_g
@ -54,6 +41,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats,
cdef:
long num_repeats = 10
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
@ -160,27 +148,18 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats,
return np.asarray(avg_cmc).astype(np.float32), mAP
cpdef eval_market1501_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = q_feats.shape[0]
cdef long num_g = g_feats.shape[0]
cdef long dim = q_feats.shape[1]
cdef long[:,:] indices
cdef index = faiss.IndexFlatL2(dim)
index.add(np.asarray(g_feats))
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
indices = index.search(np.asarray(q_feats), k=num_g)[1]
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)