Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_pairs_decoding.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVF.h>
#include <faiss/AutoTune.h>
#include <faiss/VectorTransform.h>
20 
21 
22 /*************************************************************
23  * The functions to test, that can be useful in FANN
24  *************************************************************/
25 
26 /* Returns the cluster the embeddings belong to.
27  *
28  * @param index Index, which should be an IVF index
29  * (otherwise there are no clusters)
30  * @param embeddings object descriptors for which the centroids should be found,
31  * size num_objects * d
32  * @param cebtroid_ids
33  * cluster id each object belongs to, size num_objects
34  */
35 void Search_centroid(faiss::Index *index,
36  const float* embeddings, int num_objects,
37  int64_t* centroid_ids)
38 {
39  const float *x = embeddings;
40  std::unique_ptr<float[]> del;
41  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
42  x = index_pre->apply_chain(num_objects, x);
43  del.reset((float*)x);
44  index = index_pre->index;
45  }
46  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
47  assert(index_ivf);
48  index_ivf->quantizer->assign(num_objects, x, centroid_ids);
49 }
50 
51 
52 
53 /* Returns the cluster the embeddings belong to.
54  *
55  * @param index Index, which should be an IVF index
56  * (otherwise there are no clusters)
57  * @param query_centroid_ids
58  * centroid ids corresponding to the query vectors (size n)
59  * @param result_centroid_ids
60  * centroid ids corresponding to the results (size n * k)
61  * other arguments are the same as the standard search function
62  */
63 void search_and_retrun_centroids(faiss::Index *index,
64  size_t n,
65  const float* xin,
66  long k,
67  float *distances,
68  int64_t* labels,
69  int64_t* query_centroid_ids,
70  int64_t* result_centroid_ids)
71 {
72  const float *x = xin;
73  std::unique_ptr<float []> del;
74  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
75  x = index_pre->apply_chain(n, x);
76  del.reset((float*)x);
77  index = index_pre->index;
78  }
79  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
80  assert(index_ivf);
81 
82  size_t nprobe = index_ivf->nprobe;
83  std::vector<long> cent_nos (n * nprobe);
84  std::vector<float> cent_dis (n * nprobe);
85  index_ivf->quantizer->search(
86  n, x, nprobe, cent_dis.data(), cent_nos.data());
87 
88  if (query_centroid_ids) {
89  for (size_t i = 0; i < n; i++)
90  query_centroid_ids[i] = cent_nos[i * nprobe];
91  }
92 
93  index_ivf->search_preassigned (n, x, k,
94  cent_nos.data(), cent_dis.data(),
95  distances, labels, true);
96 
97  for (size_t i = 0; i < n * k; i++) {
98  int64_t label = labels[i];
99  if (label < 0) {
100  if (result_centroid_ids)
101  result_centroid_ids[i] = -1;
102  } else {
103  long list_no = label >> 32;
104  long list_index = label & 0xffffffff;
105  if (result_centroid_ids)
106  result_centroid_ids[i] = list_no;
107  labels[i] = index_ivf->ids[list_no][list_index];
108  }
109  }
110 }
111 
112 /*************************************************************
113  * Test utils
114  *************************************************************/
115 
116 // return an IndexIVF that may be embedded in an IndexPreTransform
117 faiss::IndexIVF * get_IndexIVF(faiss::Index *index) {
118  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
119  index = index_pre->index;
120  }
121  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
122  bool t = index_ivf != nullptr;
123  assert(index_ivf);
124  return index_ivf;
125 }
126 
127 
128 
// Dimensionality of every vector built by these tests.
int d{64};

// Number of database vectors indexed by each test.
size_t nb{8000};

// Number of query vectors searched by each test.
size_t nq{200};

