# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""this is a basic test script for simple indices work"""

import os
import sys
import numpy as np
import unittest
import faiss

from common_faiss_tests import compare_binary_result_lists, make_binary_dataset



def binary_to_float(x):
    n, d = x.shape
    x8 = x.reshape(n * d, -1)
    c8 = 2 * ((x8 >> np.arange(8)) & 1).astype('int8') - 1
    return c8.astype('float32').reshape(n, d * 8)


def binary_dis(x, y):
    return sum(faiss.popcount64(int(xi ^ yi)) for xi, yi in zip(x, y))


class TestBinaryPQ(unittest.TestCase):
    """ Use a PQ that mimicks a binary encoder """

    def test_encode_to_binary(self):
        d = 256
        nt = 256
        nb = 1500
        nq = 500
        (xt, xb, xq) = make_binary_dataset(d, nt, nb, nq)
        pq = faiss.ProductQuantizer(d, int(d / 8), 8)

        centroids = binary_to_float(
            np.tile(np.arange(256), int(d / 8)).astype('uint8').reshape(-1, 1))

        faiss.copy_array_to_vector(centroids.ravel(), pq.centroids)
        pq.is_trained = True

        codes = pq.compute_codes(binary_to_float(xb))

        assert np.all(codes == xb)

        indexpq = faiss.IndexPQ(d, int(d / 8), 8)
        indexpq.pq = pq
        indexpq.is_trained = True

        indexpq.add(binary_to_float(xb))
        D, I = indexpq.search(binary_to_float(xq), 3)

        for i in range(nq):
            for j, dj in zip(I[i], D[i]):
                ref_dis = binary_dis(xq[i], xb[j])
                assert 4 * ref_dis == dj

        nlist = 32
        quantizer = faiss.IndexFlatL2(d)
        # pretext class for training
        iflat = faiss.IndexIVFFlat(quantizer, d, nlist)
        iflat.train(binary_to_float(xt))

        indexivfpq = faiss.IndexIVFPQ(quantizer, d, nlist, int(d / 8), 8)

        indexivfpq.pq = pq
        indexivfpq.is_trained = True
        indexivfpq.by_residual = False

        indexivfpq.add(binary_to_float(xb))
        indexivfpq.nprobe = 4

        D, I = indexivfpq.search(binary_to_float(xq), 3)

        for i in range(nq):
            for j, dj in zip(I[i], D[i]):
                ref_dis = binary_dis(xq[i], xb[j])
                assert 4 * ref_dis == dj


class TestBinaryFlat(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)

    def test_flat(self):
        d = self.xq.shape[1] * 8
        nq = self.xq.shape[0]

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        for i in range(nq):
            for j, dj in zip(I[i], D[i]):
                ref_dis = binary_dis(self.xq[i], self.xb[j])
                assert dj == ref_dis

        # test reconstruction
        assert np.all(index.reconstruct(12) == self.xb[12])

    def test_empty_flat(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)

        for use_heap in [True, False]:
            index.use_heap = use_heap
            Dflat, Iflat = index.search(self.xq, 10)

            assert(np.all(Iflat == -1))
            assert(np.all(Dflat == 2147483647)) # NOTE(hoss): int32_t max

    def test_range_search(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 10)
        thresh = int(np.median(D[:, -1]))

        lims, D2, I2 = index.range_search(self.xq, thresh)
        nt1 = nt2 = 0
        for i in range(len(self.xq)):
            range_res = I2[lims[i]:lims[i + 1]]
            if thresh > D[i, -1]:
                self.assertTrue(set(I[i]) <= set(range_res))
                nt1 += 1
            elif thresh < D[i, -1]:
                self.assertTrue(set(range_res) <= set(I[i]))
                nt2 += 1
            # in case of equality we have a problem with ties
        print('nb tests', nt1, nt2)
        # nb tests is actually low...
        self.assertTrue(nt1 > 19 and nt2 > 19)


