faiss/contrib/ivf_tools.py
Matthijs Douze fa53e2c941 Implementation of big-batch IVF search (single machine) (#2567)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2567

Intuitively, it should be easier to handle big-batch searches because all distance computations for a set of queries can be done locally within each inverted list.

This benchmark implements this in pure Python (but should be close to optimal in terms of speed), on CPU for IndexIVFFlat, IndexIVFPQ and IndexIVFScalarQuantizer. GPU is also supported.

The results are not systematically better, see https://docs.google.com/document/d/1d3YuV8uN7hut6aOATCOMx8Ut-QEl_oRnJdPgDBRF1QA/edit?usp=sharing

Reviewed By: algoriddle

Differential Revision: D41098338

fbshipit-source-id: 479e471b0d541f242d420f581775d57b708a61b8
2022-12-09 08:53:13 -08:00

488 lines
16 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
from multiprocessing.pool import ThreadPool
import threading
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist
def add_preassigned(index_ivf, x, a, ids=None):
    """
    Add the rows of x to an IVF index using a precomputed list assignment.

    a[i] is the inverted list that row i of x goes into, so the coarse
    quantizer is not called. For IndexBinaryIVF, x holds packed bits
    (8 dimensions per byte column). ids, if given, is forwarded to
    add_core as the ids of the added rows; otherwise add_core gets None.
    """
    nrow, ncol = x.shape
    assert a.shape == (nrow, )
    # binary vectors are packed 8 bits per byte, so the index dimension is
    # 8x the number of columns
    dim = ncol * 8 if isinstance(index_ivf, faiss.IndexBinaryIVF) else ncol
    assert dim == index_ivf.d
    ids_ptr = None
    if ids is not None:
        assert ids.shape == (nrow, )
        ids_ptr = faiss.swig_ptr(ids)
    index_ivf.add_core(nrow, faiss.swig_ptr(x), ids_ptr, faiss.swig_ptr(a))
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
    """
    k-NN search in an IVF index, visiting caller-chosen inverted lists.

    list_nos must have shape (nq, index_ivf.nprobe): row i holds the
    inverted lists scanned for query i. coarse_dis, when given, must have
    the same shape and is forwarded to the index as the coarse distances.

    Returns (D, I), both of shape (nq, k). D is int32 for IndexBinaryIVF
    and float32 otherwise; I is int64.
    """
    nq, ncol = xq.shape
    is_binary = isinstance(index_ivf, faiss.IndexBinaryIVF)
    # binary vectors are packed 8 bits per byte
    dim = ncol * 8 if is_binary else ncol
    dis_type = "int32" if is_binary else "float32"

    assert dim == index_ivf.d
    nprobe = index_ivf.nprobe
    assert list_nos.shape == (nq, nprobe)

    if coarse_dis is not None:
        assert coarse_dis.shape == (nq, nprobe)
    else:
        # coarse distances are used by IVFPQ with L2 distance and
        # by_residual=True; in other cases a dummy all-zero array is enough
        coarse_dis = np.zeros((nq, nprobe), dtype=dis_type)

    distances = np.empty((nq, k), dtype=dis_type)
    labels = np.empty((nq, k), dtype='int64')
    ptr = faiss.swig_ptr
    # trailing False: store_pairs is off, so labels are plain ids
    index_ivf.search_preassigned(
        nq, ptr(xq), k,
        ptr(list_nos), ptr(coarse_dis), ptr(distances), ptr(labels), False)
    return distances, labels
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
    """
    Perform a range search in the IVF index, with predefined lists to
    search into.

    list_nos must have shape (n, index_ivf.nprobe): row i holds the
    inverted lists scanned for query i. coarse_dis, when given, must have
    the same shape and is forwarded to the index as the coarse distances.

    Returns (lims, dist, indices): lims has n + 1 entries and the results
    of query i are dist[lims[i]:lims[i + 1]] / indices[lims[i]:lims[i + 1]].
    The arrays are copies, independent of the RangeSearchResult buffers.
    """
    n, d = x.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        # binary vectors are packed 8 bits per byte
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    # the coarse distances are used in IVFPQ with L2 distance and
    # by_residual=True otherwise we provide dummy coarse_dis.
    # Use zeros (as search_preassigned does) rather than uninitialized
    # memory so results never depend on garbage values.
    if coarse_dis is None:
        coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    res = faiss.RangeSearchResult(n)
    sp = faiss.swig_ptr

    index_ivf.range_search_preassigned(
        n, sp(x), radius,
        sp(list_nos), sp(coarse_dis),
        res
    )
    # get pointers and copy them, so the returned arrays do not alias
    # memory owned by res
    lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
    num_results = int(lims[-1])
    dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
    indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
    return lims, dist, indices
class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not).

    The caller must set the ``decode_func`` and ``by_residual`` attributes
    before calling prepare_bucket (big_batch_search copies them from a
    BlockComputer).
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):
        # verbosity: 0 = silent, 1 = tic/toc timings only,
        # >= 2 = also per-list progress lines (see report)
        self.verbose = verbose
        self.tictoc = []

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        # for inner product larger is better, so the heap keeps the max
        keep_max = index.metric_type == faiss.METRIC_INNER_PRODUCT
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        # accumulated seconds: 0 = prepare queries, 1 = prepare database
        # side, 2 = computation, 3 = heap update, 4 = waiting for prefetch
        self.t_accu = [0] * 5
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        """Start a timing interval (ended by stop_t_accu)."""
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        """Add the time since start_t_accu to accumulator slot n."""
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        """Start timing the stage called name."""
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        """Stop timing the current stage and return its duration (s)."""
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        """Print a progress line for inverted list l.

        Only printed when verbose >= 2. After the first 1000 lists, it is
        throttled to at most one line per second.
        """
        # verbose 0 (silent) and 1 (stage timings only) print nothing here;
        # previously verbose=0 fell through and printed progress lines
        if self.verbose < 2:
            return
        if l > 1000 and time.time() < self.t_display + 1.0:
            return
        print(
            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait {self.t_accu[4]:.3f}",
            end="\r", flush=True
        )
        self.t_display = time.time()

    def coarse_quantization(self, bs=65536):
        """Compute self.q_assign, the nprobe nearest lists of each query.

        Queries are sent to the quantizer in batches of bs rows.
        """
        self.tic("coarse quantization")
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            # only the assignment is needed, the coarse distances are dropped
            _, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        """Group the queries by inverted list.

        Sets self.query_ids (query numbers ordered by list) and
        self.bucket_lims, so that the queries of list l are
        query_ids[bucket_lims[l]:bucket_lims[l + 1]]. self.q_assign is
        consumed: it is modified in place and then deleted.
        """
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print('  number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]   # shift back to ignore -1s
        del self.q_assign   # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l

        Returns (q_subset, xq_l, list_ids, xb_l): the query numbers
        assigned to list l, their vectors (minus centroid l when
        by_residual is set), the ids stored in list l and its database
        side (decoded with decode_func, or the flattened raw codes when
        decode_func is None).
        """
        t0 = time.time()
        index = self.index

        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()

        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            # raw codes are kept byte-exact: converting them to float16
            # would corrupt them, so use_float16 does not apply here
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)
            if self.use_float16:
                xb_l = xb_l.astype('float16')
                xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure

        D is None for an empty bucket (nothing to add). When I is None,
        every row of D is matched against list_ids as a whole; otherwise
        I holds positions within the bucket that are mapped to list_ids.
        """
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0
class BlockComputer:
    """ computation within one bucket

    Matches a batch of queries against the content of one inverted list.
    Only the exact classes IndexIVFFlat, IndexIVFPQ and
    IndexIVFScalarQuantizer are supported; anything else raises
    RuntimeError.

    Attributes derived from the index type:
      index_help: non-IVF index with the same codec as the IVF index
      decode_func: maps stored codes to vectors; None for method "index",
          which searches the raw codes through index_help
      by_residual: copied from index.by_residual (False for IVFFlat)
    """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            # flat codes are the float32 vectors stored as bytes
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        """Match the queries xq_l against one bucket's database side xb_l.

        Returns (D, I). Both are None when there are no queries or no
        database items. With method "pairwise_distances", D holds all
        query-to-item distances and I is None. Otherwise I holds positions
        of the matches within the bucket. extra_args (e.g. thread_id) are
        only forwarded to the knn function.

        Raises ValueError for an unknown method.
        """
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            # load the bucket's raw codes into the helper index and search it
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
        else:
            # previously this fell through to an UnboundLocalError
            raise ValueError(f"unknown method {self.method!r}")

        return D, I
def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=8,
        computation_threads=0,
        q_assign=None):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than the
    regular search of the index.

    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted
        list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded>1: prefetch this many buckets at a time.

    prefetch_threads: (threaded>1 only) size of the thread pool that
        prepares buckets and writes back results; 0 = do this sequentially
        in the background thread.
    computation_threads>0: (threaded>1 only) run the bucket computations in
        a pool of this many threads. With method "knn_function", the knn
        function then gets an additional thread_id argument that tells
        which worker handles the call.

    In threaded mode, the computation is tiled with the bucket preparation
    and the writeback of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment. The array is copied, so the
        caller's array is not modified.

    Returns (D, I), the k nearest neighbor distances and ids of each query.
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        print(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    # the searcher prepares buckets according to how the computer decodes
    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        # reorder_assign shifts and bucket-sorts bbs.q_assign in place:
        # work on a C-contiguous copy so the caller's array is not clobbered
        bbs.q_assign = np.array(q_assign, order='C')

    bbs.reorder_assign()

    if threaded == 0:
        # simple sequential version
        for l in range(index.nlist):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:
        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)
            return None

        prefetched_bucket = bbs.prepare_bucket(0)
        to_add = None
        pool = ThreadPool(1)
        try:
            for l in range(index.nlist):
                bbs.report(l)
                # the background thread writes back the previous bucket and
                # prepares the next one while this thread computes
                prefetched_bucket_a = pool.apply_async(
                    add_results_and_prefetch, (to_add, l + 1))
                q_subset, xq_l, list_ids, xb_l = prefetched_bucket
                bbs.start_t_accu()
                D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                bbs.stop_t_accu(2)
                to_add = q_subset, D, list_ids, I
                bbs.start_t_accu()
                prefetched_bucket = prefetched_bucket_a.get()
                bbs.stop_t_accu(4)

            bbs.add_results_to_heap(*to_add)
        finally:
            # release the worker thread even if a bucket raised
            pool.close()

    else:
        # run by batches with parallel prefetch and parallel comp
        list_step = threaded
        pool = prefetch_pool = comp_pool = None
        try:
            if prefetch_threads == 0:
                prefetch_map = map
            else:
                prefetch_pool = ThreadPool(prefetch_threads)
                prefetch_map = prefetch_pool.map

            if computation_threads > 0:
                comp_pool = ThreadPool(computation_threads)

            def add_results_and_prefetch_batch(to_add, l):
                def add_results(to_add):
                    for ta in to_add:   # this one cannot be run in parallel...
                        if ta is not None:
                            bbs.add_results_to_heap(*ta)

                if prefetch_threads == 0:
                    add_results(to_add)
                else:
                    add_a = prefetch_pool.apply_async(add_results, (to_add, ))
                next_lists = range(l, min(l + list_step, index.nlist))
                res = list(prefetch_map(bbs.prepare_bucket, next_lists))
                if prefetch_threads > 0:
                    add_a.get()
                return res

            # used only when computation_threads > 0: maps each worker
            # thread to a small sequential id passed to the knn function
            thread_id_to_seq_lock = threading.Lock()
            thread_id_to_seq = {}

            def do_comp(bucket):
                (q_subset, xq_l, list_ids, xb_l) = bucket
                try:
                    tid = thread_id_to_seq[threading.get_ident()]
                except KeyError:
                    with thread_id_to_seq_lock:
                        tid = len(thread_id_to_seq)
                        thread_id_to_seq[threading.get_ident()] = tid
                D, I = comp.block_search(
                    xq_l, xb_l, list_ids, k, thread_id=tid)
                return q_subset, D, list_ids, I

            prefetched_buckets = add_results_and_prefetch_batch([], 0)
            to_add = []
            pool = ThreadPool(1)
            prefetched_buckets_a = None

            # loop over inverted lists
            for l in range(0, index.nlist, list_step):
                bbs.report(l)
                buckets = prefetched_buckets
                prefetched_buckets_a = pool.apply_async(
                    add_results_and_prefetch_batch, (to_add, l + list_step))

                bbs.start_t_accu()

                to_add = []
                if comp_pool is None:
                    for q_subset, xq_l, list_ids, xb_l in buckets:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                        to_add.append((q_subset, D, list_ids, I))
                else:
                    to_add = list(comp_pool.map(do_comp, buckets))

                bbs.stop_t_accu(2)

                bbs.start_t_accu()
                prefetched_buckets = prefetched_buckets_a.get()
                bbs.stop_t_accu(4)

            # flush add
            for ta in to_add:
                bbs.add_results_to_heap(*ta)
        finally:
            # release all worker threads even if a bucket raised
            for p in (pool, prefetch_pool, comp_pool):
                if p is not None:
                    p.close()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I