faiss/tests/test_fast_scan.py

701 lines
21 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import time
import os
import tempfile
import numpy as np
import faiss
from faiss.contrib import datasets
# Limit the number of OpenMP threads used by faiss: otherwise the tests tend
# to time out in stress modes and dev builds.
_NUM_TEST_THREADS = 4
faiss.omp_set_num_threads(_NUM_TEST_THREADS)
class TestCompileOptions(unittest.TestCase):
    """Check that faiss only reports compile options this file knows about."""

    def test_compile_options(self):
        known_options = ['AVX2', 'NEON', 'GENERIC', 'OPTIMIZE']
        # get_compile_options() returns the options as one space-separated
        # string
        for opt in faiss.get_compile_options().split(' '):
            assert opt in known_options
class TestSearch(unittest.TestCase):
    """End-to-end accuracy and speed checks for 4-bit PQ fast-scan search."""

    def test_PQ4_accuracy(self):
        """PQ16x4fs must return the exact nearest neighbor for > 60% of
        the queries."""
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)

        index_gt = faiss.IndexFlatL2(32)
        index_gt.add(ds.get_database())
        Dref, Iref = index_gt.search(ds.get_queries(), 10)

        index = faiss.index_factory(32, 'PQ16x4fs')
        index.train(ds.get_train())
        index.add(ds.get_database())
        Da, Ia = index.search(ds.get_queries(), 10)

        nq = Iref.shape[0]
        recall_at_1 = (Iref[:, 0] == Ia[:, 0]).sum() / nq
        self.assertGreater(recall_at_1, 0.6)

    @staticmethod
    def _best_search_time(index, xq, k, nrun=3):
        """Return the minimum wall-clock time of `nrun` searches of xq.

        Keeping the fastest of several runs filters out one-off jitter
        (e.g. scheduling noise) that would make the speed comparison flaky.
        perf_counter is used because it is monotonic and high-resolution,
        unlike time.time().
        """
        best = float('inf')
        for _ in range(nrun):
            t0 = time.perf_counter()
            index.search(xq, k)
            best = min(best, time.perf_counter() - t0)
        return best

    # This is an experiment to see if we can catch performance
    # regressions. It runs 2 codes, one should be faster than the
    # other by a factor ~10 in opt mode. We check for a factor 5.
    # Each timing is the best of a few runs to limit spurious test
    # failures caused by jitter in execution time. Unoptimized timings
    # are not exploitable, hence the flag test on that as well.
    @unittest.skipUnless(
        ('AVX2' in faiss.get_compile_options() or
         'NEON' in faiss.get_compile_options()) and
        "OPTIMIZE" in faiss.get_compile_options(),
        "only test while building with avx2 or neon")
    def test_PQ4_speed(self):
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
        xt = ds.get_train()
        xb = ds.get_database()
        xq = ds.get_queries()

        index = faiss.index_factory(32, 'PQ16x4')
        index.train(xt)
        index.add(xb)
        pq_t = self._best_search_time(index, xq, 10)
        print('PQ16x4 search time:', pq_t)

        index2 = faiss.index_factory(32, 'PQ16x4fs')
        index2.train(xt)
        index2.add(xb)
        pqfs_t = self._best_search_time(index2, xq, 10)
        print('PQ16x4fs search time:', pqfs_t)

        self.assertLess(pqfs_t * 5, pq_t)
