mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
feat: update roc curve and TPR@FPR metric
support plot multiple ROC curves with different model
This commit is contained in:
parent
e344eae1cc
commit
2ac55a7601
23
demo/plot_roc_with_pickle.py
Normal file
23
demo/plot_roc_with_pickle.py
Normal file
@ -0,0 +1,23 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: liaoxingyu5@jd.com
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import sys
|
||||
|
||||
sys.path.append('.')
|
||||
from fastreid.utils.visualizer import Visualizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
baseline_res = Visualizer.load_roc_info("logs/duke_vis/roc_info.pickle")
|
||||
mgn_res = Visualizer.load_roc_info("logs/mgn_duke_vis/roc_info.pickle")
|
||||
|
||||
fig = Visualizer.plot_roc_curve(baseline_res['fpr'], baseline_res['tpr'], name='baseline')
|
||||
Visualizer.plot_roc_curve(mgn_res['fpr'], mgn_res['tpr'], name='mgn', fig=fig)
|
||||
plt.savefig('roc.jpg')
|
||||
|
||||
fig = Visualizer.plot_distribution(baseline_res['pos'], baseline_res['neg'], name='baseline')
|
||||
Visualizer.plot_distribution(mgn_res['pos'], mgn_res['neg'], name='mgn', fig=fig)
|
||||
plt.savefig('dist.jpg')
|
@ -1,3 +1,3 @@
|
||||
python demo/visualize_result.py --config-file ''configs/DukeMTMC/sbs_R50.yml'' \
|
||||
--parallel --vis-label --dataset-name 'DukeMTMC' --output 'logs/duke_vis' \
|
||||
--opts MODEL.WEIGHTS "logs/dukemtmc/sbs_R50_60epoch/model_final.pth"
|
||||
python demo/visualize_result.py --config-file 'logs/dukemtmc/mgn_R50-ibn/config.yaml' \
|
||||
--parallel --vis-label --dataset-name 'DukeMTMC' --output 'logs/mgn_duke_vis' \
|
||||
--opts MODEL.WEIGHTS "logs/dukemtmc/mgn_R50-ibn/model_final.pth"
|
||||
|
@ -121,17 +121,18 @@ if __name__ == '__main__':
|
||||
g_camids = np.asarray(camids[num_query:])
|
||||
|
||||
# compute cosine distance
|
||||
distmat = torch.mm(q_feat, g_feat.t())
|
||||
distmat = 1 - torch.mm(q_feat, g_feat.t())
|
||||
distmat = distmat.numpy()
|
||||
|
||||
logger.info("Computing APs for all query images ...")
|
||||
cmc, all_ap, all_inp = evaluate_rank(1 - distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
|
||||
visualizer = Visualizer(test_loader.loader.dataset)
|
||||
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
|
||||
|
||||
logger.info("Saving ROC curve ...")
|
||||
visualizer.vis_roc_curve(args.output)
|
||||
fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
|
||||
visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
|
||||
|
||||
logger.info("Saving rank list result ...")
|
||||
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,
|
||||
|
@ -97,6 +97,8 @@ class ReidEvaluator(DatasetEvaluator):
|
||||
self._results['mAP'] = mAP
|
||||
self._results['mINP'] = mINP
|
||||
|
||||
auc = evaluate_roc(1 - dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
self._results["AUC"] = auc
|
||||
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
fprs = [1e-4, 1e-3, 1e-2]
|
||||
for i in range(len(fprs)):
|
||||
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
|
||||
return copy.deepcopy(self._results)
|
||||
|
@ -13,7 +13,7 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
|
||||
Key: for each query identity, its gallery images from the same camera view are discarded.
|
||||
|
||||
Args:
|
||||
distmat (np.ndarray): similarity matrix
|
||||
distmat (np.ndarray): cosine distance matrix
|
||||
"""
|
||||
num_q, num_g = distmat.shape
|
||||
|
||||
@ -41,10 +41,12 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
|
||||
ind_neg = np.where(cmc == 0)[0]
|
||||
neg.extend(q_dist[sort_idx[ind_neg]])
|
||||
|
||||
pos = 1 - np.array(pos)
|
||||
neg = 1 - np.array(neg)
|
||||
scores = np.hstack((pos, neg))
|
||||
|
||||
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
|
||||
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
|
||||
return metrics.auc(fpr, tpr)
|
||||
tprs = []
|
||||
for i in [1e-4, 1e-3, 1e-2]:
|
||||
ind = np.argmin(np.abs(fpr-i))
|
||||
tprs.append(tpr[ind])
|
||||
return tprs
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tqdm
|
||||
@ -13,6 +14,7 @@ from scipy.stats import norm
|
||||
from sklearn import metrics
|
||||
|
||||
from .file_io import PathManager
|
||||
import random
|
||||
|
||||
|
||||
class Visualizer:
|
||||
@ -21,15 +23,16 @@ class Visualizer:
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def get_model_output(self, all_ap, sim, q_pids, g_pids, q_camids, g_camids):
|
||||
def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids):
|
||||
self.all_ap = all_ap
|
||||
self.sim = sim
|
||||
self.dist = dist
|
||||
self.sim = 1 - dist
|
||||
self.q_pids = q_pids
|
||||
self.g_pids = g_pids
|
||||
self.q_camids = q_camids
|
||||
self.g_camids = g_camids
|
||||
|
||||
self.indices = np.argsort(1 - sim, axis=1)
|
||||
self.indices = np.argsort(dist, axis=1)
|
||||
self.matches = (g_pids[self.indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||
|
||||
self.num_query = len(q_pids)
|
||||
@ -148,54 +151,91 @@ class Visualizer:
|
||||
assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort)
|
||||
|
||||
query_indices = np.argsort(self.all_ap)
|
||||
if rank_sort == 'descending': query_indices = query_indices[::-1]
|
||||
if rank_sort == 'descending': query_indices = query_indices[::-1]
|
||||
|
||||
query_indices = query_indices[:num_vis]
|
||||
self.save_rank_result(query_indices, output, max_rank, vis_label, label_sort, actmap)
|
||||
|
||||
def vis_roc_curve(self, output):
|
||||
PathManager.mkdirs(output)
|
||||
pos_sim, neg_sim = [], []
|
||||
pos, neg = [], []
|
||||
for i, q in enumerate(self.q_pids):
|
||||
cmc, sort_idx = self.get_matched_result(i) # remove same id in same camera
|
||||
ind_pos = np.where(cmc == 1)[0]
|
||||
q_dist = self.sim[i]
|
||||
pos_sim.extend(q_dist[sort_idx[ind_pos]])
|
||||
q_dist = self.dist[i]
|
||||
pos.extend(q_dist[sort_idx[ind_pos]])
|
||||
|
||||
ind_neg = np.where(cmc == 0)[0]
|
||||
neg_sim.extend(q_dist[sort_idx[ind_neg]])
|
||||
neg.extend(q_dist[sort_idx[ind_neg]])
|
||||
|
||||
pos = 1 - np.array(pos_sim)
|
||||
neg = 1 - np.array(neg_sim)
|
||||
scores = np.hstack((pos, neg))
|
||||
|
||||
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
|
||||
|
||||
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
|
||||
plt.figure()
|
||||
plt.semilogx(fpr, tpr, 'b')
|
||||
|
||||
self.plot_roc_curve(fpr, tpr)
|
||||
filepath = os.path.join(output, "roc.jpg")
|
||||
plt.savefig(filepath)
|
||||
self.plot_distribution(pos, neg)
|
||||
filepath = os.path.join(output, "pos_neg_dist.jpg")
|
||||
plt.savefig(filepath)
|
||||
return fpr, tpr, pos, neg
|
||||
|
||||
plt.figure(figsize=(10, 5))
|
||||
n, bins, _ = plt.hist(pos, bins=80, alpha=0.7, density=True, color='red', label='positive')
|
||||
@staticmethod
|
||||
def plot_roc_curve(fpr, tpr, name='model', fig=None):
|
||||
if fig is None:
|
||||
fig = plt.figure()
|
||||
plt.semilogx(np.arange(0, 1, 0.01), np.arange(0, 1, 0.01), 'r', linestyle='--', label='Random guess')
|
||||
plt.semilogx(fpr, tpr, color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)),
|
||||
label='ROC curve with {}'.format(name))
|
||||
plt.title('Receiver Operating Characteristic')
|
||||
plt.xlabel('False Positive Rate')
|
||||
plt.ylabel('True Positive Rate')
|
||||
plt.legend(loc='best')
|
||||
return fig
|
||||
|
||||
@staticmethod
|
||||
def plot_distribution(pos, neg, name='model', fig=None):
|
||||
if fig is None:
|
||||
fig = plt.figure()
|
||||
pos_color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
|
||||
n, bins, _ = plt.hist(pos, bins=80, alpha=0.7, density=True,
|
||||
color=pos_color,
|
||||
label='positive with {}'.format(name))
|
||||
mu = np.mean(pos)
|
||||
sigma = np.std(pos)
|
||||
y = norm.pdf(bins, mu, sigma) # fitting curve
|
||||
plt.plot(bins, y, 'r--') # plot y curve
|
||||
plt.plot(bins, y, color=pos_color) # plot y curve
|
||||
|
||||
n, bins, _ = plt.hist(neg, bins=80, alpha=0.5, density=True, color='blue', label='negative')
|
||||
neg_color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
|
||||
n, bins, _ = plt.hist(neg, bins=80, alpha=0.5, density=True,
|
||||
color=neg_color,
|
||||
label='negative with {}'.format(name))
|
||||
mu = np.mean(neg)
|
||||
sigma = np.std(neg)
|
||||
y = norm.pdf(bins, mu, sigma) # fitting curve
|
||||
plt.plot(bins, y, 'b--') # plot y curve
|
||||
plt.plot(bins, y, color=neg_color) # plot y curve
|
||||
|
||||
plt.xticks(np.arange(0, 1.5, 0.1))
|
||||
plt.title('positive and negative pairs distribution')
|
||||
plt.legend(loc='best')
|
||||
filepath = os.path.join(output, "pos_neg_dist.jpg")
|
||||
plt.savefig(filepath)
|
||||
return fig
|
||||
|
||||
@staticmethod
|
||||
def save_roc_info(output, fpr, tpr, pos, neg):
|
||||
results = {
|
||||
"fpr": np.asarray(fpr),
|
||||
"tpr": np.asarray(tpr),
|
||||
"pos": np.asarray(pos),
|
||||
"neg": np.asarray(neg),
|
||||
}
|
||||
with open(os.path.join(output, "roc_info.pickle"), "wb") as handle:
|
||||
pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
@staticmethod
|
||||
def load_roc_info(path):
|
||||
with open(path, 'rb') as handle: res = pickle.load(handle)
|
||||
return res
|
||||
# def plot_camera_dist(self):
|
||||
# same_cam, diff_cam = [], []
|
||||
# for i, q in enumerate(self.q_pids):
|
||||
|
Loading…
x
Reference in New Issue
Block a user