faiss/tests/test_product_quantizer.py

148 lines
4.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import numpy as np
import faiss
import unittest
class TestProductQuantizer(unittest.TestCase):
    """Train/encode/decode checks for faiss.ProductQuantizer."""

    def test_pq(self):
        """Train an 8-bit and a 10-bit PQ and compare reconstruction errors.

        The 10-bit quantizer is expected to reconstruct the training
        vectors with a strictly smaller squared error than the 8-bit one.
        """
        d = 64
        n = 2000
        cs = 4
        np.random.seed(123)
        x = np.random.random(size=(n, d)).astype('float32')

        pq = faiss.ProductQuantizer(d, cs, 8)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2)**2).sum()

        # value recorded by the original author for this seed:
        # diff = 4418.0562
        self.assertLess(diff, 5000)

        pq10 = faiss.ProductQuantizer(d, cs, 10)
        # 4 sub-quantizers x 10 bits = 40 bits, i.e. 5 bytes per code
        self.assertEqual(pq10.code_size, 5)
        pq10.verbose = True
        pq10.cp.verbose = True
        pq10.train(x)
        codes = pq10.compute_codes(x)
        x10 = pq10.decode(codes)
        diff10 = ((x - x10)**2).sum()

        # more bits per sub-quantizer must give a lower reconstruction error
        self.assertGreater(diff, diff10)

    def do_test_codec(self, nbit):
        """Check that encoding then decoding is exact for `nbit`-bit codes.

        The quantizer is not trained: random centroids are copied into it,
        and the input vectors are built by concatenating centroids, so each
        vector has an exact code. The codes are then recomputed through an
        external IndexFlatL2 assignment index and must be identical.

        Args:
            nbit: number of bits per sub-quantizer index.

        Raises:
            AssertionError: if the round trip or the two encoders disagree.
        """
        M = 2        # number of sub-quantizers
        dsub = 8     # dimension of each sub-vector
        n = 100      # number of vectors to encode
        ksub = 1 << nbit
        pq = faiss.ProductQuantizer(M * dsub, M, nbit)

        # simulate training
        rs = np.random.RandomState(123)
        centroids = rs.rand(M, ksub, dsub).astype('float32')
        faiss.copy_array_to_vector(centroids.ravel(), pq.centroids)

        idx = rs.randint(ksub, size=(n, M))

        # can be encoded exactly
        x = np.hstack([centroids[m, idx[:, m]] for m in range(M)])

        # encode / decode
        codes = pq.compute_codes(x)
        xr = pq.decode(codes)
        np.testing.assert_array_equal(xr, x)

        # encode w/ external index
        # NOTE(review): the local name presumably keeps the Python wrapper
        # alive while pq points at it -- confirm SWIG ownership semantics.
        assign_index = faiss.IndexFlatL2(dsub)
        pq.assign_index = assign_index
        codes2 = np.empty((n, pq.code_size), dtype='uint8')
        pq.compute_codes_with_assign_index(
            faiss.swig_ptr(x), faiss.swig_ptr(codes2), n)
        np.testing.assert_array_equal(codes, codes2)

    def test_codec(self):
        """Run the exact-codec check for every bit width from 1 to 16."""
        for nbit in range(1, 17):
            # subTest labels a failure with the bit width that triggered it
            with self.subTest(nbit=nbit):
                self.do_test_codec(nbit)
class TestPQTables(unittest.TestCase):
    """Compare faiss PQ look-up tables against a plain numpy computation."""

    def do_test(self, d, dsub, nbit=8, metric=None):
        """Check the look-up tables of a trained PQ against numpy.

        Args:
            d: dimension of the full vectors.
            dsub: dimension of each sub-vector; d // dsub sub-quantizers
                are used.
            nbit: number of bits per sub-quantizer index.
            metric: faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2. When
                None, the check is run once for each of the two metrics.

        Raises:
            AssertionError: if the metric is not one of the two supported
                values, or if the tables differ beyond 5 decimals.
        """
        if metric is None:
            for each_metric in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
                self.do_test(d, dsub, nbit, each_metric)
            return

        # Validate up front with an explicit failure: a bare `assert False`
        # is removed under `python -O`, which would leave both tables
        # zero-filled and let the final comparison pass vacuously.
        if metric not in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
            self.fail("unsupported metric: %r" % (metric,))
        use_inner_product = metric == faiss.METRIC_INNER_PRODUCT

        # faiss.cvar.distance_compute_blas_threshold = 1000000

        M = d // dsub
        pq = faiss.ProductQuantizer(d, M, nbit)
        xt = faiss.randn((max(1000, pq.ksub * 50), d), 123)
        pq.cp.niter = 4     # to avoid timeouts in tests
        pq.train(xt)

        centroids = faiss.vector_to_array(pq.centroids)
        centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub)

        nx = 100
        x = faiss.randn((nx, d), 555)

        # computation of tables in numpy
        ref_tab = np.zeros((nx, M, pq.ksub), "float32")
        for sq in range(M):
            xsub = x[:, sq * dsub:(sq + 1) * dsub]
            centsq = centroids[sq, :, :]
            if use_inner_product:
                ref_tab[:, sq, :] = xsub @ centsq.T
            else:
                # broadcast (nx, 1, dsub) against (1, ksub, dsub)
                xsub3 = xsub.reshape(nx, 1, dsub)
                cent3 = centsq.reshape(1, pq.ksub, dsub)
                ref_tab[:, sq, :] = ((xsub3 - cent3) ** 2).sum(2)

        # same tables computed by faiss, written into new_tab in place
        sp = faiss.swig_ptr
        new_tab = np.zeros((nx, M, pq.ksub), "float32")
        if use_inner_product:
            pq.compute_inner_prod_tables(nx, sp(x), sp(new_tab))
        else:
            pq.compute_distance_tables(nx, sp(x), sp(new_tab))

        np.testing.assert_array_almost_equal(ref_tab, new_tab, decimal=5)

    def test_dsub2(self):
        """d=16, sub-vectors of size 2 (8 sub-quantizers)."""
        self.do_test(16, 2)

    def test_dsub5(self):
        """d=20, sub-vectors of size 5 (4 sub-quantizers)."""
        self.do_test(20, 5)

    def test_dsub2_odd(self):
        """d=18, sub-vectors of size 2 (9 sub-quantizers, an odd count)."""
        self.do_test(18, 2)

    def test_dsub4(self):
        """d=32, sub-vectors of size 4 (8 sub-quantizers)."""
        self.do_test(32, 4)

    def test_dsub4_odd(self):
        """d=36, sub-vectors of size 4 (9 sub-quantizers, an odd count)."""
        self.do_test(36, 4)

    # too slow
    # def test_12bit(self):
    #     self.do_test(32, 4, nbit=12)

    def test_4bit(self):
        """d=32, sub-vectors of size 4, with 4-bit sub-quantizer indices."""
        self.do_test(32, 4, nbit=4)