# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import faiss
import unittest

from faiss.contrib import datasets
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks

###########################################################
# Reference implementation of encoding with beam search
###########################################################

# Set the number of OpenMP threads faiss uses in this test process to 4.
# NOTE(review): presumably to keep test runtime / CPU usage predictable on
# shared machines -- confirm.
faiss.omp_set_num_threads(4)

def pairwise_distances(a, b):
    """Return the matrix of squared L2 distances between rows of a and b.

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, so entry
    (i, j) of the (len(a), len(b)) result is the squared distance between
    a[i] and b[j].
    """
    sq_norms_rows = (a ** 2).sum(1)[:, None]
    sq_norms_cols = (b ** 2).sum(1)
    cross = 2 * a @ b.T
    return sq_norms_rows + sq_norms_cols - cross


def beam_search_encode_step_ref(cent, residuals, codes, L):
    """ Reference beam search implementation
    encodes a residual table.

    Each of the beam_size residuals of every vector is extended with each of
    the K centroids of `cent`. The candidates are then either all kept, or
    pruned to the L with the smallest squared residual norm.

    Args:
        cent: (K, d) centroids of the current codebook.
        residuals: (n, beam_size, d) residuals of the current beam.
        codes: (n, beam_size, m) codes chosen so far for each beam entry.
        L: maximum output beam size.

    Returns:
        (new_codes, new_residuals, new_distances) of shapes
        (n, B, m + 1), (n, B, d) and (n, B). new_distances are the squared
        norms of new_residuals.
        - If beam_size * K <= L, then B = beam_size * K and all candidates
          are kept in (beam entry, centroid) order. They are NOT sorted by
          distance.
        - Otherwise B = L and the candidates are sorted by increasing
          distance.
    """
    K, d = cent.shape
    n, beam_size, d2 = residuals.shape
    assert d == d2
    n2, beam_size_2, m = codes.shape
    assert n2 == n and beam_size_2 == beam_size

    # compute all possible new residuals
    cent_distances = pairwise_distances(residuals.reshape(n * beam_size, d), cent)
    cent_distances = cent_distances.reshape(n, beam_size, K)

    if beam_size * K <= L:
        # Keep all the results. Entry (i, b, k) of the (beam, centroid) grid
        # extends beam entry b of vector i with centroid k, so the previous
        # codes / residuals are broadcast explicitly over the centroid axis.
        new_beam_size = beam_size * K
        new_codes = np.zeros((n, beam_size, K, m + 1), dtype=int)
        new_codes[:, :, :, :-1] = codes[:, :, None, :]
        new_codes[:, :, :, -1] = np.arange(K)
        new_residuals = np.zeros((n, beam_size, K, d), dtype='float32')
        new_residuals[:] = residuals[:, :, None, :] - cent[None, None, :, :]
        new_codes = new_codes.reshape(n, new_beam_size, m + 1)
        new_residuals = new_residuals.reshape(n, new_beam_size, d)
        new_distances = cent_distances.reshape(n, new_beam_size)
    else:
        # keep top-L results
        # TODO vectorize this loop as well (and use argpartition)
        new_beam_size = L
        new_codes = np.zeros((n, L, m + 1), dtype=int)
        new_residuals = np.zeros((n, L, d), dtype='float32')
        new_distances = np.zeros((n, L), dtype='float32')
        for i in range(n):
            cd = cent_distances[i].ravel()
            jl = np.argsort(cd)[:L]
            js = jl // K     # input beam index
            ls = jl % K      # centroid index
            new_codes[i, :, :-1] = codes[i, js, :]
            new_codes[i, :, -1] = ls
            new_residuals[i, :, :] = residuals[i, js, :] - cent[ls, :]
            new_distances[i, :] = cd[jl]

    return new_codes, new_residuals, new_distances


def beam_search_encode_step(cent, residuals, codes, L, assign_index=None):
    """ Wrapper of the C++ function with the same interface

    Calls faiss.beam_search_encode_step and returns its outputs as numpy
    arrays. The output beam size is always exactly L, so L must not exceed
    beam_size * K.

    Args:
        cent: (K, d) centroids of the current codebook.
        residuals: (n, beam_size, d) residuals of the current beam.
        codes: (n, beam_size, m) codes chosen so far.
        L: output beam size.
        assign_index: optional faiss index, forwarded unchanged to the C++
            function (None by default).

    Returns:
        (new_codes, new_residuals, new_distances) of shapes (n, L, m + 1)
        int32, (n, L, d) float32 and (n, L) float32.
    """
    K, d = cent.shape
    n, beam_size, d2 = residuals.shape
    assert d == d2
    n2, beam_size_2, m = codes.shape
    assert n2 == n and beam_size_2 == beam_size

    assert L <= beam_size * K

    new_codes = np.zeros((n, L, m + 1), dtype='int32')
    new_residuals = np.zeros((n, L, d), dtype='float32')
    new_distances = np.zeros((n, L), dtype='float32')

    # The arrays are handed to C++ through swig_ptr. Normalize every input to
    # a C-contiguous buffer of the element type the C++ side reads, as was
    # already done for `codes`.
    sp = faiss.swig_ptr
    cent = np.ascontiguousarray(cent, dtype='float32')
    residuals = np.ascontiguousarray(residuals, dtype='float32')
    codes = np.ascontiguousarray(codes, dtype='int32')
    faiss.beam_search_encode_step(
        d, K, sp(cent), n, beam_size, sp(residuals),
        m, sp(codes), L, sp(new_codes), sp(new_residuals), sp(new_distances),
        assign_index
    )

    return new_codes, new_residuals, new_distances


def beam_search_encoding_ref(centroids, x, L):
    """
    Encode the vectors of x with a beam search of size L (reference version).

    Args:
        centroids: sequence of (K_i, d) codebooks, applied in order.
        x: (n, d) vectors to encode.
        L: maximum beam size kept after each step.

    Returns:
        (codes, residuals, distances) for the final beam, as returned by
        beam_search_encode_step_ref for the last codebook. With no
        codebooks, the beam is x itself with empty codes and its squared
        norms as distances.
    """
    n, d = x.shape
    # initial single-entry beam: the vector itself, with no code chosen yet
    state = (
        np.zeros((n, 1, 0), dtype=int),
        x.reshape((n, 1, d)),
        (x ** 2).sum(1).reshape(n, 1),
    )
    for cent in centroids:
        prev_codes, prev_residuals, _ = state
        state = beam_search_encode_step_ref(cent, prev_residuals, prev_codes, L)
    return state


###########################################################
# Unittests for basic routines
###########################################################


