# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import time
|
|
|
|
from faiss.contrib import rpc
|
|
|
|
import combined_index
|
|
import argparse
|
|
|
|
|
|
|
|
############################################################
# Server implementation
############################################################
|
|
|
|
|
|
class MyServer(rpc.Server):
    """RPC server that forwards unknown attributes to a wrapped index.

    Any attribute that is not found on the server object itself is looked
    up on ``self.index``, so the index's methods become reachable through
    the server object.
    """

    def __init__(self, s, index):
        """s: passed through to rpc.Server.__init__, index: wrapped object."""
        rpc.Server.__init__(self, s)
        self.index = index

    def __getattr__(self, f):
        # __getattr__ is only invoked when normal attribute lookup fails.
        # If 'index' itself is the missing attribute (e.g. it is accessed
        # before __init__ has assigned it), evaluating self.index below
        # would re-enter this method and recurse without bound.
        if f == 'index':
            raise AttributeError(f)
        return getattr(self.index, f)
|
|
|
|
def main():
    """Parse the command line, load the index and serve it over RPC.

    Options:
      --port            port passed to rpc.run_server
      --when_ready_dir  if set, the path <when_ready_dir>/<rank> is passed
                        to rpc.run_server as report_to_file
      --ipv4            disables the v6 flag passed to rpc.run_server
      --rank            integer used as the file name in when_ready_dir
    """
    # local import: the path handling below is the only user of os here
    import os

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('server options')
    aa = group.add_argument

    aa('--port', default=12012, type=int, help='server port')
    aa('--when_ready_dir', default=None,
       help='directory where host:port is stored, in a file named '
            'after the rank, when ready')
    aa('--ipv4', default=False, action='store_true', help='force ipv4')
    aa('--rank', default=0, type=int,
       help='rank used as index in the client table')

    args = parser.parse_args()

    # one report file per server, named after its rank
    when_ready = None
    if args.when_ready_dir:
        when_ready = os.path.join(args.when_ready_dir, '%d' % args.rank)

    print('loading index')

    index = combined_index.CombinedIndexDeep1B()

    print('starting server')
    # the same index object is shared by every MyServer that gets created
    rpc.run_server(
        lambda s: MyServer(s, index),
        args.port, report_to_file=when_ready,
        v6=not args.ipv4)
|
|
|
|
# Script entry point. Note that main() is invoked here, i.e. before any
# definitions that appear further down in this file have been executed.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
############################################################
# Client implementation
############################################################
|
|
|
|
from multiprocessing.pool import ThreadPool
|
|
import faiss
|
|
import numpy as np
|
|
|
|
|
|
|
|
class ResultHeap:
    """Combine query results from a sliced dataset (for k-nn search).

    The merged results live in ``self.D`` (float32 distances) and
    ``self.I`` (int64 ids), both of shape (nq, k). A
    ``faiss.float_maxheap_array_t`` operates directly on the memory of
    these two arrays.
    """

    def __init__(self, nq, k):
        """nq: number of query vectors, k: number of results per query."""
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        # the heap array points at the buffers of self.D / self.I, so
        # these two arrays must stay alive as long as self.heaps is used
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        """Merge one batch of results into the heaps.

        D, I: distances and ids of shape (nq, k).
        i0: offset added to every id before merging.

        The caller's arrays are left untouched: the offset is applied to a
        copy of I rather than in place.
        """
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        # The native call below receives raw pointers, so the buffers must
        # be C-contiguous and of the same element types as self.D / self.I.
        D = np.ascontiguousarray(D, dtype='float32')
        # I + i0 allocates a new array, so the input I is not modified
        I = np.ascontiguousarray(I + i0, dtype='int64')
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        """Reorder the heaps; call once after all batches were added."""
        self.heaps.reorder()
|
|
|
|
def distribute_weights(weights, nbin):
    """Assign a set of weights to a smaller set of bins to balance them.

    Greedy heuristic: the weights are visited from the largest to the
    smallest and each one is put into the bin whose running total is the
    lowest at that point.

    weights: array-like of weights (lists are accepted as well as numpy
        arrays; the input is flattened).
    nbin: number of bins.

    Returns (bins, assign): bins[b] is the total weight put into bin b,
    assign[i] is the bin that weight i was assigned to.
    """
    weights = np.asarray(weights).ravel()
    bins = np.zeros(nbin)
    # every entry is overwritten in the loop below
    assign = np.zeros(weights.size, dtype=int)
    # argsort is ascending, so walk it backwards to get largest first
    for i in weights.argsort()[::-1]:
        b = bins.argmin()
        assign[i] = b
        bins[b] += weights[i]
    return bins, assign
