faiss/tests/test_binary_io.py
Lucas Hosseini a8118acbc5
Facebook sync (May 2019) + relicense (#838)
Changelog:

- changed license: BSD+Patents -> MIT
- propagates exceptions raised in sub-indexes of IndexShards and IndexReplicas
- support for searching several inverted lists in parallel (parallel_mode != 0)
- better support for PQ codes where nbit != 8 or 16
- IVFSpectralHash implementation: spectral hash codes inside an IVF
- 6-bit per component scalar quantizer (4 and 8 bit were already supported)
- combinations of inverted lists: HStackInvertedLists and VStackInvertedLists
- configurable number of threads for OnDiskInvertedLists prefetching (including 0=no prefetch)
- more test and demo code compatible with Python 3 (print with parentheses)
- refactored benchmark code: data loading is now in a single file
2019-05-28 16:17:22 +02:00

218 lines
5.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#! /usr/bin/env python2
"""Binary indexes (de)serialization"""
import numpy as np
import unittest
import faiss
import os
import tempfile
def make_binary_dataset(d, nb, nt, nq):
    """Generate random packed binary vectors split into three sets.

    Each vector has ``d`` bits stored as ``d // 8`` bytes of dtype uint8.
    The values are drawn from ``np.random.randint``, so the result depends
    on NumPy's global random state.

    Args:
        d: vector dimension in bits; must be a multiple of 8.
        nb: number of database vectors.
        nt: number of training vectors.
        nq: number of query vectors.

    Returns:
        A tuple ``(xt, xb, xq)`` of uint8 arrays holding ``nt``, ``nb`` and
        ``nq`` rows respectively, each with ``d // 8`` columns.
    """
    assert d % 8 == 0
    # Integer division avoids a float round trip for the byte width.
    x = np.random.randint(256, size=(nb + nq + nt, int(d // 8))).astype('uint8')
    # Use explicit offsets rather than negative indices: with nq == 0 the
    # slices x[nt:-nq] and x[-nq:] would yield an empty database set and
    # return the whole array as the query set.
    return x[:nt], x[nt:nt + nb], x[nt + nb:]
class TestBinaryFlat(unittest.TestCase):
    """Write/read round trip of an IndexBinaryFlat through a file."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 0
        nb = 1500
        nq = 500

        # No training set is requested here (nt = 0); only the database and
        # query vectors are kept.
        (_, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_flat(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are packed 8 bits per byte, so the bit dimension is
        # recovered from the byte width.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp() returns an open OS-level descriptor along with the path.
        # Only the path is needed, so close the descriptor immediately to
        # avoid leaking it (an open handle can also block os.remove on
        # Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Unlike bare asserts, these survive `python -O` and report the
            # mismatching entries on failure.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)
        finally:
            os.remove(tmpnam)
class TestBinaryIVF(unittest.TestCase):
    """Write/read round trip of an IndexBinaryIVF through a file."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_ivf_flat(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are packed 8 bits per byte, so the bit dimension is
        # recovered from the byte width.
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp() returns an open OS-level descriptor along with the path.
        # Only the path is needed, so close the descriptor immediately to
        # avoid leaking it (an open handle can also block os.remove on
        # Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Unlike bare asserts, these survive `python -O` and report the
            # mismatching entries on failure.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)
        finally:
            os.remove(tmpnam)
class TestObjectOwnership(unittest.TestCase):
    """Check the ownership flag of an index returned by read_index_binary."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_read_index_ownership(self):
        """An index loaded from a file must have ``thisown`` set."""
        # Vectors are packed 8 bits per byte, so the bit dimension is
        # recovered from the byte width.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)

        # mkstemp() returns an open OS-level descriptor along with the path.
        # Only the path is needed, so close the descriptor immediately to
        # avoid leaking it (an open handle can also block os.remove on
        # Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            # assertTrue is not stripped under `python -O`, unlike a bare
            # assert statement.
            self.assertTrue(index2.thisown)
        finally:
            os.remove(tmpnam)
class TestBinaryFromFloat(unittest.TestCase):
    """Write/read round trip of an IndexBinaryFromFloat through a file."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def test_binary_from_float(self):
        """Search results must be identical before and after a write/read."""
        # Vectors are packed 8 bits per byte, so the bit dimension is
        # recovered from the byte width.
        d = self.xq.shape[1] * 8

        # The binary index wraps a float HNSW index built with the same d.
        float_index = faiss.IndexHNSWFlat(d, 16)
        index = faiss.IndexBinaryFromFloat(float_index)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp() returns an open OS-level descriptor along with the path.
        # Only the path is needed, so close the descriptor immediately to
        # avoid leaking it (an open handle can also block os.remove on
        # Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Unlike bare asserts, these survive `python -O` and report the
            # mismatching entries on failure.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)
        finally:
            os.remove(tmpnam)
class TestBinaryHNSW(unittest.TestCase):
    """Write/read round trips of HNSW-based binary indexes through a file."""

    def __init__(self, *args, **kwargs):
        unittest.TestCase.__init__(self, *args, **kwargs)
        d = 32
        nt = 200
        nb = 1500
        nq = 500

        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nb, nt, nq)

    def _check_io_roundtrip(self, index):
        """Write ``index`` to a temp file, reload it, and compare searches.

        The 3 nearest neighbours of ``self.xq`` are computed with ``index``
        and with the reloaded copy; labels and distances must match exactly.
        The temporary file is removed whether or not the comparison passes.
        """
        D, I = index.search(self.xq, 3)

        # mkstemp() returns an open OS-level descriptor along with the path.
        # Only the path is needed, so close the descriptor immediately to
        # avoid leaking it (an open handle can also block os.remove on
        # Windows).
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            # Unlike bare asserts, these survive `python -O` and report the
            # mismatching entries on failure.
            np.testing.assert_array_equal(I2, I)
            np.testing.assert_array_equal(D2, D)
        finally:
            os.remove(tmpnam)

    def test_hnsw(self):
        """Round trip of a standalone IndexBinaryHNSW."""
        # Vectors are packed 8 bits per byte, so the bit dimension is
        # recovered from the byte width.
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryHNSW(d)
        index.add(self.xb)

        self._check_io_roundtrip(index)

    def test_ivf_hnsw(self):
        """Round trip of an IndexBinaryIVF using an HNSW quantizer."""
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryHNSW(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)

        self._check_io_roundtrip(index)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()