class TestBeamSearch(unittest.TestCase):

    def do_test(self, K=70, L=10, use_assign_index=False):
        """ compare C++ beam search with reference python implementation

        Two quantization steps are checked:
        - the first uses K centroids on an input beam of size 1;
        - the second uses 50 centroids on the beam produced by the
          reference first step.

        K: number of centroids of the first step.
        L: beam size (output beam size of both steps).
        use_assign_index: pass an IndexFlatL2 as assign_index to the C++
            function in the first step.
        """
        d = 32
        n = 500
        # L (the beam size) comes from the caller and must not be reassigned
        # here. Otherwise the test_small_beam* variants would test exactly
        # the same thing as test_beam_search.

        rs = np.random.RandomState(123)
        x = rs.rand(n, d).astype('float32')

        cent = rs.rand(K, d).astype('float32')

        # first quant step --> input beam size is 1
        codes = np.zeros((n, 1, 0), dtype=int)
        residuals = x.reshape(n, 1, d)

        assign_index = faiss.IndexFlatL2(d) if use_assign_index else None

        ref_codes, ref_residuals, ref_distances = beam_search_encode_step_ref(
            cent, residuals, codes, L
        )

        new_codes, new_residuals, new_distances = beam_search_encode_step(
            cent, residuals, codes, L, assign_index
        )

        np.testing.assert_array_equal(new_codes, ref_codes)
        np.testing.assert_array_equal(new_residuals, ref_residuals)
        np.testing.assert_allclose(new_distances, ref_distances, rtol=1e-5)

        # second quant step:
        K = 50
        cent = rs.rand(K, d).astype('float32')

        codes, residuals = ref_codes, ref_residuals

        ref_codes, ref_residuals, ref_distances = beam_search_encode_step_ref(
            cent, residuals, codes, L
        )

        new_codes, new_residuals, new_distances = beam_search_encode_step(
            cent, residuals, codes, L
        )

        np.testing.assert_array_equal(new_codes, ref_codes)
        np.testing.assert_array_equal(new_residuals, ref_residuals)
        np.testing.assert_allclose(new_distances, ref_distances, rtol=1e-5)

    def test_beam_search(self):
        self.do_test()

    def test_beam_search_assign_index(self):
        self.do_test(use_assign_index=True)

    def test_small_beam(self):
        self.do_test(L=1)

    def test_small_beam_2(self):
        self.do_test(L=2)


def eval_codec(q, xb):
    """Return the total squared reconstruction error of codec q on xb.

    The vectors are encoded with q.compute_codes, decoded back with
    q.decode, and the squared differences are summed over all entries.
    """
    reconstruction = q.decode(q.compute_codes(xb))
    residual = xb - reconstruction
    return (residual ** 2).sum()


class TestResidualQuantizer(unittest.TestCase):

    def test_training(self):
        """check that the error is in the same ballpark as PQ """
        ds = datasets.SyntheticDataset(32, 3000, 1000, 0)

        xt = ds.get_train()
        xb = ds.get_database()

        rq = faiss.ResidualQuantizer(ds.d, 4, 6)
        # Bare attribute reads: presumably they check that the fields are
        # exposed by the wrapper (a missing one raises AttributeError).
        rq.verbose
        rq.verbose = True
        rq.train_type = faiss.ResidualQuantizer.Train_default
        rq.cp.verbose
        rq.train(xt)
        err_rq = eval_codec(rq, xb)

        pq = faiss.ProductQuantizer(ds.d, 4, 6)
        pq.train(xt)
        err_pq = eval_codec(pq, xb)

        # In practice RQ is often better than PQ, but that is not the case
        # here, so just check that we are within some factor.
        self.assertLess(err_rq, err_pq * 1.2)

    def test_beam_size(self):
        """ check that a larger beam gives a lower error """
        ds = datasets.SyntheticDataset(32, 3000, 1000, 0)

        xt = ds.get_train()
        xb = ds.get_database()

        rq0 = faiss.ResidualQuantizer(ds.d, 4, 6)
        rq0.train_type = faiss.ResidualQuantizer.Train_default
        rq0.max_beam_size = 2
        rq0.train(xt)
        err_rq0 = eval_codec(rq0, xb)

        rq1 = faiss.ResidualQuantizer(ds.d, 4, 6)
        rq1.train_type = faiss.ResidualQuantizer.Train_default
        rq1.max_beam_size = 10
        rq1.train(xt)
        err_rq1 = eval_codec(rq1, xb)

        self.assertLess(err_rq1, err_rq0)

    def test_training_with_limited_mem(self):
        """ make sure a different batch size gives the same result"""
        ds = datasets.SyntheticDataset(32, 3000, 1000, 0)

        xt = ds.get_train()

        rq0 = faiss.ResidualQuantizer(ds.d, 4, 6)
        rq0.train_type = faiss.ResidualQuantizer.Train_default
        rq0.max_beam_size = 5
        rq0.train(xt)
        cb0 = get_additive_quantizer_codebooks(rq0)

        rq1 = faiss.ResidualQuantizer(ds.d, 4, 6)
        rq1.train_type = faiss.ResidualQuantizer.Train_default
        rq1.max_beam_size = 5
        rq1.max_mem_distances   # bare read: presumably checks the field exists
        # NOTE(review): presumably small enough to force batched distance
        # computations during training -- confirm
        rq1.max_mem_distances = 3000 * ds.d * 4 * 3
        rq1.train(xt)
        cb1 = get_additive_quantizer_codebooks(rq1)

        # assert_array_equal also checks shapes and reports the mismatch
        self.assertEqual(len(cb0), len(cb1))
        for c0, c1 in zip(cb0, cb1):
            np.testing.assert_array_equal(c0, c1)

    def test_clipping(self):
        """ verify that a clipped residual quantizer gives the same
        code prefix + suffix as the full RQ """
        ds = datasets.SyntheticDataset(32, 1000, 100, 0)

        rq = faiss.ResidualQuantizer(ds.d, 5, 4)
        rq.train_type = faiss.ResidualQuantizer.Train_default
        rq.max_beam_size = 5
        rq.train(ds.get_train())

        # encode greedily: with a larger beam the prefix property checked
        # below does not hold
        rq.max_beam_size = 1
        codes = rq.compute_codes(ds.get_database())

        rq2 = faiss.ResidualQuantizer(ds.d, 2, 4)
        rq2.initialize_from(rq)
        self.assertEqual(rq2.M, 2)
        # verify that the beginning of the codes are the same
        codes2 = rq2.compute_codes(ds.get_database())

        rq3 = faiss.ResidualQuantizer(ds.d, 3, 4)
        rq3.initialize_from(rq, 2)
        self.assertEqual(rq3.M, 3)
        codes3 = rq3.compute_codes(ds.get_database() - rq2.decode(codes2))

        # verify that prefixes are the same
        for i in range(ds.nb):
            br = faiss.BitstringReader(faiss.swig_ptr(codes[i]), rq.code_size)
            br2 = faiss.BitstringReader(faiss.swig_ptr(codes2[i]), rq2.code_size)
            self.assertEqual(br.read(rq2.tot_bits), br2.read(rq2.tot_bits))
            br3 = faiss.BitstringReader(faiss.swig_ptr(codes3[i]), rq3.code_size)
            self.assertEqual(br.read(rq3.tot_bits), br3.read(rq3.tot_bits))

