# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from fastai.vision import *
from fastai.callbacks import *


class ReidInterpretation():
    "Interpretation methods for reid models."

    def __init__(self, learn, test_labels, num_q):
        """Store the inputs and immediately compute the query/gallery similarity matrix.

        Args:
            learn: learner object; this class uses ``learn.get_preds``,
                ``learn.model`` and ``learn.data``.
            test_labels: iterable of ``(pid, camid)`` pairs. The first
                ``num_q`` entries are treated as queries, the rest as gallery.
            num_q: number of query samples at the head of the test set.
        """
        self.learn, self.test_labels, self.num_q = learn, test_labels, num_q
        self.get_distmat()

    def get_distmat(self):
        """Extract test-set features and build the query x gallery similarity matrix.

        Sets ``q_pids``, ``g_pids``, ``q_camids``, ``g_camids``, ``distmat``,
        ``indices`` and ``matches`` on the instance.

        Despite its name, ``distmat`` holds cosine *similarities* (dot products
        of L2-normalized features): larger means more similar.
        """
        pids, camids = [], []
        for pid, camid in self.test_labels:
            pids.append(pid)
            camids.append(camid)
        self.q_pids = np.asarray(pids[:self.num_q])
        self.g_pids = np.asarray(pids[self.num_q:])
        self.q_camids = np.asarray(camids[:self.num_q])
        self.g_camids = np.asarray(camids[self.num_q:])

        # Identity activation: keep the raw model output as the feature.
        feats, _ = self.learn.get_preds(DatasetType.Test, activ=Lambda(lambda x: x))
        feats = F.normalize(feats)
        qf = feats[:self.num_q]
        gf = feats[self.num_q:]

        # Cosine similarity between every query and every gallery feature.
        self.distmat = to_np(torch.mm(qf, gf.t()))
        # Negate so that argsort ranks the most similar gallery item first.
        self.indices = np.argsort(-self.distmat, axis=1)
        # matches[q, r] == 1 when the r-th ranked gallery item shares the query's pid.
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    def get_matched_result(self, q_index):
        """Return the ranked gallery for one query, minus same-id/same-camera items.

        Args:
            q_index: index of the query (row of ``distmat``).

        Returns:
            ``(cmc, sort_idx)``: ``cmc`` is the 0/1 match flag per kept gallery
            item in rank order, ``sort_idx`` the corresponding gallery indices
            (columns of ``distmat``).
        """
        q_pid = self.q_pids[q_index]
        q_camid = self.q_camids[q_index]

        order = self.indices[q_index]
        # Gallery images of the same identity taken by the same camera are dropped.
        remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
        keep = np.invert(remove)
        cmc = self.matches[q_index][keep]
        sort_idx = order[keep]
        return cmc, sort_idx

    def _overlay_actmap(self, ax, im, model):
        """Draw the activation heat-map of ``im`` on top of ``ax``."""
        xb, _ = self.learn.data.one_item(im, detach=False, denorm=False)
        sz = list(xb.shape[-2:])
        # Capture the output of the backbone (``model.base``) during a forward pass.
        with hook_output(model.base) as hook_a:
            _ = model(xb)
        acts = hook_a.stored[0].cpu()
        acts = self.get_actmap(acts)
        ax.imshow(acts, alpha=0.3, extent=(0, *sz[::-1], 0),
                  interpolation='bilinear', cmap='jet')

    def plot_rank_result(self, q_idx, top=5, actmap=False, title="Rank result"):
        """Plot a query image next to its ``top`` highest-ranked gallery images.

        Each gallery image gets a red frame when it matches the query identity
        and a blue frame otherwise; its title shows ``similarity / true|false``.

        Args:
            q_idx: index of the query.
            top: number of gallery results to show; clamped to the number of
                gallery items left after same-id/same-camera filtering.
            actmap: when True, overlay an activation heat-map on every image.
            title: currently unused (the figure title is fixed); kept for
                interface compatibility.

        Returns:
            The matplotlib figure.
        """
        m = self.learn.model.eval()

        cmc, sort_idx = self.get_matched_result(q_idx)
        # Never index past the filtered ranking.
        top = min(top, len(sort_idx))

        # squeeze=False keeps ``axes`` an array even when only one axis is created.
        fig, axes = plt.subplots(1, top + 1, figsize=(15, 5), squeeze=False)
        axes = axes.ravel()
        fig.suptitle('query  similarity/true(false)')

        test_ds = self.learn.data.dl(DatasetType.Test).dataset
        query_im, cl = test_ds[q_idx]
        query_im.show(ax=axes[0], title='query')
        if actmap:
            self._overlay_actmap(axes[0], query_im, m)

        for i in range(top):
            ax = axes[i + 1]
            # Gallery samples are stored after the num_q queries in the test set.
            g_idx = self.num_q + sort_idx[i]
            im, cl = test_ds[g_idx]
            if cmc[i] == 1:
                label, edgecolor = 'true', (1, 0, 0)
            else:
                label, edgecolor = 'false', (0, 0, 1)
            ax.add_patch(plt.Rectangle(xy=(0, 0), width=im.size[1] - 1, height=im.size[0] - 1,
                                       edgecolor=edgecolor, fill=False, linewidth=5))
            im.show(ax=ax, title=f'{self.distmat[q_idx, sort_idx[i]]:.3f} / {label}')
            if actmap:
                self._overlay_actmap(ax, im, m)
        return fig

    def get_top_error(self):
        """Split queries into rank-1 correct and rank-1 wrong, sorted by difficulty.

        Sets ``storeCorrect`` (correct rank-1, lowest similarity first) and
        ``storeWrong`` (wrong rank-1, highest similarity first). Every entry is
        a ``similarityScore(query, gallery, sim, cmc)`` namedtuple holding up to
        the top five results. Queries with no gallery item left after filtering
        are skipped.
        """
        # Iterate over queries and store query/gallery similarity.
        similarity_score = namedtuple('similarityScore', 'query gallery sim cmc')
        storeCorrect = []
        storeWrong = []
        for q_index in range(self.num_q):
            cmc, sort_idx = self.get_matched_result(q_index)
            k = min(5, len(sort_idx))
            if k == 0:
                # Nothing to rank for this query.
                continue
            top_idx = sort_idx[:k]
            single_item = similarity_score(query=q_index,
                                           gallery=[self.num_q + g for g in top_idx],
                                           sim=[self.distmat[q_index, g] for g in top_idx],
                                           cmc=cmc[:k])
            if cmc[0] == 1:
                storeCorrect.append(single_item)
            else:
                storeWrong.append(single_item)
        storeCorrect.sort(key=lambda x: x.sim[0])
        storeWrong.sort(key=lambda x: x.sim[0], reverse=True)

        self.storeCorrect = storeCorrect
        self.storeWrong = storeWrong

    def plot_top_error(self, error_range=range(0, 5), actmap=False, positive=True):
        """Plot rank results for the hardest queries.

        Args:
            error_range: positions in the sorted list to plot.
            actmap: forwarded to ``plot_rank_result``.
            positive: when True use ``storeCorrect`` (correct rank-1 with the
                smallest similarity), otherwise ``storeWrong`` (wrong rank-1
                with the largest similarity).
        """
        if not hasattr(self, 'storeCorrect'):
            self.get_top_error()

        img_list = self.storeCorrect if positive else self.storeWrong
        # Rank top error results, which means negative sample with largest similarity
        # and positive sample with smallest similarity.
        for i in error_range:
            self.plot_rank_result(img_list[i].query, actmap=actmap)

    def plot_positve_negative_dist(self):
        """Plot histograms of similarities for positive and negative pairs.

        Same-id/same-camera gallery items are excluded (see
        ``get_matched_result``).

        Returns:
            ``(pos_sim, neg_sim)``: lists of similarities of matching and
            non-matching query/gallery pairs.
        """
        pos_sim, neg_sim = [], []
        for i in range(len(self.q_pids)):
            cmc, sort_idx = self.get_matched_result(i)  # remove same id in same camera
            sims = self.distmat[i, sort_idx]
            is_pos = cmc == 1
            pos_sim.extend(sims[is_pos])
            neg_sim.extend(sims[~is_pos])

        fig = plt.figure(figsize=(10, 5))
        plt.hist(pos_sim, bins=80, alpha=0.7, density=True, color='red', label='positive')
        plt.hist(neg_sim, bins=80, alpha=0.5, density=True, color='blue', label='negative')
        plt.xticks(np.arange(-0.3, 0.8, 0.1))
        plt.legend()  # without this the label= arguments above are never shown
        plt.title('positive and negative pair distribution')
        return pos_sim, neg_sim

    def plot_same_cam_diff_cam_dist(self):
        """Plot histograms of same-identity similarities, split by camera.

        Only gallery items sharing the query's pid are considered; they are
        grouped by whether their camera id equals the query's.

        Returns:
            The matplotlib figure.
        """
        same_cam, diff_cam = [], []
        for i, q in enumerate(self.q_pids):
            q_camid = self.q_camids[i]

            order = self.indices[i]
            same_id = self.g_pids[order] == q
            same_camera = self.g_camids[order] == q_camid
            sameCam_idx = order[same_id & same_camera]
            diffCam_idx = order[same_id & ~same_camera]

            same_cam.extend(self.distmat[i, sameCam_idx])
            diff_cam.extend(self.distmat[i, diffCam_idx])

        fig = plt.figure(figsize=(10, 5))
        plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
        plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
        plt.xticks(np.arange(0.1, 1.0, 0.1))
        plt.legend()  # without this the label= arguments above are never shown
        plt.title('same camera and diff camera pair distribution')
        return fig

    def get_actmap(self, features):
        """Collapse a 3-D activation tensor into a 2-D map rescaled to [0, 1].

        Args:
            features: tensor whose first dimension is summed out, leaving a
                2-D ``(h, w)`` map.

        Returns:
            ``(h, w)`` numpy array of min-max normalized activation energy.
        """
        # Activation energy: sum of squares over the first dimension.
        features = (features ** 2).sum(0)
        h, w = features.size()
        features = features.view(1, h * w)
        features = F.normalize(features, p=2, dim=1)
        acts = features.view(h, w)
        # Min-max rescale; the epsilon guards against a constant map.
        acts = (acts - acts.min()) / (acts.max() - acts.min() + 1e-12)
        return to_np(acts)