# Source listing: faiss/contrib/ivf_tools.py (483 lines, 16 KiB, Python)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
from multiprocessing.pool import ThreadPool
import threading
import numpy as np
import faiss
from faiss.contrib.inspect_tools import get_invlist
def add_preassigned(index_ivf, x, a, ids=None):
    """
    Add the rows of x to an IVF index using an assignment computed beforehand.

    x is a 2D array with one vector per row, a is a 1D array with one entry
    per row of x (handed to index_ivf.add_core as the precomputed
    assignment) and ids, when given, is a 1D array with one id per row.
    When ids is None, None is forwarded to add_core in place of the ids.

    For a faiss.IndexBinaryIVF the index dimension is compared against
    8 times the number of columns of x.
    """
    nb_vectors, dim = x.shape
    assert a.shape == (nb_vectors, )
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        dim = dim * 8
    assert dim == index_ivf.d

    ids_ptr = None
    if ids is not None:
        assert ids.shape == (nb_vectors, )
        ids_ptr = faiss.swig_ptr(ids)

    index_ivf.add_core(
        nb_vectors, faiss.swig_ptr(x), ids_ptr, faiss.swig_ptr(a))
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
    """
    Search the k nearest neighbors of xq in an IVF index, visiting the
    inverted lists given in list_nos instead of running the coarse quantizer.

    list_nos (and coarse_dis when provided) must have shape
    (number of queries, index_ivf.nprobe). When coarse_dis is None, an
    all-zero array is substituted (int32 for faiss.IndexBinaryIVF, float32
    otherwise): the coarse distances are used in IVFPQ with L2 distance and
    by_residual=True, otherwise dummy values are sufficient.

    Returns whatever index_ivf.search_preassigned returns.

    NOTE(review): this function was described as supporting indexes with
    pretransforms, but no pretransform handling is visible in this body; the
    call is forwarded to index_ivf.search_preassigned as is -- confirm.
    """
    n_queries, dim = xq.shape
    is_binary = isinstance(index_ivf, faiss.IndexBinaryIVF)
    if is_binary:
        dim *= 8
    assert dim == index_ivf.d

    probe_shape = (n_queries, index_ivf.nprobe)
    assert list_nos.shape == probe_shape

    if coarse_dis is not None:
        assert coarse_dis.shape == probe_shape
    else:
        # dummy coarse distances, of the distance type of the index
        coarse_dis = np.zeros(
            probe_shape, dtype="int32" if is_binary else "float32")

    return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
    """
    Perform a range search in the IVF index, with predefined lists to
    search into.

    x is a 2D array with one query per row; list_nos (and coarse_dis when
    provided) must have shape (number of queries, index_ivf.nprobe).
    For a faiss.IndexBinaryIVF the index dimension is compared against
    8 times the number of columns of x.

    Returns (lims, dist, indices), copied out of the faiss.RangeSearchResult:
    lims has one more entry than there are queries and lims[-1] is the total
    number of results; dist and indices have lims[-1] entries each.
    """
    n, d = x.shape
    if isinstance(index_ivf, faiss.IndexBinaryIVF):
        d *= 8
        dis_type = "int32"
    else:
        dis_type = "float32"

    # the coarse distances are used in IVFPQ with L2 distance and
    # by_residual=True otherwise we provide dummy coarse_dis
    if coarse_dis is None:
        # zero-filled (not np.empty): the dummy values must not be
        # uninitialized memory, since some index types do read them.
        # This also matches search_preassigned.
        coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
    else:
        assert coarse_dis.shape == (n, index_ivf.nprobe)

    assert d == index_ivf.d
    assert list_nos.shape == (n, index_ivf.nprobe)

    res = faiss.RangeSearchResult(n)
    sp = faiss.swig_ptr

    index_ivf.range_search_preassigned(
        n, sp(x), radius,
        sp(list_nos), sp(coarse_dis),
        res
    )
    # get pointers and copy them, so that the returned arrays do not
    # reference the buffers held by res
    lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
    num_results = int(lims[-1])
    dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
    indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
    return lims, dist, indices