###########################################################
# Test index, index factory sa_encode / sa_decode
###########################################################

def unpack_codes(rq, packed_codes):
    """Split packed additive-quantizer codes into one integer per codebook.

    Args:
        rq: quantizer whose `nbits` vector gives the bit width of each code.
        packed_codes: (n, code_size) array of packed codes, one row per
            vector.

    Returns:
        (n, M) uint32 array where column j holds the j-th code of each
        vector. When every codebook uses 8 bits, the packed bytes are
        returned directly, widened to uint32.
    """
    widths = faiss.vector_to_array(rq.nbits)
    if np.all(widths == 8):
        # byte-aligned fast path: one byte per sub-code
        return packed_codes.astype("uint32")

    widths = list(map(int, widths))
    n, code_size = packed_codes.shape
    unpacked = np.zeros((n, len(widths)), dtype="uint32")
    for row, packed_row in enumerate(packed_codes):
        reader = faiss.BitstringReader(faiss.swig_ptr(packed_row), code_size)
        # the reader is sequential: read the sub-codes in codebook order
        unpacked[row] = [reader.read(w) for w in widths]
    return unpacked

def retrain_AQ_codebook(index, xt, verbose=False):
    """ reference implementation of codebook retraining

    Encodes xt with the index. Builds the (n, total_codebook_size) 0/1
    matrix C that has, in row i, a 1 in the column of each codebook entry
    selected for xt[i]. Then solves the least-squares problem
    min_B ||C B - xt||^2 for the stacked codebooks B.

    The result is only returned; it is not written back into the index.

    Args:
        index: index exposing `rq` (with nbits, codebook_offsets and
            total_codebook_size) as well as sa_encode / sa_decode.
        xt: (n, d) training vectors.
        verbose: if True, print the MSE before and after re-estimation.

    Returns:
        (C, B): the dense code-selection matrix C and the re-estimated
        codebook table B of shape (total_codebook_size, d).
    """
    rq = index.rq

    codes_packed = index.sa_encode(xt)
    n = codes_packed.shape[0]

    if verbose:
        x_decoded = index.sa_decode(codes_packed)
        MSE = ((xt - x_decoded) ** 2).sum() / n
        print(f"Initial MSE on training set: {MSE:g}")

    codes = unpack_codes(rq, codes_packed)
    codebook_offsets = faiss.vector_to_array(rq.codebook_offsets)

    # Build the sparse code matrix, represented as a dense matrix. Code j of
    # vector i indexes into codebook j, whose rows start at
    # codebook_offsets[j] in the stacked table.
    C = np.zeros((n, rq.total_codebook_size))
    C[np.arange(n)[:, None], codes + codebook_offsets[:-1]] = 1

    B, _, _, _ = np.linalg.lstsq(C, xt, rcond=None)

    if verbose:
        MSE = ((C @ B - xt) ** 2).sum() / n
        print(f"MSE after retraining: {MSE:g}")

    return C, B


class TestIndexResidualQuantizer(unittest.TestCase):

    def test_io(self):
        """serialization round-trip must preserve the encoding"""
        ds = datasets.SyntheticDataset(32, 1000, 100, 0)

        xt = ds.get_train()
        xb = ds.get_database()

        ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ir.rq.train_type = faiss.ResidualQuantizer.Train_default
        ir.train(xt)
        ref_codes = ir.sa_encode(xb)

        b = faiss.serialize_index(ir)
        ir2 = faiss.deserialize_index(b)
        codes2 = ir2.sa_encode(xb)

        np.testing.assert_array_equal(ref_codes, codes2)

    def test_equiv_rq(self):
        """
        make sure it is equivalent to search a RQ and to search an IVF
        with RCQ + RQ with the same codebooks.
        """
        ds = datasets.SyntheticDataset(32, 3000, 1000, 50)

        # make a flat RQ
        iflat = faiss.IndexResidualQuantizer(ds.d, 5, 4)
        iflat.rq.train_type = faiss.ResidualQuantizer.Train_default
        iflat.train(ds.get_train())
        iflat.add(ds.get_database())

        # ref search result
        Dref, Iref = iflat.search(ds.get_queries(), 10)

        # get its codebooks + encoded version of the dataset
        codebooks = get_additive_quantizer_codebooks(iflat.rq)
        codes = faiss.vector_to_array(iflat.codes).reshape(-1, iflat.code_size)

        # make an IVF with 2x4 + 3x4 = 5x4 bits
        ivf = faiss.index_factory(ds.d, "IVF256(RCQ2x4),RQ3x4")

        # initialize the codebooks
        rcq = faiss.downcast_index(ivf.quantizer)
        faiss.copy_array_to_vector(
            np.vstack(codebooks[:rcq.rq.M]).ravel(),
            rcq.rq.codebooks
        )
        rcq.rq.is_trained = True
        # translation of AdditiveCoarseQuantizer::train
        rcq.ntotal = 1 << rcq.rq.tot_bits
        rcq.centroid_norms.resize(rcq.ntotal)
        rcq.rq.compute_centroid_norms(rcq.centroid_norms.data())
        rcq.is_trained = True

        faiss.copy_array_to_vector(
            np.vstack(codebooks[rcq.rq.M:]).ravel(),
            ivf.rq.codebooks
        )
        ivf.rq.is_trained = True
        ivf.is_trained = True

        # add the codes (this works because 2x4 is a multiple of 8 bits)
        ivf.add_sa_codes(codes)

        # perform exhaustive search
        ivf.nprobe = ivf.nlist

        Dnew, Inew = ivf.search(ds.get_queries(), 10)

        np.testing.assert_array_equal(Iref, Inew)
        np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)

    def test_factory(self):
        index = faiss.index_factory(5, "RQ2x16_3x8_6x4")

        np.testing.assert_array_equal(
            faiss.vector_to_array(index.rq.nbits),
            np.array([16, 16, 8, 8, 8, 4, 4, 4, 4, 4, 4])
        )

    def test_factory_norm(self):
        index = faiss.index_factory(5, "RQ8x8_Nqint8")
        self.assertEqual(
            index.rq.search_type,
            faiss.AdditiveQuantizer.ST_norm_qint8)

    def test_search_decompress(self):
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        xt = ds.get_train()
        xb = ds.get_database()

        ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ir.rq.train_type = faiss.ResidualQuantizer.Train_default
        ir.train(xt)
        ir.add(xb)

        _, I = ir.search(ds.get_queries(), 10)
        gt = ds.get_groundtruth()

        # only 10 results are returned, so recall can be measured up to
        # rank 10
        recalls = {
            rank: (I[:, :rank] == gt[:, :1]).sum() / len(gt)
            for rank in [1, 10]
        }
        # reference run: recalls were {1: 0.05, 10: 0.37}
        self.assertGreater(recalls[10], 0.35)

    def test_reestimate_codebook(self):
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        xt = ds.get_train()
        xb = ds.get_database()

        ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ir.train(xt)

        xb_decoded = ir.sa_decode(ir.sa_encode(xb))
        err_before = ((xb - xb_decoded) ** 2).sum()

        # run the python reference and the C++ retrain_AQ_codebook on the
        # same data; the reference must produce one row per codebook entry
        _, ref_codebook = retrain_AQ_codebook(ir, xb)
        self.assertEqual(
            ref_codebook.shape, (ir.rq.total_codebook_size, ds.d))
        ir.rq.retrain_AQ_codebook(len(xb), faiss.swig_ptr(xb))

        xb_decoded = ir.sa_decode(ir.sa_encode(xb))
        err_after = ((xb - xb_decoded) ** 2).sum()

        # ref run: 8347.857 vs. 7710.014
        self.assertGreater(err_before, err_after * 1.05)

    def test_reestimate_codebook_2(self):
        ds = datasets.SyntheticDataset(32, 1000, 0, 0)
        xt = ds.get_train()

        ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ir.rq.train_type = faiss.ResidualQuantizer.Train_default
        ir.train(xt)

        xt_decoded = ir.sa_decode(ir.sa_encode(xt))
        err_before = ((xt - xt_decoded) ** 2).sum()

        ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ir.rq.train_type = faiss.ResidualQuantizer.Train_refine_codebook
        ir.train(xt)

        xt_decoded = ir.sa_decode(ir.sa_encode(xt))
        err_after_refined = ((xt - xt_decoded) ** 2).sum()

        # ref run 7474.98 / 7006.1777
        self.assertGreater(err_before, err_after_refined * 1.06)





