pull/380/merge
Xuanmeng Zhang 2024-07-15 15:29:06 +08:00 committed by GitHub
commit 188bf5a88e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 460 additions and 0 deletions

View File

@ -0,0 +1,37 @@
# Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
[[Paper]](https://arxiv.org/abs/2012.07620v2)
On the Market-1501 dataset, we accelerate the re-ranking processing from **89.2s** to **9.4ms** with one K40m GPU, facilitating the real-time post-processing.
Similarly, we observe that our method achieves comparable or even better retrieval results on the other four image retrieval benchmarks,
i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652, with limited time cost.
## Prerequisites
The code was mainly developed and tested with python 3.7, PyTorch 1.4.1, CUDA 10.2, and CentOS release 6.10.
The code has been included in `/extension`. To compile it:
```shell
cd extension
sh make.sh
```
## Demo
The demo script `main.py` provides the gnn re-ranking method using the prepared feature.
```shell
python main.py --data_path PATH_TO_DATA --k1 26 --k2 7
```
## Citation
```bibtex
@article{zhang2020understanding,
title={Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective},
author={Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang},
journal={arXiv preprint arXiv:2012.07620},
year={2020}
}
```

View File

@ -0,0 +1,19 @@
#include <torch/extension.h>
#include <iostream>
#include <set>
at::Tensor build_adjacency_matrix_forward(torch::Tensor initial_rank);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor build_adjacency_matrix(at::Tensor initial_rank) {
CHECK_INPUT(initial_rank);
return build_adjacency_matrix_forward(initial_rank);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &build_adjacency_matrix, "build_adjacency_matrix (CUDA)");
}

View File

@ -0,0 +1,31 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#define CUDA_1D_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
__global__ void build_adjacency_matrix_kernel(float* initial_rank, float* A, const int total_num, const int topk, const int nthreads, const int all_num) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < all_num; i += stride) {
int ii = i / topk;
A[ii * total_num + int(initial_rank[i])] = float(1.0);
}
}
at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank) {
const auto total_num = initial_rank.size(0);
const auto topk = initial_rank.size(1);
const auto all_num = total_num * topk;
auto A = torch::zeros({total_num, total_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));
const int threads = 1024;
const int blocks = (all_num + threads - 1) / threads;
build_adjacency_matrix_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), total_num, topk, threads, all_num);
return A;
}

View File

@ -0,0 +1,37 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
from setuptools import setup, Extension
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='build_adjacency_matrix',
ext_modules=[
CUDAExtension('build_adjacency_matrix', [
'build_adjacency_matrix.cpp',
'build_adjacency_matrix_kernel.cu',
]),
],
cmdclass={
'build_ext':BuildExtension
})

View File

@ -0,0 +1,4 @@
cd adjacency_matrix
python setup.py install
cd ../propagation
python setup.py install

View File

@ -0,0 +1,21 @@
#include <torch/extension.h>
#include <iostream>
#include <set>
at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor gnn_propagate(at::Tensor A ,at::Tensor initial_rank, at::Tensor S) {
CHECK_INPUT(A);
CHECK_INPUT(initial_rank);
CHECK_INPUT(S);
return gnn_propagate_forward(A, initial_rank, S);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &gnn_propagate, "gnn propagate (CUDA)");
}

View File

@ -0,0 +1,36 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <iostream>
__global__ void gnn_propagate_forward_kernel(float* initial_rank, float* A, float* A_qe, float* S, const int sample_num, const int topk, const int total_num) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < total_num; i += stride) {
int fea = i % sample_num;
int sample_index = i / sample_num;
float sum = 0.0;
for (int j = 0; j < topk ; j++) {
int topk_fea_index = int(initial_rank[sample_index*topk+j]) * sample_num + fea;
sum += A[ topk_fea_index] * S[sample_index*topk+j];
}
A_qe[i] = sum;
}
}
at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
const auto sample_num = A.size(0);
const auto topk = initial_rank.size(1);
const auto total_num = sample_num * sample_num ;
auto A_qe = torch::zeros({sample_num, sample_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));
const int threads = 1024;
const int blocks = (total_num + threads - 1) / threads;
gnn_propagate_forward_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), A_qe.data_ptr<float>(), S.data_ptr<float>(), sample_num, topk, total_num);
return A_qe;
}

View File

@ -0,0 +1,37 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
from setuptools import setup, Extension
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='gnn_propagate',
ext_modules=[
CUDAExtension('gnn_propagate', [
'gnn_propagate.cpp',
'gnn_propagate_kernel.cu',
]),
],
cmdclass={
'build_ext':BuildExtension
})

