# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Binary indexes (de)serialization"""

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import unittest
import faiss
import os
import tempfile

def make_binary_dataset(d, nb, nt, nq):
    assert d % 8 == 0
    x = np.random.randint(256, size=(nb + nq + nt, int(d / 8))).astype('uint8')
    return x[:nt], x[nt:-nq], x[-nq:]


class TestBinaryFlat(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        (_, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_flat(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)


class TestBinaryIVF(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_ivf_flat(self):
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)


class TestObjectOwnership(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_read_index_ownership(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            assert index2.thisown
        finally:
            os.remove(tmpnam)


class TestBinaryFromFloat(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_binary_from_float(self):
        d = self.xq.shape[1] * 8

        float_index = faiss.IndexHNSWFlat(d, 16)
        index = faiss.IndexBinaryFromFloat(float_index)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)


class TestBinaryHNSW(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_hnsw(self):
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryHNSW(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)

    def test_ivf_hnsw(self):
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryHNSW(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)


if __name__ == '__main__':
    unittest.main()