class TestBinaryIVF(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)
        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        Dref, Iref = index.search(self.xq, 10)
        self.Dref = Dref

    def test_ivf_flat_exhaustive(self):
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 8
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)

        np.testing.assert_array_equal(self.Dref, Divfflat)

    def test_ivf_flat2(self):
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)

        # Some centroids are equidistant from the query points.
        # So the answer will depend on the implementation of the heap.
        self.assertGreater((self.Dref == Divfflat).sum(), 4100)

    def test_ivf_range(self):
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 10)

        radius = int(np.median(D[:, -1]) + 1)
        Lr, Dr, Ir = index.range_search(self.xq, radius)

        for i in range(len(self.xq)):
            res = Ir[Lr[i]:Lr[i + 1]]
            if D[i, -1] < radius:
                self.assertTrue(set(I[i]) <= set(res))
            else:
                subset = I[i, D[i, :] < radius]
                self.assertTrue(set(subset) == set(res))


    def test_ivf_flat_empty(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryIVF(faiss.IndexBinaryFlat(d), d, 8)
        index.train(self.xt)

        for use_heap in [True, False]:
            index.use_heap = use_heap
            Divfflat, Iivfflat = index.search(self.xq, 10)

            assert(np.all(Iivfflat == -1))
            assert(np.all(Divfflat == 2147483647)) # NOTE(hoss): int32_t max

    def test_ivf_reconstruction(self):
        d = self.xq.shape[1] * 8
        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)

        index.add(self.xb)
        index.set_direct_map_type(faiss.DirectMap.Array)

        for i in range(0, len(self.xb), 13):
            np.testing.assert_array_equal(
                index.reconstruct(i),
                self.xb[i]
            )

        # try w/ hashtable
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        rs = np.random.RandomState(123)
        ids = rs.choice(10000, size=len(self.xb), replace=False).astype(np.int64)
        index.add_with_ids(self.xb, ids)
        index.set_direct_map_type(faiss.DirectMap.Hashtable)

        for i in range(0, len(self.xb), 13):
            np.testing.assert_array_equal(
                index.reconstruct(int(ids[i])),
                self.xb[i]
            )

    def test_ivf_nprobe(self):
        """Test in case of nprobe > nlist."""
        d = self.xq.shape[1] * 8
        xt, xb, xq = self.xt, self.xb, self.xq

        # nlist = 10
        index = faiss.index_binary_factory(d, "BIVF10")

        # When nprobe >= nlist, it is equivalent to an IndexFlat.

        index.train(xt)
        index.add(xb)
        index.nprobe = 2048
        k = 5

        # test kNN search
        D, I = index.search(xq, k)

        ref_index = faiss.index_binary_factory(d, "BFlat")
        ref_index.add(xb)
        ref_D, ref_I = ref_index.search(xq, k)

        print(D[0], ref_D[0])
        print(I[0], ref_I[0])
        assert np.all(D == ref_D)
        # assert np.all(I == ref_I)  # id may be different

        # test range search
        thresh = 5   # *squared* distance
        lims, D, I = index.range_search(xq, thresh)
        ref_lims, ref_D, ref_I = ref_index.range_search(xq, thresh)
        assert np.all(lims == ref_lims)
        assert np.all(D == ref_D)
        # assert np.all(I == ref_I)  # id may be different


class TestHNSW(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)

    def test_hnsw_exact_distances(self):
        d = self.xq.shape[1] * 8
        nq = self.xq.shape[0]

        index = faiss.IndexBinaryHNSW(d, 16)
        index.add(self.xb)
        Dists, Ids = index.search(self.xq, 3)

        for i in range(nq):
            for j, dj in zip(Ids[i], Dists[i]):
                ref_dis = binary_dis(self.xq[i], self.xb[j])
                self.assertEqual(dj, ref_dis)

    def test_hnsw(self):
        d = self.xq.shape[1] * 8

        # NOTE(hoss): Ensure the HNSW construction is deterministic.
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)

        index_hnsw_float = faiss.IndexHNSWFlat(d, 16)
        index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float)

        index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16)

        index_hnsw_ref.add(self.xb)
        index_hnsw_bin.add(self.xb)

        faiss.omp_set_num_threads(nthreads)

        Dref, Iref = index_hnsw_ref.search(self.xq, 3)
        Dbin, Ibin = index_hnsw_bin.search(self.xq, 3)

        self.assertTrue((Dref == Dbin).all())



class TestReplicasAndShards(unittest.TestCase):

    @unittest.skipIf(os.name == "posix" and os.uname().sysname == "Darwin",
                     "There is a bug in the OpenMP implementation on OSX.")
    def test_replicas(self):
        d = 32
        nq = 100
        nb = 200

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        Dref, Iref = index_ref.search(xq, 10)

        nrep = 5
        index = faiss.IndexBinaryReplicas()
        for _i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            sub_idx.add(xb)
            index.addIndex(sub_idx)

        D, I = index.search(xq, 10)

        self.assertTrue((Dref == D).all())
        self.assertTrue((Iref == I).all())

        index2 = faiss.IndexBinaryReplicas()
        for _i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            index2.addIndex(sub_idx)

        index2.add(xb)
        D2, I2 = index2.search(xq, 10)

        self.assertTrue((Dref == D2).all())
        self.assertTrue((Iref == I2).all())

    def test_shards(self):
        d = 32
        nq = 100
        nb = 200

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        Dref, Iref = index_ref.search(xq, 10)

        nrep = 5
        index = faiss.IndexBinaryShards(d)
        for i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            sub_idx.add(xb[i * nb // nrep : (i + 1) * nb // nrep])
            index.add_shard(sub_idx)

        D, I = index.search(xq, 10)

        compare_binary_result_lists(Dref, Iref, D, I)

        index2 = faiss.IndexBinaryShards(d)
        for _i in range(nrep):
            sub_idx = faiss.IndexBinaryFlat(d)
            index2.add_shard(sub_idx)

        index2.add(xb)
        D2, I2 = index2.search(xq, 10)

        compare_binary_result_lists(Dref, Iref, D2, I2)


if __name__ == '__main__':
    unittest.main()