faiss/tests/test_refine.py
Matthijs Douze 291353c5a9 Generalize DistanceComputer for flat indexes (#2255)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2255

The `DistanceComputer` object is derived from an Index (obtained with `get_distance_computer()`). It maintains a current query and quickly computes distances from that query to any item in the database. This is useful, e.g., for IndexHNSW and IndexNSG, which rely on query-to-point comparisons in the dataset.

This diff introduces the `FlatCodesDistanceComputer`, which inherits from `DistanceComputer`, for Flat indexes. In addition to the distance-to-item function, it adds a `distance_to_code` method that computes the distance from the current query to any code, even one that is not stored in the index.

This is implemented for all FlatCode indexes (IndexFlat, IndexPQ, IndexScalarQuantizer and IndexAdditiveQuantizer).

In the process, the two classes were extracted to their own header file, `impl/DistanceComputer.h`.

Reviewed By: beauby

Differential Revision: D34863609

fbshipit-source-id: 39d8c66475e55c3223c4a6a210827aa48bca292d
2022-03-20 23:43:33 -07:00

67 lines
2.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import unittest
import faiss
from faiss.contrib import datasets
class TestDistanceComputer(unittest.TestCase):
    """Compare distance-computer outputs with the index's search results."""

    def _check_computer(self, computer, n_queries, queries,
                        ref_dists, ref_ids):
        """Assert that `computer` reproduces the reference distances.

        For each of the first `n_queries` rows of `queries`, the row is set
        as the current query and the distance to every id in the matching
        row of `ref_ids` is compared (to 5 decimals) with the corresponding
        entry of `ref_dists`.
        """
        # SWIG ownership flag of the wrapped object is expected to be set.
        self.assertTrue(computer.this.own())
        for q in range(n_queries):
            computer.set_query(faiss.swig_ptr(queries[q]))
            for ref_dis, ref_id in zip(ref_dists[q], ref_ids[q]):
                new_dis = computer(int(ref_id))
                np.testing.assert_almost_equal(
                    new_dis, ref_dis, decimal=5)

    def do_test(self, factory_string, metric_type=faiss.METRIC_L2):
        """Build an index from `factory_string` and check its computers.

        The index is trained and populated from a synthetic dataset, the
        queries are searched for their 10 nearest results, and those results
        serve as the reference for the generic distance computer and, when
        the index is an `IndexFlatCodes`, for its flat-codes computer too.
        """
        dataset = datasets.SyntheticDataset(32, 1000, 200, 20)

        idx = faiss.index_factory(32, factory_string, metric_type)
        idx.train(dataset.get_train())
        idx.add(dataset.get_database())

        queries = dataset.get_queries()
        ref_dists, ref_ids = idx.search(queries, 10)

        # Generic computer first, available on every index tested here.
        self._check_computer(
            idx.get_distance_computer(),
            dataset.nq, queries, ref_dists, ref_ids)

        # The specialised computer only exists on flat-code indexes.
        if isinstance(idx, faiss.IndexFlatCodes):
            self._check_computer(
                idx.get_FlatCodesDistanceComputer(),
                dataset.nq, queries, ref_dists, ref_ids)

    # Product quantizer, 8 sub-quantizers.
    def test_distance_computer_PQ(self):
        self.do_test("PQ8np")

    # Scalar quantizer, 8 bits.
    def test_distance_computer_SQ(self):
        self.do_test("SQ8")

    # Scalar quantizer, 6 bits.
    def test_distance_computer_SQ6(self):
        self.do_test("SQ6")

    # Product quantizer with 6-bit codes.
    def test_distance_computer_PQbit6(self):
        self.do_test("PQ8x6np")

    # Same 6-bit product quantizer, inner-product metric.
    def test_distance_computer_PQbit6_ip(self):
        self.do_test("PQ8x6np", faiss.METRIC_INNER_PRODUCT)

    # Index wrapped behind a vector transform (PCA).
    def test_distance_computer_VT(self):
        self.do_test("PCA20,SQ8")

    # Additive quantizer: test decompress path.
    def test_distance_computer_AQ_decompress(self):
        self.do_test("RQ3x4")

    # Additive quantizer: test LUT path.
    def test_distance_computer_AQ_LUT(self):
        self.do_test("RQ3x4_Nqint8")

    # Additive quantizer LUT path, inner-product metric.
    def test_distance_computer_AQ_LUT_IP(self):
        self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)