149 lines
4.8 KiB
Python
149 lines
4.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import numpy as np
|
|
import faiss
|
|
|
|
from faiss.contrib.inspect_tools import get_invlist_sizes
|
|
|
|
|
|
def add_preassigned(index_ivf, x, a, ids=None):
|
|
"""
|
|
Add elements to an IVF index, where the assignment is already computed
|
|
"""
|
|
n, d = x.shape
|
|
assert a.shape == (n, )
|
|
if isinstance(index_ivf, faiss.IndexBinaryIVF):
|
|
d *= 8
|
|
assert d == index_ivf.d
|
|
if ids is not None:
|
|
assert ids.shape == (n, )
|
|
ids = faiss.swig_ptr(ids)
|
|
index_ivf.add_core(
|
|
n, faiss.swig_ptr(x), ids, faiss.swig_ptr(a)
|
|
)
|
|
|
|
|
|
def search_preassigned(index_ivf, xq, k, list_nos, coarse_dis=None):
|
|
"""
|
|
Perform a search in the IVF index, with predefined lists to search into.
|
|
Supports indexes with pretransforms (as opposed to the
|
|
IndexIVF.search_preassigned, that cannot be applied with pretransform).
|
|
"""
|
|
if isinstance(index_ivf, faiss.IndexPreTransform):
|
|
assert index_ivf.chain.size() == 1, "chain must have only one component"
|
|
transform = faiss.downcast_VectorTransform(index_ivf.chain.at(0))
|
|
xq = transform.apply(xq)
|
|
index_ivf = faiss.downcast_index(index_ivf.index)
|
|
n, d = xq.shape
|
|
if isinstance(index_ivf, faiss.IndexBinaryIVF):
|
|
d *= 8
|
|
dis_type = "int32"
|
|
else:
|
|
dis_type = "float32"
|
|
|
|
assert d == index_ivf.d
|
|
assert list_nos.shape == (n, index_ivf.nprobe)
|
|
|
|
# the coarse distances are used in IVFPQ with L2 distance and
|
|
# by_residual=True otherwise we provide dummy coarse_dis
|
|
if coarse_dis is None:
|
|
coarse_dis = np.zeros((n, index_ivf.nprobe), dtype=dis_type)
|
|
else:
|
|
assert coarse_dis.shape == (n, index_ivf.nprobe)
|
|
|
|
return index_ivf.search_preassigned(xq, k, list_nos, coarse_dis)
|
|
|
|
|
|
def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
|
|
"""
|
|
Perform a range search in the IVF index, with predefined lists to
|
|
search into
|
|
"""
|
|
n, d = x.shape
|
|
if isinstance(index_ivf, faiss.IndexBinaryIVF):
|
|
d *= 8
|
|
dis_type = "int32"
|
|
else:
|
|
dis_type = "float32"
|
|
|
|
# the coarse distances are used in IVFPQ with L2 distance and
|
|
# by_residual=True otherwise we provide dummy coarse_dis
|
|
if coarse_dis is None:
|
|
coarse_dis = np.empty((n, index_ivf.nprobe), dtype=dis_type)
|
|
else:
|
|
assert coarse_dis.shape == (n, index_ivf.nprobe)
|
|
|
|
assert d == index_ivf.d
|
|
assert list_nos.shape == (n, index_ivf.nprobe)
|
|
|
|
res = faiss.RangeSearchResult(n)
|
|
sp = faiss.swig_ptr
|
|
|
|
index_ivf.range_search_preassigned_c(
|
|
n, sp(x), radius,
|
|
sp(list_nos), sp(coarse_dis),
|
|
res
|
|
)
|
|
# get pointers and copy them
|
|
lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
|
|
num_results = int(lims[-1])
|
|
dist = faiss.rev_swig_ptr(res.distances, num_results).copy()
|
|
indices = faiss.rev_swig_ptr(res.labels, num_results).copy()
|
|
return lims, dist, indices
|
|
|
|
|
|
def replace_ivf_quantizer(index_ivf, new_quantizer):
|
|
""" replace the IVF quantizer with a flat quantizer and return the
|
|
old quantizer"""
|
|
if new_quantizer.ntotal == 0:
|
|
centroids = index_ivf.quantizer.reconstruct_n()
|
|
new_quantizer.train(centroids)
|
|
new_quantizer.add(centroids)
|
|
else:
|
|
assert new_quantizer.ntotal == index_ivf.nlist
|
|
|
|
# cleanly dealloc old quantizer
|
|
old_own = index_ivf.own_fields
|
|
index_ivf.own_fields = False
|
|
old_quantizer = faiss.downcast_index(index_ivf.quantizer)
|
|
old_quantizer.this.own(old_own)
|
|
index_ivf.quantizer = new_quantizer
|
|
|
|
if hasattr(index_ivf, "referenced_objects"):
|
|
index_ivf.referenced_objects.append(new_quantizer)
|
|
else:
|
|
index_ivf.referenced_objects = [new_quantizer]
|
|
return old_quantizer
|
|
|
|
|
|
def permute_invlists(index_ivf, perm):
|
|
""" Apply some permutation to the inverted lists, and modify the quantizer
|
|
entries accordingly.
|
|
Perm is an array of size nlist, where old_index = perm[new_index]
|
|
"""
|
|
nlist, = perm.shape
|
|
assert index_ivf.nlist == nlist
|
|
quantizer = faiss.downcast_index(index_ivf.quantizer)
|
|
assert quantizer.ntotal == index_ivf.nlist
|
|
perm = np.ascontiguousarray(perm, dtype='int64')
|
|
|
|
# just make sure it's a permutation...
|
|
bc = np.bincount(perm, minlength=nlist)
|
|
assert np.all(bc == np.ones(nlist, dtype=int))
|
|
|
|
# handle quantizer
|
|
quantizer.permute_entries(perm)
|
|
|
|
# handle inverted lists
|
|
invlists = faiss.downcast_InvertedLists(index_ivf.invlists)
|
|
invlists.permute_invlists(faiss.swig_ptr(perm))
|
|
|
|
|
|
def sort_invlists_by_size(index_ivf):
|
|
invlist_sizes = get_invlist_sizes(index_ivf.invlists)
|
|
perm = np.argsort(invlist_sizes)
|
|
permute_invlists(index_ivf, perm)
|