Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_pairs_decoding.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdio>
10 #include <cstdlib>
11 
12 #include <memory>
13 #include <vector>
14 
15 #include <gtest/gtest.h>
16 
17 #include <faiss/IndexIVF.h>
18 #include <faiss/AutoTune.h>
19 #include <faiss/VectorTransform.h>
20 #include <faiss/IVFlib.h>
21 
22 
23 namespace {
24 
25 typedef faiss::Index::idx_t idx_t;
26 
27 /*************************************************************
28  * Test utils
29  *************************************************************/
30 
31 
// Dimension of the vectors to index. Kept as an int because it is handed
// directly to faiss::index_factory. Never modified, hence constexpr.
constexpr int d = 64;

// Number of database vectors that are generated, used for training and
// added to the index.
constexpr size_t nb = 8000;

// Number of query vectors used by the search test.
constexpr size_t nq = 200;
40 
41 std::vector<float> make_data(size_t n)
42 {
43  std::vector <float> database (n * d);
44  for (size_t i = 0; i < n * d; i++) {
45  database[i] = drand48();
46  }
47  return database;
48 }
49 
50 std::unique_ptr<faiss::Index> make_index(const char *index_type,
51  const std::vector<float> & x) {
52 
53  auto index = std::unique_ptr<faiss::Index> (
54  faiss::index_factory(d, index_type));
55  index->train(nb, x.data());
56  index->add(nb, x.data());
57  return index;
58 }
59 
60 /*************************************************************
61  * Test functions for a given index type
62  *************************************************************/
63 
64 bool test_search_centroid(const char *index_key) {
65  std::vector<float> xb = make_data(nb); // database vectors
66  auto index = make_index(index_key, xb);
67 
68  /* First test: find the centroids associated to the database
69  vectors and make sure that each vector does indeed appear in
70  the inverted list corresponding to its centroid */
71 
72  std::vector<idx_t> centroid_ids (nb);
73  faiss::ivflib::search_centroid(
74  index.get(), xb.data(), nb, centroid_ids.data());
75 
76  const faiss::IndexIVF * ivf = faiss::ivflib::extract_index_ivf
77  (index.get());
78 
79  for(int i = 0; i < nb; i++) {
80  bool found = false;
81  int list_no = centroid_ids[i];
82  int list_size = ivf->invlists->list_size (list_no);
83  auto * list = ivf->invlists->get_ids (list_no);
84 
85  for(int j = 0; j < list_size; j++) {
86  if (list[j] == i) {
87  found = true;
88  break;
89  }
90  }
91  if(!found) return false;
92  }
93  return true;
94 }
95 
96 int test_search_and_return_centroids(const char *index_key) {
97  std::vector<float> xb = make_data(nb); // database vectors
98  auto index = make_index(index_key, xb);
99 
100  std::vector<idx_t> centroid_ids (nb);
101  faiss::ivflib::search_centroid(index.get(), xb.data(),
102  nb, centroid_ids.data());
103 
104  faiss::IndexIVF * ivf =
105  faiss::ivflib::extract_index_ivf (index.get());
106  ivf->nprobe = 4;
107 
108  std::vector<float> xq = make_data(nq); // database vectors
109 
110  int k = 5;
111 
112  // compute a reference search result
113 
114  std::vector<idx_t> refI (nq * k);
115  std::vector<float> refD (nq * k);
116  index->search (nq, xq.data(), k, refD.data(), refI.data());
117 
118  // compute search result
119 
120  std::vector<idx_t> newI (nq * k);
121  std::vector<float> newD (nq * k);
122 
123  std::vector<idx_t> query_centroid_ids (nq);
124  std::vector<idx_t> result_centroid_ids (nq * k);
125 
126  faiss::ivflib::search_and_return_centroids(index.get(),
127  nq, xq.data(), k,
128  newD.data(), newI.data(),
129  query_centroid_ids.data(),
130  result_centroid_ids.data());
131 
132  // first verify that we have the same result as the standard search
133 
134  if (newI != refI) {
135  return 1;
136  }
137 
138  // then check if the result ids are indeed in the inverted list
139  // they are supposed to be in
140 
141  for(int i = 0; i < nq * k; i++) {
142  int list_no = result_centroid_ids[i];
143  int result_no = newI[i];
144 
145  if (result_no < 0) continue;
146 
147  bool found = false;
148 
149  int list_size = ivf->invlists->list_size (list_no);
150  auto * list = ivf->invlists->get_ids (list_no);
151 
152  for(int j = 0; j < list_size; j++) {
153  if (list[j] == result_no) {
154  found = true;
155  break;
156  }
157  }
158  if(!found) return 2;
159  }
160  return 0;
161 }
162 
163 } // namespace
164 
165 
166 /*************************************************************
167  * Test entry points
168  *************************************************************/
169 
170 TEST(test_search_centroid, IVFFlat) {
171  bool ok = test_search_centroid("IVF32,Flat");
172  EXPECT_TRUE(ok);
173 }
174 
175 TEST(test_search_centroid, PCAIVFFlat) {
176  bool ok = test_search_centroid("PCA16,IVF32,Flat");
177  EXPECT_TRUE(ok);
178 }
179 
180 TEST(test_search_and_return_centroids, IVFFlat) {
181  int err = test_search_and_return_centroids("IVF32,Flat");
182  EXPECT_NE(err, 1);
183  EXPECT_NE(err, 2);
184 }
185 
186 TEST(test_search_and_return_centroids, PCAIVFFlat) {
187  int err = test_search_and_return_centroids("PCA16,IVF32,Flat");
188  EXPECT_NE(err, 1);
189  EXPECT_NE(err, 2);
190 }
size_t nprobe
number of probes at query time
Definition: IndexIVF.h:98
long idx_t
all indices are this type
Definition: Index.h:64
Index * index_factory(int d, const char *description_in, MetricType metric)
Definition: AutoTune.cpp:722