Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_pairs_decoding.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include <memory>
#include <stdexcept>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVF.h>
#include <faiss/AutoTune.h>
#include <faiss/VectorTransform.h>
20 
21 
22 /*************************************************************
 * The functions under test, which can also be useful in FANN
24  *************************************************************/
25 
26 /* Returns the cluster the embeddings belong to.
27  *
28  * @param index Index, which should be an IVF index
29  * (otherwise there are no clusters)
30  * @param embeddings object descriptors for which the centroids should be found,
31  * size num_objects * d
32  * @param cebtroid_ids
33  * cluster id each object belongs to, size num_objects
34  */
35 void Search_centroid(faiss::Index *index,
36  const float* embeddings, int num_objects,
37  int64_t* centroid_ids)
38 {
39  const float *x = embeddings;
40  std::unique_ptr<float[]> del;
41  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
42  x = index_pre->apply_chain(num_objects, x);
43  del.reset((float*)x);
44  index = index_pre->index;
45  }
46  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
47  assert(index_ivf);
48  index_ivf->quantizer->assign(num_objects, x, centroid_ids);
49 }
50 
51 
52 
53 /* Returns the cluster the embeddings belong to.
54  *
55  * @param index Index, which should be an IVF index
56  * (otherwise there are no clusters)
57  * @param query_centroid_ids
58  * centroid ids corresponding to the query vectors (size n)
59  * @param result_centroid_ids
60  * centroid ids corresponding to the results (size n * k)
61  * other arguments are the same as the standard search function
62  */
63 void search_and_retrun_centroids(faiss::Index *index,
64  size_t n,
65  const float* xin,
66  long k,
67  float *distances,
68  int64_t* labels,
69  int64_t* query_centroid_ids,
70  int64_t* result_centroid_ids)
71 {
72  const float *x = xin;
73  std::unique_ptr<float []> del;
74  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
75  x = index_pre->apply_chain(n, x);
76  del.reset((float*)x);
77  index = index_pre->index;
78  }
79  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
80  assert(index_ivf);
81 
82  size_t nprobe = index_ivf->nprobe;
83  std::vector<long> cent_nos (n * nprobe);
84  std::vector<float> cent_dis (n * nprobe);
85  index_ivf->quantizer->search(
86  n, x, nprobe, cent_dis.data(), cent_nos.data());
87 
88  if (query_centroid_ids) {
89  for (size_t i = 0; i < n; i++)
90  query_centroid_ids[i] = cent_nos[i * nprobe];
91  }
92 
93  index_ivf->search_preassigned (n, x, k,
94  cent_nos.data(), cent_dis.data(),
95  distances, labels, true);
96 
97  for (size_t i = 0; i < n * k; i++) {
98  int64_t label = labels[i];
99  if (label < 0) {
100  if (result_centroid_ids)
101  result_centroid_ids[i] = -1;
102  } else {
103  long list_no = label >> 32;
104  long list_index = label & 0xffffffff;
105  if (result_centroid_ids)
106  result_centroid_ids[i] = list_no;
107  labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
108  }
109  }
110 }
111 
112 /*************************************************************
113  * Test utils
114  *************************************************************/
115 
116 // return an IndexIVF that may be embedded in an IndexPreTransform
117 faiss::IndexIVF * get_IndexIVF(faiss::Index *index) {
118  if (auto index_pre = dynamic_cast<faiss::IndexPreTransform*>(index)) {
119  index = index_pre->index;
120  }
121  faiss::IndexIVF* index_ivf = dynamic_cast<faiss::IndexIVF*>(index);
122  bool t = index_ivf != nullptr;
123  assert(index_ivf);
124  return index_ivf;
125 }
126 
127 
128 
// Test configuration shared by the helpers and tests below.
int d{64};        // dimension of the vectors to index
size_t nb{8000};  // size of the database we plan to index
size_t nq{200};   // number of queries
137 
138 std::vector<float> make_data(size_t n)
139 {
140  std::vector <float> database (n * d);
141  for (size_t i = 0; i < n * d; i++) {
142  database[i] = drand48();
143  }
144  return database;
145 }
146 
147 std::unique_ptr<faiss::Index> make_index(const char *index_type,
148  const std::vector<float> & x) {
149 
150  auto index = std::unique_ptr<faiss::Index> (
151  faiss::index_factory(d, index_type));
152  index->train(nb, x.data());
153  index->add(nb, x.data());
154  return index;
155 }
156 
157 /*************************************************************
158  * Test functions for a given index type
159  *************************************************************/
160 
161 bool test_Search_centroid(const char *index_key) {
162  std::vector<float> xb = make_data(nb); // database vectors
163  auto index = make_index(index_key, xb);
164 
165  /* First test: find the centroids associated to the database
166  vectors and make sure that each vector does indeed appear in
167  the inverted list corresponding to its centroid */
168 
169  std::vector<int64_t> centroid_ids (nb);
170  Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
171 
172  const faiss::IndexIVF * ivf = get_IndexIVF(index.get());
173 
174  for(int i = 0; i < nb; i++) {
175  bool found = false;
176  int list_no = centroid_ids[i];
177  int list_size = ivf->invlists->list_size (list_no);
178  auto * list = ivf->invlists->get_ids (list_no);
179 
180  for(int j = 0; j < list_size; j++) {
181  if (list[j] == i) {
182  found = true;
183  break;
184  }
185  }
186  if(!found) return false;
187  }
188  return true;
189 }
190 
191 int test_search_and_return_centroids(const char *index_key) {
192  std::vector<float> xb = make_data(nb); // database vectors
193  auto index = make_index(index_key, xb);
194 
195  std::vector<int64_t> centroid_ids (nb);
196  Search_centroid(index.get(), xb.data(), nb, centroid_ids.data());
197 
198  faiss::IndexIVF * ivf = get_IndexIVF(index.get());
199  ivf->nprobe = 4;
200 
201  std::vector<float> xq = make_data(nq); // database vectors
202 
203  int k = 5;
204 
205  // compute a reference search result
206 
207  std::vector<long> refI (nq * k);
208  std::vector<float> refD (nq * k);
209  index->search (nq, xq.data(), k, refD.data(), refI.data());
210 
211  // compute search result
212 
213  std::vector<long> newI (nq * k);
214  std::vector<float> newD (nq * k);
215 
216  std::vector<int64_t> query_centroid_ids (nq);
217  std::vector<int64_t> result_centroid_ids (nq * k);
218 
219  search_and_retrun_centroids(index.get(),
220  nq, xq.data(), k,
221  newD.data(), newI.data(),
222  query_centroid_ids.data(),
223  result_centroid_ids.data());
224 
225  // first verify that we have the same result as the standard search
226 
227  if (newI != refI) {
228  return 1;
229  }
230 
231  // then check if the result ids are indeed in the inverted list
232  // they are supposed to be in
233 
234  for(int i = 0; i < nq * k; i++) {
235  int list_no = result_centroid_ids[i];
236  int result_no = newI[i];
237 
238  if (result_no < 0) continue;
239 
240  bool found = false;
241 
242  int list_size = ivf->invlists->list_size (list_no);
243  auto * list = ivf->invlists->get_ids (list_no);
244 
245  for(int j = 0; j < list_size; j++) {
246  if (list[j] == result_no) {
247  found = true;
248  break;
249  }
250  }
251  if(!found) return 2;
252  }
253  return 0;
254 }
255 
256 /*************************************************************
257  * Test entry points
258  *************************************************************/
259 
260 TEST(test_Search_centroid, IVFFlat) {
261  bool ok = test_Search_centroid("IVF32,Flat");
262  EXPECT_TRUE(ok);
263 }
264 
265 TEST(test_Search_centroid, PCAIVFFlat) {
266  bool ok = test_Search_centroid("PCA16,IVF32,Flat");
267  EXPECT_TRUE(ok);
268 }
269 
270 TEST(test_search_and_return_centroids, IVFFlat) {
271  int err = test_search_and_return_centroids("IVF32,Flat");
272  EXPECT_NE(err, 1);
273  EXPECT_NE(err, 2);
274 }
275 
276 TEST(test_search_and_return_centroids, PCAIVFFlat) {
277  int err = test_search_and_return_centroids("PCA16,IVF32,Flat");
278  EXPECT_NE(err, 1);
279  EXPECT_NE(err, 2);
280 }
virtual void search_preassigned(idx_t n, const float *x, idx_t k, const idx_t *assign, const float *centroid_dis, float *distances, idx_t *labels, bool store_pairs) const =0
size_t nprobe
number of probes at query time
Definition: IndexIVF.h:173
void assign(idx_t n, const float *x, idx_t *labels, idx_t k=1)
Definition: Index.cpp:34
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:23
virtual idx_t get_single_id(size_t list_no, size_t offset) const
Definition: IndexIVF.cpp:118
virtual void add(idx_t n, const float *x)=0
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
InvertedLists * invlists
Access to the actual data.
Definition: IndexIVF.h:168
Index * quantizer
quantizer that maps vectors to inverted lists
Definition: IndexIVF.h:33
Index * index_factory(int d, const char *description_in, MetricType metric)
Definition: AutoTune.cpp:688