View File

@ -0,0 +1,57 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
import torch
import numpy as np
import build_adjacency_matrix
import gnn_propagate
from utils import *
def gnn_reranking(X_q, X_g, k1, k2):
query_num, gallery_num = X_q.shape[0], X_g.shape[0]
X_u = torch.cat((X_q, X_g), axis = 0)
original_score = torch.mm(X_u, X_u.t())
del X_u, X_q, X_g
# initial ranking list
S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True)
# stage 1
A = build_adjacency_matrix.forward(initial_rank.float())
S = S * S
# stage 2
if k2 != 1:
for i in range(2):
A = A + A.T
A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float())
A_norm = torch.norm(A, p=2, dim=1, keepdim=True)
A = A.div(A_norm.expand_as(A))
cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t())
del A, S
L = torch.sort(-cosine_similarity, dim = 1)[1]
L = L.data.cpu().numpy()
return L

View File

@ -0,0 +1,62 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
import os
import torch
import argparse
import numpy as np
from utils import *
from gnn_reranking import *
parser = argparse.ArgumentParser(description='Reranking_is_GNN')
parser.add_argument('--data_path',
type=str,
default='../xm_rerank_gpu_2/features/market_88_test.pkl',
help='path to dataset')
parser.add_argument('--k1',
type=int,
default=26, # Market-1501
# default=60, # Veri-776
help='parameter k1')
parser.add_argument('--k2',
type=int,
default=7, # Market-1501
# default=10, # Veri-776
help='parameter k2')
args = parser.parse_args()
def main():
data = load_pickle(args.data_path)
query_cam = data['query_cam']
query_label = data['query_label']
gallery_cam = data['gallery_cam']
gallery_label = data['gallery_label']
gallery_feature = torch.FloatTensor(data['gallery_f'])
query_feature = torch.FloatTensor(data['query_f'])
query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()
indices = gnn_reranking(query_feature, gallery_feature, args.k1, args.k2)
evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,119 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""
import pickle
import numpy as np
import os
import torch
def load_pickle(pickle_path):
with open(pickle_path, 'rb') as f:
data = pickle.load(f)
return data
def save_pickle(pickle_path, data):
with open(pickle_path, 'wb') as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
def pairwise_squared_distance(x):
'''
x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points)
'''
x2s = (x * x).sum(-1, keepdim=True)
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
def pairwise_distance(x, y):
m, n = x.size(0), y.size(0)
x = x.view(m, -1)
y = y.view(n, -1)
dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n,m).t()
dist.addmm_(1, -2, x, y.t())
return dist
def cosine_similarity(x, y):
m, n = x.size(0), y.size(0)
x = x.view(m, -1)
y = y.view(n, -1)
y = y.t()
score = torch.mm(x, y)
return score
def evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam):
CMC = np.zeros((len(gallery_label)), dtype=np.int)
ap = 0.0
for i in range(len(query_label)):
ap_tmp, CMC_tmp = evaluate(indices[i],query_label[i], query_cam[i], gallery_label, gallery_cam)
if CMC_tmp[0]==-1:
continue
CMC = CMC + CMC_tmp
ap += ap_tmp
CMC = CMC.astype(np.float32)
CMC = CMC/len(query_label) #average CMC
print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
def evaluate(index, ql,qc,gl,gc):
query_index = np.argwhere(gl==ql)
camera_index = np.argwhere(gc==qc)
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
junk_index1 = np.argwhere(gl==-1)
junk_index2 = np.intersect1d(query_index, camera_index)
junk_index = np.append(junk_index2, junk_index1) #.flatten())
CMC_tmp = compute_mAP(index, good_index, junk_index)
return CMC_tmp
def compute_mAP(index, good_index, junk_index):
ap = 0
cmc = np.zeros((len(index)), dtype=np.int)
if good_index.size==0: # if empty
cmc[0] = -1
return ap,cmc
# remove junk_index
mask = np.in1d(index, junk_index, invert=True)
index = index[mask]
# find good_index index
ngood = len(good_index)
mask = np.in1d(index, good_index)
rows_good = np.argwhere(mask==True)
rows_good = rows_good.flatten()
cmc[rows_good[0]:] = 1
for i in range(ngood):
d_recall = 1.0/ngood
precision = (i+1)*1.0/(rows_good[i]+1)
if rows_good[i]!=0:
old_precision = i*1.0/rows_good[i]
else:
old_precision=1.0
ap = ap + d_recall*(old_precision + precision)/2
return ap, cmc