fix so make in windows (#849)
* fix so make in windows * add index.exe for win * fix doc * fix yaml * fix exe to dll
pull/851/head
parent
fd4a548897
commit
e2b4ca58c6
|
@ -1,6 +1,15 @@
|
|||
CXX=/usr/bin/g++-5
|
||||
CXX=g++
|
||||
|
||||
ifeq ($(OS),Windows_NT)
|
||||
postfix=dll
|
||||
else
|
||||
postfix=so
|
||||
endif
|
||||
|
||||
all : index
|
||||
|
||||
index.so : src/config.h src/graph.h src/data.h interface.cc
|
||||
$(CXX) -shared -fPIC interface.cc -o index.so -std=c++11 -Ofast -march=native -g -flto -funroll-loops -DOMP -fopenmp
|
||||
index : src/config.h src/graph.h src/data.h interface.cc
|
||||
${CXX} -shared -fPIC interface.cc -o index.${postfix} -std=c++11 -Ofast -march=native -g -flto -funroll-loops -DOMP -fopenmp
|
||||
|
||||
clean :
|
||||
rm index.${postfix}
|
|
@ -1,8 +1,7 @@
|
|||
# 向量检索
|
||||
|
||||
|
||||
|
||||
## 简介
|
||||
## 1. 简介
|
||||
|
||||
一些垂域识别任务(如车辆、商品等)需要识别的类别数较大,往往采用基于检索的方式,通过查询向量与底库向量进行快速的最近邻搜索,获得匹配的预测类别。向量检索模块提供基础的近似最近邻搜索算法。该算法基于百度自研的Möbius算法,是一种基于图的近似最近邻搜索算法,用于最大内积搜索 (MIPS)。该模块提供Python接口,支持numpy和tensor类型向量,支持L2和Inner Product距离计算。
|
||||
|
||||
|
@ -10,17 +9,49 @@ Mobius 算法细节详见论文 ([Möbius Transformation for Fast Inner Produc
|
|||
|
||||
|
||||
|
||||
## 安装
|
||||
## 2. 安装
|
||||
|
||||
若index.so不可用,在项目目录下运行以下命令生成新的index.so文件
|
||||
### 2.1 直接使用提供的库文件
|
||||
|
||||
make index.so
|
||||
该文件夹下有已经编译好的`index.so`(gcc8.2.0下编译,用于Linux)以及`index.dll`(gcc10.3.0下编译,用于Windows),可以跳过2.2与2.3节,直接使用。
|
||||
|
||||
编译环境: g++ 5.4.0 , 9.3.0. 其他版本也可能工作。 请确保您的 C++ 编译器支持 C++11 标准。
|
||||
如果因为gcc版本过低或者环境不兼容的问题,导致库文件无法使用,则需要在不同的平台下手动编译库文件。
|
||||
|
||||
**注意:**
|
||||
请确保您的 C++ 编译器支持 C++11 标准。
|
||||
|
||||
|
||||
### 2.2 Linux上编译生成库文件
|
||||
|
||||
## 快速使用
|
||||
运行下面的命令,安装gcc与g++。
|
||||
|
||||
```shell
|
||||
sudo apt-get update
|
||||
sudo apt-get upgrade -y
|
||||
sudo apt-get install build-essential gcc g++
|
||||
```
|
||||
|
||||
可以通过命令`gcc -v`查看gcc版本。
|
||||
|
||||
进入该文件夹,直接运行`make`即可。如果希望重新生成`index.so`文件,可以首先使用`make clean`删除已经生成的`index.so`,再使用`make`生成更新之后的库文件。
|
||||
|
||||
|
||||
### 2.3 Windows上编译生成库文件
|
||||
|
||||
Windows上首先需要安装gcc编译工具,推荐使用[TDM-GCC](https://jmeubank.github.io/tdm-gcc/articles/2020-03/9.2.0-release),进入官网之后,可以选择合适的版本进行下载。推荐下载[tdm64-gcc-10.3.0-2.exe](https://github.com/jmeubank/tdm-gcc/releases/download/v10.3.0-tdm64-2/tdm64-gcc-10.3.0-2.exe)。
|
||||
|
||||
下载完成之后,按照默认的安装步骤进行安装即可。这里有3点需要注意:
|
||||
1. 向量检索模块依赖于openmp,因此在安装到`choose components`步骤的时候,需要勾选上`openmp`的安装选项,否则之后编译的时候会报错`libgomp.spec: No such file or directory`,[参考链接](https://github.com/dmlc/xgboost/issues/1027)
|
||||
2. 安装过程中会提示是否需要添加到系统的环境变量中,这里建议勾选上,否则之后使用的时候还需要手动添加系统环境变量。
|
||||
3. Linux上的编译命令为`make`,Windows上为`mingw32-make`,这里需要区分一下。
|
||||
|
||||
|
||||
安装完成后,可以打开一个命令行终端,通过命令`gcc -v`查看gcc版本。
|
||||
|
||||
在该文件夹下,运行命令`mingw32-make`,即可生成`index.dll`库文件。如果希望重新生成`index.dll`文件,可以首先使用`mingw32-make clean`删除已经生成的`index.dll`,再使用`mingw32-make`生成更新之后的库文件。
|
||||
|
||||
|
||||
## 3. 快速使用
|
||||
|
||||
import numpy as np
|
||||
from interface import Graph_Index
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -18,48 +18,77 @@ import numpy.ctypeslib as ctl
|
|||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import platform
|
||||
|
||||
from ctypes import *
|
||||
from numpy.ctypeslib import ndpointer
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
so_path = os.path.join(__dir__, "index.so")
|
||||
if platform.system() == "Windows":
|
||||
lib_filename = "index.dll"
|
||||
else:
|
||||
lib_filename = "index.so"
|
||||
so_path = os.path.join(__dir__, lib_filename)
|
||||
lib = ctypes.cdll.LoadLibrary(so_path)
|
||||
|
||||
|
||||
class IndexContext(Structure):
|
||||
_fields_=[("graph",c_void_p),
|
||||
("data",c_void_p)]
|
||||
_fields_ = [("graph", c_void_p), ("data", c_void_p)]
|
||||
|
||||
|
||||
# for mobius IP index
|
||||
build_mobius_index = lib.build_mobius_index
|
||||
build_mobius_index.restype = None
|
||||
build_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_double, ctypes.c_char_p]
|
||||
build_mobius_index.argtypes = [
|
||||
ctl.ndpointer(
|
||||
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
|
||||
ctypes.c_int, ctypes.c_double, ctypes.c_char_p
|
||||
]
|
||||
|
||||
search_mobius_index = lib.search_mobius_index
|
||||
search_mobius_index.restype = None
|
||||
search_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')]
|
||||
search_mobius_index.argtypes = [
|
||||
ctl.ndpointer(
|
||||
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
|
||||
ctypes.c_int, POINTER(IndexContext), ctl.ndpointer(
|
||||
np.uint64, flags='aligned, c_contiguous'), ctl.ndpointer(
|
||||
np.float64, flags='aligned, c_contiguous')
|
||||
]
|
||||
|
||||
load_mobius_index_prefix = lib.load_mobius_index_prefix
|
||||
load_mobius_index_prefix.restype = None
|
||||
load_mobius_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p]
|
||||
load_mobius_index_prefix.argtypes = [
|
||||
ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p
|
||||
]
|
||||
|
||||
save_mobius_index_prefix = lib.save_mobius_index_prefix
|
||||
save_mobius_index_prefix.restype = None
|
||||
save_mobius_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p]
|
||||
|
||||
|
||||
# for L2 index
|
||||
build_l2_index = lib.build_l2_index
|
||||
build_l2_index.restype = None
|
||||
build_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p]
|
||||
build_l2_index.argtypes = [
|
||||
ctl.ndpointer(
|
||||
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
|
||||
ctypes.c_int, ctypes.c_char_p
|
||||
]
|
||||
|
||||
search_l2_index = lib.search_l2_index
|
||||
search_l2_index.restype = None
|
||||
search_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')]
|
||||
search_l2_index.argtypes = [
|
||||
ctl.ndpointer(
|
||||
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
|
||||
ctypes.c_int, POINTER(IndexContext), ctl.ndpointer(
|
||||
np.uint64, flags='aligned, c_contiguous'), ctl.ndpointer(
|
||||
np.float64, flags='aligned, c_contiguous')
|
||||
]
|
||||
|
||||
load_l2_index_prefix = lib.load_l2_index_prefix
|
||||
load_l2_index_prefix.restype = None
|
||||
load_l2_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p]
|
||||
load_l2_index_prefix.argtypes = [
|
||||
ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p
|
||||
]
|
||||
|
||||
save_l2_index_prefix = lib.save_l2_index_prefix
|
||||
save_l2_index_prefix.restype = None
|
||||
|
@ -70,11 +99,11 @@ release_context.restype = None
|
|||
release_context.argtypes = [POINTER(IndexContext)]
|
||||
|
||||
|
||||
|
||||
class Graph_Index(object):
|
||||
"""
|
||||
graph index
|
||||
"""
|
||||
|
||||
def __init__(self, dist_type="IP"):
|
||||
self.dim = 0
|
||||
self.total_num = 0
|
||||
|
@ -85,7 +114,11 @@ class Graph_Index(object):
|
|||
self.with_attr = False
|
||||
assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
|
||||
|
||||
def build(self, gallery_vectors, gallery_docs=[], pq_size=100, index_path='graph_index/'):
|
||||
def build(self,
|
||||
gallery_vectors,
|
||||
gallery_docs=[],
|
||||
pq_size=100,
|
||||
index_path='graph_index/'):
|
||||
"""
|
||||
build index
|
||||
"""
|
||||
|
@ -96,19 +129,32 @@ class Graph_Index(object):
|
|||
self.total_num = gallery_vectors.shape[0]
|
||||
self.dim = gallery_vectors.shape[1]
|
||||
|
||||
assert (len(gallery_docs) == self.total_num if len(gallery_docs)>0 else True)
|
||||
assert (len(gallery_docs) == self.total_num
|
||||
if len(gallery_docs) > 0 else True)
|
||||
|
||||
print("training index -> num: {}, dim: {}, dist_type: {}".format(self.total_num, self.dim, self.dist_type))
|
||||
print("training index -> num: {}, dim: {}, dist_type: {}".format(
|
||||
self.total_num, self.dim, self.dist_type))
|
||||
|
||||
if not os.path.exists(index_path):
|
||||
os.makedirs(index_path)
|
||||
|
||||
if self.dist_type == "IP":
|
||||
build_mobius_index(gallery_vectors,self.total_num,self.dim, pq_size, self.mobius_pow, create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
load_mobius_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
build_mobius_index(
|
||||
gallery_vectors, self.total_num, self.dim, pq_size,
|
||||
self.mobius_pow,
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
load_mobius_index_prefix(
|
||||
self.total_num, self.dim,
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
else:
|
||||
build_l2_index(gallery_vectors,self.total_num,self.dim, pq_size, create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
load_l2_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
build_l2_index(
|
||||
gallery_vectors, self.total_num, self.dim, pq_size,
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
load_l2_index_prefix(
|
||||
self.total_num, self.dim,
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
|
||||
self.gallery_doc_dict = {}
|
||||
if len(gallery_docs) > 0:
|
||||
|
@ -136,9 +182,13 @@ class Graph_Index(object):
|
|||
if paddle.is_tensor(query):
|
||||
query = query.numpy()
|
||||
if self.dist_type == "IP":
|
||||
search_mobius_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score)
|
||||
search_mobius_index(query, self.dim, search_budget, return_k,
|
||||
ctypes.byref(self.index_context), ret_id,
|
||||
ret_score)
|
||||
else:
|
||||
search_l2_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score)
|
||||
search_l2_index(query, self.dim, search_budget, return_k,
|
||||
ctypes.byref(self.index_context), ret_id,
|
||||
ret_score)
|
||||
|
||||
ret_id = ret_id.tolist()
|
||||
ret_doc = []
|
||||
|
@ -155,9 +205,13 @@ class Graph_Index(object):
|
|||
os.makedirs(index_path)
|
||||
|
||||
if self.dist_type == "IP":
|
||||
save_mobius_index_prefix(ctypes.byref(self.index_context),create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
save_mobius_index_prefix(
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
else:
|
||||
save_l2_index_prefix(ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
save_l2_index_prefix(
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
|
||||
with open(index_path + "/info.json", "w") as f:
|
||||
json.dump(self.gallery_doc_dict, f)
|
||||
|
@ -174,9 +228,12 @@ class Graph_Index(object):
|
|||
self.with_attr = self.gallery_doc_dict["with_attr"]
|
||||
|
||||
if self.dist_type == "IP":
|
||||
load_mobius_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
load_mobius_index_prefix(
|
||||
self.total_num, self.dim,
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
else:
|
||||
load_l2_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))
|
||||
|
||||
|
||||
|
||||
load_l2_index_prefix(
|
||||
self.total_num, self.dim,
|
||||
ctypes.byref(self.index_context),
|
||||
create_string_buffer((index_path + "/index").encode('utf-8')))
|
||||
|
|
|
@ -96,7 +96,7 @@ DataLoader:
|
|||
dataset:
|
||||
name: LogoDataset
|
||||
image_root: "dataset/LogoDet-3K-crop/val/"
|
||||
cls_label_path: "LogoDet-3K-crop/LogoDet-3K+query.txt"
|
||||
cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+query.txt"
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
|
|
Loading…
Reference in New Issue