###########################################################
# As a coarse quantizer
###########################################################

class TestIVFResidualCoarseQuantizer(unittest.TestCase):

    def test_IVF_resiudal(self):
        """IVF with a residual coarse quantizer: accuracy must increase with
        nprobe and decrease with a narrower beam."""
        ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

        xt = ds.get_train()
        xb = ds.get_database()

        gt = ds.get_groundtruth(1)

        # RQ 2x6 = 12 bits = 4096 centroids
        quantizer = faiss.ResidualCoarseQuantizer(ds.d, 2, 6)
        rq = quantizer.rq
        rq.train_type = faiss.ResidualQuantizer.Train_default
        index = faiss.IndexIVFFlat(quantizer, ds.d, 1 << rq.tot_bits)
        # bare attribute read: presumably checks the field is exposed
        index.quantizer_trains_alone
        index.quantizer_trains_alone = True

        index.train(xt)
        index.add(xb)

        # make sure that increasing the nprobe increases accuracy

        index.nprobe = 10
        D, I = index.search(ds.get_queries(), 10)
        r10 = (I == gt[None, :]).sum() / ds.nq

        index.nprobe = 40
        D, I = index.search(ds.get_queries(), 10)
        r40 = (I == gt[None, :]).sum() / ds.nq

        self.assertGreater(r40, r10)

        # make sure that decreasing beam factor decreases accuracy
        quantizer.beam_factor
        quantizer.beam_factor = 1.0
        index.nprobe = 10
        D, I = index.search(ds.get_queries(), 10)
        r10_narrow_beam = (I == gt[None, :]).sum() / ds.nq

        self.assertGreater(r10, r10_narrow_beam)

    def test_factory(self):
        """factory-built IVF-RCQ index must survive a serialization
        round-trip with identical search results"""
        ds = datasets.SyntheticDataset(16, 500, 1000, 100)

        index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),Flat")
        index.train(ds.get_train())
        index.add(ds.get_database())

        Dref, Iref = index.search(ds.get_queries(), 10)

        b = faiss.serialize_index(index)
        index2 = faiss.deserialize_index(b)

        Dnew, Inew = index2.search(ds.get_queries(), 10)

        np.testing.assert_equal(Dref, Dnew)
        np.testing.assert_equal(Iref, Inew)

    def test_ivfsq(self):
        ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

        xt = ds.get_train()
        xb = ds.get_database()

        gt = ds.get_groundtruth(1)

        # RQ 2x5 = 10 bits = 1024 centroids
        index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),SQ8")
        quantizer = faiss.downcast_index(index.quantizer)
        rq = quantizer.rq
        rq.train_type = faiss.ResidualQuantizer.Train_default

        index.train(xt)
        index.add(xb)

        # make sure that increasing the nprobe increases accuracy

        index.nprobe = 10
        D, I = index.search(ds.get_queries(), 10)
        r10 = (I == gt[None, :]).sum() / ds.nq

        index.nprobe = 40
        D, I = index.search(ds.get_queries(), 10)
        r40 = (I == gt[None, :]).sum() / ds.nq

        self.assertGreater(r40, r10)

    def test_rcq_LUT(self):
        ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

        xt = ds.get_train()
        xb = ds.get_database()

        # RQ 2x5 = 10 bits = 1024 centroids
        index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),SQ8")

        quantizer = faiss.downcast_index(index.quantizer)
        rq = quantizer.rq
        rq.train_type = faiss.ResidualQuantizer.Train_default

        index.train(xt)
        index.add(xb)
        index.nprobe = 10

        # set exact centroids as coarse quantizer
        all_centroids = quantizer.reconstruct_n(0, quantizer.ntotal)
        q2 = faiss.IndexFlatL2(32)
        q2.add(all_centroids)
        index.quantizer = q2
        Dref, Iref = index.search(ds.get_queries(), 10)
        index.quantizer = quantizer

        # search with LUT
        quantizer.set_beam_factor(-1)
        Dnew, Inew = index.search(ds.get_queries(), 10)

        np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
        np.testing.assert_array_equal(Iref, Inew)

        # check i/o
        CDref, CIref = quantizer.search(ds.get_queries(), 10)
        quantizer2 = faiss.deserialize_index(faiss.serialize_index(quantizer))
        CDnew, CInew = quantizer2.search(ds.get_queries(), 10)
        np.testing.assert_array_almost_equal(CDref, CDnew, decimal=5)
        np.testing.assert_array_equal(CIref, CInew)

        # check that you can load the index without computing the tables
        quantizer.set_beam_factor(2.0)
        self.assertNotEqual(quantizer.rq.codebook_cross_products.size(), 0)
        quantizer3 = faiss.deserialize_index(
            faiss.serialize_index(quantizer), faiss.IO_FLAG_SKIP_PRECOMPUTE_TABLE)
        self.assertEqual(quantizer3.rq.codebook_cross_products.size(), 0)
        # searching without the tables must still work and return results of
        # the expected shape
        CD3, CI3 = quantizer3.search(ds.get_queries(), 10)
        self.assertEqual(CD3.shape, CDref.shape)
        self.assertEqual(CI3.shape, CIref.shape)


