# faiss/tests/test_index_binary.py (325 lines, 9.1 KiB, Python)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#! /usr/bin/env python2
"""this is a basic test script for simple indices work"""
import numpy as np
import unittest
import faiss
def make_binary_dataset(d, nt, nb, nq):
    """Generate random packed binary vectors split into train/database/query.

    Each vector has ``d`` bits stored as ``d // 8`` uint8 bytes, so ``d``
    must be a multiple of 8. The generator is seeded with 123, so the same
    arguments always produce the same data.

    Returns:
        (xt, xb, xq): uint8 arrays of shapes (nt, d // 8), (nb, d // 8) and
        (nq, d // 8).
    """
    assert d % 8 == 0
    rs = np.random.RandomState(123)
    x = rs.randint(256, size=(nb + nq + nt, d // 8)).astype('uint8')
    # Slice with explicit offsets: the previous ``x[nt:-nq]`` / ``x[-nq:]``
    # form broke for nq == 0, because ``-0`` is simply ``0``.
    return x[:nt], x[nt:nt + nb], x[nt + nb:]
def binary_to_float(x):
    """Expand packed binary codes into +/-1 float vectors.

    Each byte of ``x`` becomes 8 floats, least-significant bit first: a set
    bit maps to +1.0 and a cleared bit to -1.0. Two differing bits are thus
    (1 - (-1))**2 = 4 apart in squared L2, i.e. squared L2 distance between
    expanded vectors is 4x the Hamming distance between the codes.

    Args:
        x: 2-D array of shape (n, nbytes); presumably uint8 -- other integer
            dtypes are not exercised here.

    Returns:
        float32 array of shape (n, nbytes * 8).
    """
    n, nbytes = x.shape
    # One row per byte. reshape(-1, 1) (rather than reshape(n * nbytes, -1))
    # keeps empty inputs working: a -1 next to a zero-sized dimension is
    # ambiguous and makes numpy raise.
    x8 = x.reshape(-1, 1)
    bits = (x8 >> np.arange(8)) & 1
    signs = 2 * bits.astype('int8') - 1
    return signs.astype('float32').reshape(n, nbytes * 8)
def binary_dis(x, y):
    """Return the Hamming distance between two packed binary codes.

    ``x`` and ``y`` are sequences of byte values (e.g. uint8 rows); the result
    is the number of differing bits across all paired bytes.

    The bits are counted in pure Python rather than with a faiss helper, so
    the reference distance the tests compare faiss results against does not
    itself depend on faiss.
    """
    return sum(bin(int(xi) ^ int(yi)).count('1') for xi, yi in zip(x, y))
class TestBinaryPQ(unittest.TestCase):
    """Use a PQ that mimics a binary encoder."""

    def _check_distances(self, xq, xb, D, I):
        """Assert each L2 distance in D is 4x the Hamming distance of its pair.

        xq and xb hold packed binary codes; D and I are search results for
        ``binary_to_float(xq)`` against an index built on
        ``binary_to_float(xb)``.
        """
        for i in range(xq.shape[0]):
            for j, dj in zip(I[i], D[i]):
                self.assertEqual(4 * binary_dis(xq[i], xb[j]), dj)

    def test_encode_to_binary(self):
        d = 256
        nt = 256
        nb = 1500
        nq = 500
        M = d // 8  # one 8-bit sub-quantizer per byte of binary code
        (xt, xb, xq) = make_binary_dataset(d, nt, nb, nq)
        xb_float = binary_to_float(xb)
        xq_float = binary_to_float(xq)

        pq = faiss.ProductQuantizer(d, M, 8)
        # Centroid c of every sub-quantizer is the +/-1 expansion of byte
        # value c, so encoding an expanded code must give back its bytes.
        centroids = binary_to_float(
            np.tile(np.arange(256), M).astype('uint8').reshape(-1, 1))
        faiss.copy_array_to_vector(centroids.ravel(), pq.centroids)
        pq.is_trained = True
        codes = pq.compute_codes(xb_float)
        np.testing.assert_array_equal(codes, xb)

        indexpq = faiss.IndexPQ(d, M, 8)
        indexpq.pq = pq
        indexpq.is_trained = True
        indexpq.add(xb_float)
        D, I = indexpq.search(xq_float, 3)
        self._check_distances(xq, xb, D, I)

        nlist = 32
        quantizer = faiss.IndexFlatL2(d)
        # The IndexIVFFlat is only a pretext to train `quantizer`, which is
        # then shared with the IVFPQ index below.
        iflat = faiss.IndexIVFFlat(quantizer, d, nlist)
        iflat.train(binary_to_float(xt))
        indexivfpq = faiss.IndexIVFPQ(quantizer, d, nlist, M, 8)
        indexivfpq.pq = pq
        indexivfpq.is_trained = True
        # Encode the vectors themselves (not residuals), which is what the
        # hand-built byte-expansion PQ above expects.
        indexivfpq.by_residual = False
        indexivfpq.add(xb_float)
        indexivfpq.nprobe = 4
        D, I = indexivfpq.search(xq_float, 3)
        self._check_distances(xq, xb, D, I)
class TestBinaryFlat(unittest.TestCase):
    """Brute-force IndexBinaryFlat: exact Hamming search and reconstruction."""

    def setUp(self):
        # Built in setUp rather than an overridden __init__, so fixtures are
        # created when each test runs instead of when the loader instantiates
        # the test cases (where failures show up as loader errors).
        d = 32
        nt = 0
        nb = 1500
        nq = 500
        (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)

    def test_flat(self):
        d = self.xq.shape[1] * 8
        nq = self.xq.shape[0]
        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)
        # every reported distance must equal the reference Hamming distance
        for i in range(nq):
            for j, dj in zip(I[i], D[i]):
                self.assertEqual(dj, binary_dis(self.xq[i], self.xb[j]))
        # test reconstruction
        np.testing.assert_array_equal(index.reconstruct(12), self.xb[12])

    def test_empty_flat(self):
        d = self.xq.shape[1] * 8
        index = faiss.IndexBinaryFlat(d)
        for use_heap in [True, False]:
            index.use_heap = use_heap
            Dflat, Iflat = index.search(self.xq, 10)
            # Searching an empty index: every slot is id -1 with the int32
            # max distance, for both result-collection modes.
            self.assertTrue(np.all(Iflat == -1))
            self.assertTrue(np.all(Dflat == np.iinfo('int32').max))
class TestBinaryIVF(unittest.TestCase):
    """Compare IndexBinaryIVF search results with brute-force IndexBinaryFlat."""

    def setUp(self):
        # Built in setUp rather than an overridden __init__, so the data and
        # the reference search run when each test runs, not when the loader
        # instantiates the test cases (where failures show up as loader
        # errors).
        d = 32
        nt = 200
        nb = 1500
        nq = 500
        (self.xt, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)
        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        # brute-force top-10 distances used as ground truth
        self.Dref, _ = index.search(self.xq, 10)

    def test_ivf_flat_exhaustive(self):
        d = self.xq.shape[1] * 8
        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5  # quiet warning
        # nprobe matches the 8 lists the index was built with, so the search
        # is exhaustive and distances must equal brute force exactly.
        index.nprobe = 8
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)
        np.testing.assert_array_equal(self.Dref, Divfflat)

    def test_ivf_flat2(self):
        d = self.xq.shape[1] * 8
        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5  # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)
        # Non-exhaustive search; 4122 is presumably a regression value for
        # this fixed seed and configuration.
        self.assertEqual((self.Dref == Divfflat).sum(), 4122)

    def test_ivf_flat_empty(self):
        d = self.xq.shape[1] * 8
        index = faiss.IndexBinaryIVF(faiss.IndexBinaryFlat(d), d, 8)
        index.train(self.xt)
        for use_heap in [True, False]:
            index.use_heap = use_heap
            Divfflat, Iivfflat = index.search(self.xq, 10)
            # Trained but empty index: every slot is id -1 with the int32
            # max distance.
            self.assertTrue(np.all(Iivfflat == -1))
            self.assertTrue(np.all(Divfflat == np.iinfo('int32').max))
