Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_merge.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <pthread.h>
#include <unistd.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/FaissAssert.h>
#include <faiss/VectorTransform.h>
#include <faiss/OnDiskInvertedLists.h>
21 
22 
23 namespace faiss {
24 
25 // Main function to test
26 
27 // Merge index1 into index0. Works on IndexIVF's and IndexIVF's
28 // embedded in a IndexPreTransform
29 
30 void merge_into(Index *index0, Index *index1, bool shift_ids) {
31  FAISS_THROW_IF_NOT (index0->d == index1->d);
32  IndexIVF *ivf0 = dynamic_cast<IndexIVF *>(index0);
33  IndexIVF *ivf1 = dynamic_cast<IndexIVF *>(index1);
34 
35  if (!ivf0) {
36  IndexPreTransform *pt0 = dynamic_cast<IndexPreTransform *>(index0);
37  IndexPreTransform *pt1 = dynamic_cast<IndexPreTransform *>(index1);
38 
39  // minimal sanity check
40  FAISS_THROW_IF_NOT (pt0 && pt1);
41  FAISS_THROW_IF_NOT (pt0->chain.size() == pt1->chain.size());
42  for (int i = 0; i < pt0->chain.size(); i++) {
43  FAISS_THROW_IF_NOT (typeid(pt0->chain[i]) == typeid(pt1->chain[i]));
44  }
45 
46  ivf0 = dynamic_cast<IndexIVF *>(pt0->index);
47  ivf1 = dynamic_cast<IndexIVF *>(pt1->index);
48  }
49 
50  FAISS_THROW_IF_NOT (ivf0);
51  FAISS_THROW_IF_NOT (ivf1);
52 
53  ivf0->merge_from (*ivf1, shift_ids ? ivf0->ntotal : 0);
54 
55  // useful for IndexPreTransform
56  index0->ntotal = ivf0->ntotal;
57  index1->ntotal = ivf1->ntotal;
58 }
59 
60 };
61 
62 
/// Holds a unique temporary file name (obtained from tempnam()) for the
/// lifetime of the object. The file itself is not created here; if a file
/// with that name exists when the object is destroyed, it is removed.
struct Tempfilename {

    // serializes the calls to tempnam() across threads
    static pthread_mutex_t mutex;

    std::string filename;

    /// @param prefix optional file-name prefix forwarded to tempnam()
    /// @throws std::runtime_error if tempnam() fails
    explicit Tempfilename (const char *prefix = nullptr) {
        pthread_mutex_lock (&mutex);
        char *name = tempnam (nullptr, prefix);
        pthread_mutex_unlock (&mutex);
        // tempnam() returns NULL on failure; constructing a std::string
        // from a null pointer is undefined behaviour
        if (name == nullptr) {
            throw std::runtime_error ("Tempfilename: tempnam failed");
        }
        filename = name;
        // the string returned by tempnam() is malloc'ed and caller-owned
        free (name);
    }

    ~Tempfilename () {
        // access() returns 0 when the file exists
        if (access (filename.c_str(), F_OK) == 0) {
            unlink (filename.c_str());
        }
    }

    // a copy would unlink the same file a second time
    Tempfilename (const Tempfilename &) = delete;
    Tempfilename &operator= (const Tempfilename &) = delete;

    const char *c_str() const {
        return filename.c_str();
    }

};

pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER;
88 
89 
// parameters to use for the test; constexpr because no test modifies them
constexpr int d = 64;          // vector dimension
constexpr size_t nb = 1000;    // number of database vectors
constexpr size_t nq = 100;     // number of query vectors
constexpr int nindex = 4;      // number of shards that get merged
constexpr int k = 10;          // number of results per query
constexpr int nlist = 40;      // number of inverted lists per IVF index
97 
98 typedef faiss::Index::idx_t idx_t;
99 
100 struct CommonData {
101 
102  std::vector <float> database;
103  std::vector <float> queries;
104  std::vector<idx_t> ids;
105  faiss::IndexFlatL2 quantizer;
106 
107  CommonData(): database (nb * d), queries (nq * d), ids(nb), quantizer (d) {
108 
109  for (size_t i = 0; i < nb * d; i++) {
110  database[i] = drand48();
111  }
112  for (size_t i = 0; i < nq * d; i++) {
113  queries[i] = drand48();
114  }
115  for (int i = 0; i < nb; i++) {
116  ids[i] = 123 + 456 * i;
117  }
118  { // just to train the quantizer
119  faiss::IndexIVFFlat iflat (&quantizer, d, nlist);
120  iflat.train(nb, database.data());
121  }
122  }
123 };
124 
125 CommonData cd;
126 
127 
128 
129 /// perform a search on shards, then merge and search again and
130 /// compare results.
131 int compare_merged (faiss::IndexShards *index_shards, bool shift_ids,
132  bool standard_merge = true)
133 {
134 
135  std::vector<idx_t> refI(k * nq);
136  std::vector<float> refD(k * nq);
137 
138  index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data());
139  Tempfilename filename;
140 
141  std::vector<idx_t> newI(k * nq);
142  std::vector<float> newD(k * nq);
143 
144  if (standard_merge) {
145 
146  for (int i = 1; i < nindex; i++) {
147  merge_into(index_shards->at(0), index_shards->at(i), shift_ids);
148  }
149 
150  index_shards->sync_with_shard_indexes();
151  } else {
152  std::vector<const faiss::InvertedLists *> lists;
153  faiss::IndexIVF *index0 = nullptr;
154  size_t ntotal = 0;
155  for (int i = 0; i < nindex; i++) {
156  auto index_ivf = dynamic_cast<faiss::IndexIVF*>(index_shards->at(i));
157  assert (index_ivf);
158  if (i == 0) {
159  index0 = index_ivf;
160  }
161  lists.push_back (index_ivf->invlists);
162  ntotal += index_ivf->ntotal;
163  }
164 
165  auto il = new faiss::OnDiskInvertedLists(
166  index0->nlist, index0->code_size,
167  filename.c_str());
168 
169  il->merge_from(lists.data(), lists.size());
170 
171  index0->replace_invlists(il, true);
172  index0->ntotal = ntotal;
173  }
174  // search only on first index
175  index_shards->at(0)->search(nq, cd.queries.data(),
176  k, newD.data(), newI.data());
177 
178  size_t ndiff = 0;
179  for (size_t i = 0; i < k * nq; i++) {
180  if (refI[i] != newI[i]) {
181  ndiff ++;
182  }
183  }
184  return ndiff;
185 }
186 
187 
188 // test on IVFFlat with implicit numbering
189 TEST(MERGE, merge_flat_no_ids) {
190  faiss::IndexShards index_shards(d);
191  index_shards.own_fields = true;
192  for (int i = 0; i < nindex; i++) {
193  index_shards.add_shard (
194  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
195  }
196  EXPECT_TRUE(index_shards.is_trained);
197  index_shards.add(nb, cd.database.data());
198  size_t prev_ntotal = index_shards.ntotal;
199  int ndiff = compare_merged(&index_shards, true);
200  EXPECT_EQ (prev_ntotal, index_shards.ntotal);
201  EXPECT_EQ(0, ndiff);
202 }
203 
204 
205 // test on IVFFlat, explicit ids
206 TEST(MERGE, merge_flat) {
207  faiss::IndexShards index_shards(d, false, false);
208  index_shards.own_fields = true;
209 
210  for (int i = 0; i < nindex; i++) {
211  index_shards.add_shard (
212  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
213  }
214 
215  EXPECT_TRUE(index_shards.is_trained);
216  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
217  int ndiff = compare_merged(&index_shards, false);
218  EXPECT_GE(0, ndiff);
219 }
220 
221 // test on IVFFlat and a VectorTransform
222 TEST(MERGE, merge_flat_vt) {
223  faiss::IndexShards index_shards(d, false, false);
224  index_shards.own_fields = true;
225 
226  // here we have to retrain because of the vectorTransform
227  faiss::RandomRotationMatrix rot(d, d);
228  rot.init(1234);
229  faiss::IndexFlatL2 quantizer (d);
230 
231  { // just to train the quantizer
232  faiss::IndexIVFFlat iflat (&quantizer, d, nlist);
233  faiss::IndexPreTransform ipt (&rot, &iflat);
234  ipt.train(nb, cd.database.data());
235  }
236 
237  for (int i = 0; i < nindex; i++) {
239  new faiss::RandomRotationMatrix (rot),
240  new faiss::IndexIVFFlat (&quantizer, d, nlist)
241  );
242  ipt->own_fields = true;
243  index_shards.add_shard (ipt);
244  }
245  EXPECT_TRUE(index_shards.is_trained);
246  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
247  size_t prev_ntotal = index_shards.ntotal;
248  int ndiff = compare_merged(&index_shards, false);
249  EXPECT_EQ (prev_ntotal, index_shards.ntotal);
250  EXPECT_GE(0, ndiff);
251 }
252 
253 
254 // put the merged invfile on disk
255 TEST(MERGE, merge_flat_ondisk) {
256  faiss::IndexShards index_shards(d, false, false);
257  index_shards.own_fields = true;
258  Tempfilename filename;
259 
260  for (int i = 0; i < nindex; i++) {
261  auto ivf = new faiss::IndexIVFFlat (&cd.quantizer, d, nlist);
262  if (i == 0) {
263  auto il = new faiss::OnDiskInvertedLists (
264  ivf->nlist, ivf->code_size,
265  filename.c_str());
266  ivf->replace_invlists(il, true);
267  }
268  index_shards.add_shard (ivf);
269  }
270 
271  EXPECT_TRUE(index_shards.is_trained);
272  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
273  int ndiff = compare_merged(&index_shards, false);
274 
275  EXPECT_EQ(ndiff, 0);
276 }
277 
278 // non use ondisk specific merge
279 TEST(MERGE, merge_flat_ondisk_2) {
280  faiss::IndexShards index_shards(d, false, false);
281  index_shards.own_fields = true;
282 
283  for (int i = 0; i < nindex; i++) {
284  index_shards.add_shard (
285  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
286  }
287  EXPECT_TRUE(index_shards.is_trained);
288  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
289  int ndiff = compare_merged(&index_shards, false, false);
290  EXPECT_GE(0, ndiff);
291 }
Randomly rotate a set of vectors.
void add_with_ids(idx_t n, const float *x, const long *xids) override
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
void train(idx_t n, const float *x) override
Trains the quantizer and calls train_residual to train sub-quantizers.
Definition: IndexIVF.cpp:424
void add(idx_t n, const float *x) override
supported only for sub-indices that implement add_with_ids
long idx_t
all indices are this type
Definition: Index.h:62
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:65
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
bool own_fields
! the sub-index
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
size_t nlist
number of possible key values
Definition: IndexIVF.h:34
size_t code_size
code size per vector in bytes
Definition: IndexIVF.h:171