###########################################################
# Test search with LUTs
###########################################################


class TestAdditiveQuantizerWithLUT(unittest.TestCase):

    def test_RCQ_knn(self):
        """Check the RCQ centroid norms and centroid k-NN searches (IP and
        L2) against brute force over the explicitly reconstructed
        centroids."""
        ds = datasets.SyntheticDataset(32, 1000, 0, 123)
        xt = ds.get_train()
        xq = ds.get_queries()

        # RQ 3+4+5 = 12 bits = 4096 centroids
        rcq = faiss.index_factory(ds.d, "RCQ1x3_1x4_1x5")
        rcq.train(xt)
        aq = rcq.rq
        ptr = faiss.swig_ptr

        # every centroid of the coarse quantizer, reconstructed explicitly
        all_cents = rcq.reconstruct_n(0, rcq.ntotal)

        # norms: C++ computation vs direct numpy computation
        norms = np.zeros(1 << aq.tot_bits, dtype="float32")
        aq.compute_centroid_norms(ptr(norms))
        np.testing.assert_array_almost_equal(
            norms, (all_cents ** 2).sum(1), decimal=5)

        k = 10
        nq = len(xq)

        # inner-product search
        D_ip_ref, I_ip_ref = faiss.knn(
            xq, all_cents, k, metric=faiss.METRIC_INNER_PRODUCT)
        D_ip = np.zeros_like(D_ip_ref)
        I_ip = np.zeros_like(I_ip_ref)
        aq.knn_centroids_inner_product(nq, ptr(xq), k, ptr(D_ip), ptr(I_ip))
        np.testing.assert_array_almost_equal(D_ip_ref, D_ip, decimal=5)
        np.testing.assert_array_equal(I_ip_ref, I_ip)

        # L2 search, given the norms computed above
        D_l2_ref, I_l2_ref = faiss.knn(xq, all_cents, k, metric=faiss.METRIC_L2)
        D_l2 = np.zeros_like(D_l2_ref)
        I_l2 = np.zeros_like(I_l2_ref)
        aq.knn_centroids_L2(nq, ptr(xq), k, ptr(D_l2), ptr(I_l2), ptr(norms))
        np.testing.assert_array_equal(I_l2_ref, I_l2)
        np.testing.assert_array_almost_equal(D_l2_ref, D_l2, decimal=5)


class TestIndexResidualQuantizerSearch(unittest.TestCase):

    def test_search_IP(self):
        """LUT-based IP search must match the decompression-based search."""
        ds = datasets.SyntheticDataset(32, 1000, 200, 100)
        xt, xb, xq = ds.get_train(), ds.get_database(), ds.get_queries()
        AQ = faiss.AdditiveQuantizer

        ref_index = faiss.IndexResidualQuantizer(
            ds.d, 3, 4, faiss.METRIC_INNER_PRODUCT)
        ref_index.rq.train_type = faiss.ResidualQuantizer.Train_default
        ref_index.train(xt)
        ref_index.add(xb)
        Dref, Iref = ref_index.search(xq, 4)

        # same codebooks, searched with look-up tables and no norms
        lut_index = faiss.IndexResidualQuantizer(
            ds.d, 3, 4, faiss.METRIC_INNER_PRODUCT, AQ.ST_LUT_nonorm)
        lut_index.rq.codebooks = ref_index.rq.codebooks    # fake training
        lut_index.rq.is_trained = True
        lut_index.is_trained = True
        lut_index.add(xb)
        D_lut, I_lut = lut_index.search(xq, 4)

        np.testing.assert_array_equal(Iref, I_lut)
        np.testing.assert_array_almost_equal(Dref, D_lut, decimal=5)

    def test_search_L2(self):
        """Norm-based L2 search types vs. the decompression reference."""
        ds = datasets.SyntheticDataset(32, 1000, 200, 100)
        xt, xb, xq = ds.get_train(), ds.get_database(), ds.get_queries()
        gt = ds.get_groundtruth(10)

        ref_index = faiss.IndexResidualQuantizer(ds.d, 3, 4)
        ref_index.rq.train_type = faiss.ResidualQuantizer.Train_default
        ref_index.rq.max_beam_size = 30
        ref_index.train(xt)

        # reference run w/ decoding
        ref_index.add(xb)
        Dref, Iref = ref_index.search(xq, 10)
        # presumably 388 in a reference run
        inter_ref = faiss.eval_intersection(Iref, gt)

        AQ = faiss.AdditiveQuantizer
        search_types = (
            AQ.ST_norm_float, AQ.ST_norm_qint8, AQ.ST_norm_qint4,
            AQ.ST_norm_cqint8, AQ.ST_norm_cqint4,
        )
        for st in search_types:
            index = faiss.IndexResidualQuantizer(ds.d, 3, 4, faiss.METRIC_L2, st)
            index.rq.max_beam_size = 30
            index.train(xt)   # to get the norm bounds
            index.rq.codebooks = ref_index.rq.codebooks    # fake training
            index.add(xb)
            D, I = index.search(xq, 10)

            if st != AQ.ST_norm_float:
                # quantized norms: must not beat the reference accuracy
                self.assertGreaterEqual(
                    inter_ref, faiss.eval_intersection(I, gt))
                continue
            # float norms: same distances, few index swaps (ties)
            np.testing.assert_array_almost_equal(Dref, D, decimal=5)
            self.assertLess((Iref != I).sum(), Iref.size * 0.05)


###########################################################
# IVF version
###########################################################