class TestHNSW(unittest.TestCase):
    """IndexBinaryHNSW: exact distances and agreement with float-based HNSW."""

    def setUp(self):
        # Built in setUp rather than an overridden __init__, so fixtures are
        # created when each test runs instead of at test-loading time.
        d = 32
        nt = 0
        nb = 1500
        nq = 500
        (_, self.xb, self.xq) = make_binary_dataset(d, nt, nb, nq)

    def test_hnsw_exact_distances(self):
        d = self.xq.shape[1] * 8
        nq = self.xq.shape[0]
        index = faiss.IndexBinaryHNSW(d, 16)
        index.add(self.xb)
        Dists, Ids = index.search(self.xq, 3)
        # each reported distance must be the exact Hamming distance to the
        # returned id
        for i in range(nq):
            for j, dj in zip(Ids[i], Dists[i]):
                self.assertEqual(dj, binary_dis(self.xq[i], self.xb[j]))

    def test_hnsw(self):
        d = self.xq.shape[1] * 8
        # NOTE(hoss): Ensure the HNSW construction is deterministic.
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
        try:
            index_hnsw_float = faiss.IndexHNSWFlat(d, 16)
            index_hnsw_ref = faiss.IndexBinaryFromFloat(index_hnsw_float)
            index_hnsw_bin = faiss.IndexBinaryHNSW(d, 16)
            index_hnsw_ref.add(self.xb)
            index_hnsw_bin.add(self.xb)
        finally:
            # Restore the global thread count even if construction fails, so
            # it does not leak into later tests.
            faiss.omp_set_num_threads(nthreads)
        Dref, _ = index_hnsw_ref.search(self.xq, 3)
        Dbin, _ = index_hnsw_bin.search(self.xq, 3)
        np.testing.assert_array_equal(Dref, Dbin)
