support faiss retrieval and cython roc evaluation

pull/228/head
liaoxingyu 2020-08-12 16:27:57 +08:00
parent f74cebcd88
commit ae7c9288cf
12 changed files with 344 additions and 94 deletions

View File

@ -2,10 +2,10 @@
## Prepare pretrained model
If you use the original ResNet, you do not need to do anything. But if you want to use ResNet-ibn or ResNeSt, you need to download the pre-trained model from [here](https://github.com/XingangPan/IBN-Net).
Then you need to put it in `~/.cache/torch/checkpoints` or anywhere you like.
If you use a backbone supported by fastreid, you do not need to do anything; fastreid will download the pre-trained model automatically.
But if your machine has no network connection, you can download the pre-trained models manually and put them in `~/.cache/torch/checkpoints`.
Then you should set the pre-trained model path in `configs/Base-bagtricks.yml`.
If you want to use other pre-trained models, such as a MoCo pre-trained model, you can download them yourself and set the pre-trained model path in `configs/Base-bagtricks.yml`.
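For example, on a machine with network access a checkpoint can be pre-fetched into the torch cache first (a sketch only; the torchvision ResNet-50 URL below is just an example, swap in the weights your backbone needs):

```python
# Sketch: pre-download a checkpoint into the local torch cache, then copy it to
# (or point the config at) the offline machine.
import torch

url = "https://download.pytorch.org/models/resnet50-19c8e357.pth"
state_dict = torch.hub.load_state_dict_from_url(url, progress=True)
print("downloaded {} tensors".format(len(state_dict)))
```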
## Compile with cython to accelerate evaluation
@ -29,13 +29,13 @@ The configs are made for 1-GPU training.
If you want to train the model with 4 GPUs, you can run:
```bash
./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
```
To evaluate a model's performance, use
```bash
./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
```

View File

@ -12,3 +12,4 @@
- sklearn
- termcolor
- tabulate
- [faiss](https://github.com/facebookresearch/faiss) `pip install faiss-cpu`
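A quick sanity check (illustrative only) that the `faiss-cpu` wheel installed correctly:

```python
# Sketch: confirm faiss imports and can build an exact (flat) L2 index.
import faiss

index = faiss.IndexFlatL2(128)
print("faiss OK, empty index size:", index.ntotal)  # prints 0
```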

View File

@ -238,6 +238,7 @@ _C.TEST.EVAL_PERIOD = 20
# Number of images per batch in one process.
_C.TEST.IMS_PER_BATCH = 64
_C.TEST.METRIC = "cosine"
_C.TEST.ROC_ENABLED = False
# Average query expansion
_C.TEST.AQE = CN()
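The new flag is a plain yacs option, so it can be set in a config file or merged from the command line the same way as `MODEL.WEIGHTS` in the training commands above. A minimal sketch (illustrative only, using yacs directly):

```python
# Sketch: TEST.ROC_ENABLED behaves like any other yacs option.
from yacs.config import CfgNode as CN

_C = CN()
_C.TEST = CN()
_C.TEST.ROC_ENABLED = False                       # default added in this commit
_C.merge_from_list(["TEST.ROC_ENABLED", "True"])  # command-line style override
assert _C.TEST.ROC_ENABLED is True                # yacs coerces the string to a bool
```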

View File

@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
from .rank import evaluate_rank
from .roc import evaluate_roc
from .reid_evaluation import ReidEvaluator
from .testing import print_csv_format, verify_results

View File

@ -1,9 +1,11 @@
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
import numpy as np
import warnings
from collections import defaultdict
import faiss
import numpy as np
try:
from .rank_cylib.rank_cy import evaluate_cy
@ -11,18 +13,27 @@ try:
except ImportError:
IS_CYTHON_AVAI = False
warnings.warn(
'Cython evaluation (very fast so highly recommended) is '
'Cython rank evaluation (very fast so highly recommended) is '
'unavailable, now use python evaluation.'
)
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
def eval_cuhk03(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat):
"""Evaluation with cuhk03 metric
Key: one image for each gallery identity is randomly sampled for each query identity.
Random sampling is performed num_repeats times.
"""
num_repeats = 10
num_q, num_g = distmat.shape
dim = q_feats.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(g_feats)
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
_, indices = index.search(q_feats, k=num_g)
if num_g < max_rank:
max_rank = num_g
@ -31,7 +42,6 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
format(num_g)
)
indices = np.argsort(distmat, axis=1)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
# compute cmc curve for each query
@ -93,17 +103,24 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
return all_cmc, mAP
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
def eval_market1501(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat):
"""Evaluation with market1501 metric
Key: for each query identity, its gallery images from the same camera view are discarded.
"""
num_q, num_g = distmat.shape
dim = q_feats.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(g_feats)
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
indices = np.argsort(distmat, axis=1)
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
_, indices = index.search(q_feats, k=num_g)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
@ -159,31 +176,36 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
def evaluate_py(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03, use_distmat
):
if use_metric_cuhk03:
return eval_cuhk03(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat
)
else:
return eval_market1501(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat
)
def evaluate_rank(
distmat,
q_feats,
g_feats,
q_pids,
g_pids,
q_camids,
g_camids,
max_rank=50,
use_metric_cuhk03=False,
use_distmat=False,
use_cython=True
):
"""Evaluates CMC rank.
Args:
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
q_feats (numpy.ndarray): 2-D array containing query features.
g_feats (numpy.ndarray): 2-D array containing gallery features.
q_pids (numpy.ndarray): 1-D array containing person identities
of each query instance.
g_pids (numpy.ndarray): 1-D array containing person identities
@ -201,11 +223,11 @@ def evaluate_rank(
"""
if use_cython and IS_CYTHON_AVAI:
return evaluate_cy(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03, use_distmat
)
else:
return evaluate_py(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03, use_distmat
)
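A minimal sketch (not part of the diff) of what the new `use_distmat` switch amounts to: when it is `False`, the ranking comes from an exact faiss L2 search over the gallery features instead of `argsort` over a precomputed distance matrix. On L2-normalized features the two orderings agree, since ||q - g||^2 = 2 - 2 * cos(q, g).

```python
# Sketch: faiss retrieval vs. argsort over a cosine distance matrix.
import faiss
import numpy as np

num_q, num_g, dim = 5, 50, 32
q_feats = np.random.rand(num_q, dim).astype(np.float32)
g_feats = np.random.rand(num_g, dim).astype(np.float32)
q_feats /= np.linalg.norm(q_feats, axis=1, keepdims=True)
g_feats /= np.linalg.norm(g_feats, axis=1, keepdims=True)

index = faiss.IndexFlatL2(dim)             # exact (brute-force) L2 index
index.add(g_feats)                         # index the gallery features
_, faiss_indices = index.search(q_feats, k=num_g)

distmat = 1 - q_feats @ g_feats.T          # cosine distance on normalized features
argsort_indices = np.argsort(distmat, axis=1)
print("rankings agree:", np.array_equal(faiss_indices, argsort_indices))
```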

View File

@ -1,6 +1,7 @@
all:
python setup.py build_ext --inplace
rm -rf build
python test_cython.py
clean:
rm -rf build
rm -f rank_cy.c *.so
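Once `make all` builds both extensions, a quick import check (a sketch, not part of the commit) confirms the fast Cython path will be used instead of the pure-Python fallback:

```python
# Sketch: verify the compiled Cython extensions are importable.
try:
    from fastreid.evaluation.rank_cylib.rank_cy import evaluate_cy
    from fastreid.evaluation.rank_cylib.roc_cy import evaluate_roc_cy
    print("Cython rank/roc extensions available.")
except ImportError as exc:
    print("Cython extensions missing, falling back to Python evaluation:", exc)
```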

View File

@ -5,6 +5,7 @@ import cython
import numpy as np
cimport numpy as np
from collections import defaultdict
import faiss
"""
@ -17,22 +18,35 @@ Credit to https://github.com/luzai
# Main interface
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
cpdef evaluate_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False,
use_distmat=False):
distmat = np.asarray(distmat, dtype=np.float32)
q_feats = np.asarray(q_feats, dtype=np.float32)
g_feats = np.asarray(g_feats, dtype=np.float32)
q_pids = np.asarray(q_pids, dtype=np.int64)
g_pids = np.asarray(g_pids, dtype=np.int64)
q_camids = np.asarray(q_camids, dtype=np.int64)
g_camids = np.asarray(g_camids, dtype=np.int64)
if use_metric_cuhk03:
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
return eval_cuhk03_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
return eval_market1501_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
cdef long num_q = q_feats.shape[0]
cdef long num_g = g_feats.shape[0]
cdef long dim = q_feats.shape[1]
cdef long[:,:] indices
cdef index = faiss.IndexFlatL2(dim)
index.add(np.asarray(g_feats))
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
indices = index.search(np.asarray(q_feats), k=num_g)[1]
if num_g < max_rank:
max_rank = num_g
@ -40,7 +54,6 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
cdef:
long num_repeats = 10
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
@ -147,24 +160,34 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
return np.asarray(avg_cmc).astype(np.float32), mAP
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cpdef eval_market1501_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
cdef long num_q = q_feats.shape[0]
cdef long num_g = g_feats.shape[0]
cdef long dim = q_feats.shape[1]
cdef long[:,:] indices
cdef index = faiss.IndexFlatL2(dim)
index.add(np.asarray(g_feats))
if use_distmat:
indices = np.argsort(distmat, axis=1)
else:
indices = index.search(np.asarray(q_feats), k=num_g)[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
float[:] all_INP = np.zeros(num_q, dtype=np.float32)
float num_valid_q = 0. # number of valid query
long valid_index = 0
long q_idx, q_pid, q_camid, g_idx
long[:] order = np.zeros(num_g, dtype=np.int64)
@ -181,7 +204,6 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
long valid_index = 0
for q_idx in range(num_q):
# get query pid and camid
@ -234,7 +256,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
num_rel += raw_cmc[g_idx]
all_AP[valid_index] = tmp_cmc_sum / num_rel
valid_index+=1
valid_index += 1
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
@ -245,7 +267,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
return np.asarray(avg_cmc).astype(np.float32), all_AP[:valid_index], all_INP[:valid_index]
return np.asarray(avg_cmc).astype(np.float32), np.asarray(all_AP[:valid_index]), np.asarray(all_INP[:valid_index])
# Compute the cumulative sum

View File

@ -0,0 +1,96 @@
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx
import cython
import faiss
import numpy as np
cimport numpy as np
"""
Compiler directives:
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
Cython tutorial:
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
Credit to https://github.com/luzai
"""
# Main interface
cpdef evaluate_roc_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids):
distmat = np.asarray(distmat, dtype=np.float32)
q_feats = np.asarray(q_feats, dtype=np.float32)
g_feats = np.asarray(g_feats, dtype=np.float32)
q_pids = np.asarray(q_pids, dtype=np.int64)
g_pids = np.asarray(g_pids, dtype=np.int64)
q_camids = np.asarray(q_camids, dtype=np.int64)
g_camids = np.asarray(g_camids, dtype=np.int64)
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
cdef long dim = q_feats.shape[1]
cdef long[:,:] indices
cdef index = faiss.IndexFlatL2(dim)
index.add(np.asarray(g_feats))
indices = index.search(np.asarray(q_feats), k=num_g)[1]
cdef:
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:] pos = np.zeros(num_q*num_g, dtype=np.float32)
float[:] neg = np.zeros(num_q*num_g, dtype=np.float32)
long valid_pos = 0
long valid_neg = 0
long ind
long q_idx, q_pid, q_camid, g_idx
long[:] order = np.zeros(num_g, dtype=np.int64)
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
long[:] sort_idx = np.zeros(num_g, dtype=np.int64)
long idx
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
q_camid = q_camids[q_idx]
for g_idx in range(num_g):
order[g_idx] = indices[q_idx, g_idx]
num_g_real = 0
# remove gallery samples that have the same pid and camid with query
for g_idx in range(num_g):
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
raw_cmc[num_g_real] = matches[q_idx][g_idx]
sort_idx[num_g_real] = order[g_idx]
num_g_real += 1
q_dist = distmat[q_idx]
for valid_idx in range(num_g_real):
if raw_cmc[valid_idx] == 1:
pos[valid_pos] = q_dist[sort_idx[valid_idx]]
valid_pos += 1
elif raw_cmc[valid_idx] == 0:
neg[valid_neg] = q_dist[sort_idx[valid_idx]]
valid_neg += 1
cdef float[:] scores = np.hstack((pos[:valid_pos], neg[:valid_neg]))
cdef float[:] labels = np.hstack((np.zeros(valid_pos, dtype=np.float32),
np.ones(valid_neg, dtype=np.float32)))
return np.asarray(scores), np.asarray(labels)
# Compute the cumulative sum
cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n):
cdef long i
dst[0] = src[0]
for i in range(1, n):
dst[i] = src[i] + dst[i - 1]

View File

@ -18,6 +18,11 @@ ext_modules = [
'rank_cy',
['rank_cy.pyx'],
include_dirs=[numpy_include()],
),
Extension(
'roc_cy',
['roc_cy.pyx'],
include_dirs=[numpy_include()],
)
]

View File

@ -1,11 +1,13 @@
import sys
import numpy as np
import timeit
import numpy as np
import os.path as osp
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
from fastreid.evaluation import evaluate_rank
from fastreid.evaluation import evaluate_roc
"""
Test the speed of cython-based evaluation code. The speed improvements
can be much bigger when using the real reid data, which contains a larger
@ -22,58 +24,97 @@ import sys
import os.path as osp
import numpy as np
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
from fastreid import evaluation
from fastreid.evaluation import evaluate_rank
from fastreid.evaluation import evaluate_roc
num_q = 30
num_g = 300
dim = 512
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20
q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True)
g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20
g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True)
distmat = 1 - np.dot(q_feats, g_feats.transpose())
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)
'''
# print('=> Using market1501\'s metric')
# pytime = timeit.timeit(
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
# setup=setup,
# number=20
# )
# cytime = timeit.timeit(
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
# setup=setup,
# number=20
# )
# print('Python time: {} s'.format(pytime))
# print('Cython time: {} s'.format(cytime))
# print('Cython is {} times faster than python\n'.format(pytime / cytime))
#
# print('=> Using cuhk03\'s metric')
# pytime = timeit.timeit(
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
# setup=setup,
# number=20
# )
# cytime = timeit.timeit(
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
# setup=setup,
# number=20
# )
# print('Python time: {} s'.format(pytime))
# print('Cython time: {} s'.format(cytime))
# print('Cython is {} times faster than python\n'.format(pytime / cytime))
print('=> Using CMC metric')
pytime = timeit.timeit(
'evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat=True, use_cython=False)',
setup=setup,
number=20
)
cytime = timeit.timeit(
'evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat=True, use_cython=True)',
setup=setup,
number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('CMC Cython is {} times faster than python\n'.format(pytime / cytime))
print('=> Using ROC metric')
pytime = timeit.timeit(
'evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, use_cython=False)',
setup=setup,
number=20
)
cytime = timeit.timeit(
'evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, use_cython=True)',
setup=setup,
number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('ROC Cython is {} times faster than python\n'.format(pytime / cytime))
from fastreid.evaluation import evaluate_rank
print("=> Check precision")
num_q = 30
num_g = 300
dim = 512
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20
q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True)
g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20
g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True)
distmat = 1 - np.dot(q_feats, g_feats.transpose())
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(np.array(mAP), cmc, np.array(mINP)))
cmc_py_d, mAP_py_d, mINP_py_d = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_distmat=True, use_cython=False)
cmc_py, mAP_py, mINP_py = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_distmat=False, use_cython=False)
np.testing.assert_allclose(cmc_py_d, cmc_py, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mAP_py_d, mAP_py, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mINP_py_d, mINP_py, rtol=1e-3, atol=1e-6)
print('Results between distmat and features are the same in python!')
cmc_cy_d, mAP_cy_d, mINP_cy_d = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_distmat=True, use_cython=True)
cmc_cy, mAP_cy, mINP_cy = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
use_distmat=False, use_cython=True)
np.testing.assert_allclose(cmc_cy_d, cmc_cy, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mAP_cy_d, mAP_cy, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mINP_cy_d, mINP_cy, rtol=1e-3, atol=1e-6)
print('Results between distmat and features are the same in cython!')
np.testing.assert_allclose(cmc_py, cmc_cy, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mAP_py, mAP_cy, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(mINP_py, mINP_cy, rtol=1e-3, atol=1e-6)
print('Rank results between python and cython are the same!')
scores_cy, labels_cy = evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
scores_py, labels_py = evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids,
use_cython=False)
np.testing.assert_allclose(scores_cy, scores_py, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(labels_cy, labels_py, rtol=1e-3, atol=1e-6)
print('ROC results between python and cython are the same!\n')
print("=> Check exact values")
print("mAP = {} \ncmc = {}\nmINP = {}\nScores = {}".format(np.array(mAP_cy), cmc_cy, np.array(mINP_cy), scores_cy))

View File

@ -6,6 +6,7 @@
import copy
import logging
from collections import OrderedDict
from sklearn import metrics
import numpy as np
import torch
@ -45,8 +46,6 @@ class ReidEvaluator(DatasetEvaluator):
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
if metric == "cosine":
query_feat = F.normalize(query_feat, dim=1)
gallery_feat = F.normalize(gallery_feat, dim=1)
dist = 1 - torch.mm(query_feat, gallery_feat.t())
else:
m, n = query_feat.size(0), gallery_feat.size(0)
@ -96,6 +95,10 @@ class ReidEvaluator(DatasetEvaluator):
alpha = self.cfg.TEST.AQE.ALPHA
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
if self.cfg.TEST.METRIC == "cosine":
query_features = F.normalize(query_features, dim=1)
gallery_features = F.normalize(gallery_features, dim=1)
dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)
if self.cfg.TEST.RERANK.ENABLED:
@ -105,9 +108,18 @@ class ReidEvaluator(DatasetEvaluator):
lambda_value = self.cfg.TEST.RERANK.LAMBDA
q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
re_dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
query_features = query_features.numpy()
gallery_features = gallery_features.numpy()
cmc, all_AP, all_INP = evaluate_rank(re_dist, query_features, gallery_features,
query_pids, gallery_pids, query_camids,
gallery_camids, use_distmat=True)
else:
query_features = query_features.numpy()
gallery_features = gallery_features.numpy()
cmc, all_AP, all_INP = evaluate_rank(dist, query_features, gallery_features,
query_pids, gallery_pids, query_camids, gallery_camids,
use_distmat=False)
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
for r in [1, 5, 10]:
@ -115,9 +127,13 @@ class ReidEvaluator(DatasetEvaluator):
self._results['mAP'] = mAP
self._results['mINP'] = mINP
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
fprs = [1e-4, 1e-3, 1e-2]
for i in range(len(fprs)):
self._results["TPR@FPR={:.0e}".format(fprs[i])] = tprs[i]
if self.cfg.TEST.ROC_ENABLED:
scores, labels = evaluate_roc(dist, query_features, gallery_features,
query_pids, gallery_pids, query_camids, gallery_camids)
fprs, tprs, thres = metrics.roc_curve(labels, scores)
for fpr in [1e-4, 1e-3, 1e-2]:
ind = np.argmin(np.abs(fprs - fpr))
self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]
return copy.deepcopy(self._results)

View File

@ -4,11 +4,24 @@
@contact: sherlockliao01@gmail.com
"""
import warnings
import faiss
import numpy as np
from sklearn import metrics
try:
from .rank_cylib.roc_cy import evaluate_roc_cy
IS_CYTHON_AVAI = True
except ImportError:
IS_CYTHON_AVAI = False
warnings.warn(
'Cython roc evaluation (very fast so highly recommended) is '
'unavailable, now use python evaluation.'
)
def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
def evaluate_roc_py(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids):
r"""Evaluation with ROC curve.
Key: for each query identity, its gallery images from the same camera view are discarded.
@ -16,8 +29,12 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
distmat (np.ndarray): cosine distance matrix
"""
num_q, num_g = distmat.shape
dim = q_feats.shape[1]
indices = np.argsort(distmat, axis=1)
index = faiss.IndexFlatL2(dim)
index.add(g_feats)
_, indices = index.search(q_feats, k=num_g)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
pos = []
@ -31,22 +48,49 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
order = indices[q_idx]
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
keep = np.invert(remove)
cmc = matches[q_idx][keep]
raw_cmc = matches[q_idx][keep]
sort_idx = order[keep]
q_dist = distmat[q_idx]
ind_pos = np.where(cmc == 1)[0]
ind_pos = np.where(raw_cmc == 1)[0]
pos.extend(q_dist[sort_idx[ind_pos]])
ind_neg = np.where(cmc == 0)[0]
ind_neg = np.where(raw_cmc == 0)[0]
neg.extend(q_dist[sort_idx[ind_neg]])
scores = np.hstack((pos, neg))
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
tprs = []
for i in [1e-4, 1e-3, 1e-2]:
ind = np.argmin(np.abs(fpr-i))
tprs.append(tpr[ind])
return tprs
return scores, labels
def evaluate_roc(
distmat,
q_feats,
g_feats,
q_pids,
g_pids,
q_camids,
g_camids,
use_cython=True
):
"""Evaluates CMC rank.
Args:
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
q_pids (numpy.ndarray): 1-D array containing person identities
of each query instance.
g_pids (numpy.ndarray): 1-D array containing person identities
of each gallery instance.
q_camids (numpy.ndarray): 1-D array containing camera views under
which each query instance is captured.
g_camids (numpy.ndarray): 1-D array containing camera views under
which each gallery instance is captured.
use_cython (bool, optional): use cython code for evaluation. Default is True.
This is highly recommended as the cython code can speed up the ROC computation
by more than 10x. This requires Cython to be installed.
"""
if use_cython and IS_CYTHON_AVAI:
return evaluate_roc_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
else:
return evaluate_roc_py(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
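End-to-end usage sketch (assumes faiss is installed; mirrors the synthetic data in `test_cython.py` and the TPR@FPR readout in `ReidEvaluator`):

```python
# Sketch: evaluate_roc returns raw (distance, 0/1 label) pairs; TPR at fixed FPR
# operating points then comes from sklearn's roc_curve.
import numpy as np
from sklearn import metrics
from fastreid.evaluation import evaluate_roc

num_q, num_g, dim = 30, 300, 512
q_feats = np.random.rand(num_q, dim).astype(np.float32)
q_feats /= np.linalg.norm(q_feats, axis=1, keepdims=True)
g_feats = np.random.rand(num_g, dim).astype(np.float32)
g_feats /= np.linalg.norm(g_feats, axis=1, keepdims=True)
distmat = 1 - q_feats @ g_feats.T
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)

scores, labels = evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
fprs, tprs, _ = metrics.roc_curve(labels, scores)
for fpr in [1e-4, 1e-3, 1e-2]:
    ind = np.argmin(np.abs(fprs - fpr))
    print("TPR@FPR={:.0e}: {:.3f}".format(fpr, tprs[ind]))
```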