class TestRounding(unittest.TestCase):
    """Accuracy of the IndexPQFastScan implementations that round the LUT."""

    def do_test_rounding(self, implem=4, metric=faiss.METRIC_L2):
        """Compare IndexPQFastScan with `implem` against IndexPQ search.

        implem 2 must reproduce the IndexPQ results exactly. `implem` is
        then checked for recall w.r.t. the IndexPQ results and for the
        relative squared error of its distances.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 200)

        index = faiss.index_factory(32, 'PQ16x4', metric)
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 10)
        nq = Iref.shape[0]

        index2 = faiss.IndexPQFastScan(index)

        # simply repro normal search
        index2.implem = 2
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, Iref)
        np.testing.assert_array_equal(D2, Dref)

        # rounded LUT with correction
        index2.implem = implem
        D4, I4 = index2.search(ds.get_queries(), 10)

        # check accuracy of indexes: fraction of queries whose exact top-1
        # result appears in the top-`rank` results
        recalls = {
            rank: (Iref[:, :1] == I4[:, :rank]).sum() / nq
            for rank in (1, 10)
        }

        min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99
        self.assertGreaterEqual(recalls[1], min_r1)
        self.assertGreater(recalls[10], 0.995)

        # check accuracy of distances
        err4 = ((D4 - D2) ** 2).sum()
        nf = (D2 ** 2).sum()
        self.assertLess(err4, nf * 1e-4)

    def test_implem_4(self):
        self.do_test_rounding(4)

    def test_implem_4_ip(self):
        self.do_test_rounding(4, faiss.METRIC_INNER_PRODUCT)

    def test_implem_12(self):
        self.do_test_rounding(12)

    def test_implem_12_ip(self):
        self.do_test_rounding(12, faiss.METRIC_INNER_PRODUCT)

    def test_implem_14(self):
        self.do_test_rounding(14)

    def test_implem_14_ip(self):
        # was mistakenly testing implem 12, leaving 14 + IP uncovered
        self.do_test_rounding(14, faiss.METRIC_INNER_PRODUCT)
class TestReconstruct(unittest.TestCase):
    """A fast-scan index built from another index must reconstruct the
    database vectors exactly like the source index."""

    def _check_same_reconstruction(self, factory_string, fast_scan_class):
        """Build `factory_string`, wrap it in `fast_scan_class` and compare
        the reconstructions of all stored vectors."""
        ds = datasets.SyntheticDataset(20, 1000, 1000, 0)

        src = faiss.index_factory(20, factory_string)
        src.train(ds.get_train())
        src.add(ds.get_database())
        expected = src.reconstruct_n(0, src.ntotal)

        fast = fast_scan_class(src)
        np.testing.assert_array_equal(
            expected, fast.reconstruct_n(0, src.ntotal))

    def test_pqfastscan(self):
        self._check_same_reconstruction('PQ5x4', faiss.IndexPQFastScan)

    def test_aqfastscan(self):
        self._check_same_reconstruction(
            'RQ5x4_Nrq2x4', faiss.IndexAdditiveQuantizerFastScan)
#########################################################
# Kernel unit test
#########################################################
def reference_accu(codes, LUT):
    """Reference (non-SIMD) accumulation of 4-bit PQ look-up tables.

    Args:
        codes: (nb, nsp // 2) array. Byte m of a row packs two 4-bit codes:
            the low nibble for sub-quantizer 2*m, the high nibble for
            sub-quantizer 2*m + 1.
        LUT: (nq, nsp, 16) array of table entries.

    Returns:
        (nq, nb) uint16 array where accu[i, j] is the sum over sub-quantizers
        sp of LUT[i, sp, code of vector j for sp], accumulated in uint16
        (wrapping modulo 2**16).
    """
    nq, nsp, is_16 = LUT.shape
    nb, nsp_2 = codes.shape
    assert is_16 == 16
    # two sub-quantizers are packed per byte, so nsp must be even
    assert nsp % 2 == 0
    assert nsp_2 == nsp // 2

    lo = codes & 15  # codes of the even sub-quantizers
    hi = codes >> 4  # codes of the odd sub-quantizers
    cols = np.arange(nsp_2)

    # Advanced indexing: cols broadcasts against the (nb, nsp_2) code
    # arrays, giving (nq, nb, nsp_2) gathered LUT entries.
    lo_part = LUT[:, 0::2, :][:, cols, lo].astype('uint16')
    hi_part = LUT[:, 1::2, :][:, cols, hi].astype('uint16')

    # uint16 array arithmetic wraps silently, like the scalar uint16
    # accumulation this replaces; modular addition is order-independent.
    return (lo_part + hi_part).sum(axis=2, dtype='uint16')
# disabled because the function to write to mem is not implemented currently
class ThisIsNotATestLoop5:  # (unittest.TestCase):
    """Unit tests of the loop5 accumulation kernel (currently disabled).

    Each test packs random codes and LUTs and compares the kernel output
    with reference_accu.
    """

    def do_loop5_kernel(self, nq, bb):
        """ unit test for the accumulation kernel """
        nb = bb * 32  # database size
        nsp = 24  # number of sub-quantizers

        rs = np.random.RandomState(123)
        codes = rs.randint(256, size=(nb, nsp // 2)).astype('uint8')
        LUT = rs.randint(256, size=(nq, nsp, 16)).astype('uint8')
        accu_ref = reference_accu(codes, LUT)

        sp = faiss.swig_ptr

        LUT_a = faiss.AlignedTableUint8(LUT.size)
        faiss.pq4_pack_LUT(
            nq, nsp, sp(LUT),
            LUT_a.get()
        )

        codes_a = faiss.AlignedTableUint8(codes.size)
        faiss.pq4_pack_codes(
            sp(codes),
            nb, nsp, nb, nb, nsp,
            codes_a.get()
        )

        accu_a = faiss.AlignedTableUint16(nq * nb)
        accu_a.clear()
        faiss.loop5_kernel_accumulate_1_block_to_mem(
            nq, nb, nsp, codes_a.get(), LUT_a.get(), accu_a.get()
        )
        accu = faiss.AlignedTable_to_array(accu_a).reshape(nq, nb)
        np.testing.assert_array_equal(accu_ref, accu)

    def test_11(self):
        self.do_loop5_kernel(1, 1)

    def test_21(self):
        self.do_loop5_kernel(2, 1)

    def test_12(self):
        self.do_loop5_kernel(1, 2)

    def test_22(self):
        self.do_loop5_kernel(2, 2)
##########################################################
# Tests for various IndexPQFastScan implementations
##########################################################
def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
    """ verify a list of results where there are draws in the distances (because
    they are integer).

    Distances must match to 5 decimals. Within a row, the ids may be
    permuted among entries that share the same reference distance. Entries
    whose distance equals the last distance of the row are not checked,
    because a tie at rank k can be cut differently.

    Args:
        testcase: unittest.TestCase used to report mismatches.
        Dref, Iref: reference distances and ids, one row per query.
        Dnew, Inew: distances and ids to verify, same layout.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    testcase.assertEqual(np.shape(Iref), np.shape(Inew))
    # here we have to be careful because of draws
    for dref_row, iref_row, inew_row in zip(Dref, Iref, Inew):
        if np.all(iref_row == inew_row):  # easy case
            continue
        # we can deduce nothing about the latest line
        skip_dis = dref_row[-1]
        # only distances occurring in this row can select entries of it
        for dis in np.unique(dref_row):
            if dis == skip_dis:
                continue
            mask = dref_row == dis
            testcase.assertEqual(set(iref_row[mask]), set(inew_row[mask]))
class TestImplems(unittest.TestCase):
    """Base class comparing an IndexPQFastScan variant with a reference.

    Subclasses override build_fast_scan_index(index, params) and call
    do_with_params from their test methods.
    """

    # unittest instantiates the TestCase once per test method, so a cache
    # stored on each instance would never be reused. Keep it at class level
    # so every test (of this class and its subclasses) shares the reference
    # index and results for a given (d, metric).
    _shared_cache = {}

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        self.cache = TestImplems._shared_cache
        # number of results compared by do_with_params (tests may lower it)
        self.k = 10

    def get_index(self, d, metric):
        """Return (ds, index, Dref, Iref) for dimension d and metric.

        index is the trained and filled 'PQ{d//2}x4' index. Dref / Iref are
        its top-10 results computed with IndexPQFastScan implem 4.
        """
        key = (d, metric)
        if key not in self.cache:
            ds = datasets.SyntheticDataset(d, 1000, 2000, 200)
            target_size = d // 2
            index = faiss.index_factory(d, 'PQ%dx4' % target_size, metric)
            index.train(ds.get_train())
            index.add(ds.get_database())

            index2 = faiss.IndexPQFastScan(index)
            # uint8 LUT but no SIMD
            index2.implem = 4
            Dref, Iref = index2.search(ds.get_queries(), 10)

            self.cache[key] = (ds, index, Dref, Iref)
        return self.cache[key]

    def do_with_params(self, d, params, metric=faiss.METRIC_L2):
        """Search with the index built from `params` and check its top-k
        results against the reference, allowing for ties in distances."""
        ds, index, Dref, Iref = self.get_index(d, metric)

        index2 = self.build_fast_scan_index(index, params)
        Dnew, Inew = index2.search(ds.get_queries(), self.k)

        Dref = Dref[:, :self.k]
        Iref = Iref[:, :self.k]

        verify_with_draws(self, Dref, Iref, Dnew, Inew)

    def build_fast_scan_index(self, index, params):
        """Default variant: IndexPQFastScan with implem 5 (params ignored)."""
        index2 = faiss.IndexPQFastScan(index)
        index2.implem = 5
        return index2
class TestImplem12(TestImplems):
    """IndexPQFastScan implem 12 for several query block sizes (qbs).

    NOTE: the hex digits of qbs appear to be per-block query counts
    (e.g. 2+2+3 = 7 for 0x223, matching the test names).
    """

    def build_fast_scan_index(self, index, qbs):
        fast_index = faiss.IndexPQFastScan(index)
        fast_index.implem = 12
        fast_index.qbs = qbs
        return fast_index

    def test_qbs7(self):
        self.do_with_params(32, 0x223)

    def test_qbs7b(self):
        self.do_with_params(32, 0x133)

    def test_qbs6(self):
        self.do_with_params(32, 0x33)

    def test_qbs6_ip(self):
        self.do_with_params(32, 0x33, faiss.METRIC_INNER_PRODUCT)

    def test_qbs6b(self):
        # test codepath where qbs is not known at compile time
        self.do_with_params(32, 0x1113)

    def test_qbs6_odd_dim(self):
        self.do_with_params(30, 0x33)
class TestImplem13(TestImplems):
    """IndexPQFastScan implem 13 with query block size 0x223."""

    def build_fast_scan_index(self, index, qbs):
        fast_index = faiss.IndexPQFastScan(index)
        fast_index.implem = 13
        fast_index.qbs = qbs
        return fast_index

    def test_qbs7(self):
        self.do_with_params(32, 0x223)

    def test_qbs7_k1(self):
        # only compare the single best result
        self.k = 1
        self.do_with_params(32, 0x223)
class TestImplem14(TestImplems):
    """IndexPQFastScan implem 14; params are (qbs, bbs) pairs."""

    def build_fast_scan_index(self, index, params):
        qbs_value, bbs_value = params
        fast_index = faiss.IndexPQFastScan(index, bbs_value)
        fast_index.implem = 14
        fast_index.qbs = qbs_value
        return fast_index

    def test_1_32(self):
        self.do_with_params(32, (1, 32))

    def test_1_64(self):
        self.do_with_params(32, (1, 64))

    def test_2_32(self):
        self.do_with_params(32, (2, 32))

    def test_2_64(self):
        self.do_with_params(32, (2, 64))

    def test_qbs_1_32_k1(self):
        # only compare the single best result
        self.k = 1
        self.do_with_params(32, (1, 32))

    def test_qbs_1_64_k1(self):
        self.k = 1
        self.do_with_params(32, (1, 64))

    def test_1_32_odd_dim(self):
        self.do_with_params(30, (1, 32))

    def test_1_64_odd_dim(self):
        self.do_with_params(30, (1, 64))
class TestImplem15(TestImplems):
    """IndexPQFastScan implem 15; params are (qbs, bbs) pairs."""

    def build_fast_scan_index(self, index, params):
        qbs_value, bbs_value = params
        fast_index = faiss.IndexPQFastScan(index, bbs_value)
        fast_index.implem = 15
        fast_index.qbs = qbs_value
        return fast_index

    def test_1_32(self):
        self.do_with_params(32, (1, 32))

    def test_2_64(self):
        self.do_with_params(32, (2, 64))
class TestAdd(unittest.TestCase):
    """Adding to an IndexPQFastScan must give the same codes as converting
    a fully-populated IndexPQ, and the direct constructor must work."""

    def do_test_add(self, d, bbs):
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        index = faiss.index_factory(d, f'PQ{d//2}x4np')
        index.train(ds.get_train())

        xb = ds.get_database()
        split = 1235
        head, tail = xb[:split], xb[split:]
        index.add(head)

        # path 1: convert the partial index, then add the remaining vectors
        incremental = faiss.IndexPQFastScan(index, bbs)
        incremental.add(tail)
        new_codes = faiss.AlignedTable_to_array(incremental.codes)

        # path 2: add everything to the source index, then convert
        index.add(tail)
        converted = faiss.IndexPQFastScan(index, bbs)
        ref_codes = faiss.AlignedTable_to_array(converted.codes)

        self.assertEqual(converted.ntotal, incremental.ntotal)
        np.testing.assert_array_equal(ref_codes, new_codes)

    def test_add(self):
        self.do_test_add(32, 32)

    def test_add_bbs64(self):
        self.do_test_add(32, 64)

    def test_add_odd_d(self):
        self.do_test_add(30, 64)

    def test_constructor(self):
        d = 32
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
        xt, xb, xq = ds.get_train(), ds.get_database(), ds.get_queries()

        index = faiss.index_factory(d, f'PQ{d//2}x4np')
        index.train(xt)
        index.add(xb)
        Dref, Iref = index.search(xq, 10)
        nq = Iref.shape[0]

        fast = faiss.IndexPQFastScan(d, d // 2, 4)
        fast.train(xt)
        fast.add(xb)
        Dnew, Inew = fast.search(xq, 10)
        self.assertGreaterEqual((Iref[:, 0] == Inew[:, 0]).sum() / nq, 0.99)

        # round-trip through serialization must preserve search results
        restored = faiss.deserialize_index(faiss.serialize_index(fast))
        self.assertEqual(fast.implem, restored.implem)
        D3, I3 = restored.search(xq, 10)
        np.testing.assert_array_equal(D3, Dnew)
        np.testing.assert_array_equal(I3, Inew)
class TestAQFastScan(unittest.TestCase):
    """Tests for IndexAdditiveQuantizerFastScan (RQ and LSQ quantizers)."""

    def subtest_accuracy(self, aq, st, implem, metric_type='L2'):
        """
        Compare IndexAdditiveQuantizerFastScan with IndexAQ (qint8)
        """
        d = 16
        ds = datasets.SyntheticDataset(d, 1000, 1000, 500, metric_type)
        gt = ds.get_groundtruth(k=1)

        if metric_type == 'L2':
            metric = faiss.METRIC_L2
            postfix1 = '_Nqint8'
            postfix2 = f'_N{st}2x4'
        else:
            metric = faiss.METRIC_INNER_PRODUCT
            postfix1 = postfix2 = ''

        index = faiss.index_factory(d, f'{aq}3x4{postfix1}', metric)
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 1)

        indexfs = faiss.index_factory(d, f'{aq}3x4fs_32{postfix2}', metric)
        indexfs.train(ds.get_train())
        indexfs.add(ds.get_database())
        indexfs.implem = implem
        Da, Ia = indexfs.search(ds.get_queries(), 1)

        nq = Iref.shape[0]
        recall_ref = (Iref == gt).sum() / nq
        recall = (Ia == gt).sum() / nq

        print(aq, st, implem, metric_type, recall_ref, recall)
        assert abs(recall_ref - recall) < 0.05

    # Disabled (xx_ prefix): the same combinations are registered as
    # individual test methods at module level.
    def xx_test_accuracy(self):
        for metric in 'L2', 'IP':
            for implem in 0, 12, 13, 14, 15:
                self.subtest_accuracy('RQ', 'rq', implem, metric)
                self.subtest_accuracy('LSQ', 'lsq', implem, metric)

    def subtest_from_idxaq(self, implem, metric):
        """Convert an RQ index to IndexAdditiveQuantizerFastScan and check
        that its top-1 recall stays within 0.05 of the source index."""
        if metric == 'L2':
            metric_type = faiss.METRIC_L2
            st = '_Nrq2x4'
        else:
            metric_type = faiss.METRIC_INNER_PRODUCT
            st = ''

        d = 16
        ds = datasets.SyntheticDataset(d, 1000, 2000, 1000, metric=metric)
        gt = ds.get_groundtruth(k=1)

        # (an `index.nprobe = 16` assignment was dropped here: this index
        # is not an IVF index, so it only set an unused Python attribute)
        index = faiss.index_factory(d, 'RQ8x4' + st, metric_type)
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 1)

        indexfs = faiss.IndexAdditiveQuantizerFastScan(index)
        indexfs.implem = implem
        D1, I1 = indexfs.search(ds.get_queries(), 1)

        nq = Iref.shape[0]
        recall_ref = (Iref == gt).sum() / nq
        recall1 = (I1 == gt).sum() / nq
        print(recall_ref, recall1)
        assert abs(recall_ref - recall1) < 0.05

    # Disabled (xx_ prefix): the same combinations are registered as
    # individual test methods at module level.
    def xx_test_from_idxaq(self):
        for implem in 2, 3, 4:
            self.subtest_from_idxaq(implem, 'L2')
            self.subtest_from_idxaq(implem, 'IP')

    def subtest_factory(self, aq, M, bbs, st):
        """
        Format: {AQ}{M}x4fs_{bbs}_N{st}2x4

        AQ (str): `LSQ` or `RQ`
        M (int): number of subquantizers
        bbs (int): build block size; 0 means the factory default (32)
        st (str): search type prefix, `lsq` or `rq`
        """
        AQ = faiss.AdditiveQuantizer
        d = 16

        if bbs > 0:
            index = faiss.index_factory(d, f'{aq}{M}x4fs_{bbs}_N{st}2x4')
        else:
            index = faiss.index_factory(d, f'{aq}{M}x4fs_N{st}2x4')
            bbs = 32

        assert index.bbs == bbs
        # Use a distinct name: rebinding `aq` would make the type checks
        # below compare the quantizer object with a string (always False).
        quantizer = faiss.downcast_AdditiveQuantizer(index.aq)
        assert quantizer.M == M

        if aq == 'LSQ':
            assert isinstance(quantizer, faiss.LocalSearchQuantizer)
        if aq == 'RQ':
            assert isinstance(quantizer, faiss.ResidualQuantizer)

        if st == 'lsq':
            assert quantizer.search_type == AQ.ST_norm_lsq2x4
        if st == 'rq':
            assert quantizer.search_type == AQ.ST_norm_rq2x4

    def test_factory(self):
        self.subtest_factory('LSQ', 16, 64, 'lsq')
        self.subtest_factory('LSQ', 16, 64, 'rq')
        self.subtest_factory('RQ', 16, 64, 'rq')
        self.subtest_factory('RQ', 16, 64, 'lsq')
        self.subtest_factory('LSQ', 64, 0, 'lsq')

    def subtest_io(self, factory_str):
        """Write and re-read an index; search results must be identical."""
        d = 8
        ds = datasets.SyntheticDataset(d, 1000, 500, 100)

        index = faiss.index_factory(d, factory_str)
        index.train(ds.get_train())
        index.add(ds.get_database())
        D1, I1 = index.search(ds.get_queries(), 1)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)
            index2 = faiss.read_index(fname)
            D2, I2 = index2.search(ds.get_queries(), 1)
            np.testing.assert_array_equal(I1, I2)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_io(self):
        self.subtest_io('LSQ4x4fs_Nlsq2x4')
        self.subtest_io('LSQ4x4fs_Nrq2x4')
        self.subtest_io('RQ4x4fs_Nrq2x4')
        self.subtest_io('RQ4x4fs_Nlsq2x4')
