diff --git a/eval_metrics.py b/eval_metrics.py index 6d2f528..465f5e6 100644 --- a/eval_metrics.py +++ b/eval_metrics.py @@ -2,12 +2,15 @@ from __future__ import absolute_import import numpy as np import copy -def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): - num_q = distmat.shape[0] +def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): + num_q, num_g = distmat.shape + if num_g < max_rank: + max_rank = num_g + print("Note: number of gallery samples is quite small, got {}".format(num_g)) indices = np.argsort(distmat, axis=1) matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) - # compute cmc curve for each query, maximum rank is fixed to _MAX_RANK + # compute cmc curve for each query all_cmc = [] all_AP = [] num_valid_q = 0. @@ -30,7 +33,7 @@ def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): cmc = cmc.cumsum() cmc[cmc > 1] = 1 - all_cmc.append(cmc) + all_cmc.append(cmc[:max_rank]) num_valid_q += 1. num_rel = cmc.sum()