mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Summary: This diff streamlines the code that collects results for brute force distance computations for the L2 / IP and range search / knn search combinations. It introduces a `ResultHandler` template class that abstracts what happens with the computed distances and ids. In addition to the heap result handler and the range search result handler, it introduces a reservoir result handler that improves the search speed for large k (>=100). Benchmark results (https://fb.quip.com/y0g1ACLEqJXx#OCaACA2Gm45) show that on small datasets (10k) search is 10-50% faster (improvements are larger for small k). There is room for improvement in the reservoir result handler, whose implementation is currently quite naive, but the diff is already useful in its current form. Experiments on precomputed db vector norms for L2 distance computations were not conclusive performance-wise, so the implementation is removed from IndexFlatL2. This diff also removes IndexL2BaseShift, which was never used. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1502 Test Plan: ``` buck test //faiss/tests/:test_product_quantizer buck test //faiss/tests/:test_index -- TestIndexFlat ``` Reviewed By: wickedfoo Differential Revision: D24705464 Pulled By: mdouze fbshipit-source-id: 270e10b19f3c89ed7b607ec30549aca0ac5027fe
145 lines
4.1 KiB
Python
145 lines
4.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import numpy as np
|
|
|
|
import faiss
|
|
import unittest
|
|
|
|
|
|
|
|
class TestProductQuantizer(unittest.TestCase):
    """Training, encoding and decoding tests for ``faiss.ProductQuantizer``."""

    def test_pq(self):
        """Check PQ reconstruction error and that more bits reduce it.

        An 8-bit PQ trained on random data must keep the total squared
        reconstruction error under a fixed bound; a 10-bit PQ trained on the
        same data must reconstruct it strictly more accurately.
        """
        d = 64
        n = 2000
        cs = 4
        # Local generator: same MT19937 stream as np.random.seed(123) +
        # np.random.random, but without mutating global NumPy RNG state.
        rs = np.random.RandomState(123)
        x = rs.random_sample((n, d)).astype('float32')

        pq = faiss.ProductQuantizer(d, cs, 8)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)
        diff = ((x - x2) ** 2).sum()

        # value observed when the bound was chosen: diff= 4418.0562
        self.assertGreater(5000, diff)

        pq10 = faiss.ProductQuantizer(d, cs, 10)
        # cs sub-quantizers * 10 bits = 40 bits -> 5 bytes per code
        self.assertEqual(pq10.code_size, 5)
        pq10.verbose = True
        pq10.cp.verbose = True
        pq10.train(x)
        codes = pq10.compute_codes(x)

        x10 = pq10.decode(codes)
        diff10 = ((x - x10) ** 2).sum()
        # a finer quantizer must reconstruct the data more accurately
        self.assertGreater(diff, diff10)

    def do_test_codec(self, nbit):
        """Round-trip exactly representable vectors through a 16-d, 2-sub PQ.

        Centroids are installed directly (no training) and every test vector
        is a concatenation of one centroid per sub-quantizer, so decoding the
        codes must reproduce the input bit-for-bit. Encoding through an
        external ``assign_index`` must produce the same codes.

        Args:
            nbit: number of bits per sub-quantizer code.
        """
        M, dsub = 2, 8
        d = M * dsub
        n = 100
        ksub = 1 << nbit
        pq = faiss.ProductQuantizer(d, M, nbit)

        # simulate training by writing random centroids into the PQ
        rs = np.random.RandomState(123)
        centroids = rs.rand(M, ksub, dsub).astype('float32')
        faiss.copy_array_to_vector(centroids.ravel(), pq.centroids)

        idx = rs.randint(ksub, size=(n, M))
        # each row is made of centroids, so it can be encoded exactly
        x = np.hstack([centroids[m, idx[:, m]] for m in range(M)])

        # encode / decode
        codes = pq.compute_codes(x)
        xr = pq.decode(codes)
        np.testing.assert_array_equal(xr, x)

        # encode w/ external index; keep the local `assign_index` reference
        # alive while pq uses it (NOTE(review): presumably pq does not take
        # ownership of the SWIG object — confirm)
        assign_index = faiss.IndexFlatL2(dsub)
        pq.assign_index = assign_index
        codes2 = np.empty((n, pq.code_size), dtype='uint8')
        pq.compute_codes_with_assign_index(
            faiss.swig_ptr(x), faiss.swig_ptr(codes2), n)
        np.testing.assert_array_equal(codes, codes2)

    def test_codec(self):
        """Run the codec round-trip check for every nbit in 1..16."""
        for nbit in range(1, 17):
            # subTest reports which nbit failed instead of printing progress
            with self.subTest(nbit=nbit):
                self.do_test_codec(nbit)
