deep-person-reid/torchreid/metrics/rank.py

208 lines
6.8 KiB
Python
Raw Normal View History

2019-12-01 10:35:44 +08:00
from __future__ import division, print_function, absolute_import
2018-03-12 05:17:48 +08:00
import numpy as np
2018-11-06 03:24:20 +08:00
import warnings
2019-12-01 10:35:44 +08:00
from collections import defaultdict
2018-06-04 17:27:07 +08:00
try:
2019-03-20 01:26:08 +08:00
from torchreid.metrics.rank_cylib.rank_cy import evaluate_cy
2018-11-11 05:09:13 +08:00
IS_CYTHON_AVAI = True
2018-06-04 17:27:07 +08:00
except ImportError:
2018-11-11 05:09:13 +08:00
IS_CYTHON_AVAI = False
2019-03-20 01:26:08 +08:00
warnings.warn(
2019-03-21 20:53:21 +08:00
'Cython evaluation (very fast so highly recommended) is '
'unavailable, now use python evaluation.'
2019-03-20 01:26:08 +08:00
)
2018-03-12 05:17:48 +08:00
2018-07-02 17:17:14 +08:00
2018-11-11 05:25:40 +08:00
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with cuhk03 metric.

    Key: one image for each gallery identity is randomly sampled for each
    query identity. Random sampling is performed ``num_repeats`` times and
    the resulting CMC curves are averaged.

    Args:
        distmat (numpy.ndarray): distance matrix of shape
            (num_query, num_gallery).
        q_pids (numpy.ndarray): 1-D array with the identity of each query.
        g_pids (numpy.ndarray): 1-D array with the identity of each gallery
            instance.
        q_camids (numpy.ndarray): 1-D array with the camera id of each query.
        g_camids (numpy.ndarray): 1-D array with the camera id of each
            gallery instance.
        max_rank (int): maximum CMC rank to be computed; clipped to the
            number of gallery samples when that is smaller.

    Returns:
        tuple: ``(all_cmc, mAP)`` where ``all_cmc`` is a float32 array holding
        the CMC curve averaged over valid queries and ``mAP`` is the mean of
        the per-query average precisions.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    num_repeats = 10
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print(
            'Note: number of gallery samples is quite small, got {}'.
            format(num_g)
        )

    # gallery indices sorted by increasing distance, one row per query
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        ordered_g_pids = g_pids[order]
        remove = (ordered_g_pids == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # binary vector, positions with value 1 are correct matches
        raw_cmc = matches[q_idx][keep]
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        # group the positions of the kept gallery samples by identity
        kept_g_pids = ordered_g_pids[keep]
        g_pids_dict = defaultdict(list)
        for idx, pid in enumerate(kept_g_pids):
            g_pids_dict[pid].append(idx)

        # NOTE: the sampled curve has one entry per distinct kept identity, so
        # it can be shorter than max_rank when few identities remain.
        cmc = 0.
        for _ in range(num_repeats):
            # builtin bool: the np.bool alias was removed in NumPy 1.24
            mask = np.zeros(len(raw_cmc), dtype=bool)
            for idxs in g_pids_dict.values():
                # randomly sample one image for each gallery person
                mask[np.random.choice(idxs)] = True
            _cmc = raw_cmc[mask].cumsum()
            _cmc[_cmc > 1] = 1
            cmc += _cmc[:max_rank].astype(np.float32)

        cmc /= num_repeats
        all_cmc.append(cmc)

        # compute AP on the full (unsampled) ranking: precision at every rank,
        # kept only at the ranks that are correct matches
        num_rel = raw_cmc.sum()
        precision = raw_cmc.cumsum() / np.arange(1., len(raw_cmc) + 1.)
        AP = (precision * raw_cmc).sum() / num_rel
        all_AP.append(AP)
        num_valid_q += 1.

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
2018-04-23 04:18:01 +08:00
2018-07-02 17:17:14 +08:00
2018-04-23 19:55:47 +08:00
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
    """Evaluation with market1501 metric.

    Key: for each query identity, its gallery images from the same camera
    view are discarded.

    Returns the CMC curve averaged over the valid queries (float32 array)
    together with the mean average precision. Queries whose identity does
    not appear among the kept gallery samples are skipped.
    """
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print(
            'Note: number of gallery samples is quite small, got {}'.
            format(num_g)
        )

    # ranking of gallery samples per query, closest first
    ranking = np.argsort(distmat, axis=1)
    is_match = (g_pids[ranking] == q_pids[:, np.newaxis]).astype(np.int32)

    cmc_rows = []
    ap_values = []
    num_valid_q = 0.  # number of valid query

    for row in range(num_q):
        ranked = ranking[row]

        # keep everything except gallery samples sharing both pid and camid
        # with the query
        keep = (g_pids[ranked] != q_pids[row]) | (
            g_camids[ranked] != q_camids[row]
        )

        # binary vector, positions with value 1 are correct matches
        hits = is_match[row][keep]
        if not np.any(hits):
            # query identity does not appear in gallery
            continue

        running_hits = hits.cumsum()

        # cmc: 1 from the first correct match onwards
        cmc_rows.append(np.minimum(running_hits, 1)[:max_rank])
        num_valid_q += 1.

        # average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        precision_at_k = running_hits / np.arange(1., hits.size + 1.)
        ap_values.append((precision_at_k * hits).sum() / hits.sum())

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    averaged_cmc = np.asarray(cmc_rows).astype(np.float32).sum(0) / num_valid_q
    return averaged_cmc, np.mean(ap_values)
2018-07-02 17:17:14 +08:00
2019-12-01 10:35:44 +08:00
def evaluate_py(
    distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03
):
    """Run the pure-Python evaluation.

    Uses the cuhk03 metric when ``use_metric_cuhk03`` is truthy and the
    market1501 metric otherwise; all other arguments are forwarded unchanged
    and the chosen metric's result is returned as-is.
    """
    metric_fn = eval_cuhk03 if use_metric_cuhk03 else eval_market1501
    return metric_fn(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
def evaluate_rank(
    distmat,
    q_pids,
    g_pids,
    q_camids,
    g_camids,
    max_rank=50,
    use_metric_cuhk03=False,
    use_cython=True
):
    """Evaluates CMC rank.

    The Cython implementation is used only when it was requested *and* its
    import succeeded; in every other case the Python implementation runs.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        q_pids (numpy.ndarray): 1-D array containing person identities
            of each query instance.
        g_pids (numpy.ndarray): 1-D array containing person identities
            of each gallery instance.
        q_camids (numpy.ndarray): 1-D array containing camera views under
            which each query instance is captured.
        g_camids (numpy.ndarray): 1-D array containing camera views under
            which each gallery instance is captured.
        max_rank (int, optional): maximum CMC rank to be computed. Default is 50.
        use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for
            cuhk03. Default is False. This should be enabled when using cuhk03
            classic split.
        use_cython (bool, optional): use cython code for evaluation. Default is
            True. This is highly recommended as the cython code can speed up
            the cmc computation by more than 10x. This requires Cython to be
            installed.

    Returns:
        The result of the selected backend. The Python backend returns a
        ``(cmc, mAP)`` pair; the Cython backend is presumably equivalent
        (not visible here -- verify against ``rank_cy``).
    """
    cython_selected = use_cython and IS_CYTHON_AVAI
    if not cython_selected:
        return evaluate_py(
            distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
            use_metric_cuhk03
        )
    return evaluate_cy(
        distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
        use_metric_cuhk03
    )