def compare_binary_result_lists(D1, I1, D2, I2):
    """Assert that two k-NN result lists are equivalent up to ties.

    Comparing result lists is difficult because there are many ties. The
    distance matrices must match exactly. Then, per row, (distance, id)
    pairs are sorted and compared, ignoring -1 ids and every entry whose
    distance equals the row's last (largest) distance, since which of several
    tied candidates fills the tail of the top-k is arbitrary. Compatible
    result lists pass this check.

    Raises:
        AssertionError: if shapes or distances differ, or if the normalized
            (distance, id) rows differ.
    """
    assert D1.shape == I1.shape == D2.shape == I2.shape
    ndiff = (D1 != D2).sum()
    assert ndiff == 0, '%d differences in distance matrix %s' % (
        ndiff, D1.shape)

    # A single divisor shared by both lists, so the same (distance, id) pair
    # encodes to the same float in both. Empty inputs are skipped (``.max()``
    # would raise), and the floor of 0 avoids dividing by zero when every id
    # is -1.
    norm = 1.0 + max([0] + [int(I.max()) for I in (I1, I2) if I.size > 0])

    def normalize_DI(D, I):
        # distance + id / norm: ids are < norm, so the fractional part only
        # orders entries with equal (integer) distance.
        Dr = D.astype('float64') + I / norm
        # ignore -1s (of *this* list) and elements tied with the last column
        Dr[I == -1] = 1e20
        Dr[D == D[:, -1:]] = 1e20
        Dr.sort(axis=1)
        return Dr

    ndiff = (normalize_DI(D1, I1) != normalize_DI(D2, I2)).sum()
    assert ndiff == 0, '%d differences in normalized D matrix' % ndiff
class TestReplicasAndShards(unittest.TestCase):
    """Replicated and sharded binary indexes must agree with one flat index."""

    def _reference(self, d, nb, nq):
        """Build the dataset and its brute-force top-10 results.

        Returns (xb, xq, Dref, Iref).
        """
        _, xb, xq = make_binary_dataset(d, 0, nb, nq)
        flat = faiss.IndexBinaryFlat(d)
        flat.add(xb)
        Dref, Iref = flat.search(xq, 10)
        return xb, xq, Dref, Iref

    def test_replicas(self):
        d, nq, nb = 32, 100, 200
        xb, xq, Dref, Iref = self._reference(d, nb, nq)
        nrep = 5

        # Replicas that are each filled with the full database up front.
        prefilled = faiss.IndexBinaryReplicas()
        for _ in range(nrep):
            replica = faiss.IndexBinaryFlat(d)
            replica.add(xb)
            prefilled.addIndex(replica)
        D, I = prefilled.search(xq, 10)
        self.assertTrue((Dref == D).all())
        self.assertTrue((Iref == I).all())

        # Empty replicas populated through the wrapper's own add().
        via_wrapper = faiss.IndexBinaryReplicas()
        for _ in range(nrep):
            replica = faiss.IndexBinaryFlat(d)
            via_wrapper.addIndex(replica)
        via_wrapper.add(xb)
        D2, I2 = via_wrapper.search(xq, 10)
        self.assertTrue((Dref == D2).all())
        self.assertTrue((Iref == I2).all())

    def test_shards(self):
        d, nq, nb = 32, 100, 200
        xb, xq, Dref, Iref = self._reference(d, nb, nq)
        nrep = 5

        # Each shard holds one contiguous slice of the database.
        presplit = faiss.IndexBinaryShards(d)
        for s in range(nrep):
            lo, hi = s * nb // nrep, (s + 1) * nb // nrep
            shard = faiss.IndexBinaryFlat(d)
            shard.add(xb[lo:hi])
            presplit.add_shard(shard)
        D, I = presplit.search(xq, 10)
        compare_binary_result_lists(Dref, Iref, D, I)

        # Empty shards populated through the wrapper's own add().
        via_wrapper = faiss.IndexBinaryShards(d)
        for _ in range(nrep):
            shard = faiss.IndexBinaryFlat(d)
            via_wrapper.add_shard(shard)
        via_wrapper.add(xb)
        D2, I2 = via_wrapper.search(xq, 10)
        compare_binary_result_lists(Dref, Iref, D2, I2)
if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()