|
|
|
|
|
|
class TestPQTables(unittest.TestCase):
    """Compare PQ look-up tables computed by faiss with a NumPy reference."""

    def do_test(self, d, dsub, nbit=8, metric=None):
        """Check PQ tables for one configuration.

        Trains a PQ with ``d // dsub`` sub-quantizers of ``nbit`` bits, then
        compares ``compute_inner_prod_tables`` (inner product) or
        ``compute_distance_tables`` (L2) with tables computed in NumPy.

        Args:
            d: vector dimension.
            dsub: dimension of each sub-vector.
            nbit: bits per sub-quantizer code.
            metric: ``faiss.METRIC_INNER_PRODUCT`` or ``faiss.METRIC_L2``;
                when None, both metrics are tested (inner product first).

        Raises:
            ValueError: if ``metric`` is not one of the supported metrics.
        """
        if metric is None:
            for m in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
                self.do_test(d, dsub, nbit, m)
            return

        # Explicit check instead of `assert False`, which `python -O` strips
        # (both tables would stay zero and the test would pass vacuously).
        if metric not in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
            raise ValueError("unsupported metric %r" % (metric,))
        is_ip = metric == faiss.METRIC_INNER_PRODUCT

        M = d // dsub
        pq = faiss.ProductQuantizer(d, M, nbit)
        pq.train(faiss.randn((max(1000, pq.ksub * 50), d), 123))

        centroids = faiss.vector_to_array(pq.centroids)
        centroids = centroids.reshape(pq.M, pq.ksub, pq.dsub)

        nx = 100
        x = faiss.randn((nx, d), 555)

        # reference tables computed in numpy: ref_tab[i, sq, j] is the
        # score between sub-vector sq of x[i] and centroid j of sub-quantizer sq
        ref_tab = np.zeros((nx, M, pq.ksub), "float32")
        for sq in range(M):
            i0, i1 = sq * dsub, (sq + 1) * dsub
            xsub = x[:, i0:i1]
            centsq = centroids[sq, :, :]
            if is_ip:
                ref_tab[:, sq, :] = xsub @ centsq.T
            else:
                # broadcast (nx, 1, dsub) against (1, ksub, dsub)
                xsub3 = xsub.reshape(nx, 1, dsub)
                cent3 = centsq.reshape(1, pq.ksub, dsub)
                ref_tab[:, sq, :] = ((xsub3 - cent3) ** 2).sum(2)

        sp = faiss.swig_ptr
        new_tab = np.zeros((nx, M, pq.ksub), "float32")
        if is_ip:
            pq.compute_inner_prod_tables(nx, sp(x), sp(new_tab))
        else:
            pq.compute_distance_tables(nx, sp(x), sp(new_tab))

        np.testing.assert_array_almost_equal(ref_tab, new_tab, decimal=5)

    def test_dsub2(self):
        self.do_test(16, 2)

    def test_dsub5(self):
        self.do_test(20, 5)

    def test_dsub2_odd(self):
        self.do_test(18, 2)

    def test_dsub4(self):
        self.do_test(32, 4)

    def test_dsub4_odd(self):
        self.do_test(36, 4)

    @unittest.skip("too slow")
    def test_12bit(self):
        self.do_test(32, 4, nbit=12)

    def test_4bit(self):
        self.do_test(32, 4, nbit=4)
|