# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import platform
import shutil
import sys
import tempfile
import unittest
from contextlib import contextmanager

import faiss
import numpy as np

from common_faiss_tests import get_dataset_2
from faiss.contrib import (
    big_batch_search,
    clustering,
    datasets,
    evaluation,
    inspect_tools,
    ivf_tools,
)
from faiss.contrib.exhaustive_search import (
    exponential_query_iterator,
    knn,
    knn_ground_truth,
    range_ground_truth,
    range_search_max_results,
)
from faiss.contrib.ondisk import merge_ondisk


class TestComputeGT(unittest.TestCase):
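    """Check that knn_ground_truth, fed by an iterator of database blocks,
    matches a brute-force IndexFlat search (optionally on GPU)."""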

    def do_test_compute_GT(self, metric=faiss.METRIC_L2, ngpu=0):
        d = 64
        xt, xb, xq = get_dataset_2(d, 0, 10000, 100)

        index = faiss.IndexFlat(d, metric)
        index.add(xb)
        Dref, Iref = index.search(xq, 10)

        # iterator function on the matrix
        def matrix_iterator(xb, bs):
            for i0 in range(0, xb.shape[0], bs):
                yield xb[i0:i0 + bs]

        Dnew, Inew = knn_ground_truth(
            xq, matrix_iterator(xb, 1000), 10, metric, ngpu=ngpu)

        np.testing.assert_array_equal(Iref, Inew)
        # decimal = 4 required when run on GPU
        np.testing.assert_almost_equal(Dref, Dnew, decimal=4)

    def test_compute_GT(self):
        self.do_test_compute_GT()

    def test_compute_GT_ip(self):
        self.do_test_compute_GT(faiss.METRIC_INNER_PRODUCT)

    def test_compute_GT_gpu(self):
        self.do_test_compute_GT(ngpu=-1)

    def test_compute_GT_ip_gpu(self):
        self.do_test_compute_GT(faiss.METRIC_INNER_PRODUCT, ngpu=-1)


class TestDatasets(unittest.TestCase):
    """here we test only the synthetic dataset. Datasets that require
    disk or manifold access are in
    //deeplearning/projects/faiss-forge/test_faiss_datasets/:test_faiss_datasets
    """

    def test_synthetic(self):
        ds = datasets.SyntheticDataset(32, 1000, 2000, 10)
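        # arguments are (d, nt, nb, nq): dimension, then sizes of the
        # train, database and query sets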
        xq = ds.get_queries()
        self.assertEqual(xq.shape, (10, 32))
        xb = ds.get_database()
        self.assertEqual(xb.shape, (2000, 32))
        ds.check_sizes()

    def test_synthetic_ip(self):
        ds = datasets.SyntheticDataset(32, 1000, 2000, 10, "IP")
        index = faiss.IndexFlatIP(32)
        index.add(ds.get_database())
        np.testing.assert_array_equal(
            ds.get_groundtruth(100),
            index.search(ds.get_queries(), 100)[1]
        )

    def test_synthetic_iterator(self):
        ds = datasets.SyntheticDataset(32, 1000, 2000, 10)
        xb = ds.get_database()
        xb2 = []
        for xbi in ds.database_iterator():
            xb2.append(xbi)
        xb2 = np.vstack(xb2)
        np.testing.assert_array_equal(xb, xb2)


class TestExhaustiveSearch(unittest.TestCase):

    def test_knn_cpu(self):
        xb = np.random.rand(200, 32).astype('float32')
        xq = np.random.rand(100, 32).astype('float32')

        index = faiss.IndexFlatL2(32)
        index.add(xb)
        Dref, Iref = index.search(xq, 10)

        Dnew, Inew = knn(xq, xb, 10)
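        # the contrib knn does the same brute-force computation without
        # materializing an index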

        assert np.all(Inew == Iref)
        assert np.allclose(Dref, Dnew)

        index = faiss.IndexFlatIP(32)
        index.add(xb)
        Dref, Iref = index.search(xq, 10)

        Dnew, Inew = knn(xq, xb, 10, metric=faiss.METRIC_INNER_PRODUCT)

        assert np.all(Inew == Iref)
        assert np.allclose(Dref, Dnew)

    def do_test_range(self, metric):
        ds = datasets.SyntheticDataset(32, 0, 1000, 10)
        xq = ds.get_queries()
        xb = ds.get_database()
        D, I = faiss.knn(xq, xb, 10, metric=metric)
        threshold = float(D[:, -1].mean())
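        # mean distance to the 10th neighbor, so the range search returns
        # about 10 results per query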

        index = faiss.IndexFlat(32, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        new_lims, new_D, new_I = range_ground_truth(
            xq, ds.database_iterator(bs=100), threshold, ngpu=0,
            metric_type=metric)

        evaluation.check_ref_range_results(
            ref_lims, ref_D, ref_I,
            new_lims, new_D, new_I
        )

    def test_range_L2(self):
        self.do_test_range(faiss.METRIC_L2)

    def test_range_IP(self):
        self.do_test_range(faiss.METRIC_INNER_PRODUCT)

    def test_query_iterator(self, metric=faiss.METRIC_L2):
        ds = datasets.SyntheticDataset(32, 0, 1000, 1000)
        xq = ds.get_queries()
        xb = ds.get_database()
        D, I = faiss.knn(xq, xb, 10, metric=metric)
        threshold = float(D[:, -1].mean())

        index = faiss.IndexFlat(32, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        def matrix_iterator(xb, bs):
            for i0 in range(0, xb.shape[0], bs):
                yield xb[i0:i0 + bs]

        # check repro OK
        _, new_lims, new_D, new_I = range_search_max_results(
            index, matrix_iterator(xq, 100), threshold, max_results=1e10)

        evaluation.check_ref_range_results(
            ref_lims, ref_D, ref_I,
            new_lims, new_D, new_I
        )

        max_res = ref_lims[-1] // 2
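        # ask for at most half the reference results;
        # range_search_max_results must tighten the radius to stay under
        # this budget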
        new_threshold, new_lims, new_D, new_I = range_search_max_results(
            index, matrix_iterator(xq, 100), threshold, max_results=max_res)

        self.assertLessEqual(new_lims[-1], max_res)

        ref_lims, ref_D, ref_I = index.range_search(xq, new_threshold)

        evaluation.check_ref_range_results(
            ref_lims, ref_D, ref_I,
            new_lims, new_D, new_I
        )


class TestInspect(unittest.TestCase):

    def test_LinearTransform(self):
        # training data
        xt = np.random.rand(1000, 20).astype('float32')
        # test data
        x = np.random.rand(10, 20).astype('float32')
        # make the PCA matrix
        pca = faiss.PCAMatrix(20, 10)
        pca.train(xt)
        # apply it to test data
        yref = pca.apply_py(x)

        A, b = inspect_tools.get_LinearTransform_matrix(pca)

        # verify that the extracted (A, b) reproduce the transform
        ynew = x @ A.T + b
        np.testing.assert_array_almost_equal(yref, ynew)

    def test_IndexFlat(self):
        xb = np.random.rand(13, 20).astype('float32')
        index = faiss.IndexFlatL2(20)
        index.add(xb)
        np.testing.assert_array_equal(
            xb, inspect_tools.get_flat_data(index)
        )

    def test_make_LT(self):
        rs = np.random.RandomState(123)
        X = rs.rand(13, 20).astype('float32')
        A = rs.rand(5, 20).astype('float32')
        b = rs.rand(5).astype('float32')
        Yref = X @ A.T + b
        lt = inspect_tools.make_LinearTransform_matrix(A, b)
        Ynew = lt.apply(X)
        np.testing.assert_allclose(Yref, Ynew, rtol=1e-06)

    def test_NSG_neighbors(self):
        # FIXME number of elements to add should be >> 100
        ds = datasets.SyntheticDataset(32, 0, 200, 10)
        index = faiss.index_factory(ds.d, "NSG")
        index.add(ds.get_database())
        neighbors = inspect_tools.get_NSG_neighbors(index.nsg)
        # neighbors should be either valid indexes or -1
        np.testing.assert_array_less(-2, neighbors)
        np.testing.assert_array_less(neighbors, ds.nb)


class TestRangeEval(unittest.TestCase):

    def test_precision_recall(self):
        Iref = [
            [1, 2, 3],
            [5, 6],
            [],
            []
        ]
        Inew = [
            [1, 2],
            [6, 7],
            [1],
            []
        ]

        lims_ref = np.cumsum([0] + [len(x) for x in Iref])
        Iref = np.hstack(Iref)
        lims_new = np.cumsum([0] + [len(x) for x in Inew])
        Inew = np.hstack(Inew)

        precision, recall = evaluation.range_PR(lims_ref, Iref, lims_new, Inew)
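        # 3 of the 5 new results are true positives ({1, 2} for query 0 and
        # {6} for query 1), and 3 of the 5 reference results are retrieved,
        # hence precision = recall = 3 / 5 = 0.6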
        self.assertEqual(precision, 0.6)
        self.assertEqual(recall, 0.6)

    def test_PR_multiple(self):
        metric = faiss.METRIC_L2
        ds = datasets.SyntheticDataset(32, 1000, 1000, 10)
        xq = ds.get_queries()
        xb = ds.get_database()

        # good for ~10k results
        threshold = 15

        index = faiss.IndexFlat(32, metric)
        index.add(xb)
        ref_lims, ref_D, ref_I = index.range_search(xq, threshold)

        # now make a slightly suboptimal index
        index2 = faiss.index_factory(32, "PCA16,Flat")
        index2.train(ds.get_train())
        index2.add(xb)

        # PCA reduces distances so will have more results
        new_lims, new_D, new_I = index2.range_search(xq, threshold)

        all_thr = np.array([5.0, 10.0, 12.0, 15.0])
        for mode in "overall", "average":
            ref_precisions = np.zeros_like(all_thr)
            ref_recalls = np.zeros_like(all_thr)

            for i, thr in enumerate(all_thr):
                lims2, _, I2 = evaluation.filter_range_results(
                    new_lims, new_D, new_I, thr)

                prec, recall = evaluation.range_PR(
                    ref_lims, ref_I, lims2, I2, mode=mode)

                ref_precisions[i] = prec
                ref_recalls[i] = recall

            precisions, recalls = evaluation.range_PR_multiple_thresholds(
                ref_lims, ref_I,
                new_lims, new_D, new_I, all_thr,
                mode=mode
            )

            np.testing.assert_array_almost_equal(ref_precisions, precisions)
            np.testing.assert_array_almost_equal(ref_recalls, recalls)


class TestPreassigned(unittest.TestCase):

    def test_index_pretransformed(self):
        ds = datasets.SyntheticDataset(128, 2000, 2000, 200)
        xt = ds.get_train()
        xq = ds.get_queries()
        xb = ds.get_database()
        index = faiss.index_factory(128, 'PCA64,IVF64,PQ4np')
        index.train(xt)
        index.add(xb)
        index_downcasted = faiss.extract_index_ivf(index)
        index_downcasted.nprobe = 10
        xq_trans = index.chain.at(0).apply_py(xq)
        D_ref, I_ref = index.search(xq, 4)

        quantizer = index_downcasted.quantizer
        Dq, Iq = quantizer.search(xq_trans, index_downcasted.nprobe)
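        # searching with the quantizer's own coarse assignment must
        # reproduce the regular search results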
        D, I = ivf_tools.search_preassigned(index, xq, 4, Iq, Dq)
        np.testing.assert_almost_equal(D_ref, D, decimal=4)
        np.testing.assert_array_equal(I_ref, I)

    def test_float(self):
        ds = datasets.SyntheticDataset(128, 2000, 2000, 200)

        d = ds.d
        xt = ds.get_train()
        xq = ds.get_queries()
        xb = ds.get_database()

        # define alternative quantizer on the 20 first dims of vectors
        km = faiss.Kmeans(20, 50)
        km.train(xt[:, :20].copy())
        alt_quantizer = km.index

        index = faiss.index_factory(d, "IVF50,PQ16np")
        index.by_residual = False
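        # the coarse centroids below are fake, so encode raw vectors
        # rather than residuals relative to the centroids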

        # (optional) fake coarse quantizer
        fake_centroids = np.zeros((index.nlist, index.d), dtype="float32")
        index.quantizer.add(fake_centroids)

        # train the PQ part
        index.train(xt)

        # add elements xb
        a = alt_quantizer.search(xb[:, :20].copy(), 1)[1].ravel()
        ivf_tools.add_preassigned(index, xb, a)

        # search elements xq, increase nprobe, check 4 first results w/
        # groundtruth
        prev_inter_perf = 0
        for nprobe in 1, 10, 20:
            index.nprobe = nprobe
            a = alt_quantizer.search(xq[:, :20].copy(), index.nprobe)[1]
            D, I = ivf_tools.search_preassigned(index, xq, 4, a)
            inter_perf = faiss.eval_intersection(
                I, ds.get_groundtruth()[:, :4])
            self.assertGreaterEqual(inter_perf, prev_inter_perf)
            prev_inter_perf = inter_perf

        # test range search
        index.nprobe = 20
        a = alt_quantizer.search(xq[:, :20].copy(), index.nprobe)[1]

        # just to find a reasonable radius
        D, I = ivf_tools.search_preassigned(index, xq, 4, a)
        radius = D.max() * 1.01

        lims, DR, IR = ivf_tools.range_search_preassigned(index, xq, radius, a)

        # with that radius the k-NN results are a subset of the range search
        # results
        for q in range(len(xq)):
            l0, l1 = lims[q], lims[q + 1]
            self.assertTrue(set(I[q]) <= set(IR[l0:l1]))

    @unittest.skipIf(
        platform.system() == 'Windows'
        and sys.version_info[0] == 3
        and sys.version_info[1] == 12,
        'test_binary hangs for Windows on Python 3.12.'
    )
    def test_binary(self):
        ds = datasets.SyntheticDataset(128, 2000, 2000, 200)

        d = ds.d
        xt = ds.get_train()
        xq = ds.get_queries()
        xb = ds.get_database()
        # define alternative quantizer on the 20 first dims of vectors
        # (will be in float)
        km = faiss.Kmeans(20, 50)
        km.train(xt[:, :20].copy())
        alt_quantizer = km.index

        binarizer = faiss.index_factory(d, "ITQ,LSHt")
        binarizer.train(xt)

        xb_bin = binarizer.sa_encode(xb)
        xq_bin = binarizer.sa_encode(xq)

        index = faiss.index_binary_factory(d, "BIVF200")

        fake_centroids = np.zeros((index.nlist, index.d // 8), dtype="uint8")
        index.quantizer.add(fake_centroids)
        index.is_trained = True

        # add elements xb
        a = alt_quantizer.search(xb[:, :20].copy(), 1)[1].ravel()
        ivf_tools.add_preassigned(index, xb_bin, a)

        # recompute GT in binary
        k = 15
        ib = faiss.IndexBinaryFlat(128)
        ib.add(xb_bin)
        Dgt, Igt = ib.search(xq_bin, k)

        # search elements xq, increase nprobe, check 4 first results w/
        # groundtruth
        prev_inter_perf = 0
        for nprobe in 1, 10, 20:
            index.nprobe = nprobe
            a = alt_quantizer.search(xq[:, :20].copy(), index.nprobe)[1]
            D, I = ivf_tools.search_preassigned(index, xq_bin, k, a)
            inter_perf = faiss.eval_intersection(I, Igt)
            self.assertGreaterEqual(inter_perf, prev_inter_perf)
            prev_inter_perf = inter_perf

        # test range search
        index.nprobe = 20
        a = alt_quantizer.search(xq[:, :20].copy(), index.nprobe)[1]

        # just to find a reasonable radius
        D, I = ivf_tools.search_preassigned(index, xq_bin, 4, a)
        radius = int(D.max() + 1)
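        # binary distances are integer Hamming distances, hence the
        # integer radius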

        lims, DR, IR = ivf_tools.range_search_preassigned(
            index, xq_bin, radius, a)

        # with that radius the k-NN results are a subset of the range
        # search results
        for q in range(len(xq)):
            l0, l1 = lims[q], lims[q + 1]
            self.assertTrue(set(I[q]) <= set(IR[l0:l1]))


class TestRangeSearchMaxResults(unittest.TestCase):

    def do_test(self, metric_type):
        ds = datasets.SyntheticDataset(32, 0, 1000, 200)
        index = faiss.IndexFlat(ds.d, metric_type)
        index.add(ds.get_database())

        # find a reasonable radius
        D, _ = index.search(ds.get_queries(), 10)
        radius0 = float(np.median(D[:, -1]))

        # baseline = search with that radius
        lims_ref, Dref, Iref = index.range_search(ds.get_queries(), radius0)

        # now see if using just the total number of results, we can get back
        # the same result table
        query_iterator = exponential_query_iterator(ds.get_queries())

        init_radius = 1e10 if metric_type == faiss.METRIC_L2 else -1e10
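        # L2 range search keeps distances below the radius, inner product
        # keeps similarities above it, so start from the loosest bound in
        # each direction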
        radius1, lims_new, Dnew, Inew = range_search_max_results(
            index, query_iterator, init_radius,
            min_results=Dref.size, clip_to_min=True
        )

        evaluation.check_ref_range_results(
            lims_ref, Dref, Iref,
            lims_new, Dnew, Inew
        )

    def test_L2(self):
        self.do_test(faiss.METRIC_L2)

    def test_IP(self):
        self.do_test(faiss.METRIC_INNER_PRODUCT)

    def test_binary(self):
        ds = datasets.SyntheticDataset(64, 1000, 1000, 200)
        tobinary = faiss.index_factory(ds.d, "LSHrt")
        tobinary.train(ds.get_train())
        index = faiss.IndexBinaryFlat(ds.d)
        xb = tobinary.sa_encode(ds.get_database())
        xq = tobinary.sa_encode(ds.get_queries())
        index.add(xb)

        # find a reasonable radius
        D, _ = index.search(xq, 10)
        radius0 = int(np.median(D[:, -1]))

        # baseline = search with that radius
        lims_ref, Dref, Iref = index.range_search(xq, radius0)

        # now see if using just the total number of results, we can get back
        # the same result table
        query_iterator = exponential_query_iterator(xq)

        radius1, lims_new, Dnew, Inew = range_search_max_results(
            index, query_iterator, ds.d // 2,
            min_results=Dref.size, clip_to_min=True
        )

        evaluation.check_ref_range_results(
            lims_ref, Dref, Iref,
            lims_new, Dnew, Inew
        )


class TestClustering(unittest.TestCase):

    def test_python_kmeans(self):
        """ Test the python implementation of kmeans """
        ds = datasets.SyntheticDataset(32, 10000, 0, 0)
        x = ds.get_train()

        # bad distribution to stress-test split code
        xt = x[:10000].copy()
        xt[:5000] = x[0]
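        # (half the points now collapse onto a single vector, so some
        # clusters will come out empty and must be split)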

        km_ref = faiss.Kmeans(ds.d, 100, niter=10)
        km_ref.train(xt)
        err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()

        data = clustering.DatasetAssign(xt)
        centroids = clustering.kmeans(100, data, 10)
        err2 = faiss.knn(xt, centroids, 1)[0].sum()

        # err=33498.332 err2=33380.477
        self.assertLess(err2, err * 1.1)

    def test_2level(self):
        " verify that 2-level clustering is not too sub-optimal "
        ds = datasets.SyntheticDataset(32, 10000, 0, 0)
        xt = ds.get_train()
        km_ref = faiss.Kmeans(ds.d, 100)
        km_ref.train(xt)
        err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()

        centroids2, _ = clustering.two_level_clustering(xt, 10, 100)
        err2 = faiss.knn(xt, centroids2, 1)[0].sum()

        self.assertLess(err2, err * 1.1)
    def test_ivf_train_2level(self):
        " check 2-level clustering with IVF training "
        ds = datasets.SyntheticDataset(32, 10000, 1000, 200)
        index = faiss.index_factory(ds.d, "PCA16,IVF100,SQ8")
        faiss.extract_index_ivf(index).nprobe = 10
        index.train(ds.get_train())
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 1)

        index = faiss.index_factory(ds.d, "PCA16,IVF100,SQ8")
        faiss.extract_index_ivf(index).nprobe = 10
        clustering.train_ivf_index_with_2level(
            index, ds.get_train(), verbose=True, rebalance=False)
        index.add(ds.get_database())
        Dnew, Inew = index.search(ds.get_queries(), 1)

        # normally 47 / 200 differences
        ndiff = (Iref != Inew).sum()
        self.assertLess(ndiff, 53)


class TestBigBatchSearch(unittest.TestCase):

    def do_test(self, factory_string, metric=faiss.METRIC_L2):
        # ds = datasets.SyntheticDataset(32, 2000, 4000, 1000)
        ds = datasets.SyntheticDataset(32, 2000, 400, 500)
        k = 10
        index = faiss.index_factory(ds.d, factory_string, metric)
        assert index.metric_type == metric
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 5
        Dref, Iref = index.search(ds.get_queries(), k)
        # faiss.omp_set_num_threads(1)
        for method in ("pairwise_distances", "knn_function", "index"):
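            # every computation backend and threading mode must agree with
            # the reference search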
            for threaded in 0, 1, 2:
                Dnew, Inew = big_batch_search.big_batch_search(
                    index, ds.get_queries(),
                    k, method=method,
                    threaded=threaded
                )
                self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
                np.testing.assert_almost_equal(Dnew, Dref, decimal=4)

    def test_Flat(self):
        self.do_test("IVF64,Flat")

    def test_Flat_IP(self):
        self.do_test("IVF64,Flat", metric=faiss.METRIC_INNER_PRODUCT)

    def test_PQ(self):
        self.do_test("IVF64,PQ4np")

    def test_SQ(self):
        self.do_test("IVF64,SQ8")

    def test_checkpoint(self):
        ds = datasets.SyntheticDataset(32, 2000, 400, 500)
        k = 10
        index = faiss.index_factory(ds.d, "IVF64,SQ8")
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 5
        Dref, Iref = index.search(ds.get_queries(), k)

        checkpoint = tempfile.mktemp()
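        # mktemp only reserves a name; the first run is made to fail on
        # purpose via crash_at so that the second run can resume from the
        # checkpoint file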
        try:
            # First big batch search
            try:
                Dnew, Inew = big_batch_search.big_batch_search(
                    index, ds.get_queries(),
                    k, method="knn_function",
                    threaded=2,
                    checkpoint=checkpoint, checkpoint_freq=0.1,
                    crash_at=20
                )
            except ZeroDivisionError:
                pass
            else:
                self.assertFalse("should have crashed")
            # Second big batch search
            Dnew, Inew = big_batch_search.big_batch_search(
                index, ds.get_queries(),
                k, method="knn_function",
                threaded=2,
                checkpoint=checkpoint, checkpoint_freq=5
            )
            self.assertLess((Inew != Iref).sum() / Iref.size, 1e-4)
            np.testing.assert_almost_equal(Dnew, Dref, decimal=4)
        finally:
            if os.path.exists(checkpoint):
                os.unlink(checkpoint)


class TestInvlistSort(unittest.TestCase):

    def test_sort(self):
        """ make sure that the search results do not change
        after sorting the inverted lists """
        ds = datasets.SyntheticDataset(32, 2000, 200, 20)
        index = faiss.index_factory(ds.d, "IVF50,SQ8")
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 5
        Dref, Iref = index.search(ds.get_queries(), 5)

        ivf_tools.sort_invlists_by_size(index)
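        # sorting permutes the storage order of the lists but keeps the
        # stored ids, so search output is unchanged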
        list_sizes = ivf_tools.get_invlist_sizes(index.invlists)
        assert np.all(list_sizes[1:] >= list_sizes[:-1])

        Dnew, Inew = index.search(ds.get_queries(), 5)
        np.testing.assert_equal(Dnew, Dref)
        np.testing.assert_equal(Inew, Iref)

    def test_hnsw_permute(self):
        """
        make sure HNSW permutation works
        (useful when used as coarse quantizer)
        """
        ds = datasets.SyntheticDataset(32, 0, 1000, 50)
        index = faiss.index_factory(ds.d, "HNSW32,Flat")
        index.add(ds.get_database())
        Dref, Iref = index.search(ds.get_queries(), 5)
        rs = np.random.RandomState(1234)
        perm = rs.permutation(index.ntotal)
        index.permute_entries(perm)
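        # distances are unchanged; labels come back permuted, so mapping
        # them through perm recovers the original ids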
        Dnew, Inew = index.search(ds.get_queries(), 5)
        np.testing.assert_equal(Dnew, Dref)
        Inew_remap = perm[Inew]
        np.testing.assert_equal(Inew_remap, Iref)


class TestCodeSet(unittest.TestCase):

    def test_code_set(self):
        """ CodeSet and np.unique should produce the same output """
        d = 8
        n = 1000  # > 256 and using only 0 or 1 so there must be duplicates
        codes = np.random.randint(0, 2, (n, d), dtype=np.uint8)
        s = faiss.CodeSet(d)
        inserted = s.insert(codes)
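        # insert returns a boolean mask marking the first occurrence of each
        # distinct code, so codes[inserted] is the set of unique codes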
        np.testing.assert_equal(
            np.sort(np.unique(codes, axis=0), axis=None),
            np.sort(codes[inserted], axis=None))


@unittest.skipIf(
    platform.system() == 'Windows',
    'OnDiskInvertedLists is unsupported on Windows.'
)
class TestMerge(unittest.TestCase):

    @contextmanager
    def temp_directory(self):
        temp_dir = tempfile.mkdtemp()
        try:
            yield temp_dir
        finally:
            shutil.rmtree(temp_dir)

    def do_test_ondisk_merge(self, shift_ids=False):
        with self.temp_directory() as tmpdir:
            # only train and write the index to disk, without adding any
            # elements: this creates empty inverted lists
            ds = datasets.SyntheticDataset(32, 2000, 200, 20)
            index = faiss.index_factory(ds.d, "IVF32,Flat")
            index.train(ds.get_train())
            faiss.write_index(index, tmpdir + "/trained.index")

            # create 4 shards and add elements to them
            ns = 4  # number of shards
            for bno in range(ns):
                index = faiss.read_index(tmpdir + "/trained.index")
                i0, i1 = int(bno * ds.nb / ns), int((bno + 1) * ds.nb / ns)
                if shift_ids:
                    index.add_with_ids(ds.xb[i0:i1], np.arange(0, ds.nb / ns))
                else:
                    index.add_with_ids(ds.xb[i0:i1], np.arange(i0, i1))
                faiss.write_index(index, tmpdir + "/block_%d.index" % bno)

            # construct the output index and merge them on disk
            index = faiss.read_index(tmpdir + "/trained.index")
            block_fnames = [
                tmpdir + "/block_%d.index" % bno for bno in range(ns)
            ]

            merge_ondisk(
                index, block_fnames, tmpdir + "/merged_index.ivfdata", shift_ids
            )
            faiss.write_index(index, tmpdir + "/populated.index")

            # perform a search from index on disk
            index = faiss.read_index(tmpdir + "/populated.index")
            index.nprobe = 5
            D, I = index.search(ds.xq, 5)

            # ground-truth
            gtI = ds.get_groundtruth(5)

            recall_at_1 = (I[:, :1] == gtI[:, :1]).sum() / float(ds.xq.shape[0])
            self.assertGreaterEqual(recall_at_1, 0.5)

    def test_ondisk_merge(self):
        self.do_test_ondisk_merge()

    def test_ondisk_merge_with_shift_ids(self):
        # verified that recall is the same with and without shifted ids
        self.do_test_ondisk_merge(True)