220 lines
5.3 KiB
Python
220 lines
5.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Binary indexes (de)serialization"""
|
|
|
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
|
|
import numpy as np
|
|
import unittest
|
|
import faiss
|
|
import os
|
|
import tempfile
|
|
|
|
def make_binary_dataset(d, nb, nt, nq):
    """Generate random binary vectors split into train, base and query sets.

    Each vector holds ``d`` bits packed into ``d // 8`` uint8 values. A single
    random matrix of ``nt + nb + nq`` rows is drawn from ``np.random`` and
    split, in that order, into the training, database and query sets.

    Args:
        d: dimension in bits; must be a multiple of 8.
        nb: number of database vectors.
        nt: number of training vectors.
        nq: number of query vectors.

    Returns:
        A tuple ``(xt, xb, xq)`` of uint8 arrays with ``nt``, ``nb`` and
        ``nq`` rows respectively, each row being ``d // 8`` bytes wide.

    Raises:
        AssertionError: if ``d`` is not a multiple of 8.
    """
    assert d % 8 == 0
    x = np.random.randint(256, size=(nb + nq + nt, int(d // 8))).astype('uint8')
    # Slice with explicit positive bounds. The negative-index form
    # (x[nt:-nq], x[-nq:]) is wrong for nq == 0: since -0 == 0, the base set
    # would come back empty and the query set would be the whole matrix.
    return x[:nt], x[nt:nt + nb], x[nt + nb:]
|
|
|
|
|
|
class TestBinaryFlat(unittest.TestCase):
    """Write/read round trip for ``faiss.IndexBinaryFlat``."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        # No training set is needed here (nt == 0), so it is discarded.
        (_, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_flat(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # Only the path is handed to faiss, so release the descriptor
        # returned by mkstemp right away.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Bare ``assert`` statements are stripped under ``python -O`` and
            # report nothing useful on failure; the numpy helper always runs
            # and prints the mismatching entries.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)

        finally:
            os.remove(tmpnam)
|
|
|
|
|
|
class TestBinaryIVF(unittest.TestCase):
    """Write/read round trip for ``faiss.IndexBinaryIVF``."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_ivf_flat(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # Only the path is handed to faiss, so release the descriptor
        # returned by mkstemp right away.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Bare ``assert`` statements are stripped under ``python -O`` and
            # report nothing useful on failure; the numpy helper always runs
            # and prints the mismatching entries.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)

        finally:
            os.remove(tmpnam)
|
|
|
|
|
|
class TestObjectOwnership(unittest.TestCase):
    """Check the ownership flag of an index returned by the reader."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_read_index_ownership(self):
        """An index loaded with ``read_index_binary`` must have ``thisown`` set."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)

        # Only the path is handed to faiss, so release the descriptor
        # returned by mkstemp right away.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            # NOTE(review): ``thisown`` is presumably the SWIG flag saying the
            # Python wrapper owns the underlying object -- confirm.
            # assertTrue is used because a bare ``assert`` is stripped under
            # ``python -O``, which would make this test check nothing.
            self.assertTrue(index2.thisown)
        finally:
            os.remove(tmpnam)
|
|
|
|
|
|
class TestBinaryFromFloat(unittest.TestCase):
    """Write/read round trip for ``faiss.IndexBinaryFromFloat``."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_binary_from_float(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        # The binary index under test wraps a float HNSW index.
        float_index = faiss.IndexHNSWFlat(d, 16)
        index = faiss.IndexBinaryFromFloat(float_index)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # Only the path is handed to faiss, so release the descriptor
        # returned by mkstemp right away.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Bare ``assert`` statements are stripped under ``python -O`` and
            # report nothing useful on failure; the numpy helper always runs
            # and prints the mismatching entries.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)

        finally:
            os.remove(tmpnam)
|
|
|
|
|
|
class TestBinaryHNSW(unittest.TestCase):
    """Write/read round trips for HNSW-based binary indexes."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def _check_io_roundtrip(self, index, D, I):
        """Write ``index`` to a temp file, reload it and compare search results.

        ``D`` and ``I`` are the distances and labels obtained from
        ``index.search(self.xq, 3)`` before serialization; the reloaded index
        must reproduce both exactly.
        """
        # Only the path is handed to faiss, so release the descriptor
        # returned by mkstemp right away.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Bare ``assert`` statements are stripped under ``python -O`` and
            # report nothing useful on failure; the numpy helper always runs
            # and prints the mismatching entries.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)

        finally:
            os.remove(tmpnam)

    def test_hnsw(self):
        """``IndexBinaryHNSW`` must search identically after a write/read."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryHNSW(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        self._check_io_roundtrip(index, D, I)

    def test_ivf_hnsw(self):
        """``IndexBinaryIVF`` with an HNSW quantizer must survive a write/read."""
        # Vectors are bit-packed: each uint8 column carries 8 dimensions.
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryHNSW(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        self._check_io_roundtrip(index, D, I)
|
|
|
|
|
|
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|