149 lines
4.4 KiB
Python
149 lines
4.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import unittest
|
|
import faiss
|
|
import numpy as np
|
|
|
|
class TestIVFlib(unittest.TestCase):
    """Checks that the ivflib helper functions are exposed on the faiss module."""

    def test_methods_exported(self):
        methods = ['check_compatible_for_merge', 'extract_index_ivf',
                   'merge_into', 'search_centroid',
                   'search_and_return_centroids', 'get_invlist_range',
                   'set_invlist_range', 'search_with_parameters']

        for method in methods:
            # subTest + an explicit message reports every missing name (not
            # just the first one) and, unlike a bare assert, survives -O
            with self.subTest(method=method):
                self.assertTrue(
                    callable(getattr(faiss, method, None)),
                    "faiss.%s is missing or not callable" % method)
|
|
|
|
|
|
def search_single_scan(index, xq, k, bs=128):
    """Perform a search so that the inverted lists are accessed
    sequentially by blocks of size ``bs``.

    The lists are visited in consecutive blocks of ``bs`` list ids. For each
    block, ``search_preassigned`` is called with an assignment matrix in which
    every probed list outside the block is replaced by -1. All partial results
    are accumulated into one ``faiss.ResultHeap``.

    Args:
        index: an IVF index, optionally wrapped in a ``faiss.IndexPreTransform``.
        xq: query vectors (converted to a C-contiguous float32 array).
        k: number of results per query.
        bs: number of consecutive inverted lists scanned per pass.

    Returns:
        (D, I) arrays taken from the finalized ``ResultHeap``.

    The ``parallel_mode`` of the (inner) IVF index is temporarily modified
    and restored before returning, so the caller's index is left unchanged.
    """
    # swig_ptr hands raw buffers to C++, so they must be contiguous float32
    xq = np.ascontiguousarray(xq, dtype='float32')

    # handle pretransform: transform the queries, then work on the inner index
    if isinstance(index, faiss.IndexPreTransform):
        xq = index.apply_py(xq)
        index = faiss.downcast_index(index.index)

    # coarse assignment: the nprobe lists each query visits
    coarse_dis, assign = index.quantizer.search(xq, index.nprobe)
    nlist = index.nlist
    assign_buckets = assign // bs
    nq = len(xq)

    rh = faiss.ResultHeap(nq, k)

    # NO_HEAP_INIT lets successive search_preassigned calls accumulate into
    # rh's buffers instead of resetting them; restore the original mode
    # afterwards so later searches on this index behave normally.
    old_parallel_mode = index.parallel_mode
    index.parallel_mode |= index.PARALLEL_MODE_NO_HEAP_INIT
    try:
        for l0 in range(0, nlist, bs):
            bucket_no = l0 // bs
            # keep only the probed lists that fall in this block; the
            # others are masked with -1 so this pass does not scan them
            sub_assign = np.where(assign_buckets == bucket_no, assign, -1)

            index.search_preassigned(
                nq, faiss.swig_ptr(xq), k,
                faiss.swig_ptr(sub_assign), faiss.swig_ptr(coarse_dis),
                faiss.swig_ptr(rh.D), faiss.swig_ptr(rh.I),
                False, None
            )
    finally:
        index.parallel_mode = old_parallel_mode

    rh.finalize()

    return rh.D, rh.I
|
|
|
|
|
|
class TestSequentialScan(unittest.TestCase):
    """Checks that a block-sequential scan of the inverted lists
    (search_single_scan) reproduces a regular IVF search."""

    def test_sequential_scan(self):
        d = 20
        index = faiss.index_factory(d, 'IVF100,SQ8')

        rs = np.random.RandomState(123)
        xt = rs.rand(5000, d).astype('float32')
        xb = rs.rand(10000, d).astype('float32')
        index.train(xt)
        index.add(xb)
        k = 15
        xq = rs.rand(200, d).astype('float32')

        ref_D, ref_I = index.search(xq, k)
        # bs=10 splits the inverted lists into blocks of 10 consecutive ids
        D, I = search_single_scan(index, xq, k, bs=10)

        # results must match exactly; only the order in which the inverted
        # lists are visited differs. assert_array_equal reports mismatches.
        np.testing.assert_array_equal(D, ref_D)
        np.testing.assert_array_equal(I, ref_I)
|
|
|
|
|
|
class TestSearchWithParameters(unittest.TestCase):
    """Checks that search_with_parameters / range_search_with_parameters use
    the nprobe from the IVFSearchParameters, not the one set on the index."""

    def _make_index_and_queries(self):
        """Return a trained, populated 'IVF100,SQ8' index (nprobe=3) and queries.

        The RandomState is seeded identically and consumed in the same order
        (training set, database, queries) on every call, so all tests using
        this helper operate on identical data.
        """
        d = 20
        index = faiss.index_factory(d, 'IVF100,SQ8')

        rs = np.random.RandomState(123)
        xt = rs.rand(5000, d).astype('float32')
        xb = rs.rand(10000, d).astype('float32')
        index.train(xt)
        index.nprobe = 3
        index.add(xb)
        xq = rs.rand(200, d).astype('float32')
        return index, xq

    def test_search_with_parameters(self):
        index, xq = self._make_index_and_queries()
        k = 15

        # reference search with index.nprobe = 3, counting distance computations
        stats = faiss.cvar.indexIVF_stats
        stats.reset()
        Dref, Iref = index.search(xq, k)
        ref_ndis = stats.ndis

        # make sure the nprobe used is the one from params not the one
        # set in the index
        index.nprobe = 1
        params = faiss.IVFSearchParameters()
        params.nprobe = 3

        Dnew, Inew, stats2 = faiss.search_with_parameters(
            index, xq, k, params, output_stats=True)

        np.testing.assert_array_equal(Inew, Iref)
        np.testing.assert_array_equal(Dnew, Dref)

        self.assertEqual(stats2["ndis"], ref_ndis)

    def test_range_search_with_parameters(self):
        index, xq = self._make_index_and_queries()

        # derive the radius from the median 15th-neighbor distance
        Dpre, _ = index.search(xq, 15)
        radius = float(np.median(Dpre[:, -1]))
        print("Radius=", radius)

        # reference range search with index.nprobe = 3
        stats = faiss.cvar.indexIVF_stats
        stats.reset()
        Lref, Dref, Iref = index.range_search(xq, radius)
        ref_ndis = stats.ndis

        # make sure the nprobe used is the one from params not the one
        # set in the index
        index.nprobe = 1
        params = faiss.IVFSearchParameters()
        params.nprobe = 3

        Lnew, Dnew, Inew, stats2 = faiss.range_search_with_parameters(
            index, xq, radius, params, output_stats=True)

        np.testing.assert_array_equal(Lnew, Lref)
        np.testing.assert_array_equal(Inew, Iref)
        np.testing.assert_array_equal(Dnew, Dref)

        self.assertEqual(stats2["ndis"], ref_ndis)
|