faiss/tests/test_residual_quantizer.py
Matthijs Douze 07a874d5b1 Post-training refinement of residual quantizer codebooks (#2166)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2166

RQ training is done progressively from one quantizer to the next, maintaining a current set of codes and quantization centroids.
However, for RQ as for any additive quantizer, there is a closed form solution for the centroids that minimizes the quantization error for fixed codes.
This diff adds an option to estimate that codebook at the end of training. The estimation is performed iteratively, i.e. several rounds of code computation followed by codebook refinement.

A pure Python implementation with results is here:
https://github.com/fairinternal/faiss_improvements/blob/dbcc746/decoder/refine_aq_codebook.ipynb
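
For reference, the closed-form step is a linear least-squares problem: with the codes held fixed, every encoded vector is a sum of one entry per codebook, so stacking the one-hot code indicators into a matrix C gives x ~= C @ B, where the rows of B are the concatenated codebook entries. A minimal numpy sketch with hypothetical names (`codes` holds the per-quantizer indices, `offsets` the cumulative codebook sizes), mirroring the reference implementation in the test file below:

    import numpy as np

    def refit_codebooks(x, codes, offsets, total_codebook_size):
        # one-hot indicators: row i selects the entries used to encode x[i]
        n = len(x)
        C = np.zeros((n, total_codebook_size))
        C[np.arange(n)[:, None], codes + offsets[:-1]] = 1
        # codebook matrix that minimizes ||C @ B - x||^2 for the fixed codes
        B, _, _, _ = np.linalg.lstsq(C, x, rcond=None)
        return B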

Reviewed By: wickedfoo

Differential Revision: D33309409

fbshipit-source-id: 55c13425292e73a1b05f00e90f4dcfdc8b3549e8
2022-01-05 00:59:16 -08:00


# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import faiss
import unittest
from faiss.contrib import datasets
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
###########################################################
# Reference implementation of encoding with beam search
###########################################################
def pairwise_distances(a, b):
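    # uses the expansion || a - b ||^2 = || a ||^2 + || b ||^2 - 2 <a, b>,
    # broadcast over all pairs of rows of a and b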
anorms = (a ** 2).sum(1)
bnorms = (b ** 2).sum(1)
return anorms.reshape(-1, 1) + bnorms - 2 * a @ b.T
def beam_search_encode_step_ref(cent, residuals, codes, L):
""" Reference beam search implementation
encodes a residual table.
"""
K, d = cent.shape
n, beam_size, d2 = residuals.shape
assert d == d2
n2, beam_size_2, m = codes.shape
assert n2 == n and beam_size_2 == beam_size
# compute all possible new residuals
cent_distances = pairwise_distances(residuals.reshape(n * beam_size, d), cent)
cent_distances = cent_distances.reshape(n, beam_size, K)
# TODO write in vector form
if beam_size * K <= L:
# then keep all the results
new_beam_size = beam_size * K
new_codes = np.zeros((n, beam_size, K, m + 1), dtype=int)
new_residuals = np.zeros((n, beam_size, K, d), dtype='float32')
for i in range(n):
new_codes[i, :, :, :-1] = codes[i]
new_codes[i, :, :, -1] = np.arange(K)
new_residuals[i] = residuals[i].reshape(1, d) - cent.reshape(K, d)
new_codes = new_codes.reshape(n, new_beam_size, m + 1)
new_residuals = new_residuals.reshape(n, new_beam_size, d)
new_distances = cent_distances.reshape(n, new_beam_size)
else:
# keep top-L results
new_beam_size = L
new_codes = np.zeros((n, L, m + 1), dtype=int)
new_residuals = np.zeros((n, L, d), dtype='float32')
new_distances = np.zeros((n, L), dtype='float32')
for i in range(n):
cd = cent_distances[i].ravel()
jl = np.argsort(cd)[:L] # TODO argpartition
js = jl // K # input beam index
ls = jl % K # centroid index
new_codes[i, :, :-1] = codes[i, js, :]
new_codes[i, :, -1] = ls
new_residuals[i, :, :] = residuals[i, js, :] - cent[ls, :]
new_distances[i, :] = cd[jl]
return new_codes, new_residuals, new_distances
def beam_search_encode_step(cent, residuals, codes, L, assign_index=None):
""" Wrapper of the C++ function with the same interface """
K, d = cent.shape
n, beam_size, d2 = residuals.shape
assert d == d2
n2, beam_size_2, m = codes.shape
assert n2 == n and beam_size_2 == beam_size
assert L <= beam_size * K
new_codes = np.zeros((n, L, m + 1), dtype='int32')
new_residuals = np.zeros((n, L, d), dtype='float32')
new_distances = np.zeros((n, L), dtype='float32')
sp = faiss.swig_ptr
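    # swig_ptr exposes the raw numpy buffers to the C++ function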
codes = np.ascontiguousarray(codes, dtype='int32')
faiss.beam_search_encode_step(
d, K, sp(cent), n, beam_size, sp(residuals),
m, sp(codes), L, sp(new_codes), sp(new_residuals), sp(new_distances),
assign_index
)
return new_codes, new_residuals, new_distances
def beam_search_encoding_ref(centroids, x, L):
"""
Perform encoding of vectors x with a beam search of size L
"""
n, d = x.shape
beam_size = 1
codes = np.zeros((n, beam_size, 0), dtype=int)
residuals = x.reshape((n, beam_size, d))
distances = (x ** 2).sum(1).reshape(n, beam_size)
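    # initially the beam contains only the empty code, whose residual is x
    # itself; each step below appends the codes of one quantizer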
for cent in centroids:
codes, residuals, distances = beam_search_encode_step_ref(
cent, residuals, codes, L)
return (codes, residuals, distances)
###########################################################
# Unittests for basic routines
###########################################################
class TestBeamSearch(unittest.TestCase):
def do_test(self, K=70, L=10, use_assign_index=False):
""" compare C++ beam search with reference python implementation """
d = 32
n = 500
rs = np.random.RandomState(123)
x = rs.rand(n, d).astype('float32')
cent = rs.rand(K, d).astype('float32')
# first quant step --> input beam size is 1
codes = np.zeros((n, 1, 0), dtype=int)
residuals = x.reshape(n, 1, d)
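        # the C++ implementation can optionally use a flat index to perform
        # the centroid assignment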
assign_index = faiss.IndexFlatL2(d) if use_assign_index else None
ref_codes, ref_residuals, ref_distances = beam_search_encode_step_ref(
cent, residuals, codes, L
)
new_codes, new_residuals, new_distances = beam_search_encode_step(
cent, residuals, codes, L, assign_index
)
np.testing.assert_array_equal(new_codes, ref_codes)
np.testing.assert_array_equal(new_residuals, ref_residuals)
np.testing.assert_allclose(new_distances, ref_distances, rtol=1e-5)
# second quant step:
K = 50
cent = rs.rand(K, d).astype('float32')
codes, residuals = ref_codes, ref_residuals
ref_codes, ref_residuals, ref_distances = beam_search_encode_step_ref(
cent, residuals, codes, L
)
new_codes, new_residuals, new_distances = beam_search_encode_step(
cent, residuals, codes, L
)
np.testing.assert_array_equal(new_codes, ref_codes)
np.testing.assert_array_equal(new_residuals, ref_residuals)
np.testing.assert_allclose(new_distances, ref_distances, rtol=1e-5)
def test_beam_search(self):
self.do_test()
def test_beam_search_assign_index(self):
self.do_test(use_assign_index=True)
def test_small_beam(self):
self.do_test(L=1)
def test_small_beam_2(self):
self.do_test(L=2)
def eval_codec(q, xb):
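    """ return the total squared reconstruction error of quantizer q on xb """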
codes = q.compute_codes(xb)
decoded = q.decode(codes)
return ((xb - decoded) ** 2).sum()
class TestResidualQuantizer(unittest.TestCase):
def test_training(self):
"""check that the error is in the same ballpark as PQ """
ds = datasets.SyntheticDataset(32, 3000, 1000, 0)
xt = ds.get_train()
xb = ds.get_database()
rq = faiss.ResidualQuantizer(ds.d, 4, 6)
        rq.verbose = True
        rq.train_type = faiss.ResidualQuantizer.Train_default
rq.train(xt)
err_rq = eval_codec(rq, xb)
pq = faiss.ProductQuantizer(ds.d, 4, 6)
pq.train(xt)
err_pq = eval_codec(pq, xb)
        # in practice RQ is often better than PQ but that is not the case
        # here, so just check that we are within some factor
print(err_pq, err_rq)
self.assertLess(err_rq, err_pq * 1.2)
def test_beam_size(self):
""" check that a larger beam gives a lower error """
ds = datasets.SyntheticDataset(32, 3000, 1000, 0)
xt = ds.get_train()
xb = ds.get_database()
rq0 = faiss.ResidualQuantizer(ds.d, 4, 6)
rq0.train_type = faiss.ResidualQuantizer.Train_default
rq0.max_beam_size = 2
rq0.train(xt)
err_rq0 = eval_codec(rq0, xb)
rq1 = faiss.ResidualQuantizer(ds.d, 4, 6)
rq1.train_type = faiss.ResidualQuantizer.Train_default
rq1.max_beam_size = 10
rq1.train(xt)
err_rq1 = eval_codec(rq1, xb)
self.assertLess(err_rq1, err_rq0)
def test_training_with_limited_mem(self):
""" make sure a different batch size gives the same result"""
ds = datasets.SyntheticDataset(32, 3000, 1000, 0)
xt = ds.get_train()
rq0 = faiss.ResidualQuantizer(ds.d, 4, 6)
rq0.train_type = faiss.ResidualQuantizer.Train_default
rq0.max_beam_size = 5
# rq0.verbose = True
rq0.train(xt)
cb0 = get_additive_quantizer_codebooks(rq0)
rq1 = faiss.ResidualQuantizer(ds.d, 4, 6)
rq1.train_type = faiss.ResidualQuantizer.Train_default
rq1.max_beam_size = 5
        # limit the memory available for distance computations, which
        # forces the training to process the dataset in batches
        rq1.max_mem_distances = 3000 * ds.d * 4 * 3
# rq1.verbose = True
rq1.train(xt)
cb1 = get_additive_quantizer_codebooks(rq1)
for c0, c1 in zip(cb0, cb1):
self.assertTrue(np.all(c0 == c1))
###########################################################
# Test index, index factory sa_encode / sa_decode
###########################################################
def unpack_codes(rq, packed_codes):
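    """ unpack packed additive quantizer codes into one uint32 entry per
    sub-quantizer (shortcut when all sub-quantizers use 8 bits) """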
nbits = faiss.vector_to_array(rq.nbits)
if np.all(nbits == 8):
return packed_codes.astype("uint32")
nbits = [int(x) for x in nbits]
nb = len(nbits)
n, code_size = packed_codes.shape
codes = np.zeros((n, nb), dtype="uint32")
for i in range(n):
br = faiss.BitstringReader(faiss.swig_ptr(packed_codes[i]), code_size)
for j, nbi in enumerate(nbits):
codes[i, j] = br.read(nbi)
return codes
def retrain_AQ_codebook(index, xt):
""" reference implementation of codebook retraining """
rq = index.rq
codes_packed = index.sa_encode(xt)
n, code_size = codes_packed.shape
x_decoded = index.sa_decode(codes_packed)
MSE = ((xt - x_decoded) ** 2).sum() / n
print(f"Initial MSE on training set: {MSE:g}")
codes = unpack_codes(index.rq, codes_packed)
print("ref codes", codes[0])
codebook_offsets = faiss.vector_to_array(rq.codebook_offsets)
# build sparse code matrix (represented as a dense matrix)
C = np.zeros((n, rq.total_codebook_size))
for i in range(n):
C[i][codes[i] + codebook_offsets[:-1]] = 1
    B, residuals, rank, singvals = np.linalg.lstsq(C, xt, rcond=None)
MSE = ((C @ B - xt) ** 2).sum() / n
print(f"MSE after retrainining: {MSE:g}")
    # to actually replace the codebook one would do:
    # faiss.copy_array_to_vector(B.astype('float32').ravel(), index.rq.codebooks)
    # index.rq.compute_codebook_tables()
return C, B
class TestIndexResidualQuantizer(unittest.TestCase):
def test_io(self):
ds = datasets.SyntheticDataset(32, 1000, 100, 0)
xt = ds.get_train()
xb = ds.get_database()
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.train(xt)
ref_codes = ir.sa_encode(xb)
b = faiss.serialize_index(ir)
ir2 = faiss.deserialize_index(b)
codes2 = ir2.sa_encode(xb)
np.testing.assert_array_equal(ref_codes, codes2)
def test_equiv_rq(self):
"""
        make sure that searching a flat RQ index and searching an IVF built
        from RCQ + RQ with the same codebooks give the same results.
"""
ds = datasets.SyntheticDataset(32, 3000, 1000, 50)
# make a flat RQ
iflat = faiss.IndexResidualQuantizer(ds.d, 5, 4)
iflat.rq.train_type = faiss.ResidualQuantizer.Train_default
iflat.train(ds.get_train())
iflat.add(ds.get_database())
# ref search result
Dref, Iref = iflat.search(ds.get_queries(), 10)
# get its codebooks + encoded version of the dataset
codebooks = get_additive_quantizer_codebooks(iflat.rq)
codes = faiss.vector_to_array(iflat.codes).reshape(-1, iflat.code_size)
# make an IVF with 2x4 + 3x4 = 5x4 bits
ivf = faiss.index_factory(ds.d, "IVF256(RCQ2x4),RQ3x4")
# initialize the codebooks
rcq = faiss.downcast_index(ivf.quantizer)
faiss.copy_array_to_vector(
np.vstack(codebooks[:rcq.rq.M]).ravel(),
rcq.rq.codebooks
)
rcq.rq.is_trained = True
# translation of AdditiveCoarseQuantizer::train
rcq.ntotal = 1 << rcq.rq.tot_bits
rcq.centroid_norms.resize(rcq.ntotal)
rcq.rq.compute_centroid_norms(rcq.centroid_norms.data())
rcq.is_trained = True
faiss.copy_array_to_vector(
np.vstack(codebooks[rcq.rq.M:]).ravel(),
ivf.rq.codebooks
)
ivf.rq.is_trained = True
ivf.is_trained = True
# add the codes (this works because 2x4 is a multiple of 8 bits)
ivf.add_sa_codes(codes)
# perform exhaustive search
ivf.nprobe = ivf.nlist
Dnew, Inew = ivf.search(ds.get_queries(), 10)
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
def test_factory(self):
index = faiss.index_factory(5, "RQ2x16_3x8_6x4")
np.testing.assert_array_equal(
faiss.vector_to_array(index.rq.nbits),
np.array([16, 16, 8, 8, 8, 4, 4, 4, 4, 4, 4])
)
def test_factory_norm(self):
index = faiss.index_factory(5, "RQ8x8_Nqint8")
self.assertEqual(
index.rq.search_type,
faiss.AdditiveQuantizer.ST_norm_qint8)
def test_search_decompress(self):
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
xt = ds.get_train()
xb = ds.get_database()
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.train(xt)
ir.add(xb)
D, I = ir.search(ds.get_queries(), 10)
gt = ds.get_groundtruth()
recalls = {
rank: (I[:, :rank] == gt[:, :1]).sum() / len(gt)
for rank in [1, 10, 100]
}
# recalls are {1: 0.05, 10: 0.37, 100: 0.37}
self.assertGreater(recalls[10], 0.35)
def test_reestimate_codebook(self):
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
xt = ds.get_train()
xb = ds.get_database()
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.train(xt)
# ir.rq.verbose = True
xb_decoded = ir.sa_decode(ir.sa_encode(xb))
err_before = ((xb - xb_decoded) ** 2).sum()
# test manual call of retrain_AQ_codebook
ref_C, ref_codebook = retrain_AQ_codebook(ir, xb)
ir.rq.retrain_AQ_codebook(len(xb), faiss.swig_ptr(xb))
xb_decoded = ir.sa_decode(ir.sa_encode(xb))
err_after = ((xb - xb_decoded) ** 2).sum()
# ref run: 8347.857 vs. 7710.014
self.assertGreater(err_before, err_after * 1.05)
def test_reestimate_codebook_2(self):
ds = datasets.SyntheticDataset(32, 1000, 0, 0)
xt = ds.get_train()
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.rq.train_type = 0
ir.train(xt)
xt_decoded = ir.sa_decode(ir.sa_encode(xt))
err_before = ((xt - xt_decoded) ** 2).sum()
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.rq.train_type = faiss.ResidualQuantizer.Train_refine_codebook
ir.train(xt)
xt_decoded = ir.sa_decode(ir.sa_encode(xt))
err_after_refined = ((xt - xt_decoded) ** 2).sum()
print(err_before, err_after_refined)
# ref run 7474.98 / 7006.1777
self.assertGreater(err_before, err_after_refined * 1.06)
###########################################################
# As a coarse quantizer
###########################################################
class TestIVFResidualCoarseQuantizer(unittest.TestCase):
    def test_IVF_residual(self):
ds = datasets.SyntheticDataset(32, 3000, 1000, 100)
xt = ds.get_train()
xb = ds.get_database()
gt = ds.get_groundtruth(1)
# RQ 2x6 = 12 bits = 4096 centroids
quantizer = faiss.ResidualCoarseQuantizer(ds.d, 2, 6)
rq = quantizer.rq
rq.train_type = faiss.ResidualQuantizer.Train_default
index = faiss.IndexIVFFlat(quantizer, ds.d, 1 << rq.tot_bits)
        index.quantizer_trains_alone = True
index.train(xt)
index.add(xb)
# make sure that increasing the nprobe increases accuracy
index.nprobe = 10
D, I = index.search(ds.get_queries(), 10)
r10 = (I == gt[None, :]).sum() / ds.nq
index.nprobe = 40
D, I = index.search(ds.get_queries(), 10)
r40 = (I == gt[None, :]).sum() / ds.nq
self.assertGreater(r40, r10)
# make sure that decreasing beam factor decreases accuracy
        quantizer.beam_factor = 1.0
index.nprobe = 10
D, I = index.search(ds.get_queries(), 10)
r10_narrow_beam = (I == gt[None, :]).sum() / ds.nq
self.assertGreater(r10, r10_narrow_beam)
def test_factory(self):
ds = datasets.SyntheticDataset(16, 500, 1000, 100)
index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),Flat")
index.train(ds.get_train())
index.add(ds.get_database())
Dref, Iref = index.search(ds.get_queries(), 10)
b = faiss.serialize_index(index)
index2 = faiss.deserialize_index(b)
Dnew, Inew = index2.search(ds.get_queries(), 10)
np.testing.assert_equal(Dref, Dnew)
np.testing.assert_equal(Iref, Inew)
def test_ivfsq(self):
ds = datasets.SyntheticDataset(32, 3000, 1000, 100)
xt = ds.get_train()
xb = ds.get_database()
gt = ds.get_groundtruth(1)
# RQ 2x5 = 10 bits = 1024 centroids
index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),SQ8")
quantizer = faiss.downcast_index(index.quantizer)
rq = quantizer.rq
rq.train_type = faiss.ResidualQuantizer.Train_default
index.train(xt)
index.add(xb)
# make sure that increasing the nprobe increases accuracy
index.nprobe = 10
D, I = index.search(ds.get_queries(), 10)
r10 = (I == gt[None, :]).sum() / ds.nq
index.nprobe = 40
D, I = index.search(ds.get_queries(), 10)
r40 = (I == gt[None, :]).sum() / ds.nq
self.assertGreater(r40, r10)
def test_rcq_LUT(self):
ds = datasets.SyntheticDataset(32, 3000, 1000, 100)
xt = ds.get_train()
xb = ds.get_database()
# RQ 2x5 = 10 bits = 1024 centroids
index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),SQ8")
quantizer = faiss.downcast_index(index.quantizer)
rq = quantizer.rq
rq.train_type = faiss.ResidualQuantizer.Train_default
index.train(xt)
index.add(xb)
index.nprobe = 10
# set exact centroids as coarse quantizer
all_centroids = quantizer.reconstruct_n(0, quantizer.ntotal)
q2 = faiss.IndexFlatL2(32)
q2.add(all_centroids)
index.quantizer = q2
Dref, Iref = index.search(ds.get_queries(), 10)
index.quantizer = quantizer
# search with LUT
quantizer.set_beam_factor(-1)
Dnew, Inew = index.search(ds.get_queries(), 10)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
np.testing.assert_array_equal(Iref, Inew)
# check i/o
CDref, CIref = quantizer.search(ds.get_queries(), 10)
quantizer2 = faiss.deserialize_index(faiss.serialize_index(quantizer))
CDnew, CInew = quantizer2.search(ds.get_queries(), 10)
np.testing.assert_array_almost_equal(CDref, CDnew, decimal=5)
np.testing.assert_array_equal(CIref, CInew)
###########################################################
# Test search with LUTs
###########################################################
class TestAdditiveQuantizerWithLUT(unittest.TestCase):
def test_RCQ_knn(self):
ds = datasets.SyntheticDataset(32, 1000, 0, 123)
xt = ds.get_train()
xq = ds.get_queries()
# RQ 3+4+5 = 12 bits = 4096 centroids
rcq = faiss.index_factory(ds.d, "RCQ1x3_1x4_1x5")
rcq.train(xt)
aq = rcq.rq
cents = rcq.reconstruct_n(0, rcq.ntotal)
sp = faiss.swig_ptr
# test norms computation
norms_ref = (cents ** 2).sum(1)
norms = np.zeros(1 << aq.tot_bits, dtype="float32")
aq.compute_centroid_norms(sp(norms))
np.testing.assert_array_almost_equal(norms, norms_ref, decimal=5)
# test IP search
Dref, Iref = faiss.knn(
xq, cents, 10,
metric=faiss.METRIC_INNER_PRODUCT
)
Dnew = np.zeros_like(Dref)
Inew = np.zeros_like(Iref)
aq.knn_centroids_inner_product(len(xq), sp(xq), 10, sp(Dnew), sp(Inew))
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
np.testing.assert_array_equal(Iref, Inew)
# test L2 search
Dref, Iref = faiss.knn(xq, cents, 10, metric=faiss.METRIC_L2)
Dnew = np.zeros_like(Dref)
Inew = np.zeros_like(Iref)
aq.knn_centroids_L2(len(xq), sp(xq), 10, sp(Dnew), sp(Inew), sp(norms))
np.testing.assert_array_equal(Iref, Inew)
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
class TestIndexResidualQuantizerSearch(unittest.TestCase):
def test_search_IP(self):
ds = datasets.SyntheticDataset(32, 1000, 200, 100)
xt = ds.get_train()
xb = ds.get_database()
xq = ds.get_queries()
ir = faiss.IndexResidualQuantizer(
ds.d, 3, 4, faiss.METRIC_INNER_PRODUCT)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.train(xt)
ir.add(xb)
Dref, Iref = ir.search(xq, 4)
AQ = faiss.AdditiveQuantizer
ir2 = faiss.IndexResidualQuantizer(
ds.d, 3, 4, faiss.METRIC_INNER_PRODUCT, AQ.ST_LUT_nonorm)
ir2.rq.codebooks = ir.rq.codebooks # fake training
ir2.rq.is_trained = True
ir2.is_trained = True
ir2.add(xb)
D2, I2 = ir2.search(xq, 4)
np.testing.assert_array_equal(Iref, I2)
np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
def test_search_L2(self):
ds = datasets.SyntheticDataset(32, 1000, 200, 100)
xt = ds.get_train()
xb = ds.get_database()
xq = ds.get_queries()
gt = ds.get_groundtruth(10)
ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
ir.rq.max_beam_size = 30
ir.train(xt)
# reference run w/ decoding
ir.add(xb)
Dref, Iref = ir.search(xq, 10)
        # reference run: inter_ref is about 388
inter_ref = faiss.eval_intersection(Iref, gt)
AQ = faiss.AdditiveQuantizer
for st in AQ.ST_norm_float, AQ.ST_norm_qint8, AQ.ST_norm_qint4, \
AQ.ST_norm_cqint8, AQ.ST_norm_cqint4:
ir2 = faiss.IndexResidualQuantizer(ds.d, 3, 4, faiss.METRIC_L2, st)
ir2.rq.max_beam_size = 30
ir2.train(xt) # to get the norm bounds
ir2.rq.codebooks = ir.rq.codebooks # fake training
ir2.add(xb)
D2, I2 = ir2.search(xq, 10)
if st == AQ.ST_norm_float:
np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
self.assertLess((Iref != I2).sum(), Iref.size * 0.05)
else:
inter_2 = faiss.eval_intersection(I2, gt)
self.assertGreater(inter_ref, inter_2)
# print(st, inter_ref, inter_2)
###########################################################
# IVF version
###########################################################
class TestIVFResidualQuantizer(unittest.TestCase):
def do_test_accuracy(self, by_residual, st):
ds = datasets.SyntheticDataset(32, 3000, 1000, 100)
quantizer = faiss.IndexFlatL2(ds.d)
index = faiss.IndexIVFResidualQuantizer(
quantizer, ds.d, 100, 3, 4,
faiss.METRIC_L2, st
)
index.by_residual = by_residual
        index.rq.train_type = faiss.ResidualQuantizer.Train_default
index.rq.max_beam_size = 30
index.train(ds.get_train())
index.add(ds.get_database())
inters = []
for nprobe in 1, 2, 5, 10, 20, 50:
index.nprobe = nprobe
D, I = index.search(ds.get_queries(), 10)
inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
# print(st, "nprobe=", nprobe, "inter=", inter)
inters.append(inter)
# do a little I/O test
index2 = faiss.deserialize_index(faiss.serialize_index(index))
D2, I2 = index2.search(ds.get_queries(), 10)
np.testing.assert_array_equal(I2, I)
np.testing.assert_array_equal(D2, D)
inters = np.array(inters)
if by_residual:
# check that we have increasing intersection measures with
# nprobe
self.assertTrue(np.all(inters[1:] >= inters[:-1]))
else:
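            # without residual encoding the trend is noisier, so only
            # check the first nprobe values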
self.assertTrue(np.all(inters[1:3] >= inters[:2]))
# check that we have the same result as the flat residual quantizer
iflat = faiss.IndexResidualQuantizer(
ds.d, 3, 4, faiss.METRIC_L2, st)
        iflat.rq.train_type = faiss.ResidualQuantizer.Train_default
iflat.rq.max_beam_size = 30
iflat.train(ds.get_train())
iflat.rq.codebooks = index.rq.codebooks
iflat.add(ds.get_database())
Dref, Iref = iflat.search(ds.get_queries(), 10)
index.nprobe = 100
D2, I2 = index.search(ds.get_queries(), 10)
np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
# there are many ties because the codes are so short
self.assertLess((Iref != I2).sum(), Iref.size * 0.2)
def test_decompress_no_residual(self):
self.do_test_accuracy(False, faiss.AdditiveQuantizer.ST_decompress)
def test_norm_float_no_residual(self):
self.do_test_accuracy(False, faiss.AdditiveQuantizer.ST_norm_float)
def test_decompress(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_decompress)
def test_norm_float(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_float)
def test_norm_cqint(self):
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8)
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4)
def test_factory(self):
index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat")
self.assertEqual(index.nlist, 1024)
self.assertEqual(
index.rq.search_type,
faiss.AdditiveQuantizer.ST_norm_float
)
index = faiss.index_factory(12, "IVF1024,RQ8x8_Ncqint8")
self.assertEqual(
index.rq.search_type,
faiss.AdditiveQuantizer.ST_norm_cqint8
)
index = faiss.index_factory(12, "IVF1024,RQ8x8_Ncqint4")
self.assertEqual(
index.rq.search_type,
faiss.AdditiveQuantizer.ST_norm_cqint4
)
def do_test_accuracy_IP(self, by_residual):
ds = datasets.SyntheticDataset(32, 3000, 1000, 100, "IP")
quantizer = faiss.IndexFlatIP(ds.d)
index = faiss.IndexIVFResidualQuantizer(
quantizer, ds.d, 100, 3, 4,
faiss.METRIC_INNER_PRODUCT, faiss.AdditiveQuantizer.ST_decompress
)
index.cp.spherical = True
index.by_residual = by_residual
        index.rq.train_type = faiss.ResidualQuantizer.Train_default
index.train(ds.get_train())
index.add(ds.get_database())
inters = []
for nprobe in 1, 2, 5, 10, 20, 50:
index.nprobe = nprobe
index.rq.search_type = faiss.AdditiveQuantizer.ST_decompress
D, I = index.search(ds.get_queries(), 10)
index.rq.search_type = faiss.AdditiveQuantizer.ST_LUT_nonorm
D2, I2 = index.search(ds.get_queries(), 10)
# print(D[:5] - D2[:5])
# print(I[:5])
np.testing.assert_array_almost_equal(D, D2, decimal=5)
# there are many ties because the codes are so short
self.assertLess((I != I2).sum(), I.size * 0.1)
inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
# print("nprobe=", nprobe, "inter=", inter)
inters.append(inter)
self.assertTrue(np.all(inters[1:4] >= inters[:3]))
def test_no_residual_IP(self):
self.do_test_accuracy_IP(False)
def test_residual_IP(self):
self.do_test_accuracy_IP(True)
############################################################
# Test functions that use precomputed codebook products
############################################################
def precomp_codebooks(codebooks):
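    """ precompute the cross-products between all pairs of codebooks and the
    squared norms of all centroids, as used by table-based beam search """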
codebook_cross_prods = [
[c1 @ c2.T for c1 in codebooks] for c2 in codebooks
]
cent_norms = [
(c ** 2).sum(1)
for c in codebooks
]
return codebook_cross_prods, cent_norms
############################################################
# Reference implementation of table-based beam search
############################################################
def beam_search_encode_step_tab(codes, L, distances, codebook_cross_prods_i,
query_cp_i, cent_norms_i):
""" Reference beam search implementation
encodes a residual table.
"""
n, beam_size, m = codes.shape
n2, beam_size_2 = distances.shape
assert n2 == n and beam_size_2 == beam_size
n2, K = query_cp_i.shape
assert n2 == n
K2, = cent_norms_i.shape
assert K == K2
assert len(codebook_cross_prods_i) == m
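    # distance update, with r = x - sum_j c_j the current residual:
    # || r - c ||^2 = || r ||^2 + || c ||^2 - 2 <x, c> + 2 * sum_j <c_j, c>
    # where all dot products come from the precomputed tables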
# n, beam_size, K
new_distances = distances[:, :, None] + cent_norms_i[None, None, :]
new_distances -= 2 * query_cp_i[:, None, :]
dotprods = np.zeros((n, beam_size, K))
for j in range(m):
cb = codebook_cross_prods_i[j]
for i in range(n):
for b in range(beam_size):
dotprods[i, b, :] += cb[codes[i, b, j]]
# print("dps", dotprods[:3, :2, :4])
new_distances += 2 * dotprods
cent_distances = new_distances
# TODO write in vector form
if beam_size * K <= L:
# then keep all the results
new_beam_size = beam_size * K
new_codes = np.zeros((n, beam_size, K, m + 1), dtype=int)
for i in range(n):
new_codes[i, :, :, :-1] = codes[i]
new_codes[i, :, :, -1] = np.arange(K)
new_codes = new_codes.reshape(n, new_beam_size, m + 1)
new_distances = cent_distances.reshape(n, new_beam_size)
else:
# keep top-L results
new_beam_size = L
new_codes = np.zeros((n, L, m + 1), dtype=int)
new_distances = np.zeros((n, L), dtype='float32')
for i in range(n):
cd = cent_distances[i].ravel()
jl = np.argsort(cd)[:L] # TODO argpartition
js = jl // K # input beam index
ls = jl % K # centroid index
new_codes[i, :, :-1] = codes[i, js, :]
new_codes[i, :, -1] = ls
new_distances[i, :] = cd[jl]
return new_codes, new_distances
def beam_search_encoding_tab(codebooks, x, L, precomp, implem="ref"):
"""
    Perform encoding of vectors x with a beam search of size L, using the
    precomputed tables. implem selects the reference implementation ("ref"),
    the C++ one ("cpp"), or both with cross-checking ("ref cpp").
    """
compare_implem = "ref" in implem and "cpp" in implem
query_cross_prods = [
x @ c.T for c in codebooks
]
M = len(codebooks)
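    # offsets of each codebook in the flattened table of centroids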
codebook_offsets = np.zeros(M + 1, dtype='uint64')
codebook_offsets[1:] = np.cumsum([len(cb) for cb in codebooks])
codebook_cross_prods, cent_norms = precomp
n, d = x.shape
beam_size = 1
codes = np.zeros((n, beam_size, 0), dtype='int32')
distances = (x ** 2).sum(1).reshape(n, beam_size)
for m, cent in enumerate(codebooks):
if "ref" in implem:
new_codes, new_distances = beam_search_encode_step_tab(
codes, L,
distances, codebook_cross_prods[m][:m],
query_cross_prods[m], cent_norms[m]
)
            new_beam_size = new_codes.shape[1]
if compare_implem:
codes_ref = new_codes
distances_ref = new_distances
if "cpp" in implem:
K = len(cent)
new_beam_size = min(beam_size * K, L)
new_codes = np.zeros((n, new_beam_size, m + 1), dtype='int32')
new_distances = np.zeros((n, new_beam_size), dtype="float32")
if m > 0:
cp = np.vstack(codebook_cross_prods[m][:m])
else:
cp = np.zeros((0, K), dtype='float32')
sp = faiss.swig_ptr
faiss.beam_search_encode_step_tab(
K, n, beam_size,
sp(cp), cp.shape[1],
sp(codebook_offsets),
sp(query_cross_prods[m]), query_cross_prods[m].shape[1],
sp(cent_norms[m]),
m,
sp(codes), sp(distances),
new_beam_size,
sp(new_codes), sp(new_distances)
)
if compare_implem:
np.testing.assert_array_almost_equal(
new_distances, distances_ref, decimal=5)
np.testing.assert_array_equal(
new_codes, codes_ref)
codes = new_codes
distances = new_distances
beam_size = new_beam_size
return (codes, distances)
class TestCrossCodebookComputations(unittest.TestCase):
def test_precomp(self):
ds = datasets.SyntheticDataset(32, 1000, 1000, 0)
        # make sure it works with a varying number of bits per sub-quantizer
nbits = faiss.UInt64Vector()
nbits.push_back(5)
nbits.push_back(6)
nbits.push_back(7)
rq = faiss.ResidualQuantizer(ds.d, nbits)
rq.train_type = faiss.ResidualQuantizer.Train_default
rq.train(ds.get_train())
codebooks = get_additive_quantizer_codebooks(rq)
precomp = precomp_codebooks(codebooks)
codebook_cross_prods_ref, cent_norms_ref = precomp
# check C++ precomp tables
codebook_cross_prods_ref = np.hstack([
np.vstack(c) for c in codebook_cross_prods_ref])
rq.compute_codebook_tables()
codebook_cross_prods = faiss.vector_to_array(
rq.codebook_cross_products)
codebook_cross_prods = codebook_cross_prods.reshape(
rq.total_codebook_size, rq.total_codebook_size)
cent_norms = faiss.vector_to_array(rq.cent_norms)
np.testing.assert_array_almost_equal(
codebook_cross_prods, codebook_cross_prods_ref, decimal=5)
np.testing.assert_array_almost_equal(
np.hstack(cent_norms_ref), cent_norms, decimal=5)
# validate that the python tab-based encoding works
xb = ds.get_database()
ref_codes, _, _ = beam_search_encoding_ref(codebooks, xb, 7)
new_codes, _ = beam_search_encoding_tab(codebooks, xb, 7, precomp)
np.testing.assert_array_equal(ref_codes, new_codes)
# validate the C++ beam_search_encode_step_tab function
beam_search_encoding_tab(codebooks, xb, 7, precomp, implem="ref cpp")
        # check that the C++ encoder based on residuals gives the same codes
n = ref_codes.shape[0]
sp = faiss.swig_ptr
ref_codes_packed = np.zeros((n, rq.code_size), dtype='uint8')
ref_codes_int32 = ref_codes.astype('int32')
rq.pack_codes(
n, sp(ref_codes_int32),
sp(ref_codes_packed), rq.M * ref_codes.shape[1]
)
rq.max_beam_size = 7
codes_ref_residuals = rq.compute_codes(xb)
np.testing.assert_array_equal(ref_codes_packed, codes_ref_residuals)
rq.use_beam_LUT = 1
codes_new = rq.compute_codes(xb)
np.testing.assert_array_equal(codes_ref_residuals, codes_new)