# fast-reid/engine/interpreter.py
# (171 lines, 7.1 KiB, Python)
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from fastai.vision import *
from fastai.callbacks import *
class ReidInterpretation():
    """Interpretation methods for reid models.

    Computes a query/gallery cosine-similarity matrix from a fastai
    ``Learner``'s test-set features and provides plotting utilities for
    rank lists, hardest errors, and similarity distributions.

    Args:
        learn: fastai ``Learner`` wrapping the trained reid model.
        test_labels: iterable of ``(pid, camid)`` pairs for the whole test
            set, queries first, then gallery images.
        num_q: number of query images; ``test_labels[:num_q]`` are queries,
            the rest are gallery.
    """

    def __init__(self, learn, test_labels, num_q):
        self.learn, self.test_labels, self.num_q = learn, test_labels, num_q
        self.get_distmat()

    def get_distmat(self):
        """Extract test features and build the query/gallery similarity matrix.

        Populates ``q_pids``/``g_pids``, ``q_camids``/``g_camids``,
        ``distmat`` (num_q x num_g cosine similarities), ``indices``
        (gallery indices sorted by decreasing similarity, one row per
        query) and ``matches`` (1 where the ranked gallery pid equals the
        query pid).
        """
        pids, camids = [], []
        for p, c in self.test_labels:
            pids.append(p)
            camids.append(c)
        self.q_pids = np.asarray(pids[:self.num_q])
        self.g_pids = np.asarray(pids[self.num_q:])
        self.q_camids = np.asarray(camids[:self.num_q])
        self.g_camids = np.asarray(camids[self.num_q:])
        # Identity activation: we want raw embeddings, not class scores.
        feats, _ = self.learn.get_preds(DatasetType.Test, activ=Lambda(lambda x: x))
        feats = F.normalize(feats)
        qf = feats[:self.num_q]
        gf = feats[self.num_q:]
        # Features are L2-normalized, so the inner product IS cosine similarity.
        distmat = torch.mm(qf, gf.t())
        self.distmat = distmat.detach().cpu().numpy()
        self.indices = np.argsort(-self.distmat, axis=1)
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    def get_matched_result(self, q_index):
        """Return ``(cmc, sort_idx)`` for one query.

        Gallery entries with the same pid AND same camera as the query are
        removed (standard reid evaluation protocol). ``cmc[i]`` is 1 when
        the i-th ranked remaining gallery image shares the query's pid;
        ``sort_idx`` holds the corresponding gallery indices in rank order.
        """
        q_pid = self.q_pids[q_index]
        q_camid = self.q_camids[q_index]
        order = self.indices[q_index]
        remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
        keep = np.invert(remove)
        cmc = self.matches[q_index][keep]
        sort_idx = order[keep]
        return cmc, sort_idx

    def plot_rank_result(self, q_idx, top=5, actmap=False, title="Rank result"):
        """Plot the query image alongside its ``top`` ranked gallery matches.

        A red frame marks a correct match, a blue frame an incorrect one;
        with ``actmap=True`` the backbone's activation map is overlaid on
        each image. Returns the matplotlib figure.
        """
        m = self.learn.model.eval()
        cmc, sort_idx = self.get_matched_result(q_idx)
        fig, axes = plt.subplots(1, top + 1, figsize=(15, 5))
        fig.suptitle('query similarity/true(false)')
        query_im, _ = self.learn.data.dl(DatasetType.Test).dataset[q_idx]
        query_im.show(ax=axes.flat[0], title='query')
        if actmap:
            self._overlay_actmap(m, query_im, axes.flat[0])
        for i in range(top):
            g_idx = self.num_q + sort_idx[i]
            im, _ = self.learn.data.dl(DatasetType.Test).dataset[g_idx]
            if cmc[i] == 1:
                label = 'true'
                edgecolor = (1, 0, 0)  # red frame: correct match
            else:
                label = 'false'
                edgecolor = (0, 0, 1)  # blue frame: wrong match
            axes.flat[i + 1].add_patch(plt.Rectangle(
                xy=(0, 0), width=im.size[1] - 1, height=im.size[0] - 1,
                edgecolor=edgecolor, fill=False, linewidth=5))
            im.show(ax=axes.flat[i + 1],
                    title=f'{self.distmat[q_idx, sort_idx[i]]:.3f} / {label}')
            if actmap:
                self._overlay_actmap(m, im, axes.flat[i + 1])
        return fig

    def _overlay_actmap(self, m, im, ax):
        """Overlay the backbone's activation map for image ``im`` on axis ``ax``."""
        xb, _ = self.learn.data.one_item(im, detach=False, denorm=False)
        sz = list(xb.shape[-2:])
        # Hook the backbone output to capture the last conv feature map.
        with hook_output(m.base) as hook_a:
            _ = m(xb)
        acts = hook_a.stored[0].cpu()
        acts = self.get_actmap(acts)
        ax.imshow(acts, alpha=0.3, extent=(0, *sz[::-1], 0),
                  interpolation='bilinear', cmap='jet')

    def get_top_error(self):
        """Split queries into rank-1 correct/wrong and sort by difficulty.

        ``storeCorrect`` holds correct queries sorted by ascending rank-1
        similarity (hardest positives first); ``storeWrong`` holds wrong
        queries sorted by descending rank-1 similarity (most confident
        mistakes first). Each entry records up to the top-5 gallery ids,
        similarities and cmc flags.
        """
        similarity_score = namedtuple('similarityScore', 'query gallery sim cmc')
        storeCorrect = []
        storeWrong = []
        for q_index in range(self.num_q):
            cmc, sort_idx = self.get_matched_result(q_index)
            # BUG FIX: clamp to the filtered gallery size — a hard-coded
            # range(5) raised IndexError when fewer than 5 entries remained.
            top = min(5, len(sort_idx))
            single_item = similarity_score(
                query=q_index,
                gallery=[self.num_q + sort_idx[i] for i in range(top)],
                sim=[self.distmat[q_index, sort_idx[i]] for i in range(top)],
                cmc=cmc[:top])
            if cmc[0] == 1:
                storeCorrect.append(single_item)
            else:
                storeWrong.append(single_item)
        storeCorrect.sort(key=lambda x: x.sim[0])
        storeWrong.sort(key=lambda x: x.sim[0], reverse=True)
        self.storeCorrect = storeCorrect
        self.storeWrong = storeWrong

    def plot_top_error(self, error_range=range(0, 5), actmap=False, positive=True):
        """Plot rank results for the hardest queries.

        With ``positive=True``, shows correct queries with the smallest
        rank-1 similarity; otherwise wrong queries with the largest rank-1
        similarity (the most confident mistakes).
        """
        if not hasattr(self, 'storeCorrect'):
            self.get_top_error()
        img_list = self.storeCorrect if positive else self.storeWrong
        # Rank top error results, which means negative sample with largest similarity
        # and positive sample with smallest similarity
        for i in error_range:
            self.plot_rank_result(img_list[i].query, actmap=actmap)

    def plot_positve_negative_dist(self):
        """Histogram positive- vs negative-pair similarities over all queries.

        Returns ``(pos_sim, neg_sim)`` lists of similarities.
        """
        pos_sim, neg_sim = [], []
        for i, q in enumerate(self.q_pids):
            cmc, sort_idx = self.get_matched_result(i)  # remove same id in same camera
            for j in range(len(cmc)):
                if cmc[j] == 1:
                    pos_sim.append(self.distmat[i, sort_idx[j]])
                else:
                    neg_sim.append(self.distmat[i, sort_idx[j]])
        fig = plt.figure(figsize=(10, 5))
        plt.hist(pos_sim, bins=80, alpha=0.7, density=True, color='red', label='positive')
        plt.hist(neg_sim, bins=80, alpha=0.5, density=True, color='blue', label='negative')
        plt.xticks(np.arange(-0.3, 0.8, 0.1))
        plt.title('positive and negative pair distribution')  # typo fix: was 'posivie'
        plt.legend()  # labels were set but never displayed without a legend
        return pos_sim, neg_sim

    def plot_same_cam_diff_cam_dist(self):
        """Histogram same-camera vs cross-camera similarities for true matches.

        Returns the matplotlib figure.
        """
        same_cam, diff_cam = [], []
        for i, q in enumerate(self.q_pids):
            q_camid = self.q_camids[i]
            order = self.indices[i]
            same = (self.g_pids[order] == q) & (self.g_camids[order] == q_camid)
            diff = (self.g_pids[order] == q) & (self.g_camids[order] != q_camid)
            same_cam.extend(self.distmat[i, order[same]])
            diff_cam.extend(self.distmat[i, order[diff]])
        fig = plt.figure(figsize=(10, 5))
        plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
        plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
        plt.xticks(np.arange(0.1, 1.0, 0.1))
        # BUG FIX: title was a copy-paste of the pos/neg plot's title.
        plt.title('same camera and different camera pair distribution')
        plt.legend()
        return fig

    def get_actmap(self, features):
        """Squash a (C, H, W) feature map into an (H, W) activation map in [0, 1].

        Channel-wise sum of squares, L2-normalized over all spatial
        locations, then min-max scaled so the strongest activation is 1.
        """
        features = (features ** 2).sum(0)  # (H, W) per-location energy
        h, w = features.size()
        features = features.view(1, h * w)
        features = F.normalize(features, p=2, dim=1)
        acts = features.view(h, w)
        # BUG FIX: min-max normalization — the original subtracted the max,
        # producing values in [-1, 0] instead of [0, 1].
        acts = (acts - acts.min()) / (acts.max() - acts.min() + 1e-12)
        return acts.detach().cpu().numpy()