701 lines
21 KiB
Python
701 lines
21 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import unittest
|
|
import time
|
|
import os
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
import faiss
|
|
|
|
from faiss.contrib import datasets
|
|
|
|
# the tests tend to timeout in stress modes + dev otherwise
|
|
faiss.omp_set_num_threads(4)
|
|
|
|
class TestCompileOptions(unittest.TestCase):
|
|
|
|
def test_compile_options(self):
|
|
options = faiss.get_compile_options()
|
|
options = options.split(' ')
|
|
for option in options:
|
|
assert option in ['AVX2', 'NEON', 'GENERIC', 'OPTIMIZE']
|
|
|
|
|
|
class TestSearch(unittest.TestCase):
|
|
|
|
def test_PQ4_accuracy(self):
|
|
ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
|
|
|
|
index_gt = faiss.IndexFlatL2(32)
|
|
index_gt.add(ds.get_database())
|
|
Dref, Iref = index_gt.search(ds.get_queries(), 10)
|
|
|
|
index = faiss.index_factory(32, 'PQ16x4fs')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Da, Ia = index.search(ds.get_queries(), 10)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_at_1 = (Iref[:, 0] == Ia[:, 0]).sum() / nq
|
|
assert recall_at_1 > 0.6
|
|
# print(f'recall@1 = {recall_at_1:.3f}')
|
|
|
|
|
|
# This is an experiment to see if we can catch performance
|
|
# regressions. It runs 2 codes, one should be faster than the
|
|
# other by a factor ~10 in opt mode. We check for a factor 5.
|
|
# hopefully the jitter in executtion time will not produce
|
|
# too many spurious test failures. Unoptimized timings are
|
|
# not exploitable, hence the flag test on that as well.
|
|
@unittest.skipUnless(
|
|
('AVX2' in faiss.get_compile_options() or
|
|
'NEON' in faiss.get_compile_options()) and
|
|
"OPTIMIZE" in faiss.get_compile_options(),
|
|
"only test while building with avx2 or neon")
|
|
def test_PQ4_speed(self):
|
|
ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
|
|
xt = ds.get_train()
|
|
xb = ds.get_database()
|
|
xq = ds.get_queries()
|
|
|
|
index = faiss.index_factory(32, 'PQ16x4')
|
|
index.train(xt)
|
|
index.add(xb)
|
|
|
|
t0 = time.time()
|
|
D1, I1 = index.search(xq, 10)
|
|
t1 = time.time()
|
|
pq_t = t1 - t0
|
|
print('PQ16x4 search time:', pq_t)
|
|
|
|
index2 = faiss.index_factory(32, 'PQ16x4fs')
|
|
index2.train(xt)
|
|
index2.add(xb)
|
|
|
|
t0 = time.time()
|
|
D2, I2 = index2.search(xq, 10)
|
|
t1 = time.time()
|
|
pqfs_t = t1 - t0
|
|
print('PQ16x4fs search time:', pqfs_t)
|
|
self.assertLess(pqfs_t * 5, pq_t)
|
|
|
|
|
|
class TestRounding(unittest.TestCase):
|
|
|
|
def do_test_rounding(self, implem=4, metric=faiss.METRIC_L2):
|
|
ds = datasets.SyntheticDataset(32, 2000, 5000, 200)
|
|
|
|
index = faiss.index_factory(32, 'PQ16x4', metric)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Dref, Iref = index.search(ds.get_queries(), 10)
|
|
nq = Iref.shape[0]
|
|
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
|
|
# simply repro normal search
|
|
index2.implem = 2
|
|
D2, I2 = index2.search(ds.get_queries(), 10)
|
|
np.testing.assert_array_equal(I2, Iref)
|
|
np.testing.assert_array_equal(D2, Dref)
|
|
|
|
# rounded LUT with correction
|
|
index2.implem = implem
|
|
D4, I4 = index2.search(ds.get_queries(), 10)
|
|
# check accuracy of indexes
|
|
recalls = {}
|
|
for rank in 1, 10:
|
|
recalls[rank] = (Iref[:, :1] == I4[:, :rank]).sum() / nq
|
|
|
|
min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99
|
|
self.assertGreaterEqual(recalls[1], min_r1)
|
|
self.assertGreater(recalls[10], 0.995)
|
|
# check accuracy of distances
|
|
# err3 = ((D3 - D2) ** 2).sum()
|
|
err4 = ((D4 - D2) ** 2).sum()
|
|
nf = (D2 ** 2).sum()
|
|
self.assertLess(err4, nf * 1e-4)
|
|
|
|
def test_implem_4(self):
|
|
self.do_test_rounding(4)
|
|
|
|
def test_implem_4_ip(self):
|
|
self.do_test_rounding(4, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_implem_12(self):
|
|
self.do_test_rounding(12)
|
|
|
|
def test_implem_12_ip(self):
|
|
self.do_test_rounding(12, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_implem_14(self):
|
|
self.do_test_rounding(14)
|
|
|
|
def test_implem_14_ip(self):
|
|
self.do_test_rounding(12, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
|
|
class TestReconstruct(unittest.TestCase):
|
|
|
|
def test_pqfastscan(self):
|
|
ds = datasets.SyntheticDataset(20, 1000, 1000, 0)
|
|
|
|
index = faiss.index_factory(20, 'PQ5x4')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
recons = index.reconstruct_n(0, index.ntotal)
|
|
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
recons2 = index2.reconstruct_n(0, index.ntotal)
|
|
|
|
np.testing.assert_array_equal(recons, recons2)
|
|
|
|
def test_aqfastscan(self):
|
|
ds = datasets.SyntheticDataset(20, 1000, 1000, 0)
|
|
|
|
index = faiss.index_factory(20, 'RQ5x4_Nrq2x4')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
recons = index.reconstruct_n(0, index.ntotal)
|
|
|
|
index2 = faiss.IndexAdditiveQuantizerFastScan(index)
|
|
recons2 = index2.reconstruct_n(0, index.ntotal)
|
|
|
|
np.testing.assert_array_equal(recons, recons2)
|
|
|
|
|
|
#########################################################
|
|
# Kernel unit test
|
|
#########################################################
|
|
|
|
|
|
def reference_accu(codes, LUT):
|
|
nq, nsp, is_16 = LUT.shape
|
|
nb, nsp_2 = codes.shape
|
|
assert is_16 == 16
|
|
assert nsp_2 == nsp // 2
|
|
accu = np.zeros((nq, nb), 'uint16')
|
|
for i in range(nq):
|
|
for j in range(nb):
|
|
a = np.uint16(0)
|
|
for sp in range(0, nsp, 2):
|
|
c = codes[j, sp // 2]
|
|
a += LUT[i, sp , c & 15].astype('uint16')
|
|
a += LUT[i, sp + 1, c >> 4].astype('uint16')
|
|
accu[i, j] = a
|
|
return accu
|
|
|
|
|
|
# disabled because the function to write to mem is not implemented currently
|
|
class ThisIsNotATestLoop5: # (unittest.TestCase):
|
|
|
|
def do_loop5_kernel(self, nq, bb):
|
|
""" unit test for the accumulation kernel """
|
|
nb = bb * 32 # databse size
|
|
nsp = 24 # number of sub-quantizers
|
|
|
|
rs = np.random.RandomState(123)
|
|
codes = rs.randint(256, size=(nb, nsp // 2)).astype('uint8')
|
|
LUT = rs.randint(256, size=(nq, nsp, 16)).astype('uint8')
|
|
accu_ref = reference_accu(codes, LUT)
|
|
|
|
def to_A(x):
|
|
return faiss.array_to_AlignedTable(x.ravel())
|
|
|
|
sp = faiss.swig_ptr
|
|
|
|
LUT_a = faiss.AlignedTableUint8(LUT.size)
|
|
faiss.pq4_pack_LUT(
|
|
nq, nsp, sp(LUT),
|
|
LUT_a.get()
|
|
)
|
|
|
|
codes_a = faiss.AlignedTableUint8(codes.size)
|
|
faiss.pq4_pack_codes(
|
|
sp(codes),
|
|
nb, nsp, nb, nb, nsp,
|
|
codes_a.get()
|
|
)
|
|
|
|
accu_a = faiss.AlignedTableUint16(nq * nb)
|
|
accu_a.clear()
|
|
faiss.loop5_kernel_accumulate_1_block_to_mem(
|
|
nq, nb, nsp, codes_a.get(), LUT_a.get(), accu_a.get()
|
|
)
|
|
accu = faiss.AlignedTable_to_array(accu_a).reshape(nq, nb)
|
|
np.testing.assert_array_equal(accu_ref, accu)
|
|
|
|
def test_11(self):
|
|
self.do_loop5_kernel(1, 1)
|
|
|
|
def test_21(self):
|
|
self.do_loop5_kernel(2, 1)
|
|
|
|
def test_12(self):
|
|
self.do_loop5_kernel(1, 2)
|
|
|
|
def test_22(self):
|
|
self.do_loop5_kernel(2, 2)
|
|
|
|
|
|
##########################################################
|
|
# Tests for various IndexPQFastScan implementations
|
|
##########################################################
|
|
|
|
def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
|
|
""" verify a list of results where there are draws in the distances (because
|
|
they are integer). """
|
|
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
|
|
# here we have to be careful because of draws
|
|
for i in range(len(Iref)):
|
|
if np.all(Iref[i] == Inew[i]): # easy case
|
|
continue
|
|
# we can deduce nothing about the latest line
|
|
skip_dis = Dref[i, -1]
|
|
for dis in np.unique(Dref):
|
|
if dis == skip_dis:
|
|
continue
|
|
mask = Dref[i, :] == dis
|
|
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
|
|
|
|
|
|
class TestImplems(unittest.TestCase):
|
|
|
|
def __init__(self, *args):
|
|
unittest.TestCase.__init__(self, *args)
|
|
self.cache = {}
|
|
self.k = 10
|
|
|
|
def get_index(self, d, metric):
|
|
if (d, metric) not in self.cache:
|
|
ds = datasets.SyntheticDataset(d, 1000, 2000, 200)
|
|
target_size = d // 2
|
|
index = faiss.index_factory(d, 'PQ%dx4' % target_size, metric)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
# uint8 LUT but no SIMD
|
|
index2.implem = 4
|
|
Dref, Iref = index2.search(ds.get_queries(), 10)
|
|
|
|
self.cache[(d, metric)] = (ds, index, Dref, Iref)
|
|
|
|
return self.cache[(d, metric)]
|
|
|
|
def do_with_params(self, d, params, metric=faiss.METRIC_L2):
|
|
ds, index, Dref, Iref = self.get_index(d, metric)
|
|
|
|
index2 = self.build_fast_scan_index(index, params)
|
|
|
|
Dnew, Inew = index2.search(ds.get_queries(), self.k)
|
|
|
|
Dref = Dref[:, :self.k]
|
|
Iref = Iref[:, :self.k]
|
|
|
|
verify_with_draws(self, Dref, Iref, Dnew, Inew)
|
|
|
|
|
|
def build_fast_scan_index(self, index, params):
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
index2.implem = 5
|
|
return index2
|
|
|
|
|
|
|
|
class TestImplem12(TestImplems):
|
|
|
|
def build_fast_scan_index(self, index, qbs):
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
index2.qbs = qbs
|
|
index2.implem = 12
|
|
return index2
|
|
|
|
def test_qbs7(self):
|
|
self.do_with_params(32, 0x223)
|
|
|
|
def test_qbs7b(self):
|
|
self.do_with_params(32, 0x133)
|
|
|
|
def test_qbs6(self):
|
|
self.do_with_params(32, 0x33)
|
|
|
|
def test_qbs6_ip(self):
|
|
self.do_with_params(32, 0x33, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_qbs6b(self):
|
|
# test codepath where qbs is not known at compile time
|
|
self.do_with_params(32, 0x1113)
|
|
|
|
def test_qbs6_odd_dim(self):
|
|
self.do_with_params(30, 0x33)
|
|
|
|
|
|
class TestImplem13(TestImplems):
|
|
|
|
def build_fast_scan_index(self, index, qbs):
|
|
index2 = faiss.IndexPQFastScan(index)
|
|
index2.qbs = qbs
|
|
index2.implem = 13
|
|
return index2
|
|
|
|
def test_qbs7(self):
|
|
self.do_with_params(32, 0x223)
|
|
|
|
def test_qbs7_k1(self):
|
|
self.k = 1
|
|
self.do_with_params(32, 0x223)
|
|
|
|
|
|
class TestImplem14(TestImplems):
|
|
|
|
def build_fast_scan_index(self, index, params):
|
|
qbs, bbs = params
|
|
index2 = faiss.IndexPQFastScan(index, bbs)
|
|
index2.qbs = qbs
|
|
index2.implem = 14
|
|
return index2
|
|
|
|
def test_1_32(self):
|
|
self.do_with_params(32, (1, 32))
|
|
|
|
def test_1_64(self):
|
|
self.do_with_params(32, (1, 64))
|
|
|
|
def test_2_32(self):
|
|
self.do_with_params(32, (2, 32))
|
|
|
|
def test_2_64(self):
|
|
self.do_with_params(32, (2, 64))
|
|
|
|
def test_qbs_1_32_k1(self):
|
|
self.k = 1
|
|
self.do_with_params(32, (1, 32))
|
|
|
|
def test_qbs_1_64_k1(self):
|
|
self.k = 1
|
|
self.do_with_params(32, (1, 64))
|
|
|
|
def test_1_32_odd_dim(self):
|
|
self.do_with_params(30, (1, 32))
|
|
|
|
def test_1_64_odd_dim(self):
|
|
self.do_with_params(30, (1, 64))
|
|
|
|
|
|
class TestImplem15(TestImplems):
|
|
|
|
def build_fast_scan_index(self, index, params):
|
|
qbs, bbs = params
|
|
index2 = faiss.IndexPQFastScan(index, bbs)
|
|
index2.qbs = qbs
|
|
index2.implem = 15
|
|
return index2
|
|
|
|
def test_1_32(self):
|
|
self.do_with_params(32, (1, 32))
|
|
|
|
def test_2_64(self):
|
|
self.do_with_params(32, (2, 64))
|
|
|
|
class TestAdd(unittest.TestCase):
|
|
|
|
def do_test_add(self, d, bbs):
|
|
|
|
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
|
|
|
|
index = faiss.index_factory(d, f'PQ{d//2}x4np')
|
|
index.train(ds.get_train())
|
|
|
|
xb = ds.get_database()
|
|
index.add(xb[:1235])
|
|
|
|
index2 = faiss.IndexPQFastScan(index, bbs)
|
|
index2.add(xb[1235:])
|
|
new_codes = faiss.AlignedTable_to_array(index2.codes)
|
|
|
|
index.add(xb[1235:])
|
|
index3 = faiss.IndexPQFastScan(index, bbs)
|
|
ref_codes = faiss.AlignedTable_to_array(index3.codes)
|
|
self.assertEqual(index3.ntotal, index2.ntotal)
|
|
|
|
np.testing.assert_array_equal(ref_codes, new_codes)
|
|
|
|
def test_add(self):
|
|
self.do_test_add(32, 32)
|
|
|
|
def test_add_bbs64(self):
|
|
self.do_test_add(32, 64)
|
|
|
|
def test_add_odd_d(self):
|
|
self.do_test_add(30, 64)
|
|
|
|
def test_constructor(self):
|
|
d = 32
|
|
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
|
|
|
|
index = faiss.index_factory(d, f'PQ{d//2}x4np')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Dref, Iref = index.search(ds.get_queries(), 10)
|
|
nq = Iref.shape[0]
|
|
|
|
index2 = faiss.IndexPQFastScan(d, d // 2, 4)
|
|
index2.train(ds.get_train())
|
|
index2.add(ds.get_database())
|
|
Dnew, Inew = index2.search(ds.get_queries(), 10)
|
|
|
|
recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq
|
|
|
|
self.assertGreaterEqual(recall_at_1, 0.99)
|
|
|
|
data = faiss.serialize_index(index2)
|
|
index3 = faiss.deserialize_index(data)
|
|
|
|
self.assertEqual(index2.implem, index3.implem)
|
|
|
|
D3, I3 = index3.search(ds.get_queries(), 10)
|
|
np.testing.assert_array_equal(D3, Dnew)
|
|
np.testing.assert_array_equal(I3, Inew)
|
|
|
|
|
|
class TestAQFastScan(unittest.TestCase):
|
|
|
|
def subtest_accuracy(self, aq, st, implem, metric_type='L2'):
|
|
"""
|
|
Compare IndexAdditiveQuantizerFastScan with IndexAQ (qint8)
|
|
"""
|
|
d = 16
|
|
ds = datasets.SyntheticDataset(d, 1000, 1000, 500, metric_type)
|
|
gt = ds.get_groundtruth(k=1)
|
|
|
|
if metric_type == 'L2':
|
|
metric = faiss.METRIC_L2
|
|
postfix1 = '_Nqint8'
|
|
postfix2 = f'_N{st}2x4'
|
|
else:
|
|
metric = faiss.METRIC_INNER_PRODUCT
|
|
postfix1 = postfix2 = ''
|
|
|
|
index = faiss.index_factory(d, f'{aq}3x4{postfix1}', metric)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.index_factory(d, f'{aq}3x4fs_32{postfix2}', metric)
|
|
indexfs.train(ds.get_train())
|
|
indexfs.add(ds.get_database())
|
|
indexfs.implem = implem
|
|
Da, Ia = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall = (Ia == gt).sum() / nq
|
|
|
|
print(aq, st, implem, metric_type, recall_ref, recall)
|
|
assert abs(recall_ref - recall) < 0.05
|
|
|
|
def xx_test_accuracy(self):
|
|
for metric in 'L2', 'IP':
|
|
for implem in 0, 12, 13, 14, 15:
|
|
self.subtest_accuracy('RQ', 'rq', implem, metric)
|
|
self.subtest_accuracy('LSQ', 'lsq', implem, metric)
|
|
|
|
def subtest_from_idxaq(self, implem, metric):
|
|
if metric == 'L2':
|
|
metric_type = faiss.METRIC_L2
|
|
st = '_Nrq2x4'
|
|
else:
|
|
metric_type = faiss.METRIC_INNER_PRODUCT
|
|
st = ''
|
|
|
|
d = 16
|
|
ds = datasets.SyntheticDataset(d, 1000, 2000, 1000, metric=metric)
|
|
gt = ds.get_groundtruth(k=1)
|
|
index = faiss.index_factory(d, 'RQ8x4' + st, metric_type)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 16
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.IndexAdditiveQuantizerFastScan(index)
|
|
indexfs.implem = implem
|
|
D1, I1 = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall1 = (I1 == gt).sum() / nq
|
|
print(recall_ref, recall1)
|
|
assert abs(recall_ref - recall1) < 0.05
|
|
|
|
def xx_test_from_idxaq(self):
|
|
for implem in 2, 3, 4:
|
|
self.subtest_from_idxaq(implem, 'L2')
|
|
self.subtest_from_idxaq(implem, 'IP')
|
|
|
|
def subtest_factory(self, aq, M, bbs, st):
|
|
"""
|
|
Format: {AQ}{M}x4fs_{bbs}_N{st}
|
|
|
|
AQ (str): `LSQ` or `RQ`
|
|
M (int): number of subquantizers
|
|
bbs (int): build block size
|
|
st (str): search type, `lsq2x4` or `rq2x4`
|
|
"""
|
|
AQ = faiss.AdditiveQuantizer
|
|
d = 16
|
|
|
|
if bbs > 0:
|
|
index = faiss.index_factory(d, f'{aq}{M}x4fs_{bbs}_N{st}2x4')
|
|
else:
|
|
index = faiss.index_factory(d, f'{aq}{M}x4fs_N{st}2x4')
|
|
bbs = 32
|
|
|
|
assert index.bbs == bbs
|
|
aq = faiss.downcast_AdditiveQuantizer(index.aq)
|
|
assert aq.M == M
|
|
|
|
if aq == 'LSQ':
|
|
assert isinstance(aq, faiss.LocalSearchQuantizer)
|
|
if aq == 'RQ':
|
|
assert isinstance(aq, faiss.ResidualQuantizer)
|
|
|
|
if st == 'lsq':
|
|
assert aq.search_type == AQ.ST_norm_lsq2x4
|
|
if st == 'rq':
|
|
assert aq.search_type == AQ.ST_norm_rq2x4
|
|
|
|
def test_factory(self):
|
|
self.subtest_factory('LSQ', 16, 64, 'lsq')
|
|
self.subtest_factory('LSQ', 16, 64, 'rq')
|
|
self.subtest_factory('RQ', 16, 64, 'rq')
|
|
self.subtest_factory('RQ', 16, 64, 'lsq')
|
|
self.subtest_factory('LSQ', 64, 0, 'lsq')
|
|
|
|
def subtest_io(self, factory_str):
|
|
d = 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 500, 100)
|
|
|
|
index = faiss.index_factory(d, factory_str)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
D1, I1 = index.search(ds.get_queries(), 1)
|
|
|
|
fd, fname = tempfile.mkstemp()
|
|
os.close(fd)
|
|
try:
|
|
faiss.write_index(index, fname)
|
|
index2 = faiss.read_index(fname)
|
|
D2, I2 = index2.search(ds.get_queries(), 1)
|
|
np.testing.assert_array_equal(I1, I2)
|
|
finally:
|
|
if os.path.exists(fname):
|
|
os.unlink(fname)
|
|
|
|
def test_io(self):
|
|
self.subtest_io('LSQ4x4fs_Nlsq2x4')
|
|
self.subtest_io('LSQ4x4fs_Nrq2x4')
|
|
self.subtest_io('RQ4x4fs_Nrq2x4')
|
|
self.subtest_io('RQ4x4fs_Nlsq2x4')
|
|
|
|
|
|
# programatically generate tests to get finer test granularity.
|
|
|
|
def add_TestAQFastScan_subset_accuracy(aq, st, implem, metric):
|
|
setattr(
|
|
TestAQFastScan,
|
|
f"test_accuracy_{metric}_{aq}_implem{implem}",
|
|
lambda self: self.subtest_accuracy(aq, st, implem, metric)
|
|
)
|
|
|
|
|
|
for metric in 'L2', 'IP':
|
|
for implem in 0, 12, 13, 14, 15:
|
|
add_TestAQFastScan_subset_accuracy('LSQ', 'lsq', implem, metric)
|
|
add_TestAQFastScan_subset_accuracy('RQ', 'rq', implem, metric)
|
|
|
|
|
|
def add_TestAQFastScan_subtest_from_idxaq(implem, metric):
|
|
setattr(
|
|
TestAQFastScan,
|
|
f"test_from_idxaq_{metric}_implem{implem}",
|
|
lambda self: self.subtest_from_idxaq(implem, metric)
|
|
)
|
|
|
|
|
|
for implem in 2, 3, 4:
|
|
add_TestAQFastScan_subtest_from_idxaq(implem, 'L2')
|
|
add_TestAQFastScan_subtest_from_idxaq(implem, 'IP')
|
|
|
|
|
|
class TestPAQFastScan(unittest.TestCase):
|
|
|
|
def subtest_accuracy(self, paq):
|
|
"""
|
|
Compare IndexPAQFastScan with IndexPAQ (qint8)
|
|
"""
|
|
d = 16
|
|
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
|
|
gt = ds.get_groundtruth(k=1)
|
|
|
|
index = faiss.index_factory(d, f'{paq}2x3x4_Nqint8')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.index_factory(d, f'{paq}2x3x4fs_Nlsq2x4')
|
|
indexfs.train(ds.get_train())
|
|
indexfs.add(ds.get_database())
|
|
Da, Ia = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall = (Ia == gt).sum() / nq
|
|
|
|
assert abs(recall_ref - recall) < 0.05
|
|
|
|
def test_accuracy_PLSQ(self):
|
|
self.subtest_accuracy("PLSQ")
|
|
|
|
def test_accuracy_PRQ(self):
|
|
self.subtest_accuracy("PRQ")
|
|
|
|
def subtest_factory(self, paq):
|
|
index = faiss.index_factory(16, f'{paq}2x3x4fs_Nlsq2x4')
|
|
q = faiss.downcast_Quantizer(index.aq)
|
|
self.assertEqual(q.nsplits, 2)
|
|
self.assertEqual(q.subquantizer(0).M, 3)
|
|
|
|
def test_factory(self):
|
|
self.subtest_factory('PRQ')
|
|
self.subtest_factory('PLSQ')
|
|
|
|
def subtest_io(self, factory_str):
|
|
d = 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 500, 100)
|
|
|
|
index = faiss.index_factory(d, factory_str)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
D1, I1 = index.search(ds.get_queries(), 1)
|
|
|
|
fd, fname = tempfile.mkstemp()
|
|
os.close(fd)
|
|
try:
|
|
faiss.write_index(index, fname)
|
|
index2 = faiss.read_index(fname)
|
|
D2, I2 = index2.search(ds.get_queries(), 1)
|
|
np.testing.assert_array_equal(I1, I2)
|
|
finally:
|
|
if os.path.exists(fname):
|
|
os.unlink(fname)
|
|
|
|
def test_io(self):
|
|
self.subtest_io('PLSQ2x3x4fs_Nlsq2x4')
|
|
self.subtest_io('PRQ2x3x4fs_Nrq2x4')
|