# Programmatically generate tests to get finer test granularity: each
# parameter combination becomes its own TestAQFastScan test method.


def add_TestAQFastScan_subset_accuracy(aq, st, implem, metric):
    """Register TestAQFastScan.test_accuracy_{metric}_{aq}_implem{implem}."""
    def generated_test(self):
        self.subtest_accuracy(aq, st, implem, metric)

    method_name = f"test_accuracy_{metric}_{aq}_implem{implem}"
    setattr(TestAQFastScan, method_name, generated_test)


for metric in ('L2', 'IP'):
    for implem in (0, 12, 13, 14, 15):
        add_TestAQFastScan_subset_accuracy('LSQ', 'lsq', implem, metric)
        add_TestAQFastScan_subset_accuracy('RQ', 'rq', implem, metric)


def add_TestAQFastScan_subtest_from_idxaq(implem, metric):
    """Register TestAQFastScan.test_from_idxaq_{metric}_implem{implem}."""
    def generated_test(self):
        self.subtest_from_idxaq(implem, metric)

    method_name = f"test_from_idxaq_{metric}_implem{implem}"
    setattr(TestAQFastScan, method_name, generated_test)


for implem in (2, 3, 4):
    add_TestAQFastScan_subtest_from_idxaq(implem, 'L2')
    add_TestAQFastScan_subtest_from_idxaq(implem, 'IP')