class BigBatchSearcher:
    """
    Object that manages all the data related to the computation
    except the actual within-bucket matching and the organization of the
    computation (parallel or not).

    Two attributes are read by prepare_bucket but are not set by the
    constructor; the caller must assign them before preparing buckets:
      decode_func: None, or a function applied to the codes of a list
      by_residual: whether the list centroid is subtracted from the queries

    verbose levels: 0 is silent, 1 prints the timing of the main stages,
    2 and above additionally prints a progress line from report().
    """

    def __init__(
            self,
            index, xq, k,
            verbose=0,
            use_float16=False):
        # verbosity
        self.verbose = verbose
        # (name, start time) of the stage being timed, set by tic()
        self.tictoc = []

        self.xq = xq
        self.index = index
        self.use_float16 = use_float16
        # for inner product the largest values are the best results
        keep_max = index.metric_type == faiss.METRIC_INNER_PRODUCT
        self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
        # accumulated times: 0 = query preparation, 1 = database-side
        # preparation, 2 = computation, 3 = result collection, 4 = wait
        self.t_accu = [0] * 5
        self.t_display = self.t0 = time.time()

    def start_t_accu(self):
        """start the stopwatch read by stop_t_accu"""
        self.t_accu_t0 = time.time()

    def stop_t_accu(self, n):
        """add the time elapsed since start_t_accu to slot n of t_accu"""
        self.t_accu[n] += time.time() - self.t_accu_t0

    def tic(self, name):
        """start timing the stage called name"""
        self.tictoc = (name, time.time())
        if self.verbose > 0:
            print(name, end="\r", flush=True)

    def toc(self):
        """return the time elapsed since the last tic(), in seconds"""
        name, t0 = self.tictoc
        dt = time.time() - t0
        if self.verbose > 0:
            print(f"{name}: {dt:.3f} s")
        return dt

    def report(self, l):
        """print a progress line for inverted list l (verbose >= 2 only)"""
        # Stay silent below verbosity 2, like the other methods do at
        # verbosity 0. Beyond the first 1000 lists, refresh the display at
        # most once per second.
        if self.verbose <= 1 or (
                l > 1000 and time.time() < self.t_display + 1.0):
            return
        print(
            f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} "
            f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
            f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
            f"wait {self.t_accu[4]:.3f}",
            end="\r", flush=True
        )
        self.t_display = time.time()

    def coarse_quantization(self):
        """
        Search the queries in the coarse quantizer and store the resulting
        (number of queries, nprobe) int32 assignment in self.q_assign.
        """
        self.tic("coarse quantization")
        bs = 65536   # the quantizer is searched by batches of this size
        nq = len(self.xq)
        q_assign = np.empty((nq, self.index.nprobe), dtype='int32')
        for i0 in range(0, nq, bs):
            i1 = min(nq, i0 + bs)
            # the coarse distances are not kept
            _, q_assign_i = self.index.quantizer.search(
                self.xq[i0:i1], self.index.nprobe)
            q_assign[i0:i1] = q_assign_i
        self.toc()
        self.q_assign = q_assign

    def reorder_assign(self):
        """
        Bucket-sort self.q_assign in place, to produce self.query_ids and
        self.bucket_lims such that query_ids[bucket_lims[l]:bucket_lims[l+1]]
        are the entries prepare_bucket uses as the query numbers of list l.
        self.q_assign is modified in place and deleted afterwards.
        """
        self.tic("bucket sort")
        q_assign = self.q_assign
        q_assign += 1   # move -1 -> 0
        self.bucket_lims = faiss.matrix_bucket_sort_inplace(
            self.q_assign, nbucket=self.index.nlist + 1, nt=16)
        self.query_ids = self.q_assign.ravel()
        if self.verbose > 0:
            print('  number of -1s:', self.bucket_lims[1])
        self.bucket_lims = self.bucket_lims[1:]  # shift back to ignore -1s
        del self.q_assign  # inplace so let's forget about the old version...
        self.toc()

    def prepare_bucket(self, l):
        """
        Prepare the queries and database items for bucket l.

        Returns (q_subset, xq_l, list_ids, xb_l): the query numbers assigned
        to the bucket, the corresponding query vectors, and the ids and
        contents of inverted list l.
        """
        t0 = time.time()
        index = self.index
        # prepare queries
        i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1]
        q_subset = self.query_ids[i0:i1]
        xq_l = self.xq[q_subset]
        if self.by_residual:
            xq_l = xq_l - index.quantizer.reconstruct(l)
        t1 = time.time()
        # prepare database side
        list_ids, xb_l = get_invlist(index.invlists, l)

        if self.decode_func is None:
            # no decoding: the codes are passed on as a flat array
            xb_l = xb_l.ravel()
        else:
            xb_l = self.decode_func(xb_l)

        if self.use_float16:
            xb_l = xb_l.astype('float16')
            xq_l = xq_l.astype('float16')

        t2 = time.time()
        self.t_accu[0] += t1 - t0
        self.t_accu[1] += t2 - t1
        return q_subset, xq_l, list_ids, xb_l

    def add_results_to_heap(self, q_subset, D, list_ids, I):
        """
        Add the bucket results to the heap structure.

        Nothing is done when D is None. When I is None, the columns of D
        are taken to correspond to all of list_ids; otherwise I indexes
        into list_ids.
        """
        if D is None:
            return
        t0 = time.time()
        if I is None:
            I = list_ids
        else:
            I = list_ids[I]
        self.rh.add_result_subset(q_subset, D, I)
        self.t_accu[3] += time.time() - t0
