mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
support faiss retrieval and cython roc evaluation
This commit is contained in:
parent
f74cebcd88
commit
ae7c9288cf
@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
## Prepare pretrained model
|
## Prepare pretrained model
|
||||||
|
|
||||||
If you use origin ResNet, you do not need to do anything. But if you want to use ResNet-ibn or ResNeSt, you need to download pretrain model in [here](https://github.com/XingangPan/IBN-Net).
|
If you use backbones supported by fastreid, you do not need to do anything. It will automatically download the pre-train models.
|
||||||
And then you need to put it in `~/.cache/torch/checkpoints` or anywhere you like.
|
But if your network is not connected, you can download pre-train models manually and put it in `~/.cache/torch/checkpoints`.
|
||||||
|
|
||||||
Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.
|
If you want to use other pre-train models, such as MoCo pre-train, you can download by yourself and set the pre-train model path in `configs/Base-bagtricks.yml`.
|
||||||
|
|
||||||
## Compile with cython to accelerate evalution
|
## Compile with cython to accelerate evalution
|
||||||
|
|
||||||
@ -29,13 +29,13 @@ The configs are made for 1-GPU training.
|
|||||||
If you want to train model with 4 GPUs, you can run:
|
If you want to train model with 4 GPUs, you can run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
|
python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --num-gpus 4
|
||||||
```
|
```
|
||||||
|
|
||||||
To evaluate a model's performance, use
|
To evaluate a model's performance, use
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
|
python tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
|
||||||
MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
|
MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -12,3 +12,4 @@
|
|||||||
- sklearn
|
- sklearn
|
||||||
- termcolor
|
- termcolor
|
||||||
- tabulate
|
- tabulate
|
||||||
|
- [faiss](https://github.com/facebookresearch/faiss) `pip install faiss-cpu`
|
||||||
|
@ -238,6 +238,7 @@ _C.TEST.EVAL_PERIOD = 20
|
|||||||
# Number of images per batch in one process.
|
# Number of images per batch in one process.
|
||||||
_C.TEST.IMS_PER_BATCH = 64
|
_C.TEST.IMS_PER_BATCH = 64
|
||||||
_C.TEST.METRIC = "cosine"
|
_C.TEST.METRIC = "cosine"
|
||||||
|
_C.TEST.ROC_ENABLED = False
|
||||||
|
|
||||||
# Average query expansion
|
# Average query expansion
|
||||||
_C.TEST.AQE = CN()
|
_C.TEST.AQE = CN()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
|
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
|
||||||
from .rank import evaluate_rank
|
from .rank import evaluate_rank
|
||||||
|
from .roc import evaluate_roc
|
||||||
from .reid_evaluation import ReidEvaluator
|
from .reid_evaluation import ReidEvaluator
|
||||||
from .testing import print_csv_format, verify_results
|
from .testing import print_csv_format, verify_results
|
||||||
|
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
|
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .rank_cylib.rank_cy import evaluate_cy
|
from .rank_cylib.rank_cy import evaluate_cy
|
||||||
|
|
||||||
@ -11,18 +13,27 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
IS_CYTHON_AVAI = False
|
IS_CYTHON_AVAI = False
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Cython evaluation (very fast so highly recommended) is '
|
'Cython rank evaluation (very fast so highly recommended) is '
|
||||||
'unavailable, now use python evaluation.'
|
'unavailable, now use python evaluation.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
def eval_cuhk03(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat):
|
||||||
"""Evaluation with cuhk03 metric
|
"""Evaluation with cuhk03 metric
|
||||||
Key: one image for each gallery identity is randomly sampled for each query identity.
|
Key: one image for each gallery identity is randomly sampled for each query identity.
|
||||||
Random sampling is performed num_repeats times.
|
Random sampling is performed num_repeats times.
|
||||||
"""
|
"""
|
||||||
num_repeats = 10
|
num_repeats = 10
|
||||||
|
|
||||||
num_q, num_g = distmat.shape
|
num_q, num_g = distmat.shape
|
||||||
|
dim = q_feats.shape[1]
|
||||||
|
|
||||||
|
index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(g_feats)
|
||||||
|
if use_distmat:
|
||||||
|
indices = np.argsort(distmat, axis=1)
|
||||||
|
else:
|
||||||
|
_, indices = index.search(q_feats, k=num_g)
|
||||||
|
|
||||||
if num_g < max_rank:
|
if num_g < max_rank:
|
||||||
max_rank = num_g
|
max_rank = num_g
|
||||||
@ -31,7 +42,6 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||||||
format(num_g)
|
format(num_g)
|
||||||
)
|
)
|
||||||
|
|
||||||
indices = np.argsort(distmat, axis=1)
|
|
||||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||||
|
|
||||||
# compute cmc curve for each query
|
# compute cmc curve for each query
|
||||||
@ -93,17 +103,24 @@ def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||||||
return all_cmc, mAP
|
return all_cmc, mAP
|
||||||
|
|
||||||
|
|
||||||
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
def eval_market1501(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat):
|
||||||
"""Evaluation with market1501 metric
|
"""Evaluation with market1501 metric
|
||||||
Key: for each query identity, its gallery images from the same camera view are discarded.
|
Key: for each query identity, its gallery images from the same camera view are discarded.
|
||||||
"""
|
"""
|
||||||
num_q, num_g = distmat.shape
|
num_q, num_g = distmat.shape
|
||||||
|
dim = q_feats.shape[1]
|
||||||
|
|
||||||
|
index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(g_feats)
|
||||||
|
|
||||||
if num_g < max_rank:
|
if num_g < max_rank:
|
||||||
max_rank = num_g
|
max_rank = num_g
|
||||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||||
|
|
||||||
indices = np.argsort(distmat, axis=1)
|
if use_distmat:
|
||||||
|
indices = np.argsort(distmat, axis=1)
|
||||||
|
else:
|
||||||
|
_, indices = index.search(q_feats, k=num_g)
|
||||||
|
|
||||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||||
|
|
||||||
@ -159,31 +176,36 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_py(
|
def evaluate_py(
|
||||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03
|
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03, use_distmat
|
||||||
):
|
):
|
||||||
if use_metric_cuhk03:
|
if use_metric_cuhk03:
|
||||||
return eval_cuhk03(
|
return eval_cuhk03(
|
||||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
|
distmat, q_feats, g_feats, g_pids, q_camids, g_camids, max_rank, use_distmat
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return eval_market1501(
|
return eval_market1501(
|
||||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
|
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_rank(
|
def evaluate_rank(
|
||||||
distmat,
|
distmat,
|
||||||
|
q_feats,
|
||||||
|
g_feats,
|
||||||
q_pids,
|
q_pids,
|
||||||
g_pids,
|
g_pids,
|
||||||
q_camids,
|
q_camids,
|
||||||
g_camids,
|
g_camids,
|
||||||
max_rank=50,
|
max_rank=50,
|
||||||
use_metric_cuhk03=False,
|
use_metric_cuhk03=False,
|
||||||
|
use_distmat=False,
|
||||||
use_cython=True
|
use_cython=True
|
||||||
):
|
):
|
||||||
"""Evaluates CMC rank.
|
"""Evaluates CMC rank.
|
||||||
Args:
|
Args:
|
||||||
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
|
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
|
||||||
|
q_feats (numpy.ndarray): 2-D array containing query features.
|
||||||
|
g_feats (numpy.ndarray): 2-D array containing gallery features.
|
||||||
q_pids (numpy.ndarray): 1-D array containing person identities
|
q_pids (numpy.ndarray): 1-D array containing person identities
|
||||||
of each query instance.
|
of each query instance.
|
||||||
g_pids (numpy.ndarray): 1-D array containing person identities
|
g_pids (numpy.ndarray): 1-D array containing person identities
|
||||||
@ -201,11 +223,11 @@ def evaluate_rank(
|
|||||||
"""
|
"""
|
||||||
if use_cython and IS_CYTHON_AVAI:
|
if use_cython and IS_CYTHON_AVAI:
|
||||||
return evaluate_cy(
|
return evaluate_cy(
|
||||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
|
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
use_metric_cuhk03
|
use_metric_cuhk03, use_distmat
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return evaluate_py(
|
return evaluate_py(
|
||||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
|
distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
use_metric_cuhk03
|
use_metric_cuhk03, use_distmat
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
all:
|
all:
|
||||||
python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
rm -rf build
|
rm -rf build
|
||||||
|
python test_cython.py
|
||||||
clean:
|
clean:
|
||||||
rm -rf build
|
rm -rf build
|
||||||
rm -f rank_cy.c *.so
|
rm -f rank_cy.c *.so
|
||||||
|
@ -5,6 +5,7 @@ import cython
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import faiss
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -17,22 +18,35 @@ Credit to https://github.com/luzai
|
|||||||
|
|
||||||
|
|
||||||
# Main interface
|
# Main interface
|
||||||
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
|
cpdef evaluate_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False,
|
||||||
|
use_distmat=False):
|
||||||
distmat = np.asarray(distmat, dtype=np.float32)
|
distmat = np.asarray(distmat, dtype=np.float32)
|
||||||
|
q_feats = np.asarray(q_feats, dtype=np.float32)
|
||||||
|
g_feats = np.asarray(g_feats, dtype=np.float32)
|
||||||
q_pids = np.asarray(q_pids, dtype=np.int64)
|
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||||
g_pids = np.asarray(g_pids, dtype=np.int64)
|
g_pids = np.asarray(g_pids, dtype=np.int64)
|
||||||
q_camids = np.asarray(q_camids, dtype=np.int64)
|
q_camids = np.asarray(q_camids, dtype=np.int64)
|
||||||
g_camids = np.asarray(g_camids, dtype=np.int64)
|
g_camids = np.asarray(g_camids, dtype=np.int64)
|
||||||
if use_metric_cuhk03:
|
if use_metric_cuhk03:
|
||||||
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
return eval_cuhk03_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
|
||||||
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
return eval_market1501_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat)
|
||||||
|
|
||||||
|
|
||||||
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
cpdef eval_cuhk03_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
|
||||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
|
||||||
|
|
||||||
cdef long num_q = distmat.shape[0]
|
cdef long num_q = q_feats.shape[0]
|
||||||
cdef long num_g = distmat.shape[1]
|
cdef long num_g = g_feats.shape[0]
|
||||||
|
cdef long dim = q_feats.shape[1]
|
||||||
|
|
||||||
|
cdef long[:,:] indices
|
||||||
|
cdef index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(np.asarray(g_feats))
|
||||||
|
|
||||||
|
if use_distmat:
|
||||||
|
indices = np.argsort(distmat, axis=1)
|
||||||
|
else:
|
||||||
|
indices = index.search(np.asarray(q_feats), k=num_g)[1]
|
||||||
|
|
||||||
if num_g < max_rank:
|
if num_g < max_rank:
|
||||||
max_rank = num_g
|
max_rank = num_g
|
||||||
@ -40,7 +54,6 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||||||
|
|
||||||
cdef:
|
cdef:
|
||||||
long num_repeats = 10
|
long num_repeats = 10
|
||||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
|
||||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||||
|
|
||||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||||
@ -147,24 +160,34 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||||||
return np.asarray(avg_cmc).astype(np.float32), mAP
|
return np.asarray(avg_cmc).astype(np.float32), mAP
|
||||||
|
|
||||||
|
|
||||||
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
cpdef eval_market1501_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
|
||||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
long[:]q_camids, long[:]g_camids, long max_rank, bint use_distmat):
|
||||||
|
|
||||||
cdef long num_q = distmat.shape[0]
|
cdef long num_q = q_feats.shape[0]
|
||||||
cdef long num_g = distmat.shape[1]
|
cdef long num_g = g_feats.shape[0]
|
||||||
|
cdef long dim = q_feats.shape[1]
|
||||||
|
|
||||||
|
cdef long[:,:] indices
|
||||||
|
cdef index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(np.asarray(g_feats))
|
||||||
|
|
||||||
|
if use_distmat:
|
||||||
|
indices = np.argsort(distmat, axis=1)
|
||||||
|
else:
|
||||||
|
indices = index.search(np.asarray(q_feats), k=num_g)[1]
|
||||||
|
|
||||||
if num_g < max_rank:
|
if num_g < max_rank:
|
||||||
max_rank = num_g
|
max_rank = num_g
|
||||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||||
|
|
||||||
cdef:
|
cdef:
|
||||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
|
||||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||||
|
|
||||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||||
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
||||||
float[:] all_INP = np.zeros(num_q, dtype=np.float32)
|
float[:] all_INP = np.zeros(num_q, dtype=np.float32)
|
||||||
float num_valid_q = 0. # number of valid query
|
float num_valid_q = 0. # number of valid query
|
||||||
|
long valid_index = 0
|
||||||
|
|
||||||
long q_idx, q_pid, q_camid, g_idx
|
long q_idx, q_pid, q_camid, g_idx
|
||||||
long[:] order = np.zeros(num_g, dtype=np.int64)
|
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||||
@ -181,7 +204,6 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||||
float tmp_cmc_sum
|
float tmp_cmc_sum
|
||||||
|
|
||||||
long valid_index = 0
|
|
||||||
|
|
||||||
for q_idx in range(num_q):
|
for q_idx in range(num_q):
|
||||||
# get query pid and camid
|
# get query pid and camid
|
||||||
@ -234,7 +256,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||||
num_rel += raw_cmc[g_idx]
|
num_rel += raw_cmc[g_idx]
|
||||||
all_AP[valid_index] = tmp_cmc_sum / num_rel
|
all_AP[valid_index] = tmp_cmc_sum / num_rel
|
||||||
valid_index+=1
|
valid_index += 1
|
||||||
|
|
||||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||||
|
|
||||||
@ -245,7 +267,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||||||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||||
avg_cmc[rank_idx] /= num_valid_q
|
avg_cmc[rank_idx] /= num_valid_q
|
||||||
|
|
||||||
return np.asarray(avg_cmc).astype(np.float32), all_AP[:valid_index], all_INP[:valid_index]
|
return np.asarray(avg_cmc).astype(np.float32), np.asarray(all_AP[:valid_index]), np.asarray(all_INP[:valid_index])
|
||||||
|
|
||||||
|
|
||||||
# Compute the cumulative sum
|
# Compute the cumulative sum
|
||||||
|
96
fastreid/evaluation/rank_cylib/roc_cy.pyx
Normal file
96
fastreid/evaluation/rank_cylib/roc_cy.pyx
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
|
||||||
|
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx
|
||||||
|
|
||||||
|
import cython
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
cimport numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compiler directives:
|
||||||
|
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
|
||||||
|
Cython tutorial:
|
||||||
|
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
|
||||||
|
Credit to https://github.com/luzai
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Main interface
|
||||||
|
cpdef evaluate_roc_cy(float[:,:] distmat, float[:,:] q_feats, float[:,:] g_feats, long[:] q_pids, long[:]g_pids,
|
||||||
|
long[:]q_camids, long[:]g_camids):
|
||||||
|
|
||||||
|
distmat = np.asarray(distmat, dtype=np.float32)
|
||||||
|
q_feats = np.asarray(q_feats, dtype=np.float32)
|
||||||
|
g_feats = np.asarray(g_feats, dtype=np.float32)
|
||||||
|
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||||
|
g_pids = np.asarray(g_pids, dtype=np.int64)
|
||||||
|
q_camids = np.asarray(q_camids, dtype=np.int64)
|
||||||
|
g_camids = np.asarray(g_camids, dtype=np.int64)
|
||||||
|
|
||||||
|
cdef long num_q = distmat.shape[0]
|
||||||
|
cdef long num_g = distmat.shape[1]
|
||||||
|
cdef long dim = q_feats.shape[1]
|
||||||
|
|
||||||
|
cdef long[:,:] indices
|
||||||
|
cdef index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(np.asarray(g_feats))
|
||||||
|
|
||||||
|
indices = index.search(np.asarray(q_feats), k=num_g)[1]
|
||||||
|
|
||||||
|
cdef:
|
||||||
|
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||||
|
|
||||||
|
float[:] pos = np.zeros(num_q*num_g, dtype=np.float32)
|
||||||
|
float[:] neg = np.zeros(num_q*num_g, dtype=np.float32)
|
||||||
|
|
||||||
|
long valid_pos = 0
|
||||||
|
long valid_neg = 0
|
||||||
|
long ind
|
||||||
|
|
||||||
|
long q_idx, q_pid, q_camid, g_idx
|
||||||
|
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||||
|
|
||||||
|
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
|
||||||
|
long[:] sort_idx = np.zeros(num_g, dtype=np.int64)
|
||||||
|
|
||||||
|
long idx
|
||||||
|
|
||||||
|
for q_idx in range(num_q):
|
||||||
|
# get query pid and camid
|
||||||
|
q_pid = q_pids[q_idx]
|
||||||
|
q_camid = q_camids[q_idx]
|
||||||
|
|
||||||
|
for g_idx in range(num_g):
|
||||||
|
order[g_idx] = indices[q_idx, g_idx]
|
||||||
|
num_g_real = 0
|
||||||
|
|
||||||
|
# remove gallery samples that have the same pid and camid with query
|
||||||
|
for g_idx in range(num_g):
|
||||||
|
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||||
|
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||||
|
sort_idx[num_g_real] = order[g_idx]
|
||||||
|
num_g_real += 1
|
||||||
|
|
||||||
|
q_dist = distmat[q_idx]
|
||||||
|
|
||||||
|
for valid_idx in range(num_g_real):
|
||||||
|
if raw_cmc[valid_idx] == 1:
|
||||||
|
pos[valid_pos] = q_dist[sort_idx[valid_idx]]
|
||||||
|
valid_pos += 1
|
||||||
|
elif raw_cmc[valid_idx] == 0:
|
||||||
|
neg[valid_neg] = q_dist[sort_idx[valid_idx]]
|
||||||
|
valid_neg += 1
|
||||||
|
|
||||||
|
cdef float[:] scores = np.hstack((pos[:valid_pos], neg[:valid_neg]))
|
||||||
|
cdef float[:] labels = np.hstack((np.zeros(valid_pos, dtype=np.float32),
|
||||||
|
np.ones(valid_neg, dtype=np.float32)))
|
||||||
|
return np.asarray(scores), np.asarray(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# Compute the cumulative sum
|
||||||
|
cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n):
|
||||||
|
cdef long i
|
||||||
|
dst[0] = src[0]
|
||||||
|
for i in range(1, n):
|
||||||
|
dst[i] = src[i] + dst[i - 1]
|
@ -18,6 +18,11 @@ ext_modules = [
|
|||||||
'rank_cy',
|
'rank_cy',
|
||||||
['rank_cy.pyx'],
|
['rank_cy.pyx'],
|
||||||
include_dirs=[numpy_include()],
|
include_dirs=[numpy_include()],
|
||||||
|
),
|
||||||
|
Extension(
|
||||||
|
'roc_cy',
|
||||||
|
['roc_cy.pyx'],
|
||||||
|
include_dirs=[numpy_include()],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
import timeit
|
import timeit
|
||||||
|
import numpy as np
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
|
|
||||||
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
||||||
|
|
||||||
|
from fastreid.evaluation import evaluate_rank
|
||||||
|
from fastreid.evaluation import evaluate_roc
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Test the speed of cython-based evaluation code. The speed improvements
|
Test the speed of cython-based evaluation code. The speed improvements
|
||||||
can be much bigger when using the real reid data, which contains a larger
|
can be much bigger when using the real reid data, which contains a larger
|
||||||
@ -22,58 +24,97 @@ import sys
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
||||||
from fastreid import evaluation
|
from fastreid.evaluation import evaluate_rank
|
||||||
|
from fastreid.evaluation import evaluate_roc
|
||||||
num_q = 30
|
num_q = 30
|
||||||
num_g = 300
|
num_g = 300
|
||||||
|
dim = 512
|
||||||
max_rank = 5
|
max_rank = 5
|
||||||
distmat = np.random.rand(num_q, num_g) * 20
|
q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20
|
||||||
|
q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True)
|
||||||
|
g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20
|
||||||
|
g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True)
|
||||||
|
distmat = 1 - np.dot(q_feats, g_feats.transpose())
|
||||||
q_pids = np.random.randint(0, num_q, size=num_q)
|
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||||
g_pids = np.random.randint(0, num_g, size=num_g)
|
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||||
q_camids = np.random.randint(0, 5, size=num_q)
|
q_camids = np.random.randint(0, 5, size=num_q)
|
||||||
g_camids = np.random.randint(0, 5, size=num_g)
|
g_camids = np.random.randint(0, 5, size=num_g)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# print('=> Using market1501\'s metric')
|
print('=> Using CMC metric')
|
||||||
# pytime = timeit.timeit(
|
pytime = timeit.timeit(
|
||||||
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
|
'evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat=True, use_cython=False)',
|
||||||
# setup=setup,
|
setup=setup,
|
||||||
# number=20
|
number=20
|
||||||
# )
|
)
|
||||||
# cytime = timeit.timeit(
|
cytime = timeit.timeit(
|
||||||
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
|
'evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank, use_distmat=True, use_cython=True)',
|
||||||
# setup=setup,
|
setup=setup,
|
||||||
# number=20
|
number=20
|
||||||
# )
|
)
|
||||||
# print('Python time: {} s'.format(pytime))
|
print('Python time: {} s'.format(pytime))
|
||||||
# print('Cython time: {} s'.format(cytime))
|
print('Cython time: {} s'.format(cytime))
|
||||||
# print('Cython is {} times faster than python\n'.format(pytime / cytime))
|
print('CMC Cython is {} times faster than python\n'.format(pytime / cytime))
|
||||||
#
|
|
||||||
# print('=> Using cuhk03\'s metric')
|
print('=> Using ROC metric')
|
||||||
# pytime = timeit.timeit(
|
pytime = timeit.timeit(
|
||||||
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
|
'evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, use_cython=False)',
|
||||||
# setup=setup,
|
setup=setup,
|
||||||
# number=20
|
number=20
|
||||||
# )
|
)
|
||||||
# cytime = timeit.timeit(
|
cytime = timeit.timeit(
|
||||||
# 'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
|
'evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, use_cython=True)',
|
||||||
# setup=setup,
|
setup=setup,
|
||||||
# number=20
|
number=20
|
||||||
# )
|
)
|
||||||
# print('Python time: {} s'.format(pytime))
|
print('Python time: {} s'.format(pytime))
|
||||||
# print('Cython time: {} s'.format(cytime))
|
print('Cython time: {} s'.format(cytime))
|
||||||
# print('Cython is {} times faster than python\n'.format(pytime / cytime))
|
print('ROC Cython is {} times faster than python\n'.format(pytime / cytime))
|
||||||
|
|
||||||
from fastreid.evaluation import evaluate_rank
|
|
||||||
print("=> Check precision")
|
print("=> Check precision")
|
||||||
num_q = 30
|
num_q = 30
|
||||||
num_g = 300
|
num_g = 300
|
||||||
|
dim = 512
|
||||||
max_rank = 5
|
max_rank = 5
|
||||||
distmat = np.random.rand(num_q, num_g) * 20
|
q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20
|
||||||
|
q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True)
|
||||||
|
g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20
|
||||||
|
g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True)
|
||||||
|
distmat = 1 - np.dot(q_feats, g_feats.transpose())
|
||||||
q_pids = np.random.randint(0, num_q, size=num_q)
|
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||||
g_pids = np.random.randint(0, num_g, size=num_g)
|
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||||
q_camids = np.random.randint(0, 5, size=num_q)
|
q_camids = np.random.randint(0, 5, size=num_q)
|
||||||
g_camids = np.random.randint(0, 5, size=num_g)
|
g_camids = np.random.randint(0, 5, size=num_g)
|
||||||
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
|
cmc_py_d, mAP_py_d, mINP_py_d = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
|
use_distmat=True, use_cython=False)
|
||||||
cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
|
cmc_py, mAP_py, mINP_py = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(np.array(mAP), cmc, np.array(mINP)))
|
use_distmat=False, use_cython=False)
|
||||||
|
np.testing.assert_allclose(cmc_py_d, cmc_py, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mAP_py_d, mAP_py, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mINP_py_d, mINP_py, rtol=1e-3, atol=1e-6)
|
||||||
|
print('Results between distmat and features are the same in python!')
|
||||||
|
|
||||||
|
cmc_cy_d, mAP_cy_d, mINP_cy_d = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
|
use_distmat=True, use_cython=True)
|
||||||
|
cmc_cy, mAP_cy, mINP_cy = evaluate_rank(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||||
|
use_distmat=False, use_cython=True)
|
||||||
|
np.testing.assert_allclose(cmc_cy_d, cmc_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mAP_cy_d, mAP_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mINP_cy_d, mINP_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
print('Results between distmat and features are the same in cython!')
|
||||||
|
|
||||||
|
np.testing.assert_allclose(cmc_py, cmc_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mAP_py, mAP_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(mINP_py, mINP_cy, rtol=1e-3, atol=1e-6)
|
||||||
|
print('Rank results between python and cython are the same!')
|
||||||
|
|
||||||
|
scores_cy, labels_cy = evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
|
||||||
|
scores_py, labels_py = evaluate_roc(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids,
|
||||||
|
use_cython=False)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(scores_cy, scores_py, rtol=1e-3, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(labels_cy, labels_py, rtol=1e-3, atol=1e-6)
|
||||||
|
print('ROC results between python and cython are the same!\n')
|
||||||
|
|
||||||
|
print("=> Check exact values")
|
||||||
|
print("mAP = {} \ncmc = {}\nmINP = {}\nScores = {}".format(np.array(mAP_cy), cmc_cy, np.array(mINP_cy), scores_cy))
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from sklearn import metrics
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -45,8 +46,6 @@ class ReidEvaluator(DatasetEvaluator):
|
|||||||
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
|
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
|
||||||
assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
|
assert metric in ["cosine", "euclidean"], "must choose from [cosine, euclidean], but got {}".format(metric)
|
||||||
if metric == "cosine":
|
if metric == "cosine":
|
||||||
query_feat = F.normalize(query_feat, dim=1)
|
|
||||||
gallery_feat = F.normalize(gallery_feat, dim=1)
|
|
||||||
dist = 1 - torch.mm(query_feat, gallery_feat.t())
|
dist = 1 - torch.mm(query_feat, gallery_feat.t())
|
||||||
else:
|
else:
|
||||||
m, n = query_feat.size(0), gallery_feat.size(0)
|
m, n = query_feat.size(0), gallery_feat.size(0)
|
||||||
@ -96,6 +95,10 @@ class ReidEvaluator(DatasetEvaluator):
|
|||||||
alpha = self.cfg.TEST.AQE.ALPHA
|
alpha = self.cfg.TEST.AQE.ALPHA
|
||||||
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
|
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
|
||||||
|
|
||||||
|
if self.cfg.TEST.METRIC == "cosine":
|
||||||
|
query_features = F.normalize(query_features, dim=1)
|
||||||
|
gallery_features = F.normalize(gallery_features, dim=1)
|
||||||
|
|
||||||
dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)
|
dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, gallery_features)
|
||||||
|
|
||||||
if self.cfg.TEST.RERANK.ENABLED:
|
if self.cfg.TEST.RERANK.ENABLED:
|
||||||
@ -105,9 +108,18 @@ class ReidEvaluator(DatasetEvaluator):
|
|||||||
lambda_value = self.cfg.TEST.RERANK.LAMBDA
|
lambda_value = self.cfg.TEST.RERANK.LAMBDA
|
||||||
q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
|
q_q_dist = self.cal_dist(self.cfg.TEST.METRIC, query_features, query_features)
|
||||||
g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
|
g_g_dist = self.cal_dist(self.cfg.TEST.METRIC, gallery_features, gallery_features)
|
||||||
dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
|
re_dist = re_ranking(dist, q_q_dist, g_g_dist, k1, k2, lambda_value)
|
||||||
|
query_features = query_features.numpy()
|
||||||
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
gallery_features = gallery_features.numpy()
|
||||||
|
cmc, all_AP, all_INP = evaluate_rank(re_dist, query_features, gallery_features,
|
||||||
|
query_pids, gallery_pids, query_camids,
|
||||||
|
gallery_camids, use_distmat=True)
|
||||||
|
else:
|
||||||
|
query_features = query_features.numpy()
|
||||||
|
gallery_features = gallery_features.numpy()
|
||||||
|
cmc, all_AP, all_INP = evaluate_rank(dist, query_features, gallery_features,
|
||||||
|
query_pids, gallery_pids, query_camids, gallery_camids,
|
||||||
|
use_distmat=False)
|
||||||
mAP = np.mean(all_AP)
|
mAP = np.mean(all_AP)
|
||||||
mINP = np.mean(all_INP)
|
mINP = np.mean(all_INP)
|
||||||
for r in [1, 5, 10]:
|
for r in [1, 5, 10]:
|
||||||
@ -115,9 +127,13 @@ class ReidEvaluator(DatasetEvaluator):
|
|||||||
self._results['mAP'] = mAP
|
self._results['mAP'] = mAP
|
||||||
self._results['mINP'] = mINP
|
self._results['mINP'] = mINP
|
||||||
|
|
||||||
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
if self.cfg.TEST.ROC_ENABLED:
|
||||||
fprs = [1e-4, 1e-3, 1e-2]
|
scores, labels = evaluate_roc(dist, query_features, gallery_features,
|
||||||
for i in range(len(fprs)):
|
query_pids, gallery_pids, query_camids, gallery_camids)
|
||||||
self._results["TPR@FPR={:.0e}".format(fprs[i])] = tprs[i]
|
fprs, tprs, thres = metrics.roc_curve(labels, scores)
|
||||||
|
|
||||||
|
for fpr in [1e-4, 1e-3, 1e-2]:
|
||||||
|
ind = np.argmin(np.abs(fprs - fpr))
|
||||||
|
self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]
|
||||||
|
|
||||||
return copy.deepcopy(self._results)
|
return copy.deepcopy(self._results)
|
||||||
|
@ -4,11 +4,24 @@
|
|||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
|
||||||
|
try:
|
||||||
|
from .rank_cylib.roc_cy import evaluate_roc_cy
|
||||||
|
|
||||||
|
IS_CYTHON_AVAI = True
|
||||||
|
except ImportError:
|
||||||
|
IS_CYTHON_AVAI = False
|
||||||
|
warnings.warn(
|
||||||
|
'Cython roc evaluation (very fast so highly recommended) is '
|
||||||
|
'unavailable, now use python evaluation.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
|
def evaluate_roc_py(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids):
|
||||||
r"""Evaluation with ROC curve.
|
r"""Evaluation with ROC curve.
|
||||||
Key: for each query identity, its gallery images from the same camera view are discarded.
|
Key: for each query identity, its gallery images from the same camera view are discarded.
|
||||||
|
|
||||||
@ -16,8 +29,12 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
|
|||||||
distmat (np.ndarray): cosine distance matrix
|
distmat (np.ndarray): cosine distance matrix
|
||||||
"""
|
"""
|
||||||
num_q, num_g = distmat.shape
|
num_q, num_g = distmat.shape
|
||||||
|
dim = q_feats.shape[1]
|
||||||
|
|
||||||
indices = np.argsort(distmat, axis=1)
|
index = faiss.IndexFlatL2(dim)
|
||||||
|
index.add(g_feats)
|
||||||
|
|
||||||
|
_, indices = index.search(q_feats, k=num_g)
|
||||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||||
|
|
||||||
pos = []
|
pos = []
|
||||||
@ -31,22 +48,49 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
|
|||||||
order = indices[q_idx]
|
order = indices[q_idx]
|
||||||
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
||||||
keep = np.invert(remove)
|
keep = np.invert(remove)
|
||||||
cmc = matches[q_idx][keep]
|
raw_cmc = matches[q_idx][keep]
|
||||||
|
|
||||||
sort_idx = order[keep]
|
sort_idx = order[keep]
|
||||||
|
|
||||||
q_dist = distmat[q_idx]
|
q_dist = distmat[q_idx]
|
||||||
ind_pos = np.where(cmc == 1)[0]
|
ind_pos = np.where(raw_cmc == 1)[0]
|
||||||
pos.extend(q_dist[sort_idx[ind_pos]])
|
pos.extend(q_dist[sort_idx[ind_pos]])
|
||||||
|
|
||||||
ind_neg = np.where(cmc == 0)[0]
|
ind_neg = np.where(raw_cmc == 0)[0]
|
||||||
neg.extend(q_dist[sort_idx[ind_neg]])
|
neg.extend(q_dist[sort_idx[ind_neg]])
|
||||||
|
|
||||||
scores = np.hstack((pos, neg))
|
scores = np.hstack((pos, neg))
|
||||||
|
|
||||||
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
|
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
|
return scores, labels
|
||||||
tprs = []
|
|
||||||
for i in [1e-4, 1e-3, 1e-2]:
|
|
||||||
ind = np.argmin(np.abs(fpr-i))
|
def evaluate_roc(
|
||||||
tprs.append(tpr[ind])
|
distmat,
|
||||||
return tprs
|
q_feats,
|
||||||
|
g_feats,
|
||||||
|
q_pids,
|
||||||
|
g_pids,
|
||||||
|
q_camids,
|
||||||
|
g_camids,
|
||||||
|
use_cython=True
|
||||||
|
):
|
||||||
|
"""Evaluates CMC rank.
|
||||||
|
Args:
|
||||||
|
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
|
||||||
|
q_pids (numpy.ndarray): 1-D array containing person identities
|
||||||
|
of each query instance.
|
||||||
|
g_pids (numpy.ndarray): 1-D array containing person identities
|
||||||
|
of each gallery instance.
|
||||||
|
q_camids (numpy.ndarray): 1-D array containing camera views under
|
||||||
|
which each query instance is captured.
|
||||||
|
g_camids (numpy.ndarray): 1-D array containing camera views under
|
||||||
|
which each gallery instance is captured.
|
||||||
|
use_cython (bool, optional): use cython code for evaluation. Default is True.
|
||||||
|
This is highly recommended as the cython code can speed up the cmc computation
|
||||||
|
by more than 10x. This requires Cython to be installed.
|
||||||
|
"""
|
||||||
|
if use_cython and IS_CYTHON_AVAI:
|
||||||
|
return evaluate_roc_cy(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
|
||||||
|
else:
|
||||||
|
return evaluate_roc_py(distmat, q_feats, g_feats, q_pids, g_pids, q_camids, g_camids)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user