15 #include <gtest/gtest.h>
17 #include <faiss/IndexIVF.h>
18 #include <faiss/AutoTune.h>
19 #include <faiss/VectorTransform.h>
20 #include <faiss/IVFlib.h>
41 std::vector<float> make_data(
size_t n)
43 std::vector <float> database (n * d);
44 for (
size_t i = 0; i < n * d; i++) {
45 database[i] = drand48();
50 std::unique_ptr<faiss::Index> make_index(
const char *index_type,
51 const std::vector<float> & x) {
53 auto index = std::unique_ptr<faiss::Index> (
55 index->train(nb, x.data());
56 index->add(nb, x.data());
64 bool test_search_centroid(
const char *index_key) {
65 std::vector<float> xb = make_data(nb);
66 auto index = make_index(index_key, xb);
72 std::vector<idx_t> centroid_ids (nb);
73 faiss::ivflib::search_centroid(
74 index.get(), xb.data(), nb, centroid_ids.data());
79 for(
int i = 0; i < nb; i++) {
81 int list_no = centroid_ids[i];
82 int list_size = ivf->invlists->list_size (list_no);
83 auto * list = ivf->invlists->get_ids (list_no);
85 for(
int j = 0; j < list_size; j++) {
91 if(!found)
return false;
// Checks faiss::ivflib::search_and_return_centroids on an index built from
// `index_key`. A reference search is first run with index->search. Then
// search_and_return_centroids is called; it also reports, for each query and
// for each returned result, a centroid (inverted-list) id. Each returned id
// is then looked up in the inverted list reported for it.
// NOTE(review): the return values are not visible in this listing. The
// caller names the result `err`, so presumably 0 means success — confirm.
96 int test_search_and_return_centroids(
const char *index_key) {
97 std::vector<float> xb = make_data(nb);
98 auto index = make_index(index_key, xb);
// Coarse assignment of the database vectors (same call as in
// test_search_centroid).
100 std::vector<idx_t> centroid_ids (nb);
101 faiss::ivflib::search_centroid(index.get(), xb.data(),
102 nb, centroid_ids.data());
// Initializer of the IVF-stage pointer used below as `ivf`. The declaration
// itself is not visible in this listing.
105 faiss::ivflib::extract_index_ivf (index.get());
// Query vectors.
108 std::vector<float> xq = make_data(nq);
// Reference results from the plain search: k ids and k distances per query.
114 std::vector<idx_t> refI (nq * k);
115 std::vector<float> refD (nq * k);
116 index->search (nq, xq.data(), k, refD.data(), refI.data());
// Output buffers for search_and_return_centroids.
120 std::vector<idx_t> newI (nq * k);
121 std::vector<float> newD (nq * k);
// One centroid id per query, and one per returned result.
123 std::vector<idx_t> query_centroid_ids (nq);
124 std::vector<idx_t> result_centroid_ids (nq * k);
// NOTE(review): the argument line between index.get() and newD.data() is
// elided here — presumably nq, xq.data(), k — confirm.
126 faiss::ivflib::search_and_return_centroids(index.get(),
128 newD.data(), newI.data(),
129 query_centroid_ids.data(),
130 result_centroid_ids.data());
// Each returned id must be stored in the inverted list named by its
// result_centroid_ids entry.
// NOTE(review): `int` truncates the idx_t values read here; idx_t would be
// safer.
141 for(
int i = 0; i < nq * k; i++) {
142 int list_no = result_centroid_ids[i];
143 int result_no = newI[i];
// Negative ids (presumably unfilled result slots) have no list to check.
145 if (result_no < 0)
continue;
149 int list_size = ivf->invlists->list_size (list_no);
150 auto * list = ivf->invlists->get_ids (list_no);
152 for(
int j = 0; j < list_size; j++) {
153 if (list[j] == result_no) {
170 TEST(test_search_centroid, IVFFlat) {
171 bool ok = test_search_centroid(
"IVF32,Flat");
175 TEST(test_search_centroid, PCAIVFFlat) {
176 bool ok = test_search_centroid(
"PCA16,IVF32,Flat");
// Runs the search_and_return_centroids check on a plain "IVF32,Flat" index.
// NOTE(review): the assertion on `err` and the closing brace are not visible
// in this listing — confirm that the result is actually checked.
180 TEST(test_search_and_return_centroids, IVFFlat) {
181 int err = test_search_and_return_centroids(
"IVF32,Flat");
// Runs the search_and_return_centroids check with the IVF stage behind a
// PCA transform ("PCA16,IVF32,Flat").
// NOTE(review): the assertion on `err` and the closing brace are not visible
// in this listing — confirm that the result is actually checked.
186 TEST(test_search_and_return_centroids, PCAIVFFlat) {
187 int err = test_search_and_return_centroids(
"PCA16,IVF32,Flat");
size_t nprobe
number of probes at query time
long idx_t
all indices are this type
Index * index_factory(int d, const char *description_in, MetricType metric)