PaddleClas/deploy/vector_search/interface.py

240 lines
8.2 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import paddle
import numpy.ctypeslib as ctl
import numpy as np
import os
import json
import platform
from ctypes import *
from numpy.ctypeslib import ndpointer
__dir__ = os.path.dirname(os.path.abspath(__file__))
if platform.system() == "Windows":
lib_filename = "index.dll"
else:
lib_filename = "index.so"
so_path = os.path.join(__dir__, lib_filename)
lib = ctypes.cdll.LoadLibrary(so_path)
class IndexContext(Structure):
_fields_ = [("graph", c_void_p), ("data", c_void_p)]
# for mobius IP index
build_mobius_index = lib.build_mobius_index
build_mobius_index.restype = None
build_mobius_index.argtypes = [
ctl.ndpointer(
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
ctypes.c_int, ctypes.c_double, ctypes.c_char_p
]
search_mobius_index = lib.search_mobius_index
search_mobius_index.restype = None
search_mobius_index.argtypes = [
ctl.ndpointer(
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
ctypes.c_int, POINTER(IndexContext), ctl.ndpointer(
np.uint64, flags='aligned, c_contiguous'), ctl.ndpointer(
np.float64, flags='aligned, c_contiguous')
]
load_mobius_index_prefix = lib.load_mobius_index_prefix
load_mobius_index_prefix.restype = None
load_mobius_index_prefix.argtypes = [
ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p
]
save_mobius_index_prefix = lib.save_mobius_index_prefix
save_mobius_index_prefix.restype = None
save_mobius_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p]
# for L2 index
build_l2_index = lib.build_l2_index
build_l2_index.restype = None
build_l2_index.argtypes = [
ctl.ndpointer(
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
ctypes.c_int, ctypes.c_char_p
]
search_l2_index = lib.search_l2_index
search_l2_index.restype = None
search_l2_index.argtypes = [
ctl.ndpointer(
np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,
ctypes.c_int, POINTER(IndexContext), ctl.ndpointer(
np.uint64, flags='aligned, c_contiguous'), ctl.ndpointer(
np.float64, flags='aligned, c_contiguous')
]
load_l2_index_prefix = lib.load_l2_index_prefix
load_l2_index_prefix.restype = None
load_l2_index_prefix.argtypes = [
ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p
]
save_l2_index_prefix = lib.save_l2_index_prefix
save_l2_index_prefix.restype = None
save_l2_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p]
release_context = lib.release_context
release_context.restype = None
release_context.argtypes = [POINTER(IndexContext)]
class Graph_Index(object):
"""
graph index
"""
def __init__(self, dist_type="IP"):
self.dim = 0
self.total_num = 0
self.dist_type = dist_type
self.mobius_pow = 2.0
self.index_context = IndexContext(0, 0)
self.gallery_doc_dict = {}
self.with_attr = False
assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
def build(self,
gallery_vectors,
gallery_docs=[],
pq_size=100,
index_path='graph_index/'):
"""
build index
"""
if paddle.is_tensor(gallery_vectors):
gallery_vectors = gallery_vectors.numpy()
assert gallery_vectors.ndim == 2, "Input vector must be 2D ..."
self.total_num = gallery_vectors.shape[0]
self.dim = gallery_vectors.shape[1]
assert (len(gallery_docs) == self.total_num
if len(gallery_docs) > 0 else True)
print("training index -> num: {}, dim: {}, dist_type: {}".format(
self.total_num, self.dim, self.dist_type))
if not os.path.exists(index_path):
os.makedirs(index_path)
if self.dist_type == "IP":
build_mobius_index(
gallery_vectors, self.total_num, self.dim, pq_size,
self.mobius_pow,
create_string_buffer((index_path + "/index").encode('utf-8')))
load_mobius_index_prefix(
self.total_num, self.dim,
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))
else:
build_l2_index(
gallery_vectors, self.total_num, self.dim, pq_size,
create_string_buffer((index_path + "/index").encode('utf-8')))
load_l2_index_prefix(
self.total_num, self.dim,
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))
self.gallery_doc_dict = {}
if len(gallery_docs) > 0:
self.with_attr = True
for i in range(gallery_vectors.shape[0]):
self.gallery_doc_dict[str(i)] = gallery_docs[i]
self.gallery_doc_dict["total_num"] = self.total_num
self.gallery_doc_dict["dim"] = self.dim
self.gallery_doc_dict["dist_type"] = self.dist_type
self.gallery_doc_dict["with_attr"] = self.with_attr
with open(index_path + "/info.json", "w") as f:
json.dump(self.gallery_doc_dict, f)
print("finished creating index ...")
def search(self, query, return_k=10, search_budget=100):
"""
search
"""
ret_id = np.zeros(return_k, dtype=np.uint64)
ret_score = np.zeros(return_k, dtype=np.float64)
if paddle.is_tensor(query):
query = query.numpy()
if self.dist_type == "IP":
search_mobius_index(query, self.dim, search_budget, return_k,
ctypes.byref(self.index_context), ret_id,
ret_score)
else:
search_l2_index(query, self.dim, search_budget, return_k,
ctypes.byref(self.index_context), ret_id,
ret_score)
ret_id = ret_id.tolist()
ret_doc = []
if self.with_attr:
for i in range(return_k):
ret_doc.append(self.gallery_doc_dict[str(ret_id[i])])
return ret_score, ret_doc
else:
return ret_score, ret_id
def dump(self, index_path):
if not os.path.exists(index_path):
os.makedirs(index_path)
if self.dist_type == "IP":
save_mobius_index_prefix(
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))
else:
save_l2_index_prefix(
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))
with open(index_path + "/info.json", "w") as f:
json.dump(self.gallery_doc_dict, f)
def load(self, index_path):
self.gallery_doc_dict = {}
with open(index_path + "/info.json", "r") as f:
self.gallery_doc_dict = json.load(f)
self.total_num = self.gallery_doc_dict["total_num"]
self.dim = self.gallery_doc_dict["dim"]
self.dist_type = self.gallery_doc_dict["dist_type"]
self.with_attr = self.gallery_doc_dict["with_attr"]
if self.dist_type == "IP":
load_mobius_index_prefix(
self.total_num, self.dim,
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))
else:
load_l2_index_prefix(
self.total_num, self.dim,
ctypes.byref(self.index_context),
create_string_buffer((index_path + "/index").encode('utf-8')))