add multiprocess

parent deb9ea86e9
commit e8c3f95c15

@@ -0,0 +1,161 @@
from __future__ import print_function, absolute_import

import numpy as np
import copy
from collections import defaultdict
import sys
import multiprocessing

def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
    """Evaluation with cuhk03 metric
    Key: one image for each gallery identity is randomly sampled for each query identity.
    Random sampling is performed N times (default: N=100).
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid queries
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid as the query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        orig_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
        if not np.any(orig_cmc):
            # this condition is true when the query identity does not appear in the gallery
            continue

        kept_g_pids = g_pids[order][keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        cmc, AP = 0., 0.
        for repeat_idx in range(N):
            mask = np.zeros(len(orig_cmc), dtype=bool)
            for _, idxs in g_pids_dict.items():
                # randomly sample one image for each gallery person
                rnd_idx = np.random.choice(idxs)
                mask[rnd_idx] = True
            masked_orig_cmc = orig_cmc[mask]
            _cmc = masked_orig_cmc.cumsum()
            _cmc[_cmc > 1] = 1
            cmc += _cmc[:max_rank].astype(np.float32)
            # compute AP
            num_rel = masked_orig_cmc.sum()
            tmp_cmc = masked_orig_cmc.cumsum()
            tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
            tmp_cmc = np.asarray(tmp_cmc) * masked_orig_cmc
            AP += tmp_cmc.sum() / num_rel
        cmc /= N
        AP /= N
        all_cmc.append(cmc)
        all_AP.append(AP)
        num_valid_q += 1.

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP

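# Illustrative sketch only: the per-repeat AP above is the standard
# average-precision formula. _ap_example is a hypothetical helper (not part of
# the original change) showing the same arithmetic on a tiny hand-made match
# vector: for masked_orig_cmc = [1, 0, 1], the running precision is
# [1/1, 1/2, 2/3]; keeping only the hit positions gives AP = (1 + 2/3) / 2 ~= 0.833.
def _ap_example():
    masked_orig_cmc = np.asarray([1, 0, 1])
    num_rel = masked_orig_cmc.sum()                           # 2 relevant items
    tmp_cmc = masked_orig_cmc.cumsum()                        # [1, 1, 2]
    tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]   # precision at each rank
    return (np.asarray(tmp_cmc) * masked_orig_cmc).sum() / num_rel  # ~0.833
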
def compute_cmc_ap(q_idx, q_pids, g_pids, q_camids, g_camids, indices, matches, max_rank, result_dict):
    """Compute the CMC curve and AP for a single query.
    Applicable to eval_market1501().
    """
    # get query pid and camid
    q_pid = q_pids[q_idx]
    q_camid = q_camids[q_idx]

    # remove gallery samples that have the same pid and camid as the query
    order = indices[q_idx]
    remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
    keep = np.invert(remove)

    # compute cmc curve
    orig_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
    if not np.any(orig_cmc):
        # this condition is true when the query identity does not appear in the gallery
        result_dict[str(q_idx) + '_valid'] = False
        return

    result_dict[str(q_idx) + '_valid'] = True

    cmc = orig_cmc.cumsum()
    cmc[cmc > 1] = 1

    result_dict[str(q_idx) + '_cmc'] = cmc[:max_rank]

    # compute average precision
    # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
    num_rel = orig_cmc.sum()
    tmp_cmc = orig_cmc.cumsum()
    tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
    tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
    AP = tmp_cmc.sum() / num_rel
    result_dict[str(q_idx) + '_ap'] = AP

def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.
    Multiprocessing is supported: each query is evaluated in its own worker process.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # do multiprocessing: one worker process per query, results gathered in a shared dict
    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    jobs = []
    for q_idx in range(num_q):
        p = multiprocessing.Process(
            target=compute_cmc_ap,
            args=(q_idx, q_pids, g_pids, q_camids, g_camids, indices, matches, max_rank, result_dict),
        )
        jobs.append(p)
        p.start()
    for proc in jobs:
        proc.join()

    num_valid_q = 0
    all_cmc, all_AP = [], []
    for q_idx in range(num_q):
        if result_dict[str(q_idx) + '_valid']:
            num_valid_q += 1
            all_cmc.append(result_dict[str(q_idx) + '_cmc'])
            all_AP.append(result_dict[str(q_idx) + '_ap'])

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP

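# Alternative sketch (an assumption, not part of this change): the loop above
# starts one Process per query, which can mean thousands of short-lived
# processes on a large query set. A bounded multiprocessing.Pool reusing
# compute_cmc_ap would produce the same per-query results; the name
# _eval_market1501_pool and the num_workers parameter are hypothetical.
def _eval_market1501_pool(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, num_workers=8):
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # Manager proxies are picklable, so the shared result dict can be handed to pool workers
    manager = multiprocessing.Manager()
    result_dict = manager.dict()
    pool = multiprocessing.Pool(processes=num_workers)
    for q_idx in range(num_q):
        pool.apply_async(
            compute_cmc_ap,
            args=(q_idx, q_pids, g_pids, q_camids, g_camids, indices, matches, max_rank, result_dict),
        )
    pool.close()
    pool.join()

    # gather per-query results exactly as eval_market1501 does
    num_valid_q, all_cmc, all_AP = 0, [], []
    for q_idx in range(num_q):
        if result_dict[str(q_idx) + '_valid']:
            num_valid_q += 1
            all_cmc.append(result_dict[str(q_idx) + '_cmc'])
            all_AP.append(result_dict[str(q_idx) + '_ap'])
    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
    all_cmc = np.asarray(all_cmc).astype(np.float32).sum(0) / num_valid_q
    return all_cmc, np.mean(all_AP)
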
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False):
    if use_metric_cuhk03:
        return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
    else:
        return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
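
# Usage sketch (the shapes and identity counts below are made up, just to show
# the call): eval_market1501 spawns worker processes, so evaluate() should be
# called under a __main__ guard, which spawn-based platforms require.
if __name__ == '__main__':
    num_q, num_g = 10, 50
    distmat = np.random.rand(num_q, num_g)
    q_pids = np.random.randint(0, 5, size=num_q)
    g_pids = np.random.randint(0, 5, size=num_g)
    q_camids = np.random.randint(0, 2, size=num_q)
    g_camids = np.random.randint(0, 2, size=num_g)
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20)
    print("Rank-1: {:.2%}  mAP: {:.2%}".format(cmc[0], mAP))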