15 #include <gtest/gtest.h>
17 #include <faiss/IndexIVF.h>
18 #include <faiss/AutoTune.h>
19 #include <faiss/VectorTransform.h>
// Fragment of a function taking an input float array `embeddings`, a count
// `num_objects`, and an output array `centroid_ids`. The line carrying the
// function name and the `index` parameter is not visible in this excerpt;
// presumably this is Search_centroid, which is called later with a matching
// four-argument list -- confirm.
36 const float* embeddings,
int num_objects,
37 int64_t* centroid_ids)
// `x` starts out pointing at the caller's vectors and may be redirected below.
39 const float *x = embeddings;
// NOTE(review): `del` is never assigned in the visible lines; presumably it
// is meant to own the buffer returned by apply_chain -- confirm.
40 std::unique_ptr<float[]> del;
// When `index` is an IndexPreTransform, push the vectors through its
// transform chain and continue with the index it wraps.
41 if (
auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
42 x = index_pre->apply_chain(num_objects, x);
44 index = index_pre->index;
// Fragment of a search routine that also reports centroid ids. The leading
// part of the signature (index, n, x, k, distances, labels) is not visible in
// this excerpt; presumably this is search_and_retrun_centroids, which is
// called later with these trailing arguments -- confirm.
// Both id arrays are treated as optional: every write below is guarded by a
// null check.
69 int64_t* query_centroid_ids,
70 int64_t* result_centroid_ids)
// NOTE(review): `del` is never assigned in the visible lines; presumably it
// is meant to own the buffer returned by apply_chain -- confirm.
73 std::unique_ptr<float []> del;
// When `index` is an IndexPreTransform, push the queries through its
// transform chain and continue with the index it wraps.
74 if (
auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
75 x = index_pre->apply_chain(n, x);
77 index = index_pre->index;
// One slot per (query, probe) pair for the centroid numbers and distances.
82 size_t nprobe = index_ivf->
nprobe;
83 std::vector<long> cent_nos (n * nprobe);
84 std::vector<float> cent_dis (n * nprobe);
// Tail of a call whose callee is not visible in this excerpt; it receives the
// queries, nprobe, and the two buffers above as outputs.
86 n, x, nprobe, cent_dis.data(), cent_nos.data());
// Report the first of the nprobe centroid entries for each query.
88 if (query_centroid_ids) {
89 for (
size_t i = 0; i < n; i++)
90 query_centroid_ids[i] = cent_nos[i * nprobe];
// Tail of a second call whose callee is not visible; presumably
// search_preassigned with store_pairs = true, given the argument order and
// the pair decoding below -- confirm.
94 cent_nos.data(), cent_dis.data(),
95 distances, labels,
true);
// Post-process each of the n * k results.
97 for (
size_t i = 0; i < n * k; i++) {
98 int64_t label = labels[i];
// NOTE(review): the condition choosing between the -1 branch and the decode
// branch below is not visible in this excerpt.
100 if (result_centroid_ids)
101 result_centroid_ids[i] = -1;
// The label packs the list number in its upper 32 bits and the position
// inside that list in its lower 32 bits.
103 long list_no = label >> 32;
104 long list_index = label & 0xffffffff;
105 if (result_centroid_ids)
106 result_centroid_ids[i] = list_no;
// Replace the packed pair with the id stored at that position of the list.
107 labels[i] = index_ivf->
ids[list_no][list_index];
// Fragment of a function whose name and signature are not visible in this
// excerpt. If `index` is an IndexPreTransform, continue with the index it
// wraps.
118 if (
auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
119 index = index_pre->index;
// `t` records whether `index_ivf` is non-null; how `index_ivf` is obtained is
// not visible here.
122 bool t = index_ivf !=
nullptr;
// Builds a buffer of n * d floats, each filled from drand48(). `d` is defined
// outside this excerpt; the return statement is not visible here.
138 std::vector<float> make_data(
size_t n)
140 std::vector <float> database (n * d);
141 for (
size_t i = 0; i < n * d; i++) {
142 database[i] = drand48();
// Creates an index for `index_type`, then trains it on and adds the `nb`
// vectors stored in `x`. `nb` is defined outside this excerpt.
147 std::unique_ptr<faiss::Index> make_index(
const char *index_type,
148 const std::vector<float> & x) {
// The expression that constructs the index is not visible in this excerpt
// (presumably a faiss::index_factory call -- confirm).
150 auto index = std::unique_ptr<faiss::Index> (
152 index->
train(nb, x.data());
153 index->
add(nb, x.data());
// Builds an index of type `index_key` over nb generated vectors, asks
// Search_centroid for a centroid id per vector, and returns false as soon as
// the per-vector check fails.
161 bool test_Search_centroid(
const char *index_key) {
162 std::vector<float> xb = make_data(nb);
163 auto index = make_index(index_key, xb);
// One centroid id per database vector.
169 std::vector<int64_t> centroid_ids (nb);
170 Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
// For each vector, walk the ids of the inverted list it was assigned to.
// How `ivf` is obtained and how `found` is set are not visible in this
// excerpt; presumably `found` becomes true when the list contains i -- confirm.
174 for(
int i = 0; i < nb; i++) {
// NOTE(review): the int64_t centroid id is narrowed to int here.
176 int list_no = centroid_ids[i];
177 for(
int j: ivf->ids[list_no]) {
183 if(!found)
return false;
// Builds an index of type `index_key`, runs a reference search and the
// centroid-returning search on the same nq queries, then inspects the
// inverted list reported for each returned result. The return value is an
// int; the statements that compute it are not visible in this excerpt.
188 int test_search_and_return_centroids(
const char *index_key) {
189 std::vector<float> xb = make_data(nb);
190 auto index = make_index(index_key, xb);
// Centroid id of every database vector.
192 std::vector<int64_t> centroid_ids (nb);
193 Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
// Query set.
198 std::vector<float> xq = make_data(nq);
// Reference results from the plain search() call.
204 std::vector<long> refI (nq * k);
205 std::vector<float> refD (nq * k);
206 index->
search (nq, xq.data(), k, refD.data(), refI.data());
// Results and centroid ids from the function under test.
210 std::vector<long> newI (nq * k);
211 std::vector<float> newD (nq * k);
213 std::vector<int64_t> query_centroid_ids (nq);
214 std::vector<int64_t> result_centroid_ids (nq * k);
// The middle arguments of this call are not visible in this excerpt.
216 search_and_retrun_centroids(index.get(),
218 newD.data(), newI.data(),
219 query_centroid_ids.data(),
220 result_centroid_ids.data());
// Examine each of the nq * k results.
231 for(
int i = 0; i < nq * k; i++) {
// NOTE(review): both values are narrowed to int from wider integer types.
232 int list_no = result_centroid_ids[i];
233 int result_no = newI[i];
// Negative result ids are skipped.
235 if (result_no < 0)
continue;
// Look for the result id among the ids of the reported inverted list; what
// happens on a match is not visible in this excerpt.
239 for(
int j: ivf->ids[list_no]) {
240 if (j == result_no) {
// Runs test_Search_centroid on an "IVF32,Flat" index; the check applied to
// `ok` is not visible in this excerpt.
254 TEST(test_Search_centroid, IVFFlat) {
255 bool ok = test_Search_centroid(
"IVF32,Flat");
// Runs test_Search_centroid on a "PCA16,IVF32,Flat" index; the check applied
// to `ok` is not visible in this excerpt.
259 TEST(test_Search_centroid, PCAIVFFlat) {
260 bool ok = test_Search_centroid(
"PCA16,IVF32,Flat");
// Runs test_search_and_return_centroids on an "IVF32,Flat" index; the check
// applied to `err` is not visible in this excerpt.
264 TEST(test_search_and_return_centroids, IVFFlat) {
265 int err = test_search_and_return_centroids(
"IVF32,Flat");
// Runs test_search_and_return_centroids on a "PCA16,IVF32,Flat" index; the
// check applied to `err` is not visible in this excerpt.
270 TEST(test_search_and_return_centroids, PCAIVFFlat) {
271 int err = test_search_and_return_centroids(
"PCA16,IVF32,Flat");
virtual void search_preassigned(idx_t n, const float *x, idx_t k, const idx_t *assign, const float *centroid_dis, float *distances, idx_t *labels, bool store_pairs) const =0
size_t nprobe
number of probes at query time
void assign(idx_t n, const float *x, idx_t *labels, idx_t k=1)
virtual void train(idx_t n, const float *x)
std::vector< std::vector< long > > ids
Inverted lists for indexes.
virtual void add(idx_t n, const float *x)=0
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
Index * quantizer
quantizer that maps vectors to inverted lists
Index * index_factory(int d, const char *description_in, MetricType metric)