diff --git a/eval_metrics_multiprocessing.py b/eval_metrics_multiprocessing.py new file mode 100755 index 0000000..cf4beb6 --- /dev/null +++ b/eval_metrics_multiprocessing.py @@ -0,0 +1,161 @@ +from __future__ import print_function, absolute_import + +import numpy as np +import copy +from collections import defaultdict +import sys +import multiprocessing + +def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100): + """Evaluation with cuhk03 metric + Key: one image for each gallery identity is randomly sampled for each query identity. + Random sampling is performed N times (default: N=100). + """ + num_q, num_g = distmat.shape + if num_g < max_rank: + max_rank = num_g + print("Note: number of gallery samples is quite small, got {}".format(num_g)) + indices = np.argsort(distmat, axis=1) + matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) + + # compute cmc curve for each query + all_cmc = [] + all_AP = [] + num_valid_q = 0. # number of valid query + for q_idx in range(num_q): + # get query pid and camid + q_pid = q_pids[q_idx] + q_camid = q_camids[q_idx] + + # remove gallery samples that have the same pid and camid with query + order = indices[q_idx] + remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) + keep = np.invert(remove) + + # compute cmc curve + orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches + if not np.any(orig_cmc): + # this condition is true when query identity does not appear in gallery + continue + + kept_g_pids = g_pids[order][keep] + g_pids_dict = defaultdict(list) + for idx, pid in enumerate(kept_g_pids): + g_pids_dict[pid].append(idx) + + cmc, AP = 0., 0. + for repeat_idx in range(N): + mask = np.zeros(len(orig_cmc), dtype=np.bool) + for _, idxs in g_pids_dict.items(): + # randomly sample one image for each gallery person + rnd_idx = np.random.choice(idxs) + mask[rnd_idx] = True + masked_orig_cmc = orig_cmc[mask] + _cmc = masked_orig_cmc.cumsum() + _cmc[_cmc > 1] = 1 + cmc += _cmc[:max_rank].astype(np.float32) + # compute AP + num_rel = masked_orig_cmc.sum() + tmp_cmc = masked_orig_cmc.cumsum() + tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] + tmp_cmc = np.asarray(tmp_cmc) * masked_orig_cmc + AP += tmp_cmc.sum() / num_rel + cmc /= N + AP /= N + all_cmc.append(cmc) + all_AP.append(AP) + num_valid_q += 1. + + assert num_valid_q > 0, "Error: all query identities do not appear in gallery" + + all_cmc = np.asarray(all_cmc).astype(np.float32) + all_cmc = all_cmc.sum(0) / num_valid_q + mAP = np.mean(all_AP) + + return all_cmc, mAP + +def compute_cmc_ap(q_idx, q_pids, g_pids, q_camids, g_camids, indices, matches, max_rank, result_dict): + """ + Applicable to eval_market1501() + """ + + # get query pid and camid + q_pid = q_pids[q_idx] + q_camid = q_camids[q_idx] + + # remove gallery samples that have the same pid and camid with query + order = indices[q_idx] + remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) + keep = np.invert(remove) + + # compute cmc curve + orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches + if not np.any(orig_cmc): + # this condition is true when query identity does not appear in gallery + result_dict[str(q_idx) + '_valid'] = False + return + + result_dict[str(q_idx) + '_valid'] = True + + cmc = orig_cmc.cumsum() + cmc[cmc > 1] = 1 + + result_dict[str(q_idx) + '_cmc'] = cmc[:max_rank] + + # compute average precision + # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision + num_rel = orig_cmc.sum() + tmp_cmc = orig_cmc.cumsum() + tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] + tmp_cmc = np.asarray(tmp_cmc) * orig_cmc + AP = tmp_cmc.sum() / num_rel + result_dict[str(q_idx) + '_ap'] = AP + +def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): + """Evaluation with market1501 metric + Key: for each query identity, its gallery images from the same camera view are discarded. + Multiprocessing is supported. + """ + num_q, num_g = distmat.shape + if num_g < max_rank: + max_rank = num_g + print("Note: number of gallery samples is quite small, got {}".format(num_g)) + indices = np.argsort(distmat, axis=1) + matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) + + # do multiprocessing + manager = multiprocessing.Manager() + result_dict = manager.dict() + + jobs = [] + for q_idx in range(num_q): + p = multiprocessing.Process( + target=compute_cmc_ap, + args=(q_idx, q_pids, g_pids, q_camids, g_camids, indices, matches, max_rank, result_dict), + ) + jobs.append(p) + p.start() + for proc in jobs: + proc.join() + + num_valid_q = 0 + all_cmc, all_AP = [], [] + for q_idx in range(num_q): + if result_dict[str(q_idx) + '_valid']: + num_valid_q += 1 + all_cmc.append(result_dict[str(q_idx) + '_cmc']) + all_AP.append(result_dict[str(q_idx) + '_ap']) + + assert num_valid_q > 0, "Error: all query identities do not appear in gallery" + + all_cmc = np.asarray(all_cmc).astype(np.float32) + all_cmc = all_cmc.sum(0) / num_valid_q + mAP = np.mean(all_AP) + + return all_cmc, mAP + +def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False): + if use_metric_cuhk03: + return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) + else: + return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) \ No newline at end of file