143 lines
4.3 KiB
Python
143 lines
4.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
# a few tests of the swig wrapper
|
|
|
|
import unittest
|
|
import faiss
|
|
import numpy as np
|
|
|
|
|
|
class TestSWIGWrap(unittest.TestCase):
    """ various regressions with the SWIG wrapper """

    def test_size_t_ptr(self):
        # issue 1064: size_t* output arguments built with swig_ptr from a
        # uint64 numpy array.
        index = faiss.IndexHNSWFlat(10, 32)

        hnsw = index.hnsw
        index.add(np.random.rand(100, 10).astype('float32'))
        # zero-initialized (rather than np.empty) so that outputs that were
        # never written by the call are detectable by the assertion below
        be = np.zeros(2, 'uint64')
        # be[1:] is a view on the same buffer, so the second pointer writes
        # into be[1]
        hnsw.neighbor_range(23, 0, faiss.swig_ptr(be), faiss.swig_ptr(be[1:]))
        # the [begin, end) neighbor range of level 0 must be non-empty
        self.assertLess(be[0], be[1])

    def test_id_map_at(self):
        # issue 1020: element access on the id_map of an IndexIDMap2
        n_features = 100
        feature_dims = 10

        features = np.random.random((n_features, feature_dims)).astype(np.float32)
        idx = np.arange(n_features).astype(np.int64)

        index = faiss.IndexFlatL2(feature_dims)
        index = faiss.IndexIDMap2(index)
        index.add_with_ids(features, idx)

        ids = [index.id_map.at(int(i)) for i in range(index.ntotal)]
        # the ids read back must be exactly the ones that were added
        self.assertEqual(ids, idx.tolist())

    def test_downcast_Refine(self):
        index = faiss.IndexRefineFlat(
            faiss.IndexScalarQuantizer(10, faiss.ScalarQuantizer.QT_8bit)
        )

        # serialize and deserialize
        index2 = faiss.deserialize_index(
            faiss.serialize_index(index)
        )

        # the deserialized index must be downcast to the concrete type
        self.assertIsInstance(index2, faiss.IndexRefineFlat)

    def do_test_array_type(self, dtype):
        """ tests swig_ptr and rev_swig_ptr for this type of array """
        a = np.arange(12).astype(dtype)
        ptr = faiss.swig_ptr(a)
        a2 = faiss.rev_swig_ptr(ptr, 12)
        np.testing.assert_array_equal(a, a2)

    def test_all_array_types(self):
        # subTest reports every failing dtype instead of stopping at the
        # first one
        for dtype in ('float32', 'float64',
                      'int8', 'uint8', 'int16', 'uint16',
                      'int32', 'uint32', 'int64', 'uint64'):
            with self.subTest(dtype=dtype):
                self.do_test_array_type(dtype)

    def test_int64(self):
        # see https://github.com/facebookresearch/faiss/issues/1529
        v = faiss.Int64Vector()

        for i in range(10):
            v.push_back(i)
        a = faiss.vector_to_array(v)
        self.assertEqual(a.dtype, 'int64')
        np.testing.assert_array_equal(a, np.arange(10, dtype='int64'))

        # check if it works in an IDMap
        idx = faiss.IndexIDMap(faiss.IndexFlatL2(32))
        xb = np.random.rand(10, 32).astype('float32')
        ids = np.random.randint(1000, size=10, dtype='int64')
        idx.add_with_ids(xb, ids)
        # the id map must round-trip the ids that were added
        np.testing.assert_array_equal(faiss.vector_to_array(idx.id_map), ids)

    def test_asan(self):
        # this test should fail with ASAN
        index = faiss.IndexFlatL2(32)
        index.this.own(False)  # this is a mem leak, should be caught by ASAN

    def test_SWIG_version(self):
        self.assertLess(faiss.swig_version(), 0x050000)
|
|
|
|
|
|
class TestRevSwigPtr(unittest.TestCase):

    def test_rev_swig_ptr(self):
        """Reading an index's raw vector storage through rev_swig_ptr
        must give back exactly the vectors that were added."""
        n_vectors, dim = 5, 4
        index = faiss.IndexFlatL2(dim)

        # row i is [1, 2, 3, 4] shifted by 10 * i
        base_row = np.array([1, 2, 3, 4], dtype='float32')
        rows = [i * 10 + base_row for i in range(n_vectors)]
        xb0 = np.vstack(rows)
        index.add(xb0)

        raw = faiss.rev_swig_ptr(index.get_xb(), dim * n_vectors)
        xb = raw.reshape(n_vectors, dim)
        total_abs_diff = np.abs(xb0 - xb).sum()
        self.assertEqual(total_abs_diff, 0)
|
|
|
|
|
|
class TestException(unittest.TestCase):

    def test_exception(self):
        # add_with_ids is an unsupported operation for IndexFlat and is
        # expected to surface as a RuntimeError (message not checked)
        index = faiss.IndexFlatL2(10)
        vectors = np.zeros((5, 10), dtype='float32')
        labels = np.zeros(5, dtype='int64')

        with self.assertRaises(RuntimeError):
            index.add_with_ids(vectors, labels)

    def test_exception_2(self):
        # an index factory string that cannot be parsed is expected to
        # raise RuntimeError (message not checked)
        with self.assertRaises(RuntimeError):
            faiss.index_factory(12, 'IVF256,Flat,PQ8')
|
|
|
|
|
|
@unittest.skipIf(faiss.swig_version() < 0x040000, "swig < 4 does not support Doxygen comments")
class TestDoxygen(unittest.TestCase):

    def test_doxygen_comments(self):
        """The wrapped class's __doc__ should carry its Doxygen comment."""
        docstring = faiss.float_maxheap_array_t().__doc__
        self.assertIn("a template structure for a set of [min|max]-heaps",
                      docstring)
|