# mirror of https://github.com/JDAI-CV/fast-reid.git
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
|
|
import torch.nn.functional as F
|
|
from collections import defaultdict
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from torch.backends import cudnn
|
|
from fastreid.modeling import build_model
|
|
from fastreid.utils.checkpoint import Checkpointer
|
|
from fastreid.config import get_cfg
|
|
|
|
cudnn.benchmark = True
|
|
|
|
|
|
class Reid(object):
    """Person re-identification feature extractor built on fast-reid.

    Loads a model from a fast-reid YAML config, moves it to the GPU,
    traces its ``inference`` entry point with TorchScript (saved to disk
    so it can be re-loaded without the fastreid package), and exposes
    helpers to preprocess images, extract features and rank a gallery
    by cosine similarity.
    """

    def __init__(self, config_file,
                 weights_path='projects/bjzProject/logs/bjz/arcface_adam/model_final.pth',
                 traced_output='reid_feat_extractor.pt'):
        """Build, load and trace the re-id model.

        Args:
            config_file: path to a fast-reid YAML config file.
            weights_path: checkpoint to load. The default keeps the
                previously hard-coded path for backward compatibility.
            traced_output: file the TorchScript trace is saved to.
        """
        cfg = get_cfg()
        cfg.merge_from_file(config_file)
        cfg.defrost()
        cfg.MODEL.WEIGHTS = weights_path

        model = build_model(cfg)
        Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
        model.cuda()
        model.eval()
        self.model = model

        # Trace the `inference` method so the exported module can be used
        # from deployment code that has no fastreid dependency.
        example = torch.rand(1, 3, 256, 128).cuda()
        traced_script_module = torch.jit.trace_module(model, {'inference': example})
        traced_script_module.save(traced_output)

    @classmethod
    def preprocess(cls, img_path):
        """Load an image and convert it to a normalized NCHW cuda tensor.

        Args:
            img_path: path to the image on disk.

        Returns:
            float32 cuda tensor of shape (1, 3, 256, 128), RGB,
            ImageNet-normalized.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img = cv2.imread(img_path)
        # cv2.imread silently returns None on failure; fail loudly instead.
        if img is None:
            raise FileNotFoundError(f'cannot read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (128, 256))  # cv2 takes (width, height)
        img = img / 255.0
        # ImageNet channel mean / std normalization.
        img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        img = img.transpose((2, 0, 1)).astype(np.float32)  # HWC -> CHW
        img = img[np.newaxis, :, :, :]  # add batch dimension
        data = torch.from_numpy(img).cuda().float()
        return data

    @torch.no_grad()
    def demo(self, img_path):
        """Extract the re-id feature of a single image.

        Args:
            img_path: path to the query image.

        Returns:
            numpy array holding the model's feature for the image.
        """
        data = self.preprocess(img_path)
        output = self.model.inference(data)
        # Under no_grad there is no graph to detach; `.data` was redundant.
        return output.cpu().numpy()

    def compute_topk(self, k, feat, feats, label):
        """Return the labels of the top-k gallery entries by cosine similarity.

        Args:
            k: number of results wanted (clipped to the gallery size).
            feat: query feature vector, shape (d,).
            feats: gallery feature matrix, shape (n, d).
            label: sequence of n gallery labels aligned with `feats` rows.

        Returns:
            list of up to k labels ordered by decreasing similarity.
        """
        norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
        norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
        matrix = np.sum(np.multiply(feat, feats), axis=-1)
        dist = matrix / np.multiply(norm_feat, norm_feats)  # cosine similarity

        index = np.argsort(-dist)  # descending similarity
        return [label[index[i]] for i in range(min(feats.shape[0], k))]
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the model, export the TorchScript trace, and sanity-check that
    # the traced module reproduces the eager model's feature.
    reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
    img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
    feat = reid_sys.demo(img_path)

    feat_extractor = torch.jit.load('reid_feat_extractor.pt')
    data = reid_sys.preprocess(img_path)
    feat2 = feat_extractor.inference(data)

    # Removed leftover ipdb breakpoint; report the discrepancy directly.
    diff = np.abs(feat - feat2.detach().cpu().numpy()).max()
    print(f'max abs difference between eager and traced features: {diff:.6e}')
|