class TestPAQFastScan(unittest.TestCase):
    """Tests for fast-scan indexes over product additive quantizers
    (PRQ and PLSQ)."""

    def subtest_accuracy(self, paq):
        """
        Compare IndexPAQFastScan with IndexPAQ (qint8)
        """
        d = 16
        ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
        gt = ds.get_groundtruth(k=1)

        def top1_ids(factory_string):
            # build, train and fill the index, return its top-1 ids
            idx = faiss.index_factory(d, factory_string)
            idx.train(ds.get_train())
            idx.add(ds.get_database())
            return idx.search(ds.get_queries(), 1)[1]

        Iref = top1_ids(f'{paq}2x3x4_Nqint8')
        Ia = top1_ids(f'{paq}2x3x4fs_Nlsq2x4')

        nq = Iref.shape[0]
        recall_ref = (Iref == gt).sum() / nq
        recall = (Ia == gt).sum() / nq
        assert abs(recall_ref - recall) < 0.05

    def test_accuracy_PLSQ(self):
        self.subtest_accuracy("PLSQ")

    def test_accuracy_PRQ(self):
        self.subtest_accuracy("PRQ")

    def subtest_factory(self, paq):
        """The factory must build 2 splits of 3 sub-quantizers each."""
        index = faiss.index_factory(16, f'{paq}2x3x4fs_Nlsq2x4')
        quantizer = faiss.downcast_Quantizer(index.aq)
        self.assertEqual(quantizer.nsplits, 2)
        self.assertEqual(quantizer.subquantizer(0).M, 3)

    def test_factory(self):
        self.subtest_factory('PRQ')
        self.subtest_factory('PLSQ')

    def subtest_io(self, factory_str):
        """Write and re-read an index; search results must be identical."""
        d = 8
        ds = datasets.SyntheticDataset(d, 1000, 500, 100)
        xq = ds.get_queries()

        index = faiss.index_factory(d, factory_str)
        index.train(ds.get_train())
        index.add(ds.get_database())
        I1 = index.search(xq, 1)[1]

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)
            I2 = faiss.read_index(fname).search(xq, 1)[1]
            np.testing.assert_array_equal(I1, I2)
        finally:
            if os.path.exists(fname):
                os.unlink(fname)

    def test_io(self):
        self.subtest_io('PLSQ2x3x4fs_Nlsq2x4')
        self.subtest_io('PRQ2x3x4fs_Nrq2x4')