class TestIVFResidualQuantizer(unittest.TestCase):

    def do_test_accuracy(self, by_residual, st):
        """Check accuracy vs nprobe, serialization and (without residual)
        equivalence with the flat residual quantizer, for search type st."""
        ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

        quantizer = faiss.IndexFlatL2(ds.d)

        index = faiss.IndexIVFResidualQuantizer(
            quantizer, ds.d, 100, 3, 4,
            faiss.METRIC_L2, st
        )
        index.by_residual = by_residual

        index.rq.train_type   # bare read: presumably checks the field exists
        index.rq.train_type = faiss.ResidualQuantizer.Train_default
        index.rq.max_beam_size = 30

        index.train(ds.get_train())
        index.add(ds.get_database())

        inters = []
        for nprobe in 1, 2, 5, 10, 20, 50:
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            inters.append(inter)

        # do a little I/O test
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        inters = np.array(inters)

        if by_residual:
            # check that we have increasing intersection measures with
            # nprobe
            self.assertTrue(np.all(inters[1:] >= inters[:-1]))
        else:
            self.assertTrue(np.all(inters[1:3] >= inters[:2]))
            # check that we have the same result as the flat residual quantizer
            iflat = faiss.IndexResidualQuantizer(
                ds.d, 3, 4, faiss.METRIC_L2, st)
            iflat.rq.train_type
            iflat.rq.train_type = faiss.ResidualQuantizer.Train_default
            iflat.rq.max_beam_size = 30
            iflat.train(ds.get_train())
            iflat.rq.codebooks = index.rq.codebooks

            iflat.add(ds.get_database())
            Dref, Iref = iflat.search(ds.get_queries(), 10)

            index.nprobe = 100
            D2, I2 = index.search(ds.get_queries(), 10)
            np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
            # there are many ties because the codes are so short
            self.assertLess((Iref != I2).sum(), Iref.size * 0.2)

    def test_decompress_no_residual(self):
        self.do_test_accuracy(False, faiss.AdditiveQuantizer.ST_decompress)

    def test_norm_float_no_residual(self):
        self.do_test_accuracy(False, faiss.AdditiveQuantizer.ST_norm_float)

    def test_decompress(self):
        self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_decompress)

    def test_norm_float(self):
        self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_float)

    def test_norm_cqint(self):
        self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8)
        self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4)

    def test_factory(self):
        index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat")
        self.assertEqual(index.nlist, 1024)
        self.assertEqual(
            index.rq.search_type,
            faiss.AdditiveQuantizer.ST_norm_float
        )

        index = faiss.index_factory(12, "IVF1024,RQ8x8_Ncqint8")
        self.assertEqual(
            index.rq.search_type,
            faiss.AdditiveQuantizer.ST_norm_cqint8
        )
        index = faiss.index_factory(12, "IVF1024,RQ8x8_Ncqint4")
        self.assertEqual(
            index.rq.search_type,
            faiss.AdditiveQuantizer.ST_norm_cqint4
        )

    def do_test_accuracy_IP(self, by_residual):
        """IP metric: LUT search must agree with decompression search, and
        accuracy must not decrease over the first nprobe values."""
        ds = datasets.SyntheticDataset(32, 3000, 1000, 100, "IP")

        quantizer = faiss.IndexFlatIP(ds.d)

        index = faiss.IndexIVFResidualQuantizer(
            quantizer, ds.d, 100, 3, 4,
            faiss.METRIC_INNER_PRODUCT, faiss.AdditiveQuantizer.ST_decompress
        )
        index.cp.spherical = True
        index.by_residual = by_residual

        index.rq.train_type
        index.rq.train_type = faiss.ResidualQuantizer.Train_default
        index.train(ds.get_train())

        index.add(ds.get_database())

        inters = []
        for nprobe in 1, 2, 5, 10, 20, 50:
            index.nprobe = nprobe
            index.rq.search_type = faiss.AdditiveQuantizer.ST_decompress
            D, I = index.search(ds.get_queries(), 10)
            index.rq.search_type = faiss.AdditiveQuantizer.ST_LUT_nonorm
            D2, I2 = index.search(ds.get_queries(), 10)
            np.testing.assert_array_almost_equal(D, D2, decimal=5)
            # there are many ties because the codes are so short
            self.assertLess((I != I2).sum(), I.size * 0.1)

            inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            inters.append(inter)

        # Convert to an array so that >= below is element-wise. On plain
        # lists it is a lexicographic comparison returning a single bool.
        inters = np.array(inters)
        self.assertTrue(np.all(inters[1:4] >= inters[:3]))

    def test_no_residual_IP(self):
        self.do_test_accuracy_IP(False)

    def test_residual_IP(self):
        self.do_test_accuracy_IP(True)


############################################################
# Test functions that use precomputed codebook products
############################################################


def precomp_codebooks(codebooks):
    """Precompute the tables used by the table-based beam search.

    Returns:
        codebook_cross_prods: nested list with
            codebook_cross_prods[i][j] = codebooks[j] @ codebooks[i].T,
            of shape (K_j, K_i).
        cent_norms: list where cent_norms[i] holds the squared L2 norm of
            every centroid (row) of codebooks[i].
    """
    codebook_cross_prods = []
    for right in codebooks:
        codebook_cross_prods.append([left @ right.T for left in codebooks])
    cent_norms = [np.sum(c ** 2, axis=1) for c in codebooks]
    return codebook_cross_prods, cent_norms


############################################################
# Reference implementation of table-based beam search
############################################################

