feat: update roc curve and TPR@FPR metric

support plot multiple ROC curves with different model
This commit is contained in:
liaoxingyu 2020-05-20 14:29:33 +08:00
parent e344eae1cc
commit 2ac55a7601
6 changed files with 100 additions and 32 deletions

View File

@ -0,0 +1,23 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
import matplotlib.pyplot as plt
import sys
sys.path.append('.')
from fastreid.utils.visualizer import Visualizer
if __name__ == "__main__":
baseline_res = Visualizer.load_roc_info("logs/duke_vis/roc_info.pickle")
mgn_res = Visualizer.load_roc_info("logs/mgn_duke_vis/roc_info.pickle")
fig = Visualizer.plot_roc_curve(baseline_res['fpr'], baseline_res['tpr'], name='baseline')
Visualizer.plot_roc_curve(mgn_res['fpr'], mgn_res['tpr'], name='mgn', fig=fig)
plt.savefig('roc.jpg')
fig = Visualizer.plot_distribution(baseline_res['pos'], baseline_res['neg'], name='baseline')
Visualizer.plot_distribution(mgn_res['pos'], mgn_res['neg'], name='mgn', fig=fig)
plt.savefig('dist.jpg')

View File

@ -1,3 +1,3 @@
python demo/visualize_result.py --config-file ''configs/DukeMTMC/sbs_R50.yml'' \
--parallel --vis-label --dataset-name 'DukeMTMC' --output 'logs/duke_vis' \
--opts MODEL.WEIGHTS "logs/dukemtmc/sbs_R50_60epoch/model_final.pth"
python demo/visualize_result.py --config-file 'logs/dukemtmc/mgn_R50-ibn/config.yaml' \
--parallel --vis-label --dataset-name 'DukeMTMC' --output 'logs/mgn_duke_vis' \
--opts MODEL.WEIGHTS "logs/dukemtmc/mgn_R50-ibn/model_final.pth"

View File

@ -121,17 +121,18 @@ if __name__ == '__main__':
g_camids = np.asarray(camids[num_query:])
# compute cosine distance
distmat = torch.mm(q_feat, g_feat.t())
distmat = 1 - torch.mm(q_feat, g_feat.t())
distmat = distmat.numpy()
logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(1 - distmat, q_pids, g_pids, q_camids, g_camids)
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
visualizer = Visualizer(test_loader.loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ROC curve ...")
visualizer.vis_roc_curve(args.output)
fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output)
visualizer.save_roc_info(args.output, fpr, tpr, pos, neg)
logger.info("Saving rank list result ...")
query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis,

View File

@ -97,6 +97,8 @@ class ReidEvaluator(DatasetEvaluator):
self._results['mAP'] = mAP
self._results['mINP'] = mINP
auc = evaluate_roc(1 - dist, query_pids, gallery_pids, query_camids, gallery_camids)
self._results["AUC"] = auc
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
fprs = [1e-4, 1e-3, 1e-2]
for i in range(len(fprs)):
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
return copy.deepcopy(self._results)

View File