class BlockComputer:
    """
    Computation within one bucket.

    The constructor builds a non-IVF helper index (index_help) matching the
    encoding of the IVF index, and selects the function that decodes the
    codes of an inverted list:
      decode_func: None when method == "index" (the codes are handed to
        index_help undecoded), otherwise a function applied to the codes
      by_residual: False for IndexIVFFlat, otherwise index.by_residual

    Only the exact classes IndexIVFFlat, IndexIVFPQ and
    IndexIVFScalarQuantizer are accepted; a RuntimeError is raised for any
    other index class.
    """

    def __init__(
            self,
            index,
            method="knn_function",
            pairwise_distances=faiss.pairwise_distances,
            knn=faiss.knn):

        self.index = index
        if index.__class__ == faiss.IndexIVFFlat:
            index_help = faiss.IndexFlat(index.d, index.metric_type)

            def decode_func(x):
                # reinterpret the stored codes as float32 values
                return x.view("float32")

            by_residual = False
        elif index.__class__ == faiss.IndexIVFPQ:
            index_help = faiss.IndexPQ(
                index.d, index.pq.M, index.pq.nbits, index.metric_type)
            index_help.pq = index.pq
            decode_func = index_help.pq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        elif index.__class__ == faiss.IndexIVFScalarQuantizer:
            index_help = faiss.IndexScalarQuantizer(
                index.d, index.sq.qtype, index.metric_type)
            index_help.sq = index.sq
            decode_func = index_help.sq.decode
            index_help.is_trained = True
            by_residual = index.by_residual
        else:
            raise RuntimeError(f"index type {index.__class__} not supported")
        self.index_help = index_help
        self.decode_func = None if method == "index" else decode_func
        self.by_residual = by_residual
        self.method = method
        self.pairwise_distances = pairwise_distances
        self.knn = knn

    def block_search(self, xq_l, xb_l, list_ids, k, **extra_args):
        """
        Match the queries xq_l with the database items xb_l of one bucket.

        Returns (D, I):
          - (None, None) when either xq_l or xb_l is empty
          - method "index": result of index_help.search, after loading
            xb_l as the codes of index_help
          - method "pairwise_distances": D is the result of the
            pairwise_distances function and I is None
          - method "knn_function": result of the knn function, which also
            receives extra_args
        extra_args is ignored by the other methods.

        Raises ValueError when self.method is none of the three methods.
        """
        metric_type = self.index.metric_type
        if xq_l.size == 0 or xb_l.size == 0:
            D = I = None
        elif self.method == "index":
            faiss.copy_array_to_vector(xb_l, self.index_help.codes)
            self.index_help.ntotal = len(list_ids)
            D, I = self.index_help.search(xq_l, k)
        elif self.method == "pairwise_distances":
            # TODO implement blockwise to avoid mem blowup
            D = self.pairwise_distances(xq_l, xb_l, metric=metric_type)
            I = None
        elif self.method == "knn_function":
            D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args)
        else:
            # fail explicitly instead of an UnboundLocalError on return
            raise ValueError(f"unknown method {self.method!r}")

        return D, I