|
|
|
|
|
|
|
|
class SplitPerListIndex:
    """Manages a local index, that does the coarse quantization, and a set
    of sub_indexes. The sub_indexes search a subset of the inverted
    lists. The SplitPerListIndex merges results from the sub-indexes.
    """

    def __init__(self, index, sub_indexes):
        """index: local index, sub_indexes: sequence of sub-index objects."""
        self.index = index
        self.code_size = faiss.extract_index_ivf(index.index).code_size
        self.sub_indexes = sub_indexes
        self.ni = len(self.sub_indexes)
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        self.verbose = False

    def _map_sub_indexes(self, fn):
        """Call fn(sub_index) for every sub-index, using the thread pool."""
        return self.pool.map(fn, self.sub_indexes)

    def set_nprobe(self, nprobe):
        """Set nprobe on the local index and on all sub-indexes."""
        self.index.set_nprobe(nprobe)
        self._map_sub_indexes(lambda idx: idx.set_nprobe(nprobe))

    def set_omp_num_threads(self, nt):
        """Set the OpenMP thread count locally and on all sub-indexes."""
        faiss.omp_set_num_threads(nt)
        self._map_sub_indexes(lambda idx: idx.set_omp_num_threads(nt))

    def set_parallel_mode(self, pm):
        """Set the parallel mode on the local index and all sub-indexes."""
        self.index.set_parallel_mode(pm)
        self._map_sub_indexes(lambda idx: idx.set_parallel_mode(pm))

    def set_prefetch_nthread(self, nt):
        """Set the prefetch thread count on the local index and all
        sub-indexes."""
        self.index.set_prefetch_nthread(nt)
        self._map_sub_indexes(lambda idx: idx.set_prefetch_nthread(nt))

    def balance_lists(self, list_nos):
        """Decide which sub-index handles each entry of list_nos.

        The size of every referenced inverted list is used as its weight
        and the weights are balanced over the self.ni sub-indexes.

        Returns an integer array of the same shape as list_nos holding
        the sub-index number for each entry. Also stores the total weight
        in self.nscan.
        """
        big_il = self.index.big_il
        # Negative list numbers are the "skip" marker (the search methods
        # of this class write -1 themselves), so they get weight 0
        # instead of being passed to list_size.
        weights = np.array([
            big_il.list_size(int(i)) if i >= 0 else 0
            for i in list_nos.ravel()
        ])
        bins, assign = distribute_weights(weights, self.ni)
        if self.verbose:
            print('bins weight range %d:%d total %d (%.2f MiB)' % (
                bins.min(), bins.max(), bins.sum(),
                bins.sum() * (self.code_size + 8) / 2 ** 20))
        self.nscan = bins.sum()
        return assign.reshape(list_nos.shape)

    def search(self, x, k):
        """k-nn search of the query vectors x.

        Returns (D, I), the merged distance and id arrays of shape
        (x.shape[0], k).
        """
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            # lists handled by another sub-index are masked out with -1
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            Di, Ii = sub_index.ivf_search_preassigned(
                xqo, list_nos_i, coarse_dis, k)
            if self.verbose:
                print('client %d: %.3f s' % (i, time.time() - t0))
            return Di, Ii

        rh = ResultHeap(x.shape[0], k)

        for Di, Ii in self.pool.imap(do_query, range(self.ni)):
            rh.add_batch_result(Di, Ii, 0)
        rh.finalize()
        return rh.D, rh.I

    def range_search(self, x, radius):
        """Range search of the query vectors x with the given radius.

        Returns (lims, D, I): the results of query q are
        D[lims[q]:lims[q + 1]] and I[lims[q]:lims[q + 1]], concatenated
        over the sub-indexes in sub-index order.
        """
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)
        nq = len(x)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            # lists handled by another sub-index are masked out with -1
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            limi, Di, Ii = sub_index.ivf_range_search_preassigned(
                xqo, list_nos_i, coarse_dis, radius)
            if self.verbose:
                print('slice %d: %.3f s' % (i, time.time() - t0))
            return limi, Di, Ii

        # per-query lists of result chunks, one chunk per sub-index
        D_parts = [[] for _ in range(nq)]
        I_parts = [[] for _ in range(nq)]
        sizes = np.zeros(nq, dtype=int)

        for sub_lims, Di, Ii in self.pool.imap(do_query, range(self.ni)):
            for q in range(nq):
                # plain ints: the limits may come back as unsigned values
                l0, l1 = int(sub_lims[q]), int(sub_lims[q + 1])
                D_parts[q].append(Di[l0:l1])
                I_parts[q].append(Ii[l0:l1])
                sizes[q] += l1 - l0

        lims = np.zeros(nq + 1, dtype=int)
        lims[1:] = np.cumsum(sizes)

        D_chunks = [chunk for parts in D_parts for chunk in parts]
        I_chunks = [chunk for parts in I_parts for chunk in parts]
        if not D_chunks:
            # no queries: np.hstack refuses an empty list of arrays
            return (lims,
                    np.zeros(0, dtype='float32'),
                    np.zeros(0, dtype='int64'))
        return lims, np.hstack(D_chunks), np.hstack(I_chunks)
|