faiss/tests/test_ivflib.py

149 lines
4.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import unittest
import faiss
import numpy as np
class TestIVFlib(unittest.TestCase):
    """Check that the IVF helper functions are reachable from the faiss module."""

    def test_methods_exported(self):
        """Every listed helper must exist on ``faiss`` and be callable."""
        methods = ['check_compatible_for_merge', 'extract_index_ivf',
                   'merge_into', 'search_centroid',
                   'search_and_return_centroids', 'get_invlist_range',
                   'set_invlist_range', 'search_with_parameters',
                   # used by TestSearchWithParameters below as well
                   'range_search_with_parameters']
        for method in methods:
            # Name the missing helper in the failure message instead of
            # raising an anonymous AssertionError.
            self.assertTrue(callable(getattr(faiss, method, None)),
                            msg="faiss.%s is not exported" % method)
def search_single_scan(index, xq, k, bs=128):
    """Search ``index`` while scanning its inverted lists block by block.

    The coarse assignment is computed once. The inverted lists are then
    visited in consecutive buckets of ``bs`` lists. On each pass only the
    (query, probe) assignments that fall in the current bucket are kept;
    the others are set to -1. Results of all passes are accumulated in a
    single ``faiss.ResultHeap``, so every inverted list is accessed
    sequentially, one bucket at a time.

    Parameters
    ----------
    index : IVF index, or ``faiss.IndexPreTransform`` wrapping one
        Index to search. Its ``parallel_mode`` is changed for the duration
        of the call and restored before returning.
    xq : array-like
        Query vectors, one per row. Converted to C-contiguous float32.
    k : int
        Number of nearest neighbors to return per query.
    bs : int
        Number of consecutive inverted lists scanned per pass.

    Returns
    -------
    (D, I)
        Distance and label arrays taken from the finalized
        ``faiss.ResultHeap(nq, k)``.
    """
    # swig_ptr (used here and inside apply_py) needs contiguous float32 data.
    xq = np.ascontiguousarray(xq, dtype='float32')

    # Apply the pre-transform once, then work directly on the inner index.
    if isinstance(index, faiss.IndexPreTransform):
        xq = index.apply_py(xq)
        index = faiss.downcast_index(index.index)

    # Coarse assignment: the nprobe inverted lists probed for each query.
    coarse_dis, assign = index.quantizer.search(xq, index.nprobe)
    assign_buckets = assign // bs
    nq = len(xq)
    rh = faiss.ResultHeap(nq, k)

    # NO_HEAP_INIT keeps search_preassigned from re-initializing the result
    # heap, so successive passes accumulate into rh. The flag is restored
    # afterwards so later searches on this index behave normally.
    saved_parallel_mode = index.parallel_mode
    index.parallel_mode |= index.PARALLEL_MODE_NO_HEAP_INIT
    try:
        for l0 in range(0, index.nlist, bs):
            bucket_no = l0 // bs
            # Keep only this bucket's assignments; -1 marks entries to skip.
            sub_assign = np.where(assign_buckets == bucket_no, assign, -1)
            index.search_preassigned(
                nq, faiss.swig_ptr(xq), k,
                faiss.swig_ptr(sub_assign), faiss.swig_ptr(coarse_dis),
                faiss.swig_ptr(rh.D), faiss.swig_ptr(rh.I),
                False, None
            )
    finally:
        index.parallel_mode = saved_parallel_mode

    rh.finalize()
    return rh.D, rh.I
class TestSequentialScan(unittest.TestCase):
    """Block-wise scanning must reproduce a regular IVF search exactly."""

    def test_sequential_scan(self):
        d = 20
        index = faiss.index_factory(d, 'IVF100,SQ8')

        rs = np.random.RandomState(123)
        xt = rs.rand(5000, d).astype('float32')
        xb = rs.rand(10000, d).astype('float32')
        index.train(xt)
        index.add(xb)
        k = 15
        xq = rs.rand(200, d).astype('float32')

        ref_D, ref_I = index.search(xq, k)
        D, I = search_single_scan(index, xq, k, bs=10)

        # assert_array_equal reports the mismatching entries on failure,
        # unlike a bare `assert np.all(...)`.
        np.testing.assert_array_equal(D, ref_D)
        np.testing.assert_array_equal(I, ref_I)
class TestSearchWithParameters(unittest.TestCase):
    """search_with_parameters / range_search_with_parameters must use the
    nprobe from the passed parameters, not the one stored in the index."""

    def _build_index_and_queries(self):
        """Return a trained, populated IVF100,SQ8 index (nprobe=3) and queries.

        The RandomState draw order (xt, xb, xq) is fixed so every test sees
        the same data.
        """
        d = 20
        index = faiss.index_factory(d, 'IVF100,SQ8')
        rs = np.random.RandomState(123)
        xt = rs.rand(5000, d).astype('float32')
        xb = rs.rand(10000, d).astype('float32')
        index.train(xt)
        index.nprobe = 3
        index.add(xb)
        xq = rs.rand(200, d).astype('float32')
        return index, xq

    def test_search_with_parameters(self):
        index, xq = self._build_index_and_queries()
        k = 15

        stats = faiss.cvar.indexIVF_stats
        stats.reset()
        Dref, Iref = index.search(xq, k)
        ref_ndis = stats.ndis

        # make sure the nprobe used is the one from params not the one
        # set in the index
        index.nprobe = 1
        params = faiss.IVFSearchParameters()
        params.nprobe = 3

        Dnew, Inew, stats2 = faiss.search_with_parameters(
            index, xq, k, params, output_stats=True)

        np.testing.assert_array_equal(Inew, Iref)
        np.testing.assert_array_equal(Dnew, Dref)
        self.assertEqual(stats2["ndis"], ref_ndis)

    def test_range_search_with_parameters(self):
        index, xq = self._build_index_and_queries()

        # Pick a radius that returns a non-trivial number of results.
        Dpre, _ = index.search(xq, 15)
        radius = float(np.median(Dpre[:, -1]))

        stats = faiss.cvar.indexIVF_stats
        stats.reset()
        Lref, Dref, Iref = index.range_search(xq, radius)
        ref_ndis = stats.ndis

        # make sure the nprobe used is the one from params not the one
        # set in the index
        index.nprobe = 1
        params = faiss.IVFSearchParameters()
        params.nprobe = 3

        Lnew, Dnew, Inew, stats2 = faiss.range_search_with_parameters(
            index, xq, radius, params, output_stats=True)

        np.testing.assert_array_equal(Lnew, Lref)
        np.testing.assert_array_equal(Inew, Iref)
        np.testing.assert_array_equal(Dnew, Dref)
        self.assertEqual(stats2["ndis"], ref_ndis)