// Produces n vectors of dimension d as one flat buffer of n * d floats,
// filling components in order with successive drand48() draws (so the
// output depends on the current drand48 seed).
std::vector<float> make_data(size_t n)
{
    std::vector<float> values(n * d);
    for (float &component : values) {
        component = static_cast<float>(drand48());
    }
    return values;
}
146 
147 std::unique_ptr<faiss::Index> make_index(const char *index_type,
148  const std::vector<float> & x) {
149 
150  auto index = std::unique_ptr<faiss::Index> (
151  faiss::index_factory(d, index_type));
152  index->train(nb, x.data());
153  index->add(nb, x.data());
154  return index;
155 }
156 
157 /*************************************************************
158  * Test functions for a given index type
159  *************************************************************/
160 
161 bool test_Search_centroid(const char *index_key) {
162  std::vector<float> xb = make_data(nb); // database vectors
163  auto index = make_index(index_key, xb);
164 
165  /* First test: find the centroids associated to the database
166  vectors and make sure that each vector does indeed appear in
167  the inverted list corresponding to its centroid */
168 
169  std::vector<int64_t> centroid_ids (nb);
170  Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
171 
172  const faiss::IndexIVF * ivf = get_IndexIVF(index.get());
173 
174  for(int i = 0; i < nb; i++) {
175  bool found = false;
176  int list_no = centroid_ids[i];
177  for(int j: ivf->ids[list_no]) {
178  if (j == i) {
179  found = true;
180  break;
181  }
182  }
183  if(!found) return false;
184  }
185  return true;
186 }
187 
188 int test_search_and_return_centroids(const char *index_key) {
189  std::vector<float> xb = make_data(nb); // database vectors
190  auto index = make_index(index_key, xb);
191 
192  std::vector<int64_t> centroid_ids (nb);
193  Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
194 
195  faiss::IndexIVF * ivf = get_IndexIVF(index.get());
196  ivf->nprobe = 4;
197 
198  std::vector<float> xq = make_data(nq); // database vectors
199 
200  int k = 5;
201 
202  // compute a reference search result
203 
204  std::vector<long> refI (nq * k);
205  std::vector<float> refD (nq * k);
206  index->search (nq, xq.data(), k, refD.data(), refI.data());
207 
208  // compute search result
209 
210  std::vector<long> newI (nq * k);
211  std::vector<float> newD (nq * k);
212 
213  std::vector<int64_t> query_centroid_ids (nq);
214  std::vector<int64_t> result_centroid_ids (nq * k);
215 
216  search_and_retrun_centroids(index.get(),
217  nq, xq.data(), k,
218  newD.data(), newI.data(),
219  query_centroid_ids.data(),
220  result_centroid_ids.data());
221 
222  // first verify that we have the same result as the standard search
223 
224  if (newI != refI) {
225  return 1;
226  }
227 
228  // then check if the result ids are indeed in the inverted list
229  // they are supposed to be in
230 
231  for(int i = 0; i < nq * k; i++) {
232  int list_no = result_centroid_ids[i];
233  int result_no = newI[i];
234 
235  if (result_no < 0) continue;
236 
237  bool found = false;
238 
239  for(int j: ivf->ids[list_no]) {
240  if (j == result_no) {
241  found = true;
242  break;
243  }
244  }
245  if(!found) return 2;
246  }
247  return 0;
248 }
249 
250 /*************************************************************
251  * Test entry points
252  *************************************************************/
253 
254 TEST(test_Search_centroid, IVFFlat) {
255  bool ok = test_Search_centroid("IVF32,Flat");
256  EXPECT_TRUE(ok);
257 }
258 
259 TEST(test_Search_centroid, PCAIVFFlat) {
260  bool ok = test_Search_centroid("PCA16,IVF32,Flat");
261  EXPECT_TRUE(ok);
262 }
263 
264 TEST(test_search_and_return_centroids, IVFFlat) {
265  int err = test_search_and_return_centroids("IVF32,Flat");
266  EXPECT_NE(err, 1);
267  EXPECT_NE(err, 2);
268 }
269 
270 TEST(test_search_and_return_centroids, PCAIVFFlat) {
271  int err = test_search_and_return_centroids("PCA16,IVF32,Flat");
272  EXPECT_NE(err, 1);
273  EXPECT_NE(err, 2);
274 }
virtual void search_preassigned(idx_t n, const float *x, idx_t k, const idx_t *assign, const float *centroid_dis, float *distances, idx_t *labels, bool store_pairs) const =0
size_t nprobe
number of probes at query time
Definition: IndexIVF.h:78
void assign(idx_t n, const float *x, idx_t *labels, idx_t k=1)
Definition: Index.cpp:34
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:23
std::vector< std::vector< long > > ids
Inverted lists for indexes.
Definition: IndexIVF.h:81
virtual void add(idx_t n, const float *x)=0
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
Index * quantizer
quantizer that maps vectors to inverted lists
Definition: IndexIVF.h:33
Index * index_factory(int d, const char *description_in, MetricType metric)
Definition: AutoTune.cpp:687