# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Big-batch IVF search (https://github.com/facebookresearch/faiss/pull/2567).
# Intuitively, big-batch searches should be easier to handle because all
# distance computations for a set of queries can be done locally within each
# inverted list. This module implements the idea in pure Python (but should be
# close to optimal in terms of speed), on CPU for IndexIVFFlat, IndexIVFPQ and
# IndexIVFScalarQuantizer; GPU computation is also supported. The results are
# not systematically better than the regular search, see
# https://docs.google.com/document/d/1d3YuV8uN7hut6aOATCOMx8Ut-QEl_oRnJdPgDBRF1QA/edit?usp=sharing

import time
from multiprocessing.pool import ThreadPool
import threading

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist


def add_preassigned(index_ivf, x, a, ids=None):
    """
    Add elements to an IVF index, where the assignment is already computed
    """
    n, d = x.shape
    assert a.shape == (n, )
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
    assert d == index_ivf.d
    if ids is not None:
        assert ids.shape == (n, )
        ids = faiss.swig_ptr(ids)
    index_ivf.add_core(
        n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
    )


def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
    """
    Perform a search in the IVF index, with predefined lists to search into
    """
    n, d = xq.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    D = np.empty((n, k), dtype=dis_type)
    I = np.empty((n, k), dtype='int64')

    sp = faiss.swig_ptr
    index_ivf.search_preassigned(
        n, sp(xq), k,
        sp(list_nos), sp(coarse_dis), sp(D), sp(I), False)
    return D, I


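# Illustrative usage sketch (not part of the original module): add vectors and
# search with a coarse assignment computed explicitly. All sizes and names
# below are made up for the example; in practice the assignment could come
# from another process or from a GPU coarse quantizer.
def _example_preassigned():
    d, nlist, nprobe, k = 32, 64, 8, 10
    rs = np.random.RandomState(123)
    xb = rs.rand(10000, d).astype('float32')
    xq = rs.rand(100, d).astype('float32')
    index = faiss.index_factory(d, f"IVF{nlist},Flat")
    index.train(xb)
    index.nprobe = nprobe
    # coarse assignment of the database vectors, then add_preassigned
    _, assign = index.quantizer.search(xb, 1)
    add_preassigned(index, xb, assign.ravel())
    # coarse assignment of the queries, then search only those lists
    _, list_nos = index.quantizer.search(xq, nprobe)
    D, I = search_preassigned(index, xq, k, list_nos)
    return D, I

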
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
    """
    Perform a range search in the IVF index, with predefined lists to
    search into
    """
    n, d = x.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    res = faiss.RangeSearchResult(n)
    sp = faiss.swig_ptr

    index_ivf.range_search_preassigned(
        n, sp(x), radius,
        sp(list_nos), sp(coarse_dis),
        res
    )
    # get pointers and copy them
    lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
    num_results = int(lims[-1])
    dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
    indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
    return lims, dist, indices


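# Illustrative sketch (not part of the original module): how to consume the
# flat (lims, dist, indices) output of range_search_preassigned. Results for
# query i are indices[lims[i]:lims[i + 1]] with distances
# dist[lims[i]:lims[i + 1]]. The index, queries and radius are assumed given.
def _example_range_search_preassigned(index, xq, radius):
    _, list_nos = index.quantizer.search(xq, index.nprobe)
    lims, dist, indices = range_search_preassigned(index, xq, radius, list_nos)
    for i in range(len(xq)):
        print(f"query {i}: {lims[i + 1] - lims[i]} results within radius {radius}")
    return lims, dist, indices

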
class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):

        # verbosity
        self.verbose = verbose
        self.tictoc = []

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = index.metric_type == faiss.METRIC_INNER_PRODUCT
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        self.t_accu = [0] * 5
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        if self.verbose == 1 or (
                l > 1000 and time.time() < self.t_display + 1.0):
            return
        print(
            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait {self.t_accu[4]:.3f}",
            end="\r", flush=True
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1  # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        # after the sort, query_ids[bucket_lims[l]:bucket_lims[l + 1]] contains
        # the ids of the queries assigned to inverted list l
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign  # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure"""
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0


class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)

        return D, I


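# Illustrative sketch (assumptions: a GPU build of faiss and a list of
# faiss.StandardGpuResources objects, one per computation thread): a drop-in
# `knn` function for big_batch_search(..., knn=...) that runs each per-bucket
# matching on GPU. faiss.knn_gpu accepts float16 input, so this pairs well with
# use_float16=True. This helper is hypothetical and not part of the module.
def _make_gpu_knn(resources):
    def gpu_knn(xq_l, xb_l, k, metric=faiss.METRIC_L2, thread_id=0):
        # thread_id is forwarded by block_search when computation_threads > 0
        res = resources[thread_id % len(resources)]
        return faiss.knn_gpu(res, xq_l, xb_l, k, metric=metric)
    return gpu_knn

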
def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=8,
        computation_threads=0,
        q_assign=None):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than a
    regular per-query search.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances between the queries
        and the list, and add the results to the heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded>1: prefetch this many buckets at a time.

    computation_threads>0: the knn function will get an additional thread_id
    argument that tells which worker should handle this bucket.

    In threaded mode, the computation is overlapped with the bucket preparation
    and the write-back of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override the coarse assignment of the queries
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        print(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if threaded == 0:
        # simple sequential version

        for l in range(index.nlist):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:

        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(0)
        to_add = None
        pool = ThreadPool(1)

        for l in range(index.nlist):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)

        bbs.add_results_to_heap(*to_add)
        pool.close()
    else:
        # run by batches with parallel prefetch and parallel comp
        list_step = threaded

        if prefetch_threads == 0:
            prefetch_map = map
        else:
            prefetch_pool = ThreadPool(prefetch_threads)
            prefetch_map = prefetch_pool.map

        if computation_threads > 0:
            comp_pool = ThreadPool(computation_threads)

        def add_results_and_prefetch_batch(to_add, l):
            def add_results(to_add):
                for ta in to_add:  # this one cannot be run in parallel...
                    if ta is not None:
                        bbs.add_results_to_heap(*ta)
            if prefetch_threads == 0:
                add_results(to_add)
            else:
                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
            next_lists = range(l, min(l + list_step, index.nlist))
            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
            if prefetch_threads > 0:
                add_a.get()
            return res

        # used only when computation_threads > 0
        thread_id_to_seq_lock = threading.Lock()
        thread_id_to_seq = {}

        def do_comp(bucket):
            (q_subset, xq_l, list_ids, xb_l) = bucket
            try:
                tid = thread_id_to_seq[threading.get_ident()]
            except KeyError:
                with thread_id_to_seq_lock:
                    tid = len(thread_id_to_seq)
                    thread_id_to_seq[threading.get_ident()] = tid
            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
            return q_subset, D, list_ids, I

        prefetched_buckets = add_results_and_prefetch_batch([], 0)
        to_add = []
        pool = ThreadPool(1)
        prefetched_buckets_a = None

        # loop over inverted lists
        for l in range(0, index.nlist, list_step):
            bbs.report(l)
            buckets = prefetched_buckets
            prefetched_buckets_a = pool.apply_async(
                add_results_and_prefetch_batch, (to_add, l + list_step))

            bbs.start_t_accu()

            to_add = []
            if computation_threads == 0:
                for q_subset, xq_l, list_ids, xb_l in buckets:
                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    to_add.append((q_subset, D, list_ids, I))
            else:
                to_add = list(comp_pool.map(do_comp, buckets))

            bbs.stop_t_accu(2)

            bbs.start_t_accu()
            prefetched_buckets = prefetched_buckets_a.get()
            bbs.stop_t_accu(4)

        # flush add
        for ta in to_add:
            bbs.add_results_to_heap(*ta)
        pool.close()
        if prefetch_threads != 0:
            prefetch_pool.close()
        if computation_threads != 0:
            comp_pool.close()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I
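

if __name__ == "__main__":
    # Minimal smoke test / usage sketch (illustrative only; data, dimensions
    # and parameters are made up): compare big_batch_search against the
    # regular IVF search on random vectors.
    rs = np.random.RandomState(123)
    xb = rs.rand(10000, 32).astype('float32')
    xq = rs.rand(1000, 32).astype('float32')
    index = faiss.index_factory(32, "IVF64,Flat")
    index.train(xb)
    index.add(xb)
    index.nprobe = 8
    Dref, Iref = index.search(xq, 10)
    Dnew, Inew = big_batch_search(index, xq, 10, verbose=1)
    print("identical results:", np.array_equal(Iref, Inew))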