def big_batch_search(
        index, xq, k,
        method="knn_function",
        pairwise_distances=faiss.pairwise_distances,
        knn=faiss.knn,
        verbose=0,
        threaded=0,
        use_float16=False,
        prefetch_threads=8,
        computation_threads=0,
        q_assign=None):
    """
    Search queries xq in the IVF index, with a search function that collects
    batches of query vectors per inverted list. This can be faster than the
    regular search of the index.

    Supports IVFFlat, IVFPQ and IVFScalarQuantizer.

    Supports three computation methods:
    method = "index":
        build a flat index and populate it separately for each inverted list
    method = "pairwise_distances":
        decompress codes and compute all pairwise distances for the queries
        and index and add result to heap
    method = "knn_function":
        decompress codes and compute knn results for the queries

    threaded=0: sequential execution
    threaded=1: prefetch next bucket while computing the current one
    threaded>1: prefetch this many buckets at a time.

    prefetch_threads (threaded>1 only): number of threads that prepare the
    buckets of a batch; 0 prepares them sequentially.

    computation_threads>0 (threaded>1 only): the buckets of a batch are
    handled by a pool of this many threads, and the search is called with an
    additional thread_id keyword argument that tells which worker handles the
    bucket (with method="knn_function" it is forwarded to the knn function).

    In threaded mode, the computation is tiled with the bucket preparation
    and the writeback of results (useful to maximize GPU utilization).

    use_float16: convert all matrices to float16 (faster for GPU gemm)

    q_assign: override coarse assignment. The array is used without copy and
    is modified in place by the bucket sort. Presumably it should have the
    shape (len(xq), nprobe) and dtype int32, like the assignment computed
    when q_assign is None -- TODO confirm.

    Returns the (D, I) arrays of the result heap after finalization.
    """
    nprobe = index.nprobe

    assert method in ("index", "pairwise_distances", "knn_function")

    # estimate of the memory used by queries, assignment and results
    mem_queries = xq.nbytes
    mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize
    mem_res = len(xq) * k * (
        np.dtype('int64').itemsize
        + np.dtype('float32').itemsize
    )
    mem_tot = mem_queries + mem_assign + mem_res
    if verbose > 0:
        print(
            f"memory: queries {mem_queries} assign {mem_assign} "
            f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
        )

    bbs = BigBatchSearcher(
        index, xq, k,
        verbose=verbose,
        use_float16=use_float16
    )

    comp = BlockComputer(
        index,
        method=method,
        pairwise_distances=pairwise_distances,
        knn=knn
    )

    # the searcher needs to know how the computer wants its inputs prepared
    bbs.decode_func = comp.decode_func
    bbs.by_residual = comp.by_residual

    if q_assign is None:
        bbs.coarse_quantization()
    else:
        bbs.q_assign = q_assign
    bbs.reorder_assign()

    if threaded == 0:
        # simple sequential version
        for l in range(index.nlist):
            bbs.report(l)
            q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l)
            t0i = time.time()
            D, I = comp.block_search(xq_l, xb_l, list_ids, k)
            bbs.t_accu[2] += time.time() - t0i
            bbs.add_results_to_heap(q_subset, D, list_ids, I)

    elif threaded == 1:
        # parallel version with granularity 1

        def add_results_and_prefetch(to_add, l):
            """ perform the addition for the previous bucket and
            prefetch the next (if applicable) """
            if to_add is not None:
                bbs.add_results_to_heap(*to_add)
            if l < index.nlist:
                return bbs.prepare_bucket(l)

        prefetched_bucket = bbs.prepare_bucket(0)
        to_add = None
        pool = ThreadPool(1)
        try:
            for l in range(index.nlist):
                bbs.report(l)
                prefetched_bucket_a = pool.apply_async(
                    add_results_and_prefetch, (to_add, l + 1))
                q_subset, xq_l, list_ids, xb_l = prefetched_bucket
                bbs.start_t_accu()
                D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                bbs.stop_t_accu(2)
                to_add = q_subset, D, list_ids, I
                bbs.start_t_accu()
                prefetched_bucket = prefetched_bucket_a.get()
                bbs.stop_t_accu(4)
        finally:
            # release the worker thread even when an iteration fails
            pool.close()
            pool.join()
        # the results of the last bucket were not handed to the pool
        bbs.add_results_to_heap(*to_add)

    else:
        # run by batches with parallel prefetch and parallel comp
        list_step = threaded
        pool = prefetch_pool = comp_pool = None
        try:
            if prefetch_threads != 0:
                prefetch_pool = ThreadPool(prefetch_threads)
            if computation_threads > 0:
                comp_pool = ThreadPool(computation_threads)

            def add_results_and_prefetch_batch(to_add, l):
                """ add the results of the previous batch and prepare the
                buckets l:l+list_step """

                def add_results(to_add):
                    for ta in to_add:  # this one cannot be run in parallel...
                        if ta is not None:
                            bbs.add_results_to_heap(*ta)

                next_lists = range(l, min(l + list_step, index.nlist))
                if prefetch_pool is None:
                    add_results(to_add)
                    return [bbs.prepare_bucket(i) for i in next_lists]
                # add the results while the next buckets are being prepared
                add_a = prefetch_pool.apply_async(add_results, (to_add, ))
                res = prefetch_pool.map(bbs.prepare_bucket, next_lists)
                add_a.get()
                return res

            # used only when there is a computation pool: maps the thread
            # identifiers of the workers to sequential numbers
            thread_id_to_seq_lock = threading.Lock()
            thread_id_to_seq = {}

            def do_comp(bucket):
                (q_subset, xq_l, list_ids, xb_l) = bucket
                try:
                    tid = thread_id_to_seq[threading.get_ident()]
                except KeyError:
                    with thread_id_to_seq_lock:
                        tid = len(thread_id_to_seq)
                        thread_id_to_seq[threading.get_ident()] = tid
                D, I = comp.block_search(
                    xq_l, xb_l, list_ids, k, thread_id=tid)
                return q_subset, D, list_ids, I

            prefetched_buckets = add_results_and_prefetch_batch([], 0)
            to_add = []
            pool = ThreadPool(1)

            # loop over inverted lists
            for l in range(0, index.nlist, list_step):
                bbs.report(l)
                buckets = prefetched_buckets
                prefetched_buckets_a = pool.apply_async(
                    add_results_and_prefetch_batch, (to_add, l + list_step))

                bbs.start_t_accu()
                if comp_pool is None:
                    to_add = []
                    for q_subset, xq_l, list_ids, xb_l in buckets:
                        D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                        to_add.append((q_subset, D, list_ids, I))
                else:
                    to_add = list(comp_pool.map(do_comp, buckets))
                bbs.stop_t_accu(2)

                bbs.start_t_accu()
                prefetched_buckets = prefetched_buckets_a.get()
                bbs.stop_t_accu(4)

            # flush add
            for ta in to_add:
                bbs.add_results_to_heap(*ta)
        finally:
            # Release the worker threads even when an iteration fails. The
            # driver pool submits work to prefetch_pool, so it goes first.
            for p in (pool, prefetch_pool, comp_pool):
                if p is not None:
                    p.close()
                    p.join()

    bbs.tic("finalize heap")
    bbs.rh.finalize()
    bbs.toc()

    return bbs.rh.D, bbs.rh.I