mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2255 The `DistanceComputer` object is obtained from an Index (via `get_distance_computer()`). It maintains a current query and quickly computes distances from that query to any item in the database. This is useful, e.g., for IndexHNSW and IndexNSG, which rely on query-to-point comparisons within the dataset. This diff introduces `FlatCodesDistanceComputer`, which inherits from `DistanceComputer`, for Flat indexes. In addition to the distance-to-item function, it adds a `distance_to_code` method that computes the distance from any code to the current query, even if that code is not stored in the index. This is implemented for all FlatCodes indexes (IndexFlat, IndexPQ, IndexScalarQuantizer and IndexAdditiveQuantizer). In the process, the two classes were moved to their own header file, `impl/DistanceComputer.h`. Reviewed By: beauby Differential Revision: D34863609 fbshipit-source-id: 39d8c66475e55c3223c4a6a210827aa48bca292d
67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import numpy as np
|
|
import unittest
|
|
import faiss
|
|
|
|
from faiss.contrib import datasets
|
|
|
|
|
|
class TestDistanceComputer(unittest.TestCase):
    """Cross-check distance computers against the index's own search.

    For each index type, the distances returned by ``index.search`` are
    recomputed one result at a time through a distance computer obtained
    from the same index, and the two must agree to 5 decimals.
    """

    def do_test(self, factory_string, metric_type=faiss.METRIC_L2):
        """Build an index from `factory_string` and check its distance
        computers against ``index.search``.

        The generic ``get_distance_computer()`` is always checked. When the
        index is an ``IndexFlatCodes``, ``get_FlatCodesDistanceComputer()``
        is checked as well, including ``distance_to_code`` on the stored
        code of every search result.

        Args:
            factory_string: faiss index_factory description of the index.
            metric_type: faiss metric used to build the index.
        """
        k = 10  # number of search results verified per query
        ds = datasets.SyntheticDataset(32, 1000, 200, 20)

        # Take the dimension from the dataset so the two cannot drift apart.
        index = faiss.index_factory(ds.d, factory_string, metric_type)
        index.train(ds.get_train())
        index.add(ds.get_database())
        xq = ds.get_queries()
        Dref, Iref = index.search(xq, k)

        self._check_distance_computer(
            index.get_distance_computer(), xq, Dref, Iref)

        # Only flat-code indexes expose the codes-level distance computer.
        if isinstance(index, faiss.IndexFlatCodes):
            # Stored codes, one row of `code_size` bytes per database vector.
            codes = faiss.vector_to_array(index.codes).reshape(
                index.ntotal, index.code_size)
            self._check_distance_computer(
                index.get_FlatCodesDistanceComputer(), xq, Dref, Iref,
                codes=codes)

    def _check_distance_computer(self, dc, xq, Dref, Iref, codes=None):
        """Assert that `dc` reproduces the reference search distances.

        Args:
            dc: distance computer returned by the index.
            xq: query vectors, one row per query.
            Dref: reference distances from ``index.search``.
            Iref: result ids from ``index.search``, same shape as `Dref`.
            codes: optional per-database-vector stored codes. When given,
                `dc` is expected to be a FlatCodesDistanceComputer and
                ``distance_to_code`` is checked on each result's code too.
        """
        # The Python wrapper is expected to own the returned C++ object
        # (SWIG ownership flag).
        self.assertTrue(dc.this.own())
        nq, k = Iref.shape
        for q in range(nq):
            dc.set_query(faiss.swig_ptr(xq[q]))
            for j in range(k):
                ref_dis = Dref[q, j]
                i = int(Iref[q, j])
                np.testing.assert_almost_equal(dc(i), ref_dis, decimal=5)
                if codes is not None:
                    np.testing.assert_almost_equal(
                        dc.distance_to_code(faiss.swig_ptr(codes[i])),
                        ref_dis, decimal=5)

    def test_distance_computer_PQ(self):
        self.do_test("PQ8np")

    def test_distance_computer_SQ(self):
        self.do_test("SQ8")

    def test_distance_computer_SQ6(self):
        self.do_test("SQ6")

    def test_distance_computer_PQbit6(self):
        self.do_test("PQ8x6np")

    def test_distance_computer_PQbit6_ip(self):
        self.do_test("PQ8x6np", faiss.METRIC_INNER_PRODUCT)

    def test_distance_computer_VT(self):
        self.do_test("PCA20,SQ8")

    def test_distance_computer_AQ_decompress(self):
        self.do_test("RQ3x4")  # test decompress path

    def test_distance_computer_AQ_LUT(self):
        self.do_test("RQ3x4_Nqint8")  # test LUT path

    def test_distance_computer_AQ_LUT_IP(self):
        self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)
|