mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Summary: The tests TestPQTables are very slow in dev mode with BLAS. This seems to be due to the training operation of the PQ. However, since it does not matter whether the training is accurate or not, we can just reduce the number of training iterations from the default 25 to 4. It is still unclear why this happens, because the runtime is spent in BLAS, which should be independent of mode/opt or mode/dev. Reviewed By: wickedfoo Differential Revision: D24783752 fbshipit-source-id: 38077709eb9a6432210c11c3040765e139353ae8
148 lines
4.2 KiB
Python
148 lines
4.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import numpy as np
|
|
|
|
import faiss
|
|
import unittest
|
|
|
|
|
|
|
|
class TestProductQuantizer(unittest.TestCase):
    """Training and encode/decode round-trip checks for faiss.ProductQuantizer."""

    def test_pq(self):
        """Train 8-bit and 10-bit PQs and compare their reconstruction error.

        The 10-bit PQ has 4x more centroids per sub-quantizer than the 8-bit
        one, and the test requires its reconstruction error to be lower.
        """
        d = 64
        n = 2000
        cs = 4  # number of sub-quantizers
        # Local generator: same sequence as np.random.seed(123) followed by
        # np.random.random, without touching NumPy's global random state.
        rs = np.random.RandomState(123)
        x = rs.random_sample(size=(n, d)).astype('float32')

        pq = faiss.ProductQuantizer(d, cs, 8)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2) ** 2).sum()
        # previously observed with this seed: diff= 4418.0562
        self.assertLess(diff, 5000)

        pq10 = faiss.ProductQuantizer(d, cs, 10)
        # 4 sub-quantizers x 10 bits = 40 bits = 5 bytes per code
        self.assertEqual(pq10.code_size, 5)
        pq10.verbose = True
        pq10.cp.verbose = True
        pq10.train(x)
        codes = pq10.compute_codes(x)
        x10 = pq10.decode(codes)
        diff10 = ((x - x10) ** 2).sum()
        self.assertLess(diff10, diff)

    def do_test_codec(self, nbit):
        """Check lossless encode/decode and assign-index encoding at `nbit` bits.

        Centroids are written directly into the PQ (no training), and every
        test vector is a concatenation of existing centroids. Encoding must
        therefore reproduce the inputs exactly.
        """
        pq = faiss.ProductQuantizer(16, 2, nbit)

        # simulate training: 2 sub-quantizers, 2**nbit centroids of dim 8
        rs = np.random.RandomState(123)
        centroids = rs.rand(2, 1 << nbit, 8).astype('float32')
        faiss.copy_array_to_vector(centroids.ravel(), pq.centroids)

        idx = rs.randint(1 << nbit, size=(100, 2))
        # each row is made of exact centroids, so it can be encoded exactly
        x = np.hstack((
            centroids[0, idx[:, 0]],
            centroids[1, idx[:, 1]]
        ))
        n = len(x)

        # encode / decode
        codes = pq.compute_codes(x)
        xr = pq.decode(codes)
        np.testing.assert_array_equal(xr, x)

        # encode w/ external index; must yield the same codes
        assign_index = faiss.IndexFlatL2(8)
        pq.assign_index = assign_index
        codes2 = np.empty((n, pq.code_size), dtype='uint8')
        pq.compute_codes_with_assign_index(
            faiss.swig_ptr(x), faiss.swig_ptr(codes2), n)
        np.testing.assert_array_equal(codes, codes2)

    def test_codec(self):
        """Run the codec check for every code width from 1 to 16 bits."""
        for nbit in range(1, 17):
            # subTest reports the failing bit width instead of printing it
            with self.subTest(nbit=nbit):
                self.do_test_codec(nbit)
|
|
|
|
|
|
class TestPQTables(unittest.TestCase):
    """Compare ProductQuantizer look-up tables against a NumPy reference."""

    _METRICS = None  # filled lazily; see do_test

    def do_test(self, d, dsub, nbit=8, metric=None):
        """Check PQ look-up tables for one configuration.

        d: vector dimension.
        dsub: sub-vector dimension (d // dsub sub-quantizers are used).
        nbit: bits per sub-quantizer code.
        metric: faiss.METRIC_INNER_PRODUCT or faiss.METRIC_L2. When None,
            both metrics are tested.

        Raises ValueError for any other metric. An explicit raise is used
        because a bare `assert` would be stripped under `python -O`, letting
        two all-zero tables compare equal.
        """
        supported = (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2)
        if metric is None:
            for m in supported:
                with self.subTest(metric=m):
                    self.do_test(d, dsub, nbit, m)
            return
        if metric not in supported:
            raise ValueError("unsupported metric %r" % (metric,))

        M = d // dsub
        pq = faiss.ProductQuantizer(d, M, nbit)
        xt = faiss.randn((max(1000, pq.ksub * 50), d), 123)
        # Training accuracy is irrelevant here; 4 k-means iterations (instead
        # of the default) avoid test timeouts.
        pq.cp.niter = 4
        pq.train(xt)

        centroids = faiss.vector_to_array(pq.centroids)
        centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub)

        nx = 100
        x = faiss.randn((nx, d), 555)

        ref_tab = self._reference_tables(x, centroids, metric)

        sp = faiss.swig_ptr
        new_tab = np.zeros((nx, M, pq.ksub), "float32")
        if metric == faiss.METRIC_INNER_PRODUCT:
            pq.compute_inner_prod_tables(nx, sp(x), sp(new_tab))
        else:  # faiss.METRIC_L2, validated above
            pq.compute_distance_tables(nx, sp(x), sp(new_tab))

        np.testing.assert_array_almost_equal(ref_tab, new_tab, decimal=5)

    @staticmethod
    def _reference_tables(x, centroids, metric):
        """Compute an (nx, M, ksub) float32 look-up table in NumPy.

        centroids has shape (M, ksub, dsub), as produced by the reshape in
        do_test. Sub-vector `sq` of x is x[:, sq*dsub:(sq+1)*dsub].

        For METRIC_INNER_PRODUCT each entry is a dot product; otherwise each
        entry is a squared L2 distance.
        """
        nx = x.shape[0]
        M, ksub, dsub = centroids.shape
        tab = np.zeros((nx, M, ksub), "float32")
        for sq in range(M):
            xsub = x[:, sq * dsub:(sq + 1) * dsub]
            centsq = centroids[sq]
            if metric == faiss.METRIC_INNER_PRODUCT:
                tab[:, sq, :] = xsub @ centsq.T
            else:
                # broadcast (nx, 1, dsub) - (1, ksub, dsub) -> (nx, ksub, dsub)
                delta = xsub[:, None, :] - centsq[None, :, :]
                tab[:, sq, :] = (delta ** 2).sum(2)
        return tab

    def test_dsub2(self):
        self.do_test(16, 2)

    def test_dsub5(self):
        self.do_test(20, 5)

    def test_dsub2_odd(self):
        self.do_test(18, 2)

    def test_dsub4(self):
        self.do_test(32, 4)

    def test_dsub4_odd(self):
        self.do_test(36, 4)

    @unittest.skip("too slow")
    def test_12bit(self):
        self.do_test(32, 4, nbit=12)

    def test_4bit(self):
        self.do_test(32, 4, nbit=4)
|