faiss/benchs/distributed_ondisk/search_server.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import time
from multiprocessing.pool import ThreadPool

import numpy as np

import faiss
from faiss.contrib import rpc

import combined_index

############################################################
# Server implementation
############################################################
class MyServer(rpc.Server):
    """ Index server: exposes the methods of the wrapped index via RPC """

    def __init__(self, s, index):
        rpc.Server.__init__(self, s)
        self.index = index

    def __getattr__(self, f):
        # forward all other method calls to the wrapped index
        return getattr(self.index, f)

def main():
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('server options')
    aa('--port', default=12012, type=int, help='server port')
    aa('--when_ready_dir', default=None,
       help='store host:port to this file when ready')
    aa('--ipv4', default=False, action='store_true', help='force ipv4')
    aa('--rank', default=0, type=int,
       help='rank used as index in the client table')
    args = parser.parse_args()

    when_ready = None
    if args.when_ready_dir:
        when_ready = '%s/%d' % (args.when_ready_dir, args.rank)

    print('loading index')
    index = combined_index.CombinedIndexDeep1B()

    print('starting server')
    rpc.run_server(
        lambda s: MyServer(s, index),
        args.port, report_to_file=when_ready,
        v6=not args.ipv4)


# rpc.run_server blocks, so only the server part above runs when this file
# is executed as a script; the client code below is used via import.
if __name__ == '__main__':
    main()
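
# Illustrative sketch (not part of the original file): talking to MyServer
# from another process. It assumes faiss.contrib.rpc provides a Client class
# whose __getattr__ forwards method calls over the socket, mirroring
# MyServer.__getattr__ on the server side. The host name is hypothetical.
def _demo_connect(host='server1', port=12012):
    client = rpc.Client(host, port)
    client.set_nprobe(64)    # executed remotely on the CombinedIndex
    return client
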
############################################################
# Client implementation
############################################################
class ResultHeap:
    """ Combine query results from a sliced dataset (for k-nn search) """

    def __init__(self, nq, k):
        " nq: number of query vectors, k: number of results per query "
        self.I = np.zeros((nq, k), dtype='int64')
        self.D = np.zeros((nq, k), dtype='float32')
        self.nq, self.k = nq, k
        heaps = faiss.float_maxheap_array_t()
        heaps.k = k
        heaps.nh = nq
        heaps.val = faiss.swig_ptr(self.D)
        heaps.ids = faiss.swig_ptr(self.I)
        heaps.heapify()
        self.heaps = heaps

    def add_batch_result(self, D, I, i0):
        """ add a (nq, k) batch of results whose ids are shifted by i0
        (I is modified in place) """
        assert D.shape == (self.nq, self.k)
        assert I.shape == (self.nq, self.k)
        I += i0
        self.heaps.addn_with_ids(
            self.k, faiss.swig_ptr(D),
            faiss.swig_ptr(I), self.k)

    def finalize(self):
        # sort the k results of each query by increasing distance
        self.heaps.reorder()
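
# Illustrative sketch (not part of the original file): merging two result
# batches with ResultHeap. Distances and ids are random placeholders; the
# second batch's ids are shifted by i0 = 1000, as for a second data slice.
def _demo_result_heap():
    nq, k = 4, 5
    rh = ResultHeap(nq, k)
    for i0 in (0, 1000):
        D = np.random.rand(nq, k).astype('float32')
        I = np.arange(nq * k, dtype='int64').reshape(nq, k)
        rh.add_batch_result(D, I, i0)
    rh.finalize()
    return rh.D, rh.I    # k smallest distances per query, sorted
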
def distribute_weights(weights, nbin):
    """ assign a set of weights to a smaller set of bins to balance them
    (greedy: heaviest weight first, into the lightest bin) """
    nw = weights.size
    o = weights.argsort()
    bins = np.zeros(nbin)
    assign = np.ones(nw, dtype=int)
    for i in o[::-1]:
        b = bins.argmin()
        assign[i] = b
        bins[b] += weights[i]
    return bins, assign
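
# Worked example (sketch, not part of the original file): five inverted-list
# sizes spread over two bins. The greedy pass happens to balance the bins
# perfectly here.
def _demo_distribute_weights():
    weights = np.array([10, 1, 8, 3, 6])
    bins, assign = distribute_weights(weights, 2)
    # bins -> [14., 14.], assign -> [0, 0, 1, 0, 1]
    return bins, assign
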
class SplitPerListIndex:
    """manages a local index that does the coarse quantization, and a set
    of sub-indexes that each search a subset of the inverted lists.
    The SplitPerListIndex merges the results from the sub-indexes"""

    def __init__(self, index, sub_indexes):
        self.index = index
        self.code_size = faiss.extract_index_ivf(index.index).code_size
        self.sub_indexes = sub_indexes
        self.ni = len(self.sub_indexes)
        # pool of threads. Each thread manages one sub-index.
        self.pool = ThreadPool(self.ni)
        self.verbose = False

    # the set_* methods propagate a setting to the local index and,
    # in parallel, to all sub-indexes

    def set_nprobe(self, nprobe):
        self.index.set_nprobe(nprobe)
        self.pool.map(
            lambda i: self.sub_indexes[i].set_nprobe(nprobe),
            range(self.ni)
        )

    def set_omp_num_threads(self, nt):
        faiss.omp_set_num_threads(nt)
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def set_parallel_mode(self, pm):
        self.index.set_parallel_mode(pm)
        self.pool.map(
            lambda idx: idx.set_parallel_mode(pm),
            self.sub_indexes
        )

    def set_prefetch_nthread(self, nt):
        self.index.set_prefetch_nthread(nt)
        self.pool.map(
            lambda idx: idx.set_prefetch_nthread(nt),
            self.sub_indexes
        )

    def balance_lists(self, list_nos):
        """ decide which sub-index visits which of the queries' inverted
        lists, so that the amount of scanned data is balanced """
        big_il = self.index.big_il
        weights = np.array([big_il.list_size(int(i))
                            for i in list_nos.ravel()])
        bins, assign = distribute_weights(weights, self.ni)
        if self.verbose:
            print('bins weight range %d:%d total %d (%.2f MiB)' % (
                bins.min(), bins.max(), bins.sum(),
                bins.sum() * (self.code_size + 8) / 2 ** 20))
        self.nscan = bins.sum()
        return assign.reshape(list_nos.shape)
    def search(self, x, k):
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            # mask out the inverted lists assigned to other sub-indexes
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            Di, Ii = sub_index.ivf_search_preassigned(
                xqo, list_nos_i, coarse_dis, k)
            if self.verbose:
                print('client %d: %.3f s' % (i, time.time() - t0))
            return Di, Ii

        rh = ResultHeap(x.shape[0], k)
        for Di, Ii in self.pool.imap(do_query, range(self.ni)):
            # ids returned by the sub-indexes are already global (i0=0)
            rh.add_batch_result(Di, Ii, 0)
        rh.finalize()
        return rh.D, rh.I
    def range_search(self, x, radius):
        xqo, list_nos, coarse_dis = self.index.transform_and_assign(x)
        assign = self.balance_lists(list_nos)
        nq = len(x)

        def do_query(i):
            sub_index = self.sub_indexes[i]
            list_nos_i = list_nos.copy()
            list_nos_i[assign != i] = -1
            t0 = time.time()
            limi, Di, Ii = sub_index.ivf_range_search_preassigned(
                xqo, list_nos_i, coarse_dis, radius)
            if self.verbose:
                print('slice %d: %.3f s' % (i, time.time() - t0))
            return limi, Di, Ii

        D = [[] for i in range(nq)]
        I = [[] for i in range(nq)]
        sizes = np.zeros(nq, dtype=int)
        # collect each sub-index's result slice for each query
        for limi, Di, Ii in self.pool.imap(do_query, range(self.ni)):
            for i in range(nq):
                l0, l1 = limi[i:i + 2]
                D[i].append(Di[l0:l1])
                I[i].append(Ii[l0:l1])
                sizes[i] += l1 - l0

        # rebuild the CSR-style limits array over the concatenated results
        lims = np.zeros(nq + 1, dtype=int)
        lims[1:] = np.cumsum(sizes)
        D = np.hstack([j for i in D for j in i])
        I = np.hstack([j for i in I for j in i])
        return lims, D, I
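
# Illustrative sketch (not part of the original file): wiring SplitPerListIndex
# to remote servers. The `hosts` list is hypothetical; it assumes each host
# runs this file's server on `port` and that faiss.contrib.rpc provides a
# Client class that forwards method calls (so each client behaves like a
# sub-index). Deep1B vectors are 96-dimensional.
def _demo_split_search(local_index, hosts, port=12012):
    sub_indexes = [rpc.Client(h, port) for h in hosts]
    split = SplitPerListIndex(local_index, sub_indexes)
    split.verbose = True
    split.set_nprobe(64)
    xq = np.random.rand(8, 96).astype('float32')   # placeholder query batch
    return split.search(xq, k=10)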