mirror of https://github.com/JDAI-CV/fast-reid.git
164 lines
5.4 KiB
Python
164 lines
5.4 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
import torch.nn.functional as F
|
|
from collections import defaultdict
|
|
import argparse
|
|
import json
|
|
import os
|
|
from data import get_check_dataloader
|
|
import sys
|
|
import time
|
|
from data.prefetcher import test_data_prefetcher
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from torch.backends import cudnn
|
|
|
|
from modeling import Baseline
|
|
|
|
# Let cuDNN benchmark the available convolution algorithms and cache the
# fastest one per input shape; this pays off when input sizes stay fixed.
cudnn.benchmark = True
|
|
|
|
|
|
class Reid(object):
    """Inference wrapper around a ResNet-50 ``Baseline`` re-id model.

    The model is built without a classifier head (``num_classes=0``), loaded
    from a checkpoint, moved to the GPU and put in eval mode. The methods below
    embed a single image or a whole dataloader and compare embeddings by
    cosine similarity.
    """

    def __init__(self, model_path):
        """Build the backbone and load weights from ``model_path``.

        Args:
            model_path: checkpoint readable by ``torch.load``; its contents are
                handed to ``Baseline.load_params_wo_fc``.
        """
        self.model = Baseline('resnet50',
                              num_classes=0,
                              last_stride=1,
                              with_ibn=False,
                              with_se=False,
                              gcb=None,
                              stage_with_gcb=[False, False, False, False],
                              pretrain=False,
                              model_path='')
        # Deserialize onto the CPU: the model is moved to the GPU right below,
        # and this avoids failures when the checkpoint was saved from a CUDA
        # device index that does not exist on this machine.
        self.model.load_params_wo_fc(torch.load(model_path, map_location='cpu'))
        self.model.cuda()
        self.model.eval()

    @torch.no_grad()
    def demo(self, img_path):
        """Return the model output for one image file as a NumPy array.

        Args:
            img_path: path to an image readable by ``cv2.imread``.

        Returns:
            ``np.ndarray`` with the model output for a batch of one image.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with a useful message.
            raise FileNotFoundError(f'cannot read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        # cv2.resize takes dsize as (width, height): 128 wide, 384 tall.
        img = cv2.resize(img, (128, 384))
        # Scale to [0, 1], then normalize with the standard ImageNet mean/std.
        img = img / 255.0
        img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        # HWC -> CHW, then add a leading batch dimension.
        img = img.transpose((2, 0, 1)).astype(np.float32)
        img = img[np.newaxis, :, :, :]
        data = torch.from_numpy(img).cuda().float()
        output = self.model(data)
        return output.detach().cpu().numpy()

    @torch.no_grad()
    def extract_feat(self, dataloader, sim_thresh=0.5):
        """Report pids whose mean embeddings are similar to other pids.

        Every image in ``dataloader`` is embedded; embeddings are averaged per
        pid, L2-normalized and compared pairwise by cosine similarity. Files
        written to the current working directory:

        * ``feats.npy``: the normalized per-pid mean embeddings, one row each;
        * ``labels.npy``: the pid of each row, in the same order;
        * ``check_cross_folder_similarity.txt``: one line per pid, holding the
          pid followed by every other pid whose similarity exceeds
          ``sim_thresh``, space separated.

        Args:
            dataloader: loader wrapped by ``test_data_prefetcher``; each batch
                is assumed to unpack as ``(images, pids, camids)``.
            sim_thresh: cosine-similarity threshold above which two pids are
                reported (default 0.5, the previously hard-coded value).

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        prefetcher = test_data_prefetcher(dataloader)
        feats = []
        labels = []
        batch = prefetcher.next()
        # The prefetcher signals exhaustion with a batch whose first item is None.
        while batch[0] is not None:
            img, pid, _camid = batch
            feats.append(self.model(img).cpu())
            labels.extend(np.asarray(pid))
            batch = prefetcher.next()

        if not feats:
            raise ValueError('dataloader yielded no batches')
        feats = torch.cat(feats, dim=0)

        # Group embeddings by pid, then average each group.
        id_feats = defaultdict(list)
        for feat, pid in zip(feats, labels):
            id_feats[pid].append(feat)
        label_names = np.asarray(list(id_feats))
        all_feats = torch.stack(
            [torch.stack(group, dim=0).mean(dim=0) for group in id_feats.values()],
            dim=0)
        all_feats = F.normalize(all_feats, p=2, dim=1)

        np.save('feats.npy', all_feats.numpy())
        np.save('labels.npy', label_names)

        # Rows are unit length, so the Gram matrix holds cosine similarities.
        cos = torch.mm(all_feats, all_feats.t()).numpy()
        # Subtract 1 from the diagonal (self-similarity) so a pid is not
        # reported as similar to itself.
        cos -= np.eye(all_feats.shape[0])

        with open('check_cross_folder_similarity.txt', 'w') as out:
            for row, name in enumerate(label_names):
                similar = label_names[np.argwhere(cos[row] > sim_thresh)[:, 0]]
                # str() so non-string pids (e.g. integers) can be written too;
                # the trailing space keeps the original line format.
                out.write(' '.join(str(n) for n in [name, *similar]) + ' \n')

    def prepare_gt(self, json_file):
        """Load labelled vectors from a JSON object mapping label -> vector.

        Args:
            json_file: path to a JSON object whose keys are labels and whose
                values are numeric vectors.

        Returns:
            ``(feat, label, time_label)`` NumPy arrays: the vectors, the keys in
            file order, and ``int(key[0:10])`` for each key (presumably a
            10-digit timestamp prefix; TODO confirm).
        """
        with open(json_file, 'r') as f:
            total = json.load(f)
        label = list(total)
        feat = [np.array(vec) for vec in total.values()]
        time_label = [int(i[0:10]) for i in label]
        return np.array(feat), np.array(label), np.array(time_label)

    def compute_topk(self, k, feat, feats, label):
        """Return the labels of the ``k`` gallery rows most similar to ``feat``.

        Similarity is cosine similarity along the last axis. The score of each
        selected row is printed as it is collected.

        Args:
            k: number of results wanted; capped at ``feats.shape[0]``.
            feat: query vector, broadcastable against ``feats``.
            feats: gallery vectors, presumably shape ``(n, d)``.
            label: per-row labels of ``feats``, indexable by row position.

        Returns:
            list of labels, most similar first.
        """
        norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
        norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
        matrix = np.sum(np.multiply(feat, feats), axis=-1)
        dist = matrix / np.multiply(norm_feat, norm_feats)
        index = np.argsort(-dist)  # descending similarity
        result = []
        for i in range(min(feats.shape[0], k)):
            print(dist[index[i]])
            result.append(label[index[i]])
        return result
|
|
|
|
|
|
if __name__ == '__main__':
    # Expose the checkpoint path on the command line; the default is the
    # previously hard-coded path, so running without arguments is unchanged.
    parser = argparse.ArgumentParser(
        description='Embed the check dataset and report pids whose mean '
                    'embeddings are similar to other pids.')
    parser.add_argument(
        '--model-path',
        default='logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth',
        help='checkpoint to load (default: %(default)s)')
    args = parser.parse_args()

    check_loader = get_check_dataloader()
    reid = Reid(args.model_path)
    reid.extract_feat(check_loader)
|