273 lines
9.5 KiB
Python
273 lines
9.5 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import ctypes
|
|
import paddle
|
|
import numpy.ctypeslib as ctl
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
import json
|
|
import platform
|
|
|
|
from ctypes import *
|
|
from numpy.ctypeslib import ndpointer
|
|
|
|
# Locate and load the native index library that ships next to this module.
#
# The directory is kept in a private constant rather than a module-level
# ``__dir__``: a module attribute named ``__dir__`` is looked up and *called*
# by ``dir(module)`` (PEP 562), so binding it to a string breaks ``dir()``.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

# ``winmode`` stays None everywhere except Windows on Python >= 3.8, where
# ``ctypes.CDLL`` accepts the ``winmode`` keyword.
winmode = None
if platform.system() == "Windows":
    lib_filename = "index.dll"
    if sys.version_info >= (3, 8):
        # NOTE(review): 0x8 looks like LOAD_WITH_ALTERED_SEARCH_PATH, which
        # lets dependent DLLs resolve relative to the library - confirm.
        winmode = 0x8
else:
    lib_filename = "index.so"

so_path = os.path.join(_MODULE_DIR, lib_filename)
try:
    if winmode is not None:
        lib = ctypes.CDLL(so_path, winmode=winmode)
    else:
        lib = ctypes.CDLL(so_path)
except Exception as ex:
    # Loading failed: point the user at the build instructions and abort.
    readme_path = os.path.join(_MODULE_DIR, "README.md")
    print(
        f"Error happened when load lib {so_path} with msg {ex},\nplease refer to {readme_path} to rebuild your library."
    )
    # ``sys.exit`` instead of the ``site``-provided ``exit`` builtin, which is
    # not defined when the interpreter runs without ``site`` (e.g. ``-S``).
    sys.exit(-1)
|
|
|
|
|
|
class IndexContext(ctypes.Structure):
    """
    ctypes structure made of two opaque pointers, ``graph`` and ``data``.

    Instances are passed by reference to the native index functions.
    """

    _fields_ = [
        ("graph", ctypes.c_void_p),
        ("data", ctypes.c_void_p),
    ]
|
|
|
|
|
|
# ctypes prototypes for the functions exported by the native index library.
# Every function returns nothing (``restype = None``).

# Argument types shared by several prototypes.
_f32_array = ctl.ndpointer(np.float32, flags='aligned, c_contiguous')
_u64_array = ctl.ndpointer(np.uint64, flags='aligned, c_contiguous')
_f64_array = ctl.ndpointer(np.float64, flags='aligned, c_contiguous')
_context_ptr = ctypes.POINTER(IndexContext)

# ---- mobius IP index ----

# Called as (vectors, total_num, dim, pq_size, mobius_pow, path prefix).
build_mobius_index = lib.build_mobius_index
build_mobius_index.argtypes = [
    _f32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_double,
    ctypes.c_char_p,
]
build_mobius_index.restype = None

# Called as (query, dim, search_budget, return_k, context, ids out, scores out).
search_mobius_index = lib.search_mobius_index
search_mobius_index.argtypes = [
    _f32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    _context_ptr,
    _u64_array,
    _f64_array,
]
search_mobius_index.restype = None

# Called as (total_num, dim, context, path prefix).
load_mobius_index_prefix = lib.load_mobius_index_prefix
load_mobius_index_prefix.argtypes = [
    ctypes.c_int,
    ctypes.c_int,
    _context_ptr,
    ctypes.c_char_p,
]
load_mobius_index_prefix.restype = None

# Called as (context, path prefix).
save_mobius_index_prefix = lib.save_mobius_index_prefix
save_mobius_index_prefix.argtypes = [_context_ptr, ctypes.c_char_p]
save_mobius_index_prefix.restype = None

# ---- L2 index ----

# Called as (vectors, total_num, dim, pq_size, path prefix).
build_l2_index = lib.build_l2_index
build_l2_index.argtypes = [
    _f32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_char_p,
]
build_l2_index.restype = None

# Called as (query, dim, search_budget, return_k, context, ids out, scores out).
search_l2_index = lib.search_l2_index
search_l2_index.argtypes = [
    _f32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    _context_ptr,
    _u64_array,
    _f64_array,
]
search_l2_index.restype = None

