mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
from fastai.train import *
|
|
from fastai.torch_core import *
|
|
from fastai.basic_data import *
|
|
|
|
|
|
class ReidInterpretation(Interpretation):
    """Interpretation methods for reid models.

    Call `get_distmat` first to build the query/gallery ranking, then
    `plot_rank_result` to visualise the top-ranked gallery images for a query.
    """

    def __init__(self, learn, preds, y_true, losses, ds_type=DatasetType.Valid):
        # Pure delegation to fastai's Interpretation; kept for an explicit signature.
        super().__init__(learn, preds, y_true, losses, ds_type=ds_type)

    def get_distmat(self, test_labels, num_query):
        """Rank every gallery sample for each query by cosine similarity.

        The first `num_query` rows of `self.preds` (and entries of
        `test_labels`) are treated as queries; the remaining rows as gallery.

        Args:
            test_labels: iterable of `(pid, camid)` pairs, one per row of
                `self.preds`, in the same order (queries first).
            num_query: number of leading rows that are queries.

        Sets on `self`:
            q_pids, g_pids, q_camids, g_camids: numpy arrays of query/gallery
                person ids and camera ids.
            num_q: `num_query`.
            distmat: (num_query, num_gallery) matrix of cosine *similarities*
                (higher = more similar, despite the attribute name).
            indices: per-query gallery column indices sorted by descending
                similarity.
            matches: int32 array aligned with `indices`; 1 where the ranked
                gallery pid equals the query pid, else 0.
        """
        pids = []
        camids = []
        for pid, camid in test_labels:
            pids.append(pid)
            camids.append(camid)
        self.q_pids = np.asarray(pids[:num_query])
        self.g_pids = np.asarray(pids[num_query:])
        self.q_camids = np.asarray(camids[:num_query])
        self.g_camids = np.asarray(camids[num_query:])

        qf = self.preds[:num_query]
        gf = self.preds[num_query:]
        self.num_q = num_query

        # Cosine similarity: dot product of L2-normalised feature rows.
        distmat = torch.mm(F.normalize(qf), F.normalize(gf).t())
        self.distmat = to_np(distmat)

        # Reverse the ascending argsort so the most similar gallery item comes first.
        self.indices = np.argsort(self.distmat, axis=1)[:, ::-1]
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    def plot_rank_result(self, q_idx, top=5, title="Rank result"):
        """Plot a query image next to its highest-ranked gallery images.

        Requires `get_distmat` to have been called. Gallery entries sharing
        both the query's pid and camid are excluded before display.

        Each gallery subplot is titled "<similarity> / true|false", where
        "true" means the gallery pid equals the query pid.

        Args:
            q_idx: query index; also its index in the test dataset, since
                queries precede the gallery there.
            top: maximum number of gallery images to show. Clipped to the
                number of gallery candidates left after filtering.
            title: currently unused (the figure suptitle is fixed); kept for
                backward compatibility.

        Returns:
            The matplotlib Figure.
        """
        q_pid = self.q_pids[q_idx]
        q_camid = self.q_camids[q_idx]

        order = self.indices[q_idx]
        # Same person seen by the same camera is excluded from the ranking.
        remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
        keep = np.invert(remove)

        raw_cmc = self.matches[q_idx][keep]
        matched_idx = order[keep]  # gallery indices, i.e. columns of self.distmat

        top = min(top, len(matched_idx))
        # squeeze=False keeps `axes` an array even when only one subplot is created.
        fig, axes = plt.subplots(1, top + 1, figsize=(12, 5), squeeze=False)
        fig.suptitle('query/sim/true(false)')
        test_ds = self.learn.data.dl(DatasetType.Test).dataset
        query_im, _ = test_ds[q_idx]
        query_im.show(ax=axes.flat[0], title='query')
        for i in range(top):
            label = 'true' if raw_cmc[i] == 1 else 'false'
            g_idx = matched_idx[i]
            # Gallery samples follow the queries in the test dataset, so gallery
            # column g_idx is dataset item num_q + g_idx (no extra offset).
            im, _ = test_ds[self.num_q + g_idx]
            # Similarity is looked up by gallery column, not by dataset index.
            im.show(ax=axes.flat[i + 1], title=f'{self.distmat[q_idx, g_idx]:.3f} / {label}')
        return fig
|
|
|
|
|
|
|
|
|