def beam_search_encode_step_tab(codes, L, distances, codebook_cross_prods_i,
                                query_cp_i, cent_norms_i):
    """ Reference table-based beam search step.

    Extends every beam entry in `codes` with one centroid of the current
    codebook. Candidates are scored with precomputed dot-product tables
    instead of explicit residual vectors:

        ||x - r - c||^2 = ||x - r||^2 + ||c||^2 - 2 <x, c> + 2 <r, c>

    Here r is the sum of the centroids already selected, and c is a centroid
    of the current codebook.

    Parameters
    ----------
    codes : (n, beam_size, m) int array
        codes selected so far for each vector and beam entry.
    L : int
        maximum beam size kept after this step.
    distances : (n, beam_size) array
        current squared distances ||x - r||^2.
    codebook_cross_prods_i : list of m 2-D arrays
        element j is indexed by the code chosen for codebook j and gives a
        length-K row of dot products with the current codebook's centroids.
    query_cp_i : (n, K) array
        dot products <x, c> with the current codebook's centroids.
    cent_norms_i : (K,) array
        squared norms ||c||^2 of the current codebook's centroids.

    Returns
    -------
    new_codes : (n, new_beam_size, m + 1) int array
    new_distances : (n, new_beam_size) array
        where new_beam_size = min(beam_size * K, L).
    """
    n, beam_size, m = codes.shape

    n2, beam_size_2 = distances.shape
    assert n2 == n and beam_size_2 == beam_size
    n2, K = query_cp_i.shape
    assert n2 == n
    K2, = cent_norms_i.shape
    assert K == K2
    assert len(codebook_cross_prods_i) == m

    # n, beam_size, K
    new_distances = distances[:, :, None] + cent_norms_i[None, None, :]
    new_distances -= 2 * query_cp_i[:, None, :]

    # <r, c>: for each previously selected code, add the cross-product row
    # it selects (vectorized over all vectors and beam entries)
    dotprods = np.zeros((n, beam_size, K))
    for j, cb in enumerate(codebook_cross_prods_i):
        dotprods += cb[codes[:, :, j]]

    new_distances += 2 * dotprods
    cent_distances = new_distances

    if beam_size * K <= L:
        # then keep all the results; candidate (b, k) ends up at b * K + k
        new_beam_size = beam_size * K
        new_codes = np.zeros((n, beam_size, K, m + 1), dtype=int)
        # each beam entry's code prefix is repeated for all K centroids
        # (explicit axis so this is also correct when beam_size > 1)
        new_codes[:, :, :, :-1] = codes[:, :, None, :]
        new_codes[:, :, :, -1] = np.arange(K)
        new_codes = new_codes.reshape(n, new_beam_size, m + 1)
        new_distances = cent_distances.reshape(n, new_beam_size)
    else:
        # keep top-L results
        new_beam_size = L
        new_codes = np.zeros((n, L, m + 1), dtype=int)
        new_distances = np.zeros((n, L), dtype='float32')
        for i in range(n):
            cd = cent_distances[i].ravel()
            jl = np.argsort(cd)[:L]    # TODO argpartition
            js = jl // K     # input beam index
            ls = jl % K      # centroid index
            new_codes[i, :, :-1] = codes[i, js, :]
            new_codes[i, :, -1] = ls
            new_distances[i, :] = cd[jl]

    return new_codes, new_distances


def beam_search_encoding_tab(codebooks, x, L, precomp, implem="ref"):
    """
    Perform encoding of vectors x with a beam search of size L, using the
    precomputed dot-product tables instead of explicit residuals.

    Parameters
    ----------
    codebooks : list of 2-D arrays, one centroid per row
    x : (n, d) array of vectors to encode
    L : int, maximum beam size
    precomp : (codebook_cross_prods, cent_norms), as returned by
        precomp_codebooks(codebooks)
    implem : str
        Selects the step implementation.
        - "ref": use the Python reference beam_search_encode_step_tab.
        - "cpp": use faiss.beam_search_encode_step_tab.
        - A string containing both (e.g. "ref cpp"): run both and assert
          that they agree.

    Returns
    -------
    codes : (n, beam_size, len(codebooks)) array of centroid indices
    distances : (n, beam_size) array of squared distances between x and
        the sum of the selected centroids

    Raises
    ------
    ValueError
        if implem contains neither "ref" nor "cpp".
    """
    use_ref = "ref" in implem
    use_cpp = "cpp" in implem
    if not (use_ref or use_cpp):
        raise ValueError(
            f"implem should contain 'ref' and/or 'cpp', got {implem!r}")
    compare_implem = use_ref and use_cpp

    query_cross_prods = [
        x @ c.T for c in codebooks
    ]

    M = len(codebooks)
    # start offset of each codebook in the concatenated tables (C++ layout)
    codebook_offsets = np.zeros(M + 1, dtype='uint64')
    codebook_offsets[1:] = np.cumsum([len(cb) for cb in codebooks])
    codebook_cross_prods, cent_norms = precomp
    n, d = x.shape
    beam_size = 1
    codes = np.zeros((n, beam_size, 0), dtype='int32')
    # with no code selected yet, the distance is ||x||^2
    distances = (x ** 2).sum(1).reshape(n, beam_size)

    for m, cent in enumerate(codebooks):

        if use_ref:
            new_codes, new_distances = beam_search_encode_step_tab(
                codes, L,
                distances, codebook_cross_prods[m][:m],
                query_cross_prods[m], cent_norms[m]
            )
            # beam size produced by this step (not the incoming one)
            new_beam_size = new_codes.shape[1]

        if compare_implem:
            codes_ref = new_codes
            distances_ref = new_distances

        if use_cpp:
            K = len(cent)
            new_beam_size = min(beam_size * K, L)
            new_codes = np.zeros((n, new_beam_size, m + 1), dtype='int32')
            new_distances = np.zeros((n, new_beam_size), dtype="float32")
            if m > 0:
                cp = np.vstack(codebook_cross_prods[m][:m])
            else:
                cp = np.zeros((0, K), dtype='float32')

            sp = faiss.swig_ptr
            faiss.beam_search_encode_step_tab(
                K, n, beam_size,
                sp(cp), cp.shape[1],
                sp(codebook_offsets),
                sp(query_cross_prods[m]), query_cross_prods[m].shape[1],
                sp(cent_norms[m]),
                m,
                sp(codes), sp(distances),
                new_beam_size,
                sp(new_codes), sp(new_distances)
            )

        if compare_implem:
            np.testing.assert_array_almost_equal(
                new_distances, distances_ref, decimal=5)
            np.testing.assert_array_equal(
                new_codes, codes_ref)

        codes = new_codes
        distances = new_distances
        beam_size = new_beam_size

    return (codes, distances)