@ -13,7 +13,7 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
Key: for each query identity, its gallery images from the same camera view are discarded.
Args:
distmat (np.ndarray): similarity matrix
distmat (np.ndarray): cosine distance matrix
"""
num_q, num_g = distmat.shape
@ -41,10 +41,12 @@ def evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids):
ind_neg = np.where(cmc == 0)[0]
neg.extend(q_dist[sort_idx[ind_neg]])
pos = 1 - np.array(pos)
neg = 1 - np.array(neg)
scores = np.hstack((pos, neg))
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
return metrics.auc(fpr, tpr)
tprs = []
for i in [1e-4, 1e-3, 1e-2]:
ind = np.argmin(np.abs(fpr-i))
tprs.append(tpr[ind])
return tprs

View File

@ -6,6 +6,7 @@
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import tqdm
@ -13,6 +14,7 @@ from scipy.stats import norm
from sklearn import metrics
from .file_io import PathManager
import random
class Visualizer:
@ -21,15 +23,16 @@ class Visualizer:
def __init__(self, dataset):
self.dataset = dataset
def get_model_output(self, all_ap, sim, q_pids, g_pids, q_camids, g_camids):
def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids):
self.all_ap = all_ap
self.sim = sim
self.dist = dist
self.sim = 1 - dist
self.q_pids = q_pids
self.g_pids = g_pids
self.q_camids = q_camids
self.g_camids = g_camids
self.indices = np.argsort(1 - sim, axis=1)
self.indices = np.argsort(dist, axis=1)
self.matches = (g_pids[self.indices] == q_pids[:, np.newaxis]).astype(np.int32)
self.num_query = len(q_pids)
@ -148,54 +151,91 @@ class Visualizer:
assert rank_sort in ['ascending', 'descending'], "{} not match [ascending, descending]".format(rank_sort)
query_indices = np.argsort(self.all_ap)
if rank_sort == 'descending': query_indices = query_indices[::-1]
if rank_sort == 'descending': query_indices = query_indices[::-1]
query_indices = query_indices[:num_vis]
self.save_rank_result(query_indices, output, max_rank, vis_label, label_sort, actmap)
def vis_roc_curve(self, output):
PathManager.mkdirs(output)
pos_sim, neg_sim = [], []
pos, neg = [], []
for i, q in enumerate(self.q_pids):
cmc, sort_idx = self.get_matched_result(i) # remove same id in same camera
ind_pos = np.where(cmc == 1)[0]
q_dist = self.sim[i]
pos_sim.extend(q_dist[sort_idx[ind_pos]])
q_dist = self.dist[i]
pos.extend(q_dist[sort_idx[ind_pos]])
ind_neg = np.where(cmc == 0)[0]
neg_sim.extend(q_dist[sort_idx[ind_neg]])
neg.extend(q_dist[sort_idx[ind_neg]])
pos = 1 - np.array(pos_sim)
neg = 1 - np.array(neg_sim)
scores = np.hstack((pos, neg))
labels = np.hstack((np.zeros(len(pos)), np.ones(len(neg))))
fpr, tpr, thresholds = metrics.roc_curve(labels, scores)
plt.figure()
plt.semilogx(fpr, tpr, 'b')
self.plot_roc_curve(fpr, tpr)
filepath = os.path.join(output, "roc.jpg")
plt.savefig(filepath)
self.plot_distribution(pos, neg)
filepath = os.path.join(output, "pos_neg_dist.jpg")
plt.savefig(filepath)
return fpr, tpr, pos, neg
plt.figure(figsize=(10, 5))
n, bins, _ = plt.hist(pos, bins=80, alpha=0.7, density=True, color='red', label='positive')
@staticmethod
def plot_roc_curve(fpr, tpr, name='model', fig=None):
if fig is None:
fig = plt.figure()
plt.semilogx(np.arange(0, 1, 0.01), np.arange(0, 1, 0.01), 'r', linestyle='--', label='Random guess')
plt.semilogx(fpr, tpr, color=(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)),
label='ROC curve with {}'.format(name))
plt.title('Receiver Operating Characteristic')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='best')
return fig
@staticmethod
def plot_distribution(pos, neg, name='model', fig=None):
if fig is None:
fig = plt.figure()
pos_color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
n, bins, _ = plt.hist(pos, bins=80, alpha=0.7, density=True,
color=pos_color,
label='positive with {}'.format(name))
mu = np.mean(pos)
sigma = np.std(pos)
y = norm.pdf(bins, mu, sigma) # fitting curve
plt.plot(bins, y, 'r--') # plot y curve
plt.plot(bins, y, color=pos_color) # plot y curve
n, bins, _ = plt.hist(neg, bins=80, alpha=0.5, density=True, color='blue', label='negative')
neg_color = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
n, bins, _ = plt.hist(neg, bins=80, alpha=0.5, density=True,
color=neg_color,
label='negative with {}'.format(name))
mu = np.mean(neg)
sigma = np.std(neg)
y = norm.pdf(bins, mu, sigma) # fitting curve
plt.plot(bins, y, 'b--') # plot y curve
plt.plot(bins, y, color=neg_color) # plot y curve
plt.xticks(np.arange(0, 1.5, 0.1))
plt.title('positive and negative pairs distribution')
plt.legend(loc='best')
filepath = os.path.join(output, "pos_neg_dist.jpg")
plt.savefig(filepath)
return fig
@staticmethod
def save_roc_info(output, fpr, tpr, pos, neg):
results = {
"fpr": np.asarray(fpr),
"tpr": np.asarray(tpr),
"pos": np.asarray(pos),
"neg": np.asarray(neg),
}
with open(os.path.join(output, "roc_info.pickle"), "wb") as handle:
pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
@staticmethod
def load_roc_info(path):
with open(path, 'rb') as handle: res = pickle.load(handle)
return res
# def plot_camera_dist(self):
# same_cam, diff_cam = [], []
# for i, q in enumerate(self.q_pids):