# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from multiprocessing.pool import ThreadPool
import threading

import numpy as np
import faiss

from faiss.contrib.inspect_tools import get_invlist

def add_preassigned(index_ivf, x, a, ids=None):
    """
    Add elements to an IVF index, where the assignment is already computed
    """
    n, d = x.shape
    assert a.shape == (n, )
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
    assert d == index_ivf.d
    if ids is not None:
        assert ids.shape == (n, )
        ids = faiss.swig_ptr(ids)
    index_ivf.add_core(
        n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
    )
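

# Usage sketch (hypothetical example): with `index` a trained faiss IVF index
# and `xb` an (nb, d) float32 array, the coarse assignment can be computed
# with the quantizer and passed in directly:
#
#     _, assign = index.quantizer.search(xb, 1)
#     add_preassigned(index, xb, assign.ravel())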


def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
    """
    Perform a search in the IVF index, with predefined lists to search into.
    Supports indexes with pretransforms (as opposed to
    IndexIVF.search_preassigned, which cannot be applied when there is a
    pretransform).
    """
    n, d = xq.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide a dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
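

# Usage sketch (hypothetical example): compute the probe lists with the coarse
# quantizer, then search only those lists:
#
#     _, list_nos = index_ivf.quantizer.search(xq, index_ivf.nprobe)
#     D, I = search_preassigned(index_ivf, xq, k, list_nos)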


def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
    """
    Perform a range search in the IVF index, with predefined lists to
    search into.
    """
    n, d = x.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    # the coarse distances are used by IVFPQ with L2 distance and
    # by_residual=True; otherwise we provide a dummy coarse_dis
    if coarse_dis is None:
        coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    res = faiss.RangeSearchResult(n)
    sp = faiss.swig_ptr

    index_ivf.range_search_preassigned(
        n, sp(x), radius,
        sp(list_nos), sp(coarse_dis),
        res
    )
    # get pointers and copy them
    lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
    num_results = int(lims[-1])
    dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
    indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
    return lims, dist, indices
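

# Usage sketch (hypothetical example): same pattern as search_preassigned, but
# returning all results within `radius`:
#
#     _, list_nos = index_ivf.quantizer.search(xq, index_ivf.nprobe)
#     lims, D, I = range_search_preassigned(index_ivf, xq, radius, list_nos)
#     # results for query i are D[lims[i]:lims[i + 1]], I[lims[i]:lims[i + 1]]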


class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not)
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):

        # verbosity
        self.verbose = verbose
        self.tictoc = []

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        keep_max = index.metric_type == faiss.METRIC_INNER_PRODUCT
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        self.t_accu = [0] * 5
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        if self.verbose == 1 or (
                l > 1000 and time.time() < self.t_display + 1.0):
            return
        print(
            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait {self.t_accu[4]:.3f}",
            end="\r", flush=True
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        self.tic("coarse quantization")
        bs = 65536
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            q_dis_i, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            # q_dis[i0:i1] = q_dis_i
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print(' number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign   # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """ prepare the queries and database items for bucket l"""
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """add the bucket results to the heap structure"""
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0
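

# Note on the bucket layout: after BigBatchSearcher.reorder_assign(), the ids
# of the queries that probe inverted list l are
# query_ids[bucket_lims[l]:bucket_lims[l + 1]]; prepare_bucket(l) uses this
# slice to build the per-bucket query matrix.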


class BlockComputer:
    """ computation within one bucket """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)
            decode_func = lambda x: x.view("float32")
            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)

        return D, I
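

# Sketch of a user-supplied knn function (hypothetical example): for
# method="knn_function", big_batch_search expects a callable with the same
# signature as faiss.knn; when computation_threads > 0 it also receives a
# thread_id keyword argument (see block_search above), e.g. to route each
# worker thread to a different device:
#
#     def my_knn(xq, xb, k, metric=faiss.METRIC_L2, thread_id=None):
#         return faiss.knn(xq, xb, k, metric=metric)
#
#     D, I = big_batch_search(index, xq, k, knn=my_knn, computation_threads=2)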


def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=8,
        computation_threads=0,
        q_assign=None):
    """
    Search the queries xq in the IVF index, with a search function that
    collects batches of query vectors per inverted list. This can be faster
    than the regular search for large query batches.
    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes, compute all pairwise distances between the queries
        and the list contents, and add the results to the heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded>1: prefetch this many buckets at a time.

    computation_threads>0: the knn function will get an additional thread_id
    argument that tells which worker thread is handling this bucket.

    In threaded mode, the computation is overlapped with the bucket
    preparation and the writeback of results (useful to maximize GPU
    utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        print(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if threaded == 0:

        # simple sequential version

        for l in range(index.nlist):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:

        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(0)
        to_add = None
        pool = ThreadPool(1)

        for l in range(index.nlist):
            bbs.report(l)
            prefetched_bucket_a = pool.apply_async(
                add_results_and_prefetch, (to_add, l + 1))
            q_subset, xq_l, list_ids, xb_l = prefetched_bucket
            bbs.start_t_accu()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.stop_t_accu(2)
            to_add = q_subset, D, list_ids, I
            bbs.start_t_accu()
            prefetched_bucket = prefetched_bucket_a.get()
            bbs.stop_t_accu(4)

        bbs.add_results_to_heap(*to_add)
        pool.close()
    else:
        # run by batches with parallel prefetch and parallel comp
        list_step = threaded

        if prefetch_threads == 0:
            prefetch_map = map
        else:
            prefetch_pool = ThreadPool(prefetch_threads)
            prefetch_map = prefetch_pool.map

        if computation_threads > 0:
            comp_pool = ThreadPool(computation_threads)

        def add_results_and_prefetch_batch(to_add, l):
            def add_results(to_add):
                for ta in to_add:  # this one cannot be run in parallel...
                    if ta is not None:
                        bbs.add_results_to_heap(*ta)
            if prefetch_threads == 0:
                add_results(to_add)
            else:
                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
            next_lists = range(l, min(l + list_step, index.nlist))
            res = list(prefetch_map(bbs.prepare_bucket, next_lists))
            if prefetch_threads > 0:
                add_a.get()
            return res

        # used only when computation_threads > 0
        thread_id_to_seq_lock = threading.Lock()
        thread_id_to_seq = {}

        def do_comp(bucket):
            (q_subset, xq_l, list_ids, xb_l) = bucket
            try:
                tid = thread_id_to_seq[threading.get_ident()]
            except KeyError:
                with thread_id_to_seq_lock:
                    tid = len(thread_id_to_seq)
                    thread_id_to_seq[threading.get_ident()] = tid
            D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid)
            return q_subset, D, list_ids, I

        prefetched_buckets = add_results_and_prefetch_batch([], 0)
        to_add = []
        pool = ThreadPool(1)
        prefetched_buckets_a = None

        # loop over inverted lists
        for l in range(0, index.nlist, list_step):
            bbs.report(l)
            buckets = prefetched_buckets
            prefetched_buckets_a = pool.apply_async(
                add_results_and_prefetch_batch, (to_add, l + list_step))

            bbs.start_t_accu()

            to_add = []
            if computation_threads == 0:
                for q_subset, xq_l, list_ids, xb_l in buckets:
                    D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                    to_add.append((q_subset, D, list_ids, I))
            else:
                to_add = list(comp_pool.map(do_comp, buckets))

            bbs.stop_t_accu(2)

            bbs.start_t_accu()
            prefetched_buckets = prefetched_buckets_a.get()
            bbs.stop_t_accu(4)

        # flush add
        for ta in to_add:
            bbs.add_results_to_heap(*ta)
        pool.close()
        if prefetch_threads != 0:
            prefetch_pool.close()
        if computation_threads != 0:
            comp_pool.close()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I
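

# End-to-end usage sketch (hypothetical example): `index` is a trained and
# populated IVF index, `xq` an (nq, d) float32 query matrix.
#
#     D, I = big_batch_search(index, xq, k=10, method="knn_function",
#                             threaded=4, verbose=1)
#
# D and I have shape (nq, k), like the output of index.search(xq, 10).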