Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_params_override.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdio>
10 #include <cstdlib>
11 
12 #include <memory>
13 #include <vector>
14 
15 #include <gtest/gtest.h>
16 
17 #include <faiss/IndexIVF.h>
18 #include <faiss/IndexBinaryIVF.h>
19 #include <faiss/AutoTune.h>
20 #include <faiss/IVFlib.h>
21 
22 using namespace faiss;
23 
24 namespace {
25 
26 typedef Index::idx_t idx_t;
27 
28 
29 // dimension of the vectors to index
30 int d = 32;
31 
32 // size of the database we plan to index
33 size_t nb = 1000;
34 
35 // nb of queries
36 size_t nq = 200;
37 
38 
39 
40 std::vector<float> make_data(size_t n)
41 {
42  std::vector <float> database (n * d);
43  for (size_t i = 0; i < n * d; i++) {
44  database[i] = drand48();
45  }
46  return database;
47 }
48 
49 std::unique_ptr<Index> make_index(const char *index_type,
50  MetricType metric,
51  const std::vector<float> & x)
52 {
53  std::unique_ptr<Index> index(index_factory(d, index_type, metric));
54  index->train(nb, x.data());
55  index->add(nb, x.data());
56  return index;
57 }
58 
59 std::vector<idx_t> search_index(Index *index, const float *xq) {
60  int k = 10;
61  std::vector<idx_t> I(k * nq);
62  std::vector<float> D(k * nq);
63  index->search (nq, xq, k, D.data(), I.data());
64  return I;
65 }
66 
67 std::vector<idx_t> search_index_with_params(
68  Index *index, const float *xq, IVFSearchParameters *params) {
69  int k = 10;
70  std::vector<idx_t> I(k * nq);
71  std::vector<float> D(k * nq);
72  ivflib::search_with_parameters (index, nq, xq, k,
73  D.data(), I.data(), params);
74  return I;
75 }
76 
77 
78 
79 
80 /*************************************************************
81  * Test functions for a given index type
82  *************************************************************/
83 
84 int test_params_override (const char *index_key, MetricType metric) {
85  std::vector<float> xb = make_data(nb); // database vectors
86  auto index = make_index(index_key, metric, xb);
87  //index->train(nb, xb.data());
88  // index->add(nb, xb.data());
89  std::vector<float> xq = make_data(nq);
90  ParameterSpace ps;
91  ps.set_index_parameter(index.get(), "nprobe", 2);
92  auto res2ref = search_index(index.get(), xq.data());
93  ps.set_index_parameter(index.get(), "nprobe", 9);
94  auto res9ref = search_index(index.get(), xq.data());
95  ps.set_index_parameter(index.get(), "nprobe", 1);
96 
97  IVFSearchParameters params;
98  params.max_codes = 0;
99  params.nprobe = 2;
100  auto res2new = search_index_with_params(index.get(), xq.data(), &params);
101  params.nprobe = 9;
102  auto res9new = search_index_with_params(index.get(), xq.data(), &params);
103 
104  if (res2ref != res2new)
105  return 2;
106 
107  if (res9ref != res9new)
108  return 9;
109 
110  return 0;
111 }
112 
113 
114 } // namespace
115 
116 
117 /*************************************************************
118  * Test entry points
119  *************************************************************/
120 
121 TEST(TPO, IVFFlat) {
122  int err1 = test_params_override ("IVF32,Flat", METRIC_L2);
123  EXPECT_EQ(err1, 0);
124  int err2 = test_params_override ("IVF32,Flat", METRIC_INNER_PRODUCT);
125  EXPECT_EQ(err2, 0);
126 }
127 
128 TEST(TPO, IVFPQ) {
129  int err1 = test_params_override ("IVF32,PQ8np", METRIC_L2);
130  EXPECT_EQ(err1, 0);
131  int err2 = test_params_override ("IVF32,PQ8np", METRIC_INNER_PRODUCT);
132  EXPECT_EQ(err2, 0);
133 }
134 
135 TEST(TPO, IVFSQ) {
136  int err1 = test_params_override ("IVF32,SQ8", METRIC_L2);
137  EXPECT_EQ(err1, 0);
138  int err2 = test_params_override ("IVF32,SQ8", METRIC_INNER_PRODUCT);
139  EXPECT_EQ(err2, 0);
140 }
141 
142 TEST(TPO, IVFFlatPP) {
143  int err1 = test_params_override ("PCA16,IVF32,SQ8", METRIC_L2);
144  EXPECT_EQ(err1, 0);
145  int err2 = test_params_override ("PCA16,IVF32,SQ8", METRIC_INNER_PRODUCT);
146  EXPECT_EQ(err2, 0);
147 }
148 
149 
150 
151 /*************************************************************
152  * Same for binary indexes
153  *************************************************************/
154 
155 
156 std::vector<uint8_t> make_data_binary(size_t n) {
157  std::vector <uint8_t> database (n * d / 8);
158  for (size_t i = 0; i < n * d / 8; i++) {
159  database[i] = lrand48();
160  }
161  return database;
162 }
163 
164 std::unique_ptr<IndexBinaryIVF> make_index(const char *index_type,
165  const std::vector<uint8_t> & x)
166 {
167 
168  auto index = std::unique_ptr<IndexBinaryIVF>
169  (dynamic_cast<IndexBinaryIVF*>(index_binary_factory (d, index_type)));
170  index->train(nb, x.data());
171  index->add(nb, x.data());
172  return index;
173 }
174 
175 std::vector<idx_t> search_index(IndexBinaryIVF *index, const uint8_t *xq) {
176  int k = 10;
177  std::vector<idx_t> I(k * nq);
178  std::vector<int32_t> D(k * nq);
179  index->search (nq, xq, k, D.data(), I.data());
180  return I;
181 }
182 
183 std::vector<idx_t> search_index_with_params(
184  IndexBinaryIVF *index, const uint8_t *xq, IVFSearchParameters *params) {
185  int k = 10;
186  std::vector<idx_t> I(k * nq);
187  std::vector<int32_t> D(k * nq);
188 
189  std::vector<idx_t> Iq(params->nprobe * nq);
190  std::vector<int32_t> Dq(params->nprobe * nq);
191 
192  index->quantizer->search(nq, xq, params->nprobe,
193  Dq.data(), Iq.data());
194  index->search_preassigned(nq, xq, k, Iq.data(), Dq.data(),
195  D.data(), I.data(),
196  false, params);
197  return I;
198 }
199 
200 int test_params_override_binary (const char *index_key) {
201  std::vector<uint8_t> xb = make_data_binary(nb); // database vectors
202  auto index = make_index (index_key, xb);
203  index->train(nb, xb.data());
204  index->add(nb, xb.data());
205  std::vector<uint8_t> xq = make_data_binary(nq);
206  index->nprobe = 2;
207  auto res2ref = search_index(index.get(), xq.data());
208  index->nprobe = 9;
209  auto res9ref = search_index(index.get(), xq.data());
210  index->nprobe = 1;
211 
212  IVFSearchParameters params;
213  params.max_codes = 0;
214  params.nprobe = 2;
215  auto res2new = search_index_with_params(index.get(), xq.data(), &params);
216  params.nprobe = 9;
217  auto res9new = search_index_with_params(index.get(), xq.data(), &params);
218 
219  if (res2ref != res2new)
220  return 2;
221 
222  if (res9ref != res9new)
223  return 9;
224 
225  return 0;
226 }
227 
228 TEST(TPOB, IVF) {
229  int err1 = test_params_override_binary ("BIVF32");
230  EXPECT_EQ(err1, 0);
231 }
void train(idx_t n, const float *x) override
virtual void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const =0
size_t nprobe
number of probes at query time
void search_preassigned(idx_t n, const uint8_t *x, idx_t k, const idx_t *assign, const int32_t *centroid_dis, int32_t *distances, idx_t *labels, bool store_pairs, const IVFSearchParameters *params=nullptr) const
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:24
IndexBinary * quantizer
quantizer that maps vectors to inverted lists
void train(idx_t n, const uint8_t *x) override
Trains the quantizer and calls train_residual to train sub-quantizers.
virtual void add(idx_t n, const float *x)=0
void add(idx_t n, const float *x) override
supported only for sub-indices that implement add_with_ids
long idx_t
all indices are this type
Definition: Index.h:64
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
size_t nprobe
number of probes at query time
Definition: IndexIVF.h:62
void add(idx_t n, const uint8_t *x) override
Quantizes x and calls add_with_key.
virtual void set_index_parameter(Index *index, const std::string &name, double val) const
set one of the parameters
Definition: AutoTune.cpp:452
virtual void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const override
size_t max_codes
max nb of codes to visit to do a query
Definition: IndexIVF.h:63
Index * index_factory(int d, const char *description_in, MetricType metric)
Definition: AutoTune.cpp:722
MetricType
Some algorithms support both an inner product version and a L2 search version.
Definition: Index.h:45