faiss/benchs/distributed_ondisk/distributed_query_demo.py

71 lines
1.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import faiss
import numpy as np
import time
import rpc
import sys
import combined_index
import search_server
# Every command-line argument is the hostname of a remote search server.
hostnames = sys.argv[1:]

print("Load local index")
ci = combined_index.CombinedIndexDeep1B()

print("connect to clients")
clients = [rpc.Client(host, 12012, v6=False) for host in hostnames]

# Fail early if any server does not answer: query each one for its size.
print("sizes seen by servers:", [cl.get_ntotal() for cl in clients])

# Wrap the local index and all remote clients into a single index that
# spreads the search work over them (the local index ci is required too).
sindex = search_server.SplitPerListIndex(ci, clients)
sindex.verbose = True

# Apply identical search settings, first to the local index and then to the
# aggregated one.
for _configured in (ci, sindex):
    _configured.set_parallel_mode(1)
    _configured.set_prefetch_nthread(0)
    _configured.set_omp_num_threads(64)
def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array.

    Each record on disk is an int32 dimension ``d`` followed by ``d`` int32
    components; every record must declare the same ``d``.

    Args:
        fname: path of the file to read.

    Returns:
        numpy.ndarray of shape (n_records, d), dtype int32, holding the
        components with the per-record dimension header stripped.

    Raises:
        ValueError: if the file is empty, the first header is not a positive
            dimension, the file size is not a whole number of records, or
            the records declare differing dimensions.
    """
    a = np.fromfile(fname, dtype='int32')
    if a.size == 0:
        raise ValueError("%s: file is empty, no vector records found" % (fname,))
    # Convert to a Python int so d + 1 cannot overflow the int32 scalar.
    d = int(a[0])
    if d <= 0:
        raise ValueError("%s: invalid dimension %d in first record header"
                         % (fname, d))
    if a.size % (d + 1) != 0:
        raise ValueError("%s: %d int32 values is not a multiple of the "
                         "record size %d" % (fname, a.size, d + 1))
    rows = a.reshape(-1, d + 1)
    # Reading with a single d is only valid if every record agrees on it.
    if not (rows[:, 0] == d).all():
        raise ValueError("%s: records declare differing dimensions "
                         "(expected %d in every header)" % (fname, d))
    # copy() drops the header column and returns a compact C-contiguous array
    # instead of a strided view into the raw buffer.
    return rows[:, 1:].copy()
def fvecs_read(fname):
    """Read an .fvecs file into a 2-D float32 array.

    An .fvecs record uses the same layout as an .ivecs record (an int32
    dimension header followed by the components), but its components are
    float32. The int32 matrix produced by ivecs_read is therefore
    reinterpreted bit-for-bit as float32; no numeric conversion happens.
    """
    int_matrix = ivecs_read(fname)
    return int_matrix.view('float32')
# Deep1B query vectors and their ground-truth neighbour ids.
deep1bdir = "/datasets01_101/simsearch/041218/deep1b/"
xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs")
gt_fname = deep1bdir + "deep1B_groundtruth.ivecs"
gt = ivecs_read(gt_fname)

# Sweep nprobe. For each value, report 1-recall@1 (the fraction of queries
# whose first result equals the first ground-truth id) and the search time.
for nprobe in 1, 10, 100, 1000:
    sindex.set_nprobe(nprobe)
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # time.time(), so it is the right tool for measuring an interval.
    t0 = time.perf_counter()
    # 100 results are requested per query; only column 0 is scored below.
    _, I = sindex.search(xq, 100)
    t1 = time.perf_counter()
    recall_at_1 = (I[:, 0] == gt[:, 0]).sum() / len(xq)
    print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % (
        nprobe, recall_at_1,
        t1 - t0
    ))