class TestCrossCodebookComputations(unittest.TestCase):
    """Checks the precomputed codebook tables and the encoders using them."""

    def test_precomp(self):
        ds = datasets.SyntheticDataset(32, 1000, 1000, 0)

        # use a different number of bits per codebook to make sure
        # codebooks of varying sizes are supported
        nbits = faiss.UInt64Vector()
        for nb in (5, 6, 7):
            nbits.push_back(nb)

        rq = faiss.ResidualQuantizer(ds.d, nbits)
        rq.train_type = faiss.ResidualQuantizer.Train_default
        rq.train(ds.get_train())

        codebooks = get_additive_quantizer_codebooks(rq)
        precomp = precomp_codebooks(codebooks)
        cross_prods_py, norms_py = precomp

        # the C++ cross products form one square matrix over all centroids:
        # assemble the per-codebook Python tables the same way
        cross_prods_py = np.hstack(
            [np.vstack(row) for row in cross_prods_py])

        rq.compute_codebook_tables()
        tsz = rq.total_codebook_size
        cross_prods_cpp = faiss.vector_to_array(
            rq.codebook_cross_products).reshape(tsz, tsz)
        norms_cpp = faiss.vector_to_array(rq.cent_norms)

        np.testing.assert_array_almost_equal(
            cross_prods_cpp, cross_prods_py, decimal=5)
        np.testing.assert_array_almost_equal(
            np.hstack(norms_py), norms_cpp, decimal=5)

        # the Python table-based encoder must match the residual-based one
        xb = ds.get_database()
        beam = 7
        ref_codes, _, _ = beam_search_encoding_ref(codebooks, xb, beam)
        tab_codes, _ = beam_search_encoding_tab(codebooks, xb, beam, precomp)
        np.testing.assert_array_equal(ref_codes, tab_codes)

        # run the Python and C++ step functions side by side
        beam_search_encoding_tab(
            codebooks, xb, beam, precomp, implem="ref cpp")

        # compare with the C++ encoder, first before enabling use_beam_LUT
        n_vec = ref_codes.shape[0]
        sp = faiss.swig_ptr
        packed_ref = np.zeros((n_vec, rq.code_size), dtype='uint8')
        codes_i32 = ref_codes.astype('int32')
        # the row stride is M * beam size, so only the first beam entry of
        # each vector gets packed
        rq.pack_codes(
            n_vec, sp(codes_i32),
            sp(packed_ref), rq.M * ref_codes.shape[1]
        )

        rq.max_beam_size = beam
        codes_residual = rq.compute_codes(xb)
        np.testing.assert_array_equal(packed_ref, codes_residual)

        # the LUT-based C++ encoder must produce the same codes
        rq.use_beam_LUT = 1
        codes_lut = rq.compute_codes(xb)
        np.testing.assert_array_equal(codes_residual, codes_lut)


class TestProductResidualQuantizer(unittest.TestCase):
    """Codec-level tests for ProductResidualQuantizer."""

    def test_codec(self):
        """PRQ must have a lower reconstruction error than a PQ with the
        same total number of sub-quantizers and bits."""
        ds = datasets.SyntheticDataset(64, 3000, 3000, 0)
        train_vecs, base_vecs = ds.get_train(), ds.get_database()

        nsplits, Msub, nbits = 2, 2, 4

        prq = faiss.ProductResidualQuantizer(ds.d, nsplits, Msub, nbits)
        prq.train(train_vecs)
        prq_error = eval_codec(prq, base_vecs)

        pq = faiss.ProductQuantizer(ds.d, nsplits * Msub, nbits)
        pq.train(train_vecs)
        pq_error = eval_codec(pq, base_vecs)

        self.assertLess(prq_error, pq_error)

    def test_with_rq(self):
        """With a single split, PRQ must give exactly the same error as RQ."""
        ds = datasets.SyntheticDataset(32, 3000, 3000, 0)
        train_vecs, base_vecs = ds.get_train(), ds.get_database()

        M, nbits = 4, 4

        prq = faiss.ProductResidualQuantizer(ds.d, 1, M, nbits)
        prq.train(train_vecs)
        prq_error = eval_codec(prq, base_vecs)

        rq = faiss.ResidualQuantizer(ds.d, M, nbits)
        rq.train(train_vecs)
        rq_error = eval_codec(rq, base_vecs)

        self.assertEqual(prq_error, rq_error)


class TestIndexProductResidualQuantizer(unittest.TestCase):
    """Tests for IndexProductResidualQuantizer built via index_factory."""

    def test_accuracy1(self):
        """check that the recall is in the same ballpark as RQ."""
        recall1 = self.eval_index_accuracy("PRQ4x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("RQ12x5_Nqint8")
        self.assertGreaterEqual(recall1 * 1.1, recall2)  # 657 vs 665

    def test_accuracy2(self):
        """when nsplits = 1, PRQ should be the same as RQ"""
        recall1 = self.eval_index_accuracy("PRQ1x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("RQ3x5_Nqint8")
        self.assertEqual(recall1, recall2)

    def eval_index_accuracy(self, index_key):
        """Train and fill the index described by `index_key` on a small
        synthetic dataset, and return the intersection of its top-10
        results with the ground truth.

        Also checks that a serialize/deserialize round trip gives identical
        search results.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, index_key)

        index.train(ds.get_train())
        index.add(ds.get_database())
        D, I = index.search(ds.get_queries(), 10)
        inter = faiss.eval_intersection(I, ds.get_groundtruth(10))

        # do a little I/O test
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_factory(self):
        """Check the parameters parsed from a PRQ factory string."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"PRQ{ns}x{Msub}x{nbits}_Nqint8")
        # assertIsInstance, unlike a bare assert, is not stripped under -O
        self.assertIsInstance(index, faiss.IndexProductResidualQuantizer)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(index.prq.subquantizer(0).M, Msub)
        self.assertEqual(index.prq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.prq.search_type, AQ.ST_norm_qint8)

        # code bits rounded up to whole bytes, plus one extra byte
        # (presumably the int8-quantized norm of ST_norm_qint8)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, code_size)


class TestIndexIVFProductResidualQuantizer(unittest.TestCase):
    """Tests for IndexIVFProductResidualQuantizer built via index_factory."""

    def eval_index_accuracy(self, factory_key):
        """Train and fill the index described by `factory_key`, and check
        that recall does not drop (up to a 5% tolerance) as nprobe increases.

        Also checks that a serialize/deserialize round trip gives identical
        search results.

        Returns the top-10 intersection with the ground truth at the largest
        nprobe tested (50).
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, factory_key)

        index.train(ds.get_train())
        index.add(ds.get_database())

        inters = []
        for nprobe in 1, 2, 5, 10, 20, 50:
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            inters.append(inter)

        inters = np.array(inters)
        # 1.05: test relaxed for OSX on ARM
        self.assertTrue(np.all(inters[1:] * 1.05 >= inters[:-1]))

        # do a little I/O test (on the last nprobe setting)
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_index_accuracy(self):
        self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")

    def test_index_accuracy2(self):
        """check that the recall is in the same ballpark as RQ."""
        inter1 = self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")
        inter2 = self.eval_index_accuracy("IVF100,RQ4x5_Nqint8")
        # observed: 392 vs 374
        self.assertGreaterEqual(inter1 * 1.1, inter2)

    def test_factory(self):
        """Check the parameters parsed from an IVF + PRQ factory string."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"IVF100,PRQ{ns}x{Msub}x{nbits}_Nqint8")
        # assertIsInstance, unlike a bare assert, is not stripped under -O
        self.assertIsInstance(index, faiss.IndexIVFProductResidualQuantizer)
        self.assertEqual(index.nlist, 100)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(index.prq.subquantizer(0).M, Msub)
        self.assertEqual(index.prq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.prq.search_type, AQ.ST_norm_qint8)

        # code bits rounded up to whole bytes, plus one extra byte
        # (presumably the int8-quantized norm of ST_norm_qint8)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, code_size)