1012 lines
32 KiB
Python
1012 lines
32 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import os
|
|
import unittest
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
import faiss
|
|
|
|
from faiss.contrib import datasets
|
|
from faiss.contrib.inspect_tools import get_invlist
|
|
|
|
# the tests tend to timeout in stress modes + dev otherwise
|
|
faiss.omp_set_num_threads(4)
|
|
|
|
class TestLUTQuantization(unittest.TestCase):
|
|
|
|
def compute_dis_float(self, codes, LUT, bias):
|
|
nprobe, nt, M = codes.shape
|
|
dis = np.zeros((nprobe, nt), dtype='float32')
|
|
if bias is not None:
|
|
dis[:] = bias.reshape(-1, 1)
|
|
|
|
if LUT.ndim == 2:
|
|
LUTp = LUT
|
|
|
|
for p in range(nprobe):
|
|
if LUT.ndim == 3:
|
|
LUTp = LUT[p]
|
|
|
|
for i in range(nt):
|
|
dis[p, i] += LUTp[np.arange(M), codes[p, i]].sum()
|
|
|
|
return dis
|
|
|
|
def compute_dis_quant(self, codes, LUT, bias, a, b):
|
|
nprobe, nt, M = codes.shape
|
|
dis = np.zeros((nprobe, nt), dtype='uint16')
|
|
if bias is not None:
|
|
dis[:] = bias.reshape(-1, 1)
|
|
|
|
if LUT.ndim == 2:
|
|
LUTp = LUT
|
|
|
|
for p in range(nprobe):
|
|
if LUT.ndim == 3:
|
|
LUTp = LUT[p]
|
|
|
|
for i in range(nt):
|
|
dis[p, i] += LUTp[np.arange(M), codes[p, i]].astype('uint16').sum()
|
|
|
|
return dis / a + b
|
|
|
|
def do_test(self, LUT, bias, nprobe, alt_3d=False):
|
|
M, ksub = LUT.shape[-2:]
|
|
nt = 200
|
|
|
|
rs = np.random.RandomState(123)
|
|
codes = rs.randint(ksub, size=(nprobe, nt, M)).astype('uint8')
|
|
|
|
dis_ref = self.compute_dis_float(codes, LUT, bias)
|
|
|
|
LUTq = np.zeros(LUT.shape, dtype='uint8')
|
|
biasq = (
|
|
np.zeros(bias.shape, dtype='uint16')
|
|
if (bias is not None) and not alt_3d else None
|
|
)
|
|
atab = np.zeros(1, dtype='float32')
|
|
btab = np.zeros(1, dtype='float32')
|
|
|
|
def sp(x):
|
|
return faiss.swig_ptr(x) if x is not None else None
|
|
|
|
faiss.quantize_LUT_and_bias(
|
|
nprobe, M, ksub, LUT.ndim == 3,
|
|
sp(LUT), sp(bias), sp(LUTq), M, sp(biasq),
|
|
sp(atab), sp(btab)
|
|
)
|
|
a = atab[0]
|
|
b = btab[0]
|
|
dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b)
|
|
|
|
avg_realtive_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum()
|
|
self.assertLess(avg_realtive_error, 0.0005)
|
|
|
|
def test_no_residual_ip(self):
|
|
ksub = 16
|
|
M = 20
|
|
nprobe = 10
|
|
rs = np.random.RandomState(1234)
|
|
LUT = rs.rand(M, ksub).astype('float32')
|
|
bias = None
|
|
|
|
self.do_test(LUT, bias, nprobe)
|
|
|
|
def test_by_residual_ip(self):
|
|
ksub = 16
|
|
M = 20
|
|
nprobe = 10
|
|
rs = np.random.RandomState(1234)
|
|
LUT = rs.rand(M, ksub).astype('float32')
|
|
bias = rs.rand(nprobe).astype('float32')
|
|
bias *= 10
|
|
|
|
self.do_test(LUT, bias, nprobe)
|
|
|
|
def test_by_residual_L2(self):
|
|
ksub = 16
|
|
M = 20
|
|
nprobe = 10
|
|
rs = np.random.RandomState(1234)
|
|
LUT = rs.rand(nprobe, M, ksub).astype('float32')
|
|
bias = rs.rand(nprobe).astype('float32')
|
|
bias *= 10
|
|
|
|
self.do_test(LUT, bias, nprobe)
|
|
|
|
def test_by_residual_L2_v2(self):
|
|
ksub = 16
|
|
M = 20
|
|
nprobe = 10
|
|
rs = np.random.RandomState(1234)
|
|
LUT = rs.rand(nprobe, M, ksub).astype('float32')
|
|
bias = rs.rand(nprobe).astype('float32')
|
|
bias *= 10
|
|
|
|
self.do_test(LUT, bias, nprobe, alt_3d=True)
|
|
|
|
|
|
##########################################################
|
|
# Tests for various IndexPQFastScan implementations
|
|
##########################################################
|
|
|
|
def verify_with_draws(testcase, Dref, Iref, Dnew, Inew):
|
|
""" verify a list of results where there are draws in the distances (because
|
|
they are integer). """
|
|
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
|
|
# here we have to be careful because of draws
|
|
for i in range(len(Iref)):
|
|
if np.all(Iref[i] == Inew[i]): # easy case
|
|
continue
|
|
# we can deduce nothing about the latest line
|
|
skip_dis = Dref[i, -1]
|
|
for dis in np.unique(Dref):
|
|
if dis == skip_dis: continue
|
|
mask = Dref[i, :] == dis
|
|
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
|
|
|
|
def three_metrics(Dref, Iref, Dnew, Inew):
|
|
nq = Iref.shape[0]
|
|
recall_at_1 = (Iref[:, 0] == Inew[:, 0]).sum() / nq
|
|
recall_at_10 = (Iref[:, :1] == Inew[:, :10]).sum() / nq
|
|
ninter = 0
|
|
for i in range(nq):
|
|
ninter += len(np.intersect1d(Inew[i], Iref[i]))
|
|
intersection_at_10 = ninter / nq
|
|
return recall_at_1, recall_at_10, intersection_at_10
|
|
|
|
|
|
##########################################################
|
|
# Tests for various IndexIVFPQFastScan implementations
|
|
##########################################################
|
|
|
|
class TestIVFImplem1(unittest.TestCase):
|
|
""" Verify implem 1 (search from original invlists)
|
|
against IndexIVFPQ """
|
|
|
|
def do_test(self, by_residual, metric_type=faiss.METRIC_L2,
|
|
use_precomputed_table=0):
|
|
ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
|
|
|
|
index = faiss.index_factory(32, "IVF32,PQ16x4np", metric_type)
|
|
index.use_precomputed_table
|
|
index.use_precomputed_table = use_precomputed_table
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
index.by_residual = by_residual
|
|
Da, Ia = index.search(ds.get_queries(), 10)
|
|
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = 1
|
|
Db, Ib = index2.search(ds.get_queries(), 10)
|
|
# self.assertLess((Ia != Ib).sum(), Ia.size * 0.005)
|
|
np.testing.assert_array_equal(Ia, Ib)
|
|
np.testing.assert_almost_equal(Da, Db, decimal=5)
|
|
|
|
def test_no_residual(self):
|
|
self.do_test(False)
|
|
|
|
def test_by_residual(self):
|
|
self.do_test(True)
|
|
|
|
def test_by_residual_no_precomputed(self):
|
|
self.do_test(True, use_precomputed_table=-1)
|
|
|
|
def test_no_residual_ip(self):
|
|
self.do_test(False, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_by_residual_ip(self):
|
|
self.do_test(True, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
|
|
class TestIVFImplem2(unittest.TestCase):
|
|
""" Verify implem 2 (search with original invlists with uint8 LUTs)
|
|
against IndexIVFPQ. Entails some loss in accuracy. """
|
|
|
|
def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2):
|
|
ds = datasets.SyntheticDataset(32, 2000, 5000, 1000)
|
|
|
|
index = faiss.index_factory(32, "IVF32,PQ16x4np", metric)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
index.by_residual = by_residual
|
|
Da, Ia = index.search(ds.get_queries(), 10)
|
|
|
|
# loss due to int8 quantization of LUTs
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = 2
|
|
Db, Ib = index2.search(ds.get_queries(), 10)
|
|
|
|
m3 = three_metrics(Da, Ia, Db, Ib)
|
|
|
|
ref_results = {
|
|
(True, 1): [0.985, 1.0, 9.872],
|
|
(True, 0): [ 0.987, 1.0, 9.914],
|
|
(False, 1): [0.991, 1.0, 9.907],
|
|
(False, 0): [0.986, 1.0, 9.917],
|
|
}
|
|
|
|
ref = ref_results[(by_residual, metric)]
|
|
|
|
self.assertGreaterEqual(m3[0], ref[0] * 0.995)
|
|
self.assertGreaterEqual(m3[1], ref[1] * 0.995)
|
|
self.assertGreaterEqual(m3[2], ref[2] * 0.995)
|
|
|
|
|
|
def test_qloss_no_residual(self):
|
|
self.eval_quant_loss(False)
|
|
|
|
def test_qloss_by_residual(self):
|
|
self.eval_quant_loss(True)
|
|
|
|
def test_qloss_no_residual_ip(self):
|
|
self.eval_quant_loss(False, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_qloss_by_residual_ip(self):
|
|
self.eval_quant_loss(True, faiss.METRIC_INNER_PRODUCT)
|
|
|
|
|
|
class TestEquivPQ(unittest.TestCase):
|
|
|
|
def test_equiv_pq(self):
|
|
ds = datasets.SyntheticDataset(32, 2000, 200, 4)
|
|
xq = ds.get_queries()
|
|
|
|
index = faiss.index_factory(32, "IVF1,PQ16x4np")
|
|
index.by_residual = False
|
|
# force coarse quantizer
|
|
index.quantizer.add(np.zeros((1, 32), dtype='float32'))
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
Dref, Iref = index.search(xq, 4)
|
|
|
|
index_pq = faiss.index_factory(32, "PQ16x4np")
|
|
index_pq.pq = index.pq
|
|
index_pq.is_trained = True
|
|
index_pq.codes = faiss. downcast_InvertedLists(
|
|
index.invlists).codes.at(0)
|
|
index_pq.ntotal = index.ntotal
|
|
Dnew, Inew = index_pq.search(xq, 4)
|
|
|
|
np.testing.assert_array_equal(Iref, Inew)
|
|
np.testing.assert_array_equal(Dref, Dnew)
|
|
|
|
index_pq2 = faiss.IndexPQFastScan(index_pq)
|
|
index_pq2.implem = 12
|
|
Dref, Iref = index_pq2.search(xq, 4)
|
|
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = 12
|
|
Dnew, Inew = index2.search(xq, 4)
|
|
np.testing.assert_array_equal(Iref, Inew)
|
|
np.testing.assert_array_equal(Dref, Dnew)
|
|
|
|
# test encode and decode
|
|
|
|
np.testing.assert_array_equal(
|
|
index_pq.sa_encode(xq),
|
|
index2.sa_encode(xq)
|
|
)
|
|
|
|
np.testing.assert_array_equal(
|
|
index_pq.sa_decode(index_pq.sa_encode(xq)),
|
|
index2.sa_decode(index2.sa_encode(xq))
|
|
)
|
|
|
|
np.testing.assert_array_equal(
|
|
((index_pq.sa_decode(index_pq.sa_encode(xq)) - xq) ** 2).sum(1),
|
|
((index2.sa_decode(index2.sa_encode(xq)) - xq) ** 2).sum(1)
|
|
)
|
|
|
|
def test_equiv_pq_encode_decode(self):
|
|
ds = datasets.SyntheticDataset(32, 1000, 200, 10)
|
|
xq = ds.get_queries()
|
|
|
|
index_ivfpq = faiss.index_factory(ds.d, "IVF10,PQ8x4np")
|
|
index_ivfpq.train(ds.get_train())
|
|
|
|
index_ivfpqfs = faiss.IndexIVFPQFastScan(index_ivfpq)
|
|
|
|
np.testing.assert_array_equal(
|
|
index_ivfpq.sa_encode(xq),
|
|
index_ivfpqfs.sa_encode(xq)
|
|
)
|
|
|
|
np.testing.assert_array_equal(
|
|
index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)),
|
|
index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq))
|
|
)
|
|
|
|
np.testing.assert_array_equal(
|
|
((index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)) - xq) ** 2)
|
|
.sum(1),
|
|
((index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq)) - xq) ** 2)
|
|
.sum(1)
|
|
)
|
|
|
|
|
|
class TestIVFImplem12(unittest.TestCase):
|
|
|
|
IMPLEM = 12
|
|
|
|
def do_test(self, by_residual, metric=faiss.METRIC_L2, d=32, nq=200):
|
|
ds = datasets.SyntheticDataset(d, 2000, 5000, nq)
|
|
|
|
index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
|
|
# force coarse quantizer
|
|
# index.quantizer.add(np.zeros((1, 32), dtype='float32'))
|
|
index.by_residual = by_residual
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
|
|
# compare against implem = 2, which includes quantized LUTs
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = 2
|
|
Dref, Iref = index2.search(ds.get_queries(), 4)
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = self.IMPLEM
|
|
Dnew, Inew = index2.search(ds.get_queries(), 4)
|
|
|
|
verify_with_draws(self, Dref, Iref, Dnew, Inew)
|
|
|
|
stats = faiss.cvar.indexIVF_stats
|
|
stats.reset()
|
|
|
|
# also verify with single result
|
|
Dnew, Inew = index2.search(ds.get_queries(), 1)
|
|
for q in range(len(Dref)):
|
|
if Dref[q, 1] == Dref[q, 0]:
|
|
# then we cannot conclude
|
|
continue
|
|
self.assertEqual(Iref[q, 0], Inew[q, 0])
|
|
np.testing.assert_almost_equal(Dref[q, 0], Dnew[q, 0], decimal=5)
|
|
|
|
self.assertGreater(stats.ndis, 0)
|
|
|
|
def test_no_residual(self):
|
|
self.do_test(False)
|
|
|
|
def test_by_residual(self):
|
|
self.do_test(True)
|
|
|
|
def test_no_residual_ip(self):
|
|
self.do_test(False, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_by_residual_ip(self):
|
|
self.do_test(True, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_no_residual_odd_dim(self):
|
|
self.do_test(False, d=30)
|
|
|
|
def test_by_residual_odd_dim(self):
|
|
self.do_test(True, d=30)
|
|
|
|
# testin single query
|
|
def test_no_residual_single_query(self):
|
|
self.do_test(False, nq=1)
|
|
|
|
def test_by_residual_single_query(self):
|
|
self.do_test(True, nq=1)
|
|
|
|
def test_no_residual_ip_single_query(self):
|
|
self.do_test(False, metric=faiss.METRIC_INNER_PRODUCT, nq=1)
|
|
|
|
def test_by_residual_ip_single_query(self):
|
|
self.do_test(True, metric=faiss.METRIC_INNER_PRODUCT, nq=1)
|
|
|
|
def test_no_residual_odd_dim_single_query(self):
|
|
self.do_test(False, d=30, nq=1)
|
|
|
|
def test_by_residual_odd_dim_single_query(self):
|
|
self.do_test(True, d=30, nq=1)
|
|
|
|
|
|
class TestIVFImplem10(TestIVFImplem12):
|
|
IMPLEM = 10
|
|
|
|
|
|
class TestIVFImplem11(TestIVFImplem12):
|
|
IMPLEM = 11
|
|
|
|
|
|
class TestIVFImplem13(TestIVFImplem12):
|
|
IMPLEM = 13
|
|
|
|
|
|
class TestIVFImplem14(TestIVFImplem12):
|
|
IMPLEM = 14
|
|
|
|
|
|
class TestIVFImplem15(TestIVFImplem12):
|
|
IMPLEM = 15
|
|
|
|
|
|
class TestAdd(unittest.TestCase):
|
|
|
|
def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
|
|
bbs = 32
|
|
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
|
|
|
|
index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
|
|
index.by_residual = by_residual
|
|
index.train(ds.get_train())
|
|
index.nprobe = 4
|
|
|
|
xb = ds.get_database()
|
|
index.add(xb[:1235])
|
|
|
|
index2 = faiss.IndexIVFPQFastScan(index, bbs)
|
|
|
|
index.add(xb[1235:])
|
|
index3 = faiss.IndexIVFPQFastScan(index, bbs)
|
|
Dref, Iref = index3.search(ds.get_queries(), 10)
|
|
|
|
index2.add(xb[1235:])
|
|
Dnew, Inew = index2.search(ds.get_queries(), 10)
|
|
|
|
np.testing.assert_array_equal(Dref, Dnew)
|
|
np.testing.assert_array_equal(Iref, Inew)
|
|
|
|
# direct verification of code content. Not sure the test is correct
|
|
# if codes are shuffled.
|
|
for list_no in range(32):
|
|
ref_ids, ref_codes = get_invlist(index3.invlists, list_no)
|
|
new_ids, new_codes = get_invlist(index2.invlists, list_no)
|
|
self.assertEqual(set(ref_ids), set(new_ids))
|
|
new_code_per_id = {
|
|
new_ids[i]: new_codes[i // bbs, :, i % bbs]
|
|
for i in range(new_ids.size)
|
|
}
|
|
for i, the_id in enumerate(ref_ids):
|
|
ref_code_i = ref_codes[i // bbs, :, i % bbs]
|
|
new_code_i = new_code_per_id[the_id]
|
|
np.testing.assert_array_equal(ref_code_i, new_code_i)
|
|
|
|
def test_add(self):
|
|
self.do_test()
|
|
|
|
def test_odd_d(self):
|
|
self.do_test(d=30)
|
|
|
|
def test_bbs64(self):
|
|
self.do_test(bbs=64)
|
|
|
|
|
|
class TestTraining(unittest.TestCase):
|
|
|
|
def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
|
|
bbs = 32
|
|
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
|
|
|
|
index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
|
|
index.by_residual = by_residual
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
Dref, Iref = index.search(ds.get_queries(), 10)
|
|
|
|
index2 = faiss.IndexIVFPQFastScan(
|
|
index.quantizer, d, 32, d // 2, 4, metric, bbs)
|
|
index2.by_residual = by_residual
|
|
index2.train(ds.get_train())
|
|
|
|
index2.add(ds.get_database())
|
|
index2.nprobe = 4
|
|
Dnew, Inew = index2.search(ds.get_queries(), 10)
|
|
|
|
m3 = three_metrics(Dref, Iref, Dnew, Inew)
|
|
ref_m3_tab = {
|
|
(True, 1, 32): (0.995, 1.0, 9.91),
|
|
(True, 0, 32): (0.99, 1.0, 9.91),
|
|
(True, 1, 30): (0.989, 1.0, 9.885),
|
|
(False, 1, 32): (0.99, 1.0, 9.875),
|
|
(False, 0, 32): (0.99, 1.0, 9.92),
|
|
(False, 1, 30): (1.0, 1.0, 9.895)
|
|
}
|
|
ref_m3 = ref_m3_tab[(by_residual, metric, d)]
|
|
self.assertGreaterEqual(m3[0], ref_m3[0] * 0.99)
|
|
self.assertGreater(m3[1], ref_m3[1] * 0.99)
|
|
self.assertGreater(m3[2], ref_m3[2] * 0.99)
|
|
|
|
# Test I/O
|
|
data = faiss.serialize_index(index2)
|
|
index3 = faiss.deserialize_index(data)
|
|
D3, I3 = index3.search(ds.get_queries(), 10)
|
|
|
|
np.testing.assert_array_equal(I3, Inew)
|
|
np.testing.assert_array_equal(D3, Dnew)
|
|
|
|
def test_no_residual(self):
|
|
self.do_test(by_residual=False)
|
|
|
|
def test_by_residual(self):
|
|
self.do_test(by_residual=True)
|
|
|
|
def test_no_residual_ip(self):
|
|
self.do_test(by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_by_residual_ip(self):
|
|
self.do_test(by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_no_residual_odd_dim(self):
|
|
self.do_test(by_residual=False, d=30)
|
|
|
|
def test_by_residual_odd_dim(self):
|
|
self.do_test(by_residual=True, d=30)
|
|
|
|
|
|
class TestReconstruct(unittest.TestCase):
|
|
""" test reconstruct and sa_encode / sa_decode
|
|
(also for a few additive quantizer variants) """
|
|
|
|
def do_test(self, by_residual=False):
|
|
d = 32
|
|
metric = faiss.METRIC_L2
|
|
|
|
ds = datasets.SyntheticDataset(d, 250, 200, 10)
|
|
|
|
index = faiss.IndexIVFPQFastScan(
|
|
faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
|
|
index.by_residual = by_residual
|
|
index.make_direct_map(True)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
|
|
# Test reconstruction
|
|
v123 = index.reconstruct(123) # single id
|
|
v120_10 = index.reconstruct_n(120, 10)
|
|
np.testing.assert_array_equal(v120_10[3], v123)
|
|
v120_10 = index.reconstruct_batch(np.arange(120, 130))
|
|
np.testing.assert_array_equal(v120_10[3], v123)
|
|
|
|
# Test original list reconstruction
|
|
index.orig_invlists = faiss.ArrayInvertedLists(
|
|
index.nlist, index.code_size)
|
|
index.reconstruct_orig_invlists()
|
|
assert index.orig_invlists.compute_ntotal() == index.ntotal
|
|
|
|
# compare with non fast-scan index
|
|
index2 = faiss.IndexIVFPQ(
|
|
index.quantizer, d, 50, d // 2, 4, metric)
|
|
index2.by_residual = by_residual
|
|
index2.pq = index.pq
|
|
index2.is_trained = True
|
|
index2.replace_invlists(index.orig_invlists, False)
|
|
index2.ntotal = index.ntotal
|
|
index2.make_direct_map(True)
|
|
assert np.all(index.reconstruct(123) == index2.reconstruct(123))
|
|
|
|
def test_no_residual(self):
|
|
self.do_test(by_residual=False)
|
|
|
|
def test_by_residual(self):
|
|
self.do_test(by_residual=True)
|
|
|
|
def do_test_generic(self, factory_string,
|
|
by_residual=False, metric=faiss.METRIC_L2):
|
|
d = 32
|
|
ds = datasets.SyntheticDataset(d, 250, 200, 10)
|
|
index = faiss.index_factory(ds.d, factory_string, metric)
|
|
if "IVF" in factory_string:
|
|
index.by_residual = by_residual
|
|
index.make_direct_map(True)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
|
|
# Test reconstruction
|
|
v123 = index.reconstruct(123) # single id
|
|
v120_10 = index.reconstruct_n(120, 10)
|
|
np.testing.assert_array_equal(v120_10[3], v123)
|
|
v120_10 = index.reconstruct_batch(np.arange(120, 130))
|
|
np.testing.assert_array_equal(v120_10[3], v123)
|
|
codes = index.sa_encode(ds.get_database()[120:130])
|
|
np.testing.assert_array_equal(index.sa_decode(codes), v120_10)
|
|
|
|
# make sure pointers are correct after serialization
|
|
index2 = faiss.deserialize_index(faiss.serialize_index(index))
|
|
codes2 = index2.sa_encode(ds.get_database()[120:130])
|
|
np.testing.assert_array_equal(codes, codes2)
|
|
|
|
|
|
def test_ivfpq_residual(self):
|
|
self.do_test_generic("IVF20,PQ16x4fs", by_residual=True)
|
|
|
|
def test_ivfpq_no_residual(self):
|
|
self.do_test_generic("IVF20,PQ16x4fs", by_residual=False)
|
|
|
|
def test_pq(self):
|
|
self.do_test_generic("PQ16x4fs")
|
|
|
|
def test_rq(self):
|
|
self.do_test_generic("RQ4x4fs", metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_ivfprq(self):
|
|
self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_ivfprq_no_residual(self):
|
|
self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
def test_prq(self):
|
|
self.do_test_generic("PRQ8x2x4fs", metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
|
|
class TestIsTrained(unittest.TestCase):
|
|
|
|
def test_issue_2019(self):
|
|
index = faiss.index_factory(
|
|
32,
|
|
"PCAR16,IVF200(IVF10,PQ2x4fs,RFlat),PQ4x4fsr"
|
|
)
|
|
des = faiss.rand((1000, 32))
|
|
index.train(des)
|
|
|
|
|
|
class TestIVFAQFastScan(unittest.TestCase):
|
|
|
|
def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'):
|
|
"""
|
|
Compare IndexIVFAdditiveQuantizerFastScan with
|
|
IndexIVFAdditiveQuantizer
|
|
"""
|
|
nlist, d = 16, 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 1000, 500, metric_type)
|
|
gt = ds.get_groundtruth(k=1)
|
|
|
|
if metric_type == 'L2':
|
|
metric = faiss.METRIC_L2
|
|
postfix1 = '_Nqint8'
|
|
postfix2 = f'_N{st}2x4'
|
|
else:
|
|
metric = faiss.METRIC_INNER_PRODUCT
|
|
postfix1 = postfix2 = ''
|
|
|
|
index = faiss.index_factory(d, f'IVF{nlist},{aq}3x4{postfix1}', metric)
|
|
index.by_residual = by_residual
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 16
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.index_factory(
|
|
d, f'IVF{nlist},{aq}3x4fs_32{postfix2}', metric)
|
|
indexfs.by_residual = by_residual
|
|
indexfs.train(ds.get_train())
|
|
indexfs.add(ds.get_database())
|
|
indexfs.nprobe = 16
|
|
indexfs.implem = implem
|
|
D1, I1 = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall1 = (I1 == gt).sum() / nq
|
|
|
|
assert abs(recall_ref - recall1) < 0.051
|
|
|
|
def xx_test_accuracy(self):
|
|
# generated programatically below
|
|
for metric in 'L2', 'IP':
|
|
for byr in True, False:
|
|
for implem in 0, 10, 11, 12, 13, 14, 15:
|
|
self.subtest_accuracy('RQ', 'rq', byr, implem, metric)
|
|
self.subtest_accuracy('LSQ', 'lsq', byr, implem, metric)
|
|
|
|
def subtest_rescale_accuracy(self, aq, st, by_residual, implem):
|
|
"""
|
|
we set norm_scale to 2 and compare it with IndexIVFAQ
|
|
"""
|
|
nlist, d = 16, 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
|
|
gt = ds.get_groundtruth(k=1)
|
|
|
|
metric = faiss.METRIC_L2
|
|
postfix1 = '_Nqint8'
|
|
postfix2 = f'_N{st}2x4'
|
|
|
|
index = faiss.index_factory(
|
|
d, f'IVF{nlist},{aq}3x4{postfix1}', metric)
|
|
index.by_residual = by_residual
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 16
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.index_factory(
|
|
d, f'IVF{nlist},{aq}3x4fs_32{postfix2}', metric)
|
|
indexfs.by_residual = by_residual
|
|
indexfs.norm_scale = 2
|
|
indexfs.train(ds.get_train())
|
|
indexfs.add(ds.get_database())
|
|
indexfs.nprobe = 16
|
|
indexfs.implem = implem
|
|
D1, I1 = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall1 = (I1 == gt).sum() / nq
|
|
|
|
assert abs(recall_ref - recall1) < 0.05
|
|
|
|
def xx_test_rescale_accuracy(self):
|
|
for byr in True, False:
|
|
for implem in 0, 10, 11, 12, 13, 14, 15:
|
|
self.subtest_accuracy('RQ', 'rq', byr, implem, 'L2')
|
|
self.subtest_accuracy('LSQ', 'lsq', byr, implem, 'L2')
|
|
|
|
def subtest_from_ivfaq(self, implem):
|
|
d = 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 2000, 1000, metric='IP')
|
|
gt = ds.get_groundtruth(k=1)
|
|
index = faiss.index_factory(d, 'IVF16,RQ8x4', faiss.METRIC_INNER_PRODUCT)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 16
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.IndexIVFAdditiveQuantizerFastScan(index)
|
|
D1, I1 = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall1 = (I1 == gt).sum() / nq
|
|
assert abs(recall_ref - recall1) < 0.02
|
|
|
|
def test_from_ivfaq(self):
|
|
for implem in 0, 1, 2:
|
|
self.subtest_from_ivfaq(implem)
|
|
|
|
def subtest_factory(self, aq, M, bbs, st, r='r'):
|
|
"""
|
|
Format: IVF{nlist},{AQ}{M}x4fs{r}_{bbs}_N{st}
|
|
|
|
nlist (int): number of inverted lists
|
|
AQ (str): `LSQ` or `RQ`
|
|
M (int): number of sub-quantizers
|
|
bbs (int): build block size
|
|
st (str): search type, `lsq2x4` or `rq2x4`
|
|
r (str): `r` or ``, by_residual or not
|
|
"""
|
|
AQ = faiss.AdditiveQuantizer
|
|
nlist, d = 128, 16
|
|
|
|
if bbs > 0:
|
|
index = faiss.index_factory(
|
|
d, f'IVF{nlist},{aq}{M}x4fs{r}_{bbs}_N{st}2x4')
|
|
else:
|
|
index = faiss.index_factory(
|
|
d, f'IVF{nlist},{aq}{M}x4fs{r}_N{st}2x4')
|
|
bbs = 32
|
|
|
|
assert index.nlist == nlist
|
|
assert index.bbs == bbs
|
|
q = faiss.downcast_Quantizer(index.aq)
|
|
assert q.M == M
|
|
|
|
if aq == 'LSQ':
|
|
assert isinstance(q, faiss.LocalSearchQuantizer)
|
|
if aq == 'RQ':
|
|
assert isinstance(q, faiss.ResidualQuantizer)
|
|
|
|
if st == 'lsq':
|
|
assert q.search_type == AQ.ST_norm_lsq2x4
|
|
if st == 'rq':
|
|
assert q.search_type == AQ.ST_norm_rq2x4
|
|
|
|
assert index.by_residual == (r == 'r')
|
|
|
|
def test_factory(self):
|
|
self.subtest_factory('LSQ', 16, 64, 'lsq')
|
|
self.subtest_factory('LSQ', 16, 64, 'rq')
|
|
self.subtest_factory('RQ', 16, 64, 'rq')
|
|
self.subtest_factory('RQ', 16, 64, 'lsq')
|
|
self.subtest_factory('LSQ', 64, 0, 'lsq')
|
|
|
|
self.subtest_factory('LSQ', 64, 0, 'lsq', r='')
|
|
|
|
def subtest_io(self, factory_str):
|
|
d = 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 2000, 1000)
|
|
|
|
index = faiss.index_factory(d, factory_str)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
D1, I1 = index.search(ds.get_queries(), 1)
|
|
|
|
fd, fname = tempfile.mkstemp()
|
|
os.close(fd)
|
|
try:
|
|
faiss.write_index(index, fname)
|
|
index2 = faiss.read_index(fname)
|
|
D2, I2 = index2.search(ds.get_queries(), 1)
|
|
np.testing.assert_array_equal(I1, I2)
|
|
finally:
|
|
if os.path.exists(fname):
|
|
os.unlink(fname)
|
|
|
|
def test_io(self):
|
|
self.subtest_io('IVF16,LSQ4x4fs_Nlsq2x4')
|
|
self.subtest_io('IVF16,LSQ4x4fs_Nrq2x4')
|
|
self.subtest_io('IVF16,RQ4x4fs_Nrq2x4')
|
|
self.subtest_io('IVF16,RQ4x4fs_Nlsq2x4')
|
|
|
|
|
|
# add more tests programatically
|
|
|
|
def add_TestIVFAQFastScan_subtest_accuracy(
|
|
aq, st, by_residual, implem, metric='L2'):
|
|
setattr(
|
|
TestIVFAQFastScan,
|
|
f"test_accuracy_{metric}_{aq}_implem{implem}_residual{by_residual}",
|
|
lambda self:
|
|
self.subtest_accuracy(aq, st, by_residual, implem, metric)
|
|
)
|
|
|
|
|
|
def add_TestIVFAQFastScan_subtest_rescale_accuracy(aq, st, by_residual, implem):
|
|
setattr(
|
|
TestIVFAQFastScan,
|
|
f"test_rescale_accuracy_{aq}_implem{implem}_residual{by_residual}",
|
|
lambda self:
|
|
self.subtest_rescale_accuracy(aq, st, by_residual, implem)
|
|
)
|
|
|
|
for byr in True, False:
|
|
for implem in 0, 10, 11, 12, 13, 14, 15:
|
|
for mt in 'L2', 'IP':
|
|
add_TestIVFAQFastScan_subtest_accuracy('RQ', 'rq', byr, implem, mt)
|
|
add_TestIVFAQFastScan_subtest_accuracy('LSQ', 'lsq', byr, implem, mt)
|
|
|
|
add_TestIVFAQFastScan_subtest_rescale_accuracy('LSQ', 'lsq', byr, implem)
|
|
add_TestIVFAQFastScan_subtest_rescale_accuracy('RQ', 'rq', byr, implem)
|
|
|
|
|
|
class TestIVFPAQFastScan(unittest.TestCase):
|
|
|
|
def subtest_accuracy(self, paq):
|
|
"""
|
|
Compare IndexIVFAdditiveQuantizerFastScan with
|
|
IndexIVFAdditiveQuantizer
|
|
"""
|
|
nlist, d = 16, 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
|
|
gt = ds.get_groundtruth(k=1)
|
|
|
|
index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4_Nqint8')
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
Dref, Iref = index.search(ds.get_queries(), 1)
|
|
|
|
indexfs = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
|
|
indexfs.train(ds.get_train())
|
|
indexfs.add(ds.get_database())
|
|
indexfs.nprobe = 4
|
|
D1, I1 = indexfs.search(ds.get_queries(), 1)
|
|
|
|
nq = Iref.shape[0]
|
|
recall_ref = (Iref == gt).sum() / nq
|
|
recall1 = (I1 == gt).sum() / nq
|
|
|
|
assert abs(recall_ref - recall1) < 0.05
|
|
|
|
def test_accuracy_PLSQ(self):
|
|
self.subtest_accuracy("PLSQ")
|
|
|
|
def test_accuracy_PRQ(self):
|
|
self.subtest_accuracy("PRQ")
|
|
|
|
def subtest_factory(self, paq):
|
|
nlist, d = 128, 16
|
|
index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
|
|
q = faiss.downcast_Quantizer(index.aq)
|
|
|
|
self.assertEqual(index.nlist, nlist)
|
|
self.assertEqual(q.nsplits, 2)
|
|
self.assertEqual(q.subquantizer(0).M, 3)
|
|
self.assertTrue(index.by_residual)
|
|
|
|
def test_factory(self):
|
|
self.subtest_factory('PLSQ')
|
|
self.subtest_factory('PRQ')
|
|
|
|
def subtest_io(self, factory_str):
|
|
d = 8
|
|
ds = datasets.SyntheticDataset(d, 1000, 2000, 1000)
|
|
|
|
index = faiss.index_factory(d, factory_str)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
D1, I1 = index.search(ds.get_queries(), 1)
|
|
|
|
fd, fname = tempfile.mkstemp()
|
|
os.close(fd)
|
|
try:
|
|
faiss.write_index(index, fname)
|
|
index2 = faiss.read_index(fname)
|
|
D2, I2 = index2.search(ds.get_queries(), 1)
|
|
np.testing.assert_array_equal(I1, I2)
|
|
finally:
|
|
if os.path.exists(fname):
|
|
os.unlink(fname)
|
|
|
|
def test_io(self):
|
|
self.subtest_io('IVF16,PLSQ2x3x4fsr_Nlsq2x4')
|
|
self.subtest_io('IVF16,PRQ2x3x4fs_Nrq2x4')
|
|
|
|
|
|
class TestSearchParams(unittest.TestCase):
|
|
|
|
def test_search_params(self):
|
|
ds = datasets.SyntheticDataset(32, 500, 100, 10)
|
|
|
|
index = faiss.index_factory(ds.d, "IVF32,PQ16x4fs")
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
|
|
index.nprobe
|
|
index.nprobe = 4
|
|
Dref4, Iref4 = index.search(ds.get_queries(), 10)
|
|
# index.nprobe = 16
|
|
# Dref16, Iref16 = index.search(ds.get_queries(), 10)
|
|
|
|
index.nprobe = 1
|
|
Dnew4, Inew4 = index.search(
|
|
ds.get_queries(), 10, params=faiss.IVFSearchParameters(nprobe=4))
|
|
np.testing.assert_array_equal(Dref4, Dnew4)
|
|
np.testing.assert_array_equal(Iref4, Inew4)
|
|
|
|
|
|
class TestRangeSearchImplem12(unittest.TestCase):
|
|
IMPLEM = 12
|
|
|
|
def do_test(self, metric=faiss.METRIC_L2):
|
|
ds = datasets.SyntheticDataset(32, 750, 200, 100)
|
|
|
|
index = faiss.index_factory(ds.d, "IVF32,PQ16x4np", metric)
|
|
index.train(ds.get_train())
|
|
index.add(ds.get_database())
|
|
index.nprobe = 4
|
|
|
|
# find a reasonable radius
|
|
D, I = index.search(ds.get_queries(), 10)
|
|
radius = np.median(D[:, -1])
|
|
lims1, D1, I1 = index.range_search(ds.get_queries(), radius)
|
|
|
|
index2 = faiss.IndexIVFPQFastScan(index)
|
|
index2.implem = self.IMPLEM
|
|
lims2, D2, I2 = index2.range_search(ds.get_queries(), radius)
|
|
|
|
nmiss = 0
|
|
nextra = 0
|
|
|
|
for i in range(ds.nq):
|
|
ref = set(I1[lims1[i]: lims1[i + 1]])
|
|
new = set(I2[lims2[i]: lims2[i + 1]])
|
|
nmiss += len(ref - new)
|
|
nextra += len(new - ref)
|
|
|
|
# need some tolerance because the look-up tables are quantized
|
|
self.assertLess(nmiss, 10)
|
|
self.assertLess(nextra, 10)
|
|
|
|
def test_L2(self):
|
|
self.do_test()
|
|
|
|
def test_IP(self):
|
|
self.do_test(metric=faiss.METRIC_INNER_PRODUCT)
|
|
|
|
|
|
class TestRangeSearchImplem10(TestRangeSearchImplem12):
|
|
IMPLEM = 10
|
|
|
|
|
|
class TestRangeSearchImplem110(TestRangeSearchImplem12):
|
|
IMPLEM = 110
|