fast-reid/engine/interpreter.py
liaoxingyu 710d40aec0 interp
Finish visualization code and interpretation
2019-08-02 19:30:16 +08:00

70 lines
2.3 KiB
Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
import torch
import matplotlib.pyplot as plt
from fastai.train import *
from fastai.torch_core import *
from fastai.basic_data import *
class ReidInterpretation(Interpretation):
    """Interpretation methods for reid models.

    Splits the stored predictions into query and gallery features, ranks the
    gallery for every query by cosine similarity, and plots the best-ranked
    gallery images for a chosen query.

    Reads ``self.preds`` and ``self.learn``; these are presumably set by the
    base ``Interpretation.__init__`` -- confirm against the fastai version in use.
    """

    def __init__(self, learn, preds, y_true, losses, ds_type=DatasetType.Valid):
        """Forward all arguments unchanged to the base ``Interpretation``."""
        super().__init__(learn, preds, y_true, losses, ds_type=ds_type)

    def get_distmat(self, test_labels, num_query):
        """Compute the query/gallery similarity matrix and the ranking.

        Args:
            test_labels: iterable of ``(pid, camid)`` pairs, query entries
                first, in the same order as the rows of ``self.preds``.
            num_query: number of leading rows that belong to the query set.

        Sets the following attributes:
            q_pids, g_pids, q_camids, g_camids: numpy arrays of identities
                and camera ids for query and gallery.
            num_q: ``num_query``.
            distmat: ``(num_query, num_gallery)`` numpy array of cosine
                *similarities* (larger means more alike, despite the name).
            indices: per-query gallery indices sorted by decreasing similarity.
            matches: int32 array, 1 where the ranked gallery pid equals the
                query pid.
        """
        pids = []
        camids = []
        for p, c in test_labels:
            pids.append(p)
            camids.append(c)

        self.q_pids = np.asarray(pids[:num_query])
        self.g_pids = np.asarray(pids[num_query:])
        self.q_camids = np.asarray(camids[:num_query])
        self.g_camids = np.asarray(camids[num_query:])
        self.num_q = num_query

        qf = self.preds[:num_query]
        gf = self.preds[num_query:]

        # Cosine similarity: dot products of L2-normalised feature rows.
        distmat = torch.mm(F.normalize(qf), F.normalize(gf).t())
        self.distmat = to_np(distmat)
        # Similarity, so the best match is the largest value: sort descending.
        self.indices = np.argsort(self.distmat, axis=1)[:, ::-1]
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    def plot_rank_result(self, q_idx, top=5, title="Rank result"):
        """Plot a query image next to its best-ranked gallery images.

        ``get_distmat`` must have been called first, since this method reads
        the attributes it sets.

        Args:
            q_idx: index of the query (row of ``self.distmat``).
            top: maximum number of gallery images to show; fewer are shown
                when fewer remain after filtering.
            title: currently unused -- the figure title is fixed to
                ``'query/sim/true(false)'``.

        Returns:
            The matplotlib figure holding the query and the ranked results.
        """
        q_pid = self.q_pids[q_idx]
        q_camid = self.q_camids[q_idx]

        order = self.indices[q_idx]
        # Drop gallery entries with the query's identity AND the query's
        # camera id; everything else is kept in ranked order.
        remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
        keep = np.invert(remove)

        raw_cmc = self.matches[q_idx][keep]
        matched_idx = order[keep]

        # Never ask for more results than survived the filtering above.
        n_show = min(top, len(matched_idx))

        # squeeze=False keeps `axes` an array even when only one panel is drawn.
        fig, axes = plt.subplots(1, n_show + 1, figsize=(12, 5), squeeze=False)
        fig.suptitle('query/sim/true(false)')

        test_ds = self.learn.data.dl(DatasetType.Test).dataset
        query_im, _ = test_ds[q_idx]
        query_im.show(ax=axes.flat[0], title='query')

        for i in range(n_show):
            label = 'true' if raw_cmc[i] == 1 else 'false'
            g_idx = matched_idx[i]
            # Gallery row g_idx is dataset item num_q + g_idx, because the
            # gallery features are preds[num_query:].
            im, _ = test_ds[self.num_q + g_idx]
            # distmat columns are gallery-relative, so index with g_idx.
            im.show(ax=axes.flat[i + 1], title=f'{self.distmat[q_idx, g_idx]:.3f} / {label}')
        return fig