# Called as (total_num, dim, context, path prefix).
load_l2_index_prefix = lib.load_l2_index_prefix
load_l2_index_prefix.argtypes = [
    ctypes.c_int,
    ctypes.c_int,
    _context_ptr,
    ctypes.c_char_p,
]
load_l2_index_prefix.restype = None

# Called as (context, path prefix).
save_l2_index_prefix = lib.save_l2_index_prefix
save_l2_index_prefix.argtypes = [_context_ptr, ctypes.c_char_p]
save_l2_index_prefix.restype = None

# ---- shared ----

release_context = lib.release_context
release_context.argtypes = [_context_ptr]
release_context.restype = None
|
|
|
|
|
|
class Graph_Index(object):
    """
    Graph-based vector index backed by the native index library.

    Two distance types are supported: "IP" (served by the mobius entry
    points) and "L2". Besides the native index files, an ``info.json`` file
    in the index directory stores ``total_num``, ``dim``, ``dist_type``,
    ``with_attr`` and, when documents were supplied, one entry per vector
    keyed by its stringified position.

    NOTE(review): ``release_context`` is bound at module level but never
    called here, so a context replaced by a second ``build``/``load`` is
    presumably never freed - confirm the native ownership rules before
    adding an explicit release.
    """

    def __init__(self, dist_type="IP"):
        """
        Create an empty index.

        Args:
            dist_type: "IP" or "L2".

        Raises:
            AssertionError: if ``dist_type`` is neither "IP" nor "L2".
        """
        assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
        self.dim = 0
        self.total_num = 0
        self.dist_type = dist_type
        self.mobius_pow = 2.0
        self.index_context = IndexContext(0, 0)
        self.gallery_doc_dict = {}
        self.with_attr = False

    @staticmethod
    def _index_prefix(index_path):
        """Return a C string buffer holding ``<index_path>/index`` as UTF-8."""
        return ctypes.create_string_buffer(
            (index_path + "/index").encode('utf-8'))

    @staticmethod
    def _merge_gallery_info(previous, current):
        """
        Append ``current`` index metadata to ``previous`` and return it.

        Both dicts use the ``info.json`` layout. ``previous`` is updated in
        place: the document entries of ``current`` (if any) are re-keyed
        after the existing ones and ``total_num`` is increased.

        Raises:
            AssertionError: if ``dist_type``, ``dim`` or ``with_attr`` differ.
        """
        assert previous["dist_type"] == current["dist_type"]
        assert previous["dim"] == current["dim"]
        assert previous["with_attr"] == current["with_attr"]

        offset = previous["total_num"]
        # Document entries only exist when documents were supplied; copying
        # them unconditionally raised KeyError for doc-less indexes.
        if current["with_attr"]:
            for i in range(current["total_num"]):
                previous[str(i + offset)] = current[str(i)]

        previous["total_num"] += current["total_num"]
        return previous

    def _load_native_index(self, index_path):
        """Load the native index stored under ``index_path`` into the context."""
        if self.dist_type == "IP":
            loader = load_mobius_index_prefix
        else:
            loader = load_l2_index_prefix
        loader(self.total_num, self.dim,
               ctypes.byref(self.index_context),
               self._index_prefix(index_path))

    def build(self,
              gallery_vectors,
              gallery_docs=None,
              pq_size=100,
              index_path='graph_index/',
              append_index=False):
        """
        Build the index from ``gallery_vectors`` and write it to ``index_path``.

        Args:
            gallery_vectors: 2-D array-like or paddle tensor, one row per
                vector. Converted to a C-contiguous float32 array, which is
                what the native prototypes require.
            gallery_docs: optional sequence with one JSON-serializable
                document per vector; when given, ``search`` returns these
                documents instead of ids.
            pq_size: forwarded unchanged to the native build function.
            index_path: directory receiving the index files and ``info.json``
                (created if missing).
            append_index: when True and ``info.json`` already exists, merge
                the new metadata into the existing one.

        Raises:
            AssertionError: if the input is not 2-D, the number of documents
                does not match the number of vectors, or the existing
                metadata is incompatible when appending.
        """
        if gallery_docs is None:
            gallery_docs = []
        if paddle.is_tensor(gallery_vectors):
            gallery_vectors = gallery_vectors.numpy()
        # The native argtypes only accept aligned, C-contiguous float32 data.
        gallery_vectors = np.ascontiguousarray(
            gallery_vectors, dtype=np.float32)
        assert gallery_vectors.ndim == 2, "Input vector must be 2D ..."

        self.total_num = gallery_vectors.shape[0]
        self.dim = gallery_vectors.shape[1]

        has_docs = len(gallery_docs) > 0
        assert not has_docs or len(gallery_docs) == self.total_num

        print(
            f"training index -> num: {self.total_num}, dim: {self.dim}, dist_type: {self.dist_type}"
        )

        os.makedirs(index_path, exist_ok=True)

        if self.dist_type == "IP":
            build_mobius_index(gallery_vectors, self.total_num, self.dim,
                               pq_size, self.mobius_pow,
                               self._index_prefix(index_path))
        else:
            build_l2_index(gallery_vectors, self.total_num, self.dim,
                           pq_size, self._index_prefix(index_path))
        self._load_native_index(index_path)

        # Reset on every build so a doc-less rebuild does not keep a stale
        # True from an earlier build or load.
        self.with_attr = has_docs
        self.gallery_doc_dict = {}
        if has_docs:
            self.gallery_doc_dict = {
                str(i): doc
                for i, doc in enumerate(gallery_docs)
            }

        self.gallery_doc_dict["total_num"] = self.total_num
        self.gallery_doc_dict["dim"] = self.dim
        self.gallery_doc_dict["dist_type"] = self.dist_type
        self.gallery_doc_dict["with_attr"] = self.with_attr

        output_path = os.path.join(index_path, "info.json")
        if append_index is True and os.path.exists(output_path):
            with open(output_path, "r") as fin:
                previous = json.load(fin)
            self.gallery_doc_dict = self._merge_gallery_info(
                previous, self.gallery_doc_dict)

        with open(output_path, "w") as f:
            json.dump(self.gallery_doc_dict, f)

        print("finished creating index ...")

    def search(self, query, return_k=10, search_budget=100):
        """
        Search the index for ``query``.

        Args:
            query: array-like or paddle tensor; converted to a C-contiguous
                float32 array before being passed to the native search.
            return_k: number of result slots to allocate and return.
            search_budget: forwarded unchanged to the native search function.

        Returns:
            ``(scores, docs)`` when the index was built with documents,
            otherwise ``(scores, ids)``. ``scores`` is a float64 numpy array
            of length ``return_k``; ``ids`` is a list of ints.
        """
        ret_id = np.zeros(return_k, dtype=np.uint64)
        ret_score = np.zeros(return_k, dtype=np.float64)

        if paddle.is_tensor(query):
            query = query.numpy()
        query = np.ascontiguousarray(query, dtype=np.float32)

        if self.dist_type == "IP":
            searcher = search_mobius_index
        else:
            searcher = search_l2_index
        searcher(query, self.dim, search_budget, return_k,
                 ctypes.byref(self.index_context), ret_id, ret_score)

        ret_id = ret_id.tolist()
        if self.with_attr:
            ret_doc = [self.gallery_doc_dict[str(doc_id)] for doc_id in ret_id]
            return ret_score, ret_doc
        return ret_score, ret_id

    def dump(self, index_path):
        """
        Save the native index and ``info.json`` into ``index_path``.

        The directory is created if it does not exist.
        """
        os.makedirs(index_path, exist_ok=True)

        if self.dist_type == "IP":
            saver = save_mobius_index_prefix
        else:
            saver = save_l2_index_prefix
        saver(ctypes.byref(self.index_context), self._index_prefix(index_path))

        with open(os.path.join(index_path, "info.json"), "w") as f:
            json.dump(self.gallery_doc_dict, f)

    def load(self, index_path):
        """
        Load ``info.json`` and the native index from ``index_path``.

        ``total_num``, ``dim``, ``dist_type`` and ``with_attr`` are restored
        from the metadata before the native index is loaded.
        """
        with open(os.path.join(index_path, "info.json"), "r") as f:
            self.gallery_doc_dict = json.load(f)

        self.total_num = self.gallery_doc_dict["total_num"]
        self.dim = self.gallery_doc_dict["dim"]
        self.dist_type = self.gallery_doc_dict["dist_type"]
        self.with_attr = self.gallery_doc_dict["with_attr"]

        self._load_native_index(index_path)
|