update cython code: add cuhk03-metric
parent
1d0851e8c4
commit
d885f11ce7
|
@ -1,6 +1,7 @@
|
|||
all:
|
||||
python setup.py build_ext --inplace
|
||||
rm -rf build
|
||||
|
||||
clean:
|
||||
rm -rf build
|
||||
rm -f eval.c *.so
|
||||
rm -f eval_metrics_cy.c *.so
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,234 @@
|
|||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import random
|
||||
|
||||
|
||||
# Main interface
|
||||
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
|
||||
distmat = np.asarray(distmat, dtype=np.float32)
|
||||
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||
g_pids = np.asarray(g_pids, dtype=np.int64)
|
||||
q_camids = np.asarray(q_camids, dtype=np.int64)
|
||||
g_camids = np.asarray(g_camids, dtype=np.int64)
|
||||
if use_metric_cuhk03:
|
||||
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
|
||||
|
||||
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||
|
||||
cdef:
|
||||
long num_repeats = 10
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
|
||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
||||
float num_valid_q = 0. # number of valid query
|
||||
|
||||
long q_idx, q_pid, q_camid, g_idx
|
||||
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||
long keep
|
||||
|
||||
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
|
||||
float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float[:] cmc, masked_cmc
|
||||
long num_g_real, num_g_real_masked, rank_i, rnd_idx
|
||||
unsigned long meet_condition
|
||||
float AP
|
||||
long[:] kept_g_pids, mask
|
||||
|
||||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
num_g_real = 0
|
||||
meet_condition = 0
|
||||
kept_g_pids = np.zeros(num_g, dtype=np.int64)
|
||||
|
||||
for g_idx in range(num_g):
|
||||
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||
kept_g_pids[num_g_real] = g_pids[order[g_idx]]
|
||||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
# cuhk03-specific setting
|
||||
g_pids_dict = defaultdict(list)
|
||||
for g_idx in range(num_g_real):
|
||||
g_pids_dict[kept_g_pids[g_idx]].append(g_idx)
|
||||
|
||||
cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
AP = 0.
|
||||
for _ in range(num_repeats):
|
||||
mask = np.zeros(num_g_real, dtype=np.int64)
|
||||
|
||||
for _, idxs in g_pids_dict.items():
|
||||
# randomly sample one image for each gallery person
|
||||
rnd_idx = np.random.choice(idxs)
|
||||
#rnd_idx = idxs[0] # use deterministic for debugging
|
||||
mask[rnd_idx] = 1
|
||||
|
||||
num_g_real_masked = 0
|
||||
for g_idx in range(num_g_real):
|
||||
if mask[g_idx] == 1:
|
||||
masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx]
|
||||
num_g_real_masked += 1
|
||||
|
||||
masked_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked)
|
||||
for g_idx in range(num_g_real_masked):
|
||||
if masked_cmc[g_idx] > 1:
|
||||
masked_cmc[g_idx] = 1
|
||||
|
||||
for rank_i in range(max_rank):
|
||||
cmc[rank_i] += masked_cmc[rank_i] / num_repeats
|
||||
|
||||
# compute AP
|
||||
function_cumsum(masked_raw_cmc, tmp_cmc, num_g_real_masked)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real_masked):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * masked_raw_cmc[g_idx]
|
||||
num_rel += masked_raw_cmc[g_idx]
|
||||
AP += tmp_cmc_sum / num_rel
|
||||
|
||||
all_AP[q_idx] = AP / num_repeats
|
||||
all_cmc[q_idx] = cmc
|
||||
num_valid_q += 1.
|
||||
|
||||
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||
|
||||
# compute averaged cmc
|
||||
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
for rank_i in range(max_rank):
|
||||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_i] += all_cmc[q_idx, rank_i]
|
||||
avg_cmc[rank_i] /= num_valid_q
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
||||
mAP /= num_valid_q
|
||||
|
||||
return np.asarray(avg_cmc).astype(np.float32), mAP
|
||||
|
||||
|
||||
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||
|
||||
cdef:
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
|
||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
||||
float num_valid_q = 0. # number of valid query
|
||||
|
||||
long q_idx, q_pid, q_camid, g_idx
|
||||
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||
long keep
|
||||
|
||||
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
|
||||
float[:] cmc = np.zeros(num_g, dtype=np.float32)
|
||||
long num_g_real
|
||||
unsigned long meet_condition
|
||||
|
||||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
num_g_real = 0
|
||||
meet_condition = 0
|
||||
|
||||
for g_idx in range(num_g):
|
||||
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
# compute cmc
|
||||
function_cumsum(raw_cmc, cmc, num_g_real)
|
||||
for g_idx in range(num_g_real):
|
||||
if cmc[g_idx] > 1:
|
||||
cmc[g_idx] = 1
|
||||
|
||||
all_cmc[q_idx] = cmc[:max_rank]
|
||||
num_valid_q += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||
num_rel += raw_cmc[g_idx]
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
|
||||
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||
|
||||
# compute averaged cmc
|
||||
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
cdef long rank_i
|
||||
for rank_i in range(max_rank):
|
||||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_i] += all_cmc[q_idx, rank_i]
|
||||
avg_cmc[rank_i] /= num_valid_q
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
||||
mAP /= num_valid_q
|
||||
|
||||
return np.asarray(avg_cmc).astype(np.float32), mAP
|
||||
|
||||
|
||||
# Compute the cumulative sum
|
||||
cpdef void function_cumsum(float[:] src, float[:] dst, long n):
|
||||
cdef long i
|
||||
dst[0] = src[0]
|
||||
for i in range(1, n):
|
||||
dst[i] = src[i] + dst[i - 1]
|
|
@ -0,0 +1,15 @@
|
|||
from distutils.core import setup
|
||||
from distutils.extension import Extension
|
||||
from Cython.Build import cythonize
|
||||
|
||||
|
||||
ext_modules = [
|
||||
Extension('eval_metrics_cy',
|
||||
['eval_metrics_cy.pyx']
|
||||
)
|
||||
]
|
||||
|
||||
setup(
|
||||
name='Cython-based reid evaluation code',
|
||||
ext_modules=cythonize(ext_modules)
|
||||
)
|
|
@ -1,133 +0,0 @@
|
|||
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
import numpy as np
|
||||
|
||||
cpdef eval_market1501_wrap(distmat,
|
||||
q_pids,
|
||||
g_pids,
|
||||
q_camids,
|
||||
g_camids,
|
||||
max_rank):
|
||||
distmat = np.asarray(distmat,dtype=np.float32)
|
||||
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||
g_pids = np.asarray(g_pids , dtype=np.int64)
|
||||
q_camids=np.asarray(q_camids,dtype=np.int64)
|
||||
g_camids=np.asarray(g_camids, dtype=np.int64)
|
||||
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
|
||||
cpdef eval_market1501(
|
||||
float[:,:] distmat,
|
||||
long[:] q_pids,
|
||||
long[:] g_pids,
|
||||
long[:] q_camids,
|
||||
long[:] g_camids,
|
||||
long max_rank,
|
||||
):
|
||||
# return 0,0
|
||||
cdef:
|
||||
long num_q = distmat.shape[0], num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||
|
||||
cdef:
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
float[:,:] all_cmc = np.zeros((num_q,max_rank),dtype=np.float32)
|
||||
float[:] all_AP = np.zeros(num_q,dtype=np.float32)
|
||||
|
||||
long q_pid, q_camid
|
||||
long[:] order=np.zeros(num_g,dtype=np.int64), keep =np.zeros(num_g,dtype=np.int64)
|
||||
|
||||
long num_valid_q = 0, q_idx, idx
|
||||
# long[:] orig_cmc=np.zeros(num_g,dtype=np.int64)
|
||||
float[:] orig_cmc=np.zeros(num_g,dtype=np.float32)
|
||||
float[:] cmc=np.zeros(num_g,dtype=np.float32), tmp_cmc=np.zeros(num_g,dtype=np.float32)
|
||||
long num_orig_cmc=0
|
||||
float num_rel=0.
|
||||
float tmp_cmc_sum =0.
|
||||
# num_orig_cmc is the valid size of orig_cmc, cmc and tmp_cmc
|
||||
unsigned int orig_cmc_flag=0
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
for idx in range(num_g):
|
||||
keep[idx] = ( g_pids[order[idx]] !=q_pid) or (g_camids[order[idx]]!=q_camid )
|
||||
# compute cmc curve
|
||||
num_orig_cmc=0
|
||||
orig_cmc_flag=0
|
||||
for idx in range(num_g):
|
||||
if keep[idx]:
|
||||
orig_cmc[num_orig_cmc] = matches[q_idx][idx]
|
||||
num_orig_cmc +=1
|
||||
if matches[q_idx][idx]>1e-31:
|
||||
orig_cmc_flag=1
|
||||
if not orig_cmc_flag:
|
||||
all_AP[q_idx]=-1
|
||||
# print('continue ', q_idx)
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
my_cusum(orig_cmc,cmc,num_orig_cmc)
|
||||
for idx in range(num_orig_cmc):
|
||||
if cmc[idx] >1:
|
||||
cmc[idx] =1
|
||||
all_cmc[q_idx] = cmc[:max_rank]
|
||||
num_valid_q+=1
|
||||
|
||||
# print('ori cmc', np.asarray(orig_cmc).tolist())
|
||||
# print('cmc', np.asarray(cmc).tolist())
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = 0.
|
||||
for idx in range(num_orig_cmc):
|
||||
num_rel += orig_cmc[idx]
|
||||
my_cusum( orig_cmc, tmp_cmc, num_orig_cmc)
|
||||
for idx in range(num_orig_cmc):
|
||||
tmp_cmc[idx] = tmp_cmc[idx] / (idx+1.) * orig_cmc[idx]
|
||||
# print('tmp_cmc', np.asarray(tmp_cmc).tolist())
|
||||
|
||||
tmp_cmc_sum=my_sum(tmp_cmc,num_orig_cmc)
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
# print('final',tmp_cmc_sum, num_rel, tmp_cmc_sum / num_rel,'\n')
|
||||
|
||||
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||
# print_dbg('all ap', all_AP)
|
||||
# print_dbg('all cmc', all_cmc)
|
||||
all_AP_np = np.asarray(all_AP)
|
||||
all_AP_np[np.isclose(all_AP,-1)] = np.nan
|
||||
return np.asarray(all_cmc).astype(np.float32).sum(axis=0) / num_valid_q, \
|
||||
np.nanmean(all_AP_np)
|
||||
|
||||
def print_dbg(msg, val):
|
||||
print(msg, np.asarray(val))
|
||||
|
||||
cpdef void my_cusum(
|
||||
cython.numeric[:] src,
|
||||
cython.numeric[:] dst,
|
||||
long size
|
||||
) nogil:
|
||||
cdef:
|
||||
long idx
|
||||
for idx in range(size):
|
||||
if idx==0:
|
||||
dst[idx] = src[idx]
|
||||
else:
|
||||
dst[idx] = src[idx]+dst[idx-1]
|
||||
|
||||
cpdef cython.numeric my_sum(
|
||||
cython.numeric[:] src,
|
||||
long size
|
||||
) nogil:
|
||||
cdef:
|
||||
long idx
|
||||
cython.numeric ttl=0
|
||||
for idx in range(size):
|
||||
ttl+=src[idx]
|
||||
return ttl
|
|
@ -1,23 +0,0 @@
|
|||
import numpy as np
|
||||
from distutils.core import setup
|
||||
from distutils.extension import Extension
|
||||
from Cython.Distutils import build_ext
|
||||
|
||||
try:
|
||||
numpy_include = np.get_include()
|
||||
except AttributeError:
|
||||
numpy_include = np.get_numpy_include()
|
||||
print(numpy_include)
|
||||
|
||||
ext_modules = [Extension("cython_eval",
|
||||
["eval.pyx"],
|
||||
libraries=["m"],
|
||||
include_dirs=[numpy_include],
|
||||
extra_compile_args=["-ffast-math", "-Wno-cpp", "-Wno-unused-function"]
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name='eval_lib',
|
||||
cmdclass={"build_ext": build_ext},
|
||||
ext_modules=ext_modules)
|
|
@ -1,48 +0,0 @@
|
|||
from __future__ import absolute_import, print_function
|
||||
import sys, os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.dirname(os.path.abspath(__file__)) + '/..'
|
||||
)
|
||||
|
||||
try:
|
||||
from eval_lib.cython_eval import eval_market1501_wrap
|
||||
except ImportError:
|
||||
print("Error: eval.pyx not compiled, please do 'make' before running 'python test.py'. exit")
|
||||
sys.exit()
|
||||
|
||||
from eval_metrics import eval_market1501
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
num_q = 100
|
||||
num_g = 1000
|
||||
|
||||
distmat = np.random.rand(num_q, num_g) * 20
|
||||
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||
q_camids = np.random.randint(0, 5, size=num_q)
|
||||
g_camids = np.random.randint(0, 5, size=num_g)
|
||||
|
||||
end = time.time()
|
||||
cmc, mAP = eval_market1501_wrap(distmat,
|
||||
q_pids,
|
||||
g_pids,
|
||||
q_camids,
|
||||
g_camids, 10)
|
||||
elapsed_cython = time.time() - end
|
||||
print("=> Cython evaluation")
|
||||
print("consume time {:.5f} \n mAP is {} \n cmc is {}".format(elapsed_cython, mAP, cmc))
|
||||
|
||||
end = time.time()
|
||||
cmc, mAP = eval_market1501(distmat,
|
||||
q_pids,
|
||||
g_pids,
|
||||
q_camids,
|
||||
g_camids, 10)
|
||||
elapsed_python = time.time() - end
|
||||
print("=> Python evaluation")
|
||||
print("consume time {:.5f} \n mAP is {} \n cmc is {}".format(elapsed_python, mAP, cmc))
|
||||
|
||||
xtimes = elapsed_python / elapsed_cython
|
||||
print("=> Conclusion: cython is {:.2f}x faster than python".format(xtimes))
|
|
@ -9,11 +9,11 @@ import sys
|
|||
import warnings
|
||||
|
||||
try:
|
||||
from torchreid.eval_lib.cython_eval import eval_market1501_wrap
|
||||
CYTHON_EVAL_AVAI = True
|
||||
from torchreid.eval_cylib.eval_metrics_cy import evaluate_cy
|
||||
IS_CYTHON_AVAI = True
|
||||
print("Using Cython evaluation code as the backend")
|
||||
except ImportError:
|
||||
CYTHON_EVAL_AVAI = False
|
||||
IS_CYTHON_AVAI = False
|
||||
warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
|
||||
|
||||
|
||||
|
@ -44,8 +44,8 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
|
|||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(orig_cmc):
|
||||
raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
|
@ -56,20 +56,20 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
|
|||
|
||||
cmc, AP = 0., 0.
|
||||
for repeat_idx in range(N):
|
||||
mask = np.zeros(len(orig_cmc), dtype=np.bool)
|
||||
mask = np.zeros(len(raw_cmc), dtype=np.bool)
|
||||
for _, idxs in g_pids_dict.items():
|
||||
# randomly sample one image for each gallery person
|
||||
rnd_idx = np.random.choice(idxs)
|
||||
mask[rnd_idx] = True
|
||||
masked_orig_cmc = orig_cmc[mask]
|
||||
_cmc = masked_orig_cmc.cumsum()
|
||||
masked_raw_cmc = raw_cmc[mask]
|
||||
_cmc = masked_raw_cmc.cumsum()
|
||||
_cmc[_cmc > 1] = 1
|
||||
cmc += _cmc[:max_rank].astype(np.float32)
|
||||
# compute AP
|
||||
num_rel = masked_orig_cmc.sum()
|
||||
tmp_cmc = masked_orig_cmc.cumsum()
|
||||
num_rel = masked_raw_cmc.sum()
|
||||
tmp_cmc = masked_raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * masked_orig_cmc
|
||||
tmp_cmc = np.asarray(tmp_cmc) * masked_raw_cmc
|
||||
AP += tmp_cmc.sum() / num_rel
|
||||
cmc /= N
|
||||
AP /= N
|
||||
|
@ -112,12 +112,12 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(orig_cmc):
|
||||
raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
cmc = orig_cmc.cumsum()
|
||||
cmc = raw_cmc.cumsum()
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
|
@ -125,10 +125,10 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = orig_cmc.sum()
|
||||
tmp_cmc = orig_cmc.cumsum()
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
|
||||
|
@ -142,10 +142,10 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||
|
||||
|
||||
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_cython=True):
|
||||
if use_metric_cuhk03:
|
||||
return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
if use_cython and IS_CYTHON_AVAI:
|
||||
return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
|
||||
else:
|
||||
if use_cython and CYTHON_EVAL_AVAI:
|
||||
return eval_market1501_wrap(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
if use_metric_cuhk03:
|
||||
return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
else:
|
||||
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
Loading…
Reference in New Issue