faiss/tests/test_fast_scan.py

390 lines
11 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import faiss
from faiss.contrib import datasets
import platform
class TestSearch(unittest.TestCase):
    """End-to-end accuracy check of a 4-bit PQ index against exact search."""

    def test_PQ4_accuracy(self):
        """Check the top-1 agreement of a 'PQ16x4' index with IndexFlatL2.

        The fraction of queries whose first result is identical in both
        indexes must exceed 0.6.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)

        # exact search, used as ground truth (distances are not needed)
        index_gt = faiss.IndexFlatL2(32)
        index_gt.add(ds.get_database())
        _, Iref = index_gt.search(ds.get_queries(), 10)

        # index under test
        index = faiss.index_factory(32, 'PQ16x4')
        index.train(ds.get_train())
        index.add(ds.get_database())
        _, Ia = index.search(ds.get_queries(), 10)

        nq = Iref.shape[0]
        recall_at_1 = (Iref[:, 0] == Ia[:, 0]).sum() / nq

        # A unittest assertion (like the other tests in this file) is not
        # stripped under `python -O` and reports both values on failure.
        self.assertGreater(recall_at_1, 0.6)
class TestRounding(unittest.TestCase):
    """Compare IndexPQFastScan implementations with the plain PQ index."""

    def do_test_rounding(self, implem=4, metric=faiss.METRIC_L2):
        """Check results of one fast-scan implementation against a 'PQ16x4'.

        implem: value assigned to IndexPQFastScan.implem for the search
            under test.
        metric: metric passed to the index factory.

        With implem = 2 the fast-scan index must reproduce the reference
        results exactly. With the requested implem the results only need
        to be close: recall and squared distance error are bounded.
        """
        ds = datasets.SyntheticDataset(32, 2000, 5000, 200)

        index = faiss.index_factory(32, 'PQ16x4', metric)
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 10)
        nq = Iref.shape[0]

        index2 = faiss.IndexPQFastScan(index)

        # simply repro normal search
        index2.implem = 2
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, Iref)
        np.testing.assert_array_equal(D2, Dref)

        # rounded LUT with correction
        index2.implem = implem
        D4, I4 = index2.search(ds.get_queries(), 10)

        # check accuracy of indexes: fraction of queries whose reference
        # top-1 result appears within the first `rank` new results
        recalls = {}
        for rank in 1, 10:
            recalls[rank] = (Iref[:, :1] == I4[:, :rank]).sum() / nq

        # inner product gets a slightly looser bound at rank 1
        min_r1 = 0.98 if metric == faiss.METRIC_INNER_PRODUCT else 0.99
        self.assertGreater(recalls[1], min_r1)
        self.assertGreater(recalls[10], 0.995)

        # check accuracy of distances: squared error relative to the
        # squared norm of the exact distances
        err4 = ((D4 - D2) ** 2).sum()
        nf = (D2 ** 2).sum()
        self.assertLess(err4, nf * 1e-4)

    def test_implem_4(self):
        self.do_test_rounding(4)

    def test_implem_4_ip(self):
        self.do_test_rounding(4, faiss.METRIC_INNER_PRODUCT)

    def test_implem_12(self):
        self.do_test_rounding(12)

    def test_implem_12_ip(self):
        self.do_test_rounding(12, faiss.METRIC_INNER_PRODUCT)

    def test_implem_14(self):
        self.do_test_rounding(14)

    def test_implem_14_ip(self):
        # was calling implem 12, which duplicated test_implem_12_ip and
        # left implem 14 untested with the inner product metric
        self.do_test_rounding(14, faiss.METRIC_INNER_PRODUCT)
#########################################################
# Kernel unit test
#########################################################
def reference_accu(codes, LUT):
    """Reference implementation of the 4-bit code accumulation.

    codes: array of shape (nb, nsp // 2). Each entry packs two 4-bit
        codes: the low nibble selects the LUT entry of sub-quantizer
        2 * j, the high nibble that of sub-quantizer 2 * j + 1.
    LUT: array of shape (nq, nsp, 16).

    Returns a uint16 array of shape (nq, nb) whose entry (i, j) is the sum
    over all sub-quantizers of the LUT entries of query i selected by the
    codes of row j. The sum is accumulated in uint16, so it wraps modulo
    2 ** 16.
    """
    nq, nsp, is_16 = LUT.shape
    nb, nsp_2 = codes.shape
    assert is_16 == 16
    assert nsp_2 == nsp // 2
    # every code byte feeds exactly two sub-quantizers
    assert nsp % 2 == 0

    lo = codes & 15
    hi = codes >> 4
    lut16 = LUT.astype('uint16')

    accu = np.zeros((nq, nb), 'uint16')
    for j in range(nsp_2):
        # lut16[:, s] is (nq, 16); indexing its columns with the nb codes
        # of byte j yields the (nq, nb) contribution of sub-quantizer s
        accu += lut16[:, 2 * j][:, lo[:, j]]
        accu += lut16[:, 2 * j + 1][:, hi[:, j]]
    return accu
# disabled because the function to write to mem is not implemented currently
class ThisIsNotATestLoop5:  # (unittest.TestCase):
    """Kernel check for loop5 accumulation.

    The class does not derive from unittest.TestCase, so none of its
    test_* methods are collected.
    """

    def do_loop5_kernel(self, nq, bb):
        """ unit test for the accumulation kernel """
        nb = bb * 32    # database size
        nsp = 24        # number of sub-quantizers

        # random codes and look-up tables, drawn in this order from a
        # fixed seed
        rng = np.random.RandomState(123)
        codes = rng.randint(256, size=(nb, nsp // 2)).astype('uint8')
        LUT = rng.randint(256, size=(nq, nsp, 16)).astype('uint8')

        expected = reference_accu(codes, LUT)

        def to_A(x):
            # helper kept from the original; not called below
            return faiss.array_to_AlignedTable(x.ravel())

        ptr = faiss.swig_ptr

        # pack the look-up tables into an aligned table
        packed_LUT = faiss.AlignedTableUint8(LUT.size)
        faiss.pq4_pack_LUT(nq, nsp, ptr(LUT), packed_LUT.get())

        # pack the codes into an aligned table
        packed_codes = faiss.AlignedTableUint8(codes.size)
        faiss.pq4_pack_codes(
            ptr(codes), nb, nsp, nb, nb, nsp, packed_codes.get()
        )

        # zeroed output buffer with one uint16 per (query, database) pair
        out = faiss.AlignedTableUint16(nq * nb)
        out.clear()

        faiss.loop5_kernel_accumulate_1_block_to_mem(
            nq, nb, nsp, packed_codes.get(), packed_LUT.get(), out.get()
        )

        result = faiss.AlignedTable_to_array(out).reshape(nq, nb)
        np.testing.assert_array_equal(expected, result)

    def test_11(self):
        self.do_loop5_kernel(1, 1)

    def test_21(self):
        self.do_loop5_kernel(2, 1)

    def test_12(self):
        self.do_loop5_kernel(1, 2)

    def test_22(self):
        self.do_loop5_kernel(2, 2)
##########################################################
# Tests for various IndexPQFastScan implementations
##########################################################
def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
    """ verify a list of results where there are draws in the distances (because
    they are integer).

    testcase: object providing assertEqual (a unittest.TestCase).
    Dref, Iref: reference distances and ids, one row per query.
    Dnew, Inew: distances and ids to verify.

    The distances must match to 5 decimals. For each query, results that
    share the same reference distance may appear in any order, so they are
    compared as sets.
    """
    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    # here we have to be careful because of draws
    for i, iref_row in enumerate(Iref):
        inew_row = Inew[i]
        if np.all(iref_row == inew_row):    # easy case
            continue
        dref_row = Dref[i]
        # we can deduce nothing about the results tied with the last
        # distance of the row, so that value is skipped
        skip_dis = dref_row[-1]
        # Only distances present in this row can select anything; looking
        # at the whole of Dref just produced empty masks for the others.
        for dis in np.unique(dref_row):
            if dis == skip_dis:
                continue
            mask = dref_row == dis
            testcase.assertEqual(set(iref_row[mask]), set(inew_row[mask]))
class TestImplems(unittest.TestCase):
    """Base class of the tests for the IndexPQFastScan implementations.

    Subclasses override build_fast_scan_index and call do_with_params
    from their test methods.
    """

    # Reference data keyed by (d, metric). It is a class attribute because
    # unittest creates a new TestCase instance for every test method: a
    # dict created in __init__ would always be empty when consulted.
    cache = {}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of results requested by do_with_params; tests may lower it
        self.k = 10

    def get_index(self, d, metric):
        """Return (ds, index, Dref, Iref) for dimension d and metric.

        ds is the synthetic dataset, index the trained and populated
        'PQ<d // 2>x4' index, and Dref, Iref the results of a 10-nearest
        neighbor search with IndexPQFastScan implem 4. The tuple is
        computed once per (d, metric) and then served from the cache.
        """
        key = (d, metric)
        if key not in self.cache:
            ds = datasets.SyntheticDataset(d, 1000, 2000, 200)
            target_size = d // 2
            index = faiss.index_factory(d, 'PQ%dx4' % target_size, metric)
            index.train(ds.get_train())
            index.add(ds.get_database())

            index2 = faiss.IndexPQFastScan(index)
            # uint8 LUT but no SIMD
            index2.implem = 4
            Dref, Iref = index2.search(ds.get_queries(), 10)

            self.cache[key] = (ds, index, Dref, Iref)
        return self.cache[key]

    def do_with_params(self, d, params, metric=faiss.METRIC_L2):
        """Compare the index built from params with the reference results.

        d: dimension of the vectors.
        params: passed unchanged to build_fast_scan_index.
        metric: metric of the index.
        """
        ds, index, Dref, Iref = self.get_index(d, metric)

        index2 = self.build_fast_scan_index(index, params)
        Dnew, Inew = index2.search(ds.get_queries(), self.k)

        # the reference holds 10 results per query; keep the first k
        Dref = Dref[:, :self.k]
        Iref = Iref[:, :self.k]

        verify_with_draws(self, Dref, Iref, Dnew, Inew)

    def build_fast_scan_index(self, index, params):
        """Return the IndexPQFastScan to test; params is unused here."""
        index2 = faiss.IndexPQFastScan(index)
        index2.implem = 5
        return index2
class TestImplem12(TestImplems):
    """Checks of IndexPQFastScan with implem = 12 for several qbs values."""

    def build_fast_scan_index(self, index, qbs):
        """Return a fast-scan copy of index using implem 12 and this qbs."""
        fast_scan = faiss.IndexPQFastScan(index)
        fast_scan.qbs = qbs
        fast_scan.implem = 12
        return fast_scan

    def test_qbs7(self):
        self.do_with_params(32, 0x223)

    def test_qbs7b(self):
        self.do_with_params(32, 0x133)

    def test_qbs6(self):
        self.do_with_params(32, 0x33)

    def test_qbs6_ip(self):
        self.do_with_params(32, 0x33, faiss.METRIC_INNER_PRODUCT)

    def test_qbs6b(self):
        # test codepath where qbs is not known at compile time
        self.do_with_params(32, 0x1113)

    def test_qbs6_odd_dim(self):
        self.do_with_params(30, 0x33)
class TestImplem13(TestImplems):
    """Checks of IndexPQFastScan with implem = 13."""

    def build_fast_scan_index(self, index, qbs):
        """Return a fast-scan copy of index using implem 13 and this qbs."""
        fast_scan = faiss.IndexPQFastScan(index)
        fast_scan.qbs = qbs
        fast_scan.implem = 13
        return fast_scan

    def test_qbs7(self):
        self.do_with_params(32, 0x223)

    def test_qbs7_k1(self):
        # request a single result per query
        self.k = 1
        self.do_with_params(32, 0x223)
class TestImplem14(TestImplems):
    """Checks of IndexPQFastScan with implem = 14; params are (qbs, bbs)."""

    def build_fast_scan_index(self, index, params):
        """Return a fast-scan copy of index built with bbs, using implem 14.

        params: (qbs, bbs) pair; bbs goes to the constructor and qbs is
            set on the resulting index.
        """
        qbs, bbs = params
        fast_scan = faiss.IndexPQFastScan(index, bbs)
        fast_scan.qbs = qbs
        fast_scan.implem = 14
        return fast_scan

    def test_1_32(self):
        self.do_with_params(32, (1, 32))

    def test_1_64(self):
        self.do_with_params(32, (1, 64))

    def test_2_32(self):
        self.do_with_params(32, (2, 32))

    def test_2_64(self):
        self.do_with_params(32, (2, 64))

    def test_qbs_1_32_k1(self):
        # request a single result per query
        self.k = 1
        self.do_with_params(32, (1, 32))

    def test_qbs_1_64_k1(self):
        # request a single result per query
        self.k = 1
        self.do_with_params(32, (1, 64))

    def test_1_32_odd_dim(self):
        self.do_with_params(30, (1, 32))

    def test_1_64_odd_dim(self):
        self.do_with_params(30, (1, 64))
class TestImplem15(TestImplems):
    """Checks of IndexPQFastScan with implem = 15; params are (qbs, bbs)."""

    def build_fast_scan_index(self, index, params):
        """Return a fast-scan copy of index built with bbs, using implem 15.

        params: (qbs, bbs) pair; bbs goes to the constructor and qbs is
            set on the resulting index.
        """
        qbs, bbs = params
        fast_scan = faiss.IndexPQFastScan(index, bbs)
        fast_scan.qbs = qbs
        fast_scan.implem = 15
        return fast_scan

    def test_1_32(self):
        self.do_with_params(32, (1, 32))

    def test_2_64(self):
        self.do_with_params(32, (2, 64))
class TestAdd(unittest.TestCase):
    """Tests of adding vectors to, and constructing, an IndexPQFastScan."""

    def do_test_add(self, d, bbs):
        """Check that adding to a fast-scan index matches converting later.

        A fast-scan index converted from a partially filled PQ index and
        then completed with add() must hold the same codes and ntotal as
        one converted from the completely filled PQ index.
        """
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
        pq_index = faiss.index_factory(d, f'PQ{d//2}x4np')
        pq_index.train(ds.get_train())

        xb = ds.get_database()
        head, tail = xb[:1235], xb[1235:]

        # convert after the first part, then add the rest directly
        pq_index.add(head)
        incremental = faiss.IndexPQFastScan(pq_index, bbs)
        incremental.add(tail)
        new_codes = faiss.AlignedTable_to_array(incremental.codes)

        # complete the PQ index and convert it in one go
        pq_index.add(tail)
        converted = faiss.IndexPQFastScan(pq_index, bbs)
        ref_codes = faiss.AlignedTable_to_array(converted.codes)

        self.assertEqual(converted.ntotal, incremental.ntotal)
        np.testing.assert_array_equal(ref_codes, new_codes)

    def test_add(self):
        self.do_test_add(32, 32)

    def test_add_bbs64(self):
        self.do_test_add(32, 64)

    def test_add_odd_d(self):
        self.do_test_add(30, 64)

    def test_constructor(self):
        """Build an IndexPQFastScan from scratch and round-trip it.

        The directly constructed index must agree with a PQ index on the
        first result for more than 99% of the queries, and a serialized
        then deserialized copy must return identical results.
        """
        d = 32
        ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

        # reference: PQ index built by the factory
        pq_index = faiss.index_factory(d, f'PQ{d//2}x4np')
        pq_index.train(ds.get_train())
        pq_index.add(ds.get_database())
        Dref, Iref = pq_index.search(ds.get_queries(), 10)
        nq = Iref.shape[0]

        # fast-scan index built with its own constructor
        fast_scan = faiss.IndexPQFastScan(d, d // 2, 4)
        fast_scan.train(ds.get_train())
        fast_scan.add(ds.get_database())
        Dnew, Inew = fast_scan.search(ds.get_queries(), 10)

        n_same_top1 = (Iref[:, 0] == Inew[:, 0]).sum()
        recall_at_1 = n_same_top1 / nq
        self.assertGreater(recall_at_1, 0.99)

        # serialization round trip
        restored = faiss.deserialize_index(faiss.serialize_index(fast_scan))
        self.assertEqual(fast_scan.implem, restored.implem)

        D3, I3 = restored.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(D3, Dnew)
        np.testing.assert_array_equal(I3, Inew)