Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_merge.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <pthread.h>
#include <unistd.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/MetaIndexes.h>
#include <faiss/FaissAssert.h>
#include <faiss/VectorTransform.h>
#include <faiss/OnDiskInvertedLists.h>
#include <faiss/IVFlib.h>
22 
23 
24 namespace {
25 
26 
/// Holds a unique temporary file name obtained from tempnam().
///
/// The constructor only reserves a name; it does not create the file.
/// On destruction, the file is removed if it exists at that point.
/// Not copyable: two copies would both try to remove the same file.
struct Tempfilename {

    /// Serializes calls to tempnam() across threads.
    static pthread_mutex_t mutex;

    std::string filename;

    /// @param prefix  optional file-name prefix forwarded to tempnam()
    ///                (may be null)
    /// @throws std::runtime_error if tempnam() fails to produce a name
    explicit Tempfilename (const char *prefix = nullptr) {
        pthread_mutex_lock (&mutex);
        char *cfname = tempnam (nullptr, prefix);
        pthread_mutex_unlock (&mutex);
        // Constructing a std::string from a null pointer is undefined
        // behavior, so fail loudly instead.
        if (cfname == nullptr) {
            throw std::runtime_error ("Tempfilename: tempnam() failed");
        }
        filename = cfname;
        free (cfname); // tempnam() returns a malloc'ed buffer
    }

    Tempfilename (const Tempfilename &) = delete;
    Tempfilename &operator= (const Tempfilename &) = delete;

    ~Tempfilename () {
        // access() returns 0 when the file exists; only then is there
        // something to remove. (The previous test was inverted, so
        // existing temporary files were never deleted.)
        if (access (filename.c_str(), F_OK) == 0) {
            unlink (filename.c_str());
        }
    }

    /// C string view of the file name, valid while this object lives.
    const char *c_str() const {
        return filename.c_str();
    }

};

pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER;
54 
55 
56 // parameters to use for the test
57 int d = 64;
58 size_t nb = 1000;
59 size_t nq = 100;
60 int nindex = 4;
61 int k = 10;
62 int nlist = 40;
63 
64 typedef faiss::Index::idx_t idx_t;
65 
66 struct CommonData {
67 
68  std::vector <float> database;
69  std::vector <float> queries;
70  std::vector<idx_t> ids;
71  faiss::IndexFlatL2 quantizer;
72 
73  CommonData(): database (nb * d), queries (nq * d), ids(nb), quantizer (d) {
74 
75  for (size_t i = 0; i < nb * d; i++) {
76  database[i] = drand48();
77  }
78  for (size_t i = 0; i < nq * d; i++) {
79  queries[i] = drand48();
80  }
81  for (int i = 0; i < nb; i++) {
82  ids[i] = 123 + 456 * i;
83  }
84  { // just to train the quantizer
85  faiss::IndexIVFFlat iflat (&quantizer, d, nlist);
86  iflat.train(nb, database.data());
87  }
88  }
89 };
90 
91 CommonData cd;
92 
93 /// perform a search on shards, then merge and search again and
94 /// compare results.
95 int compare_merged (faiss::IndexShards *index_shards, bool shift_ids,
96  bool standard_merge = true)
97 {
98 
99  std::vector<idx_t> refI(k * nq);
100  std::vector<float> refD(k * nq);
101 
102  index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data());
103  Tempfilename filename;
104 
105  std::vector<idx_t> newI(k * nq);
106  std::vector<float> newD(k * nq);
107 
108  if (standard_merge) {
109 
110  for (int i = 1; i < nindex; i++) {
111  faiss::ivflib::merge_into(
112  index_shards->at(0), index_shards->at(i),
113  shift_ids);
114  }
115 
116  index_shards->sync_with_shard_indexes();
117  } else {
118  std::vector<const faiss::InvertedLists *> lists;
119  faiss::IndexIVF *index0 = nullptr;
120  size_t ntotal = 0;
121  for (int i = 0; i < nindex; i++) {
122  auto index_ivf = dynamic_cast<faiss::IndexIVF*>(index_shards->at(i));
123  assert (index_ivf);
124  if (i == 0) {
125  index0 = index_ivf;
126  }
127  lists.push_back (index_ivf->invlists);
128  ntotal += index_ivf->ntotal;
129  }
130 
131  auto il = new faiss::OnDiskInvertedLists(
132  index0->nlist, index0->code_size,
133  filename.c_str());
134 
135  il->merge_from(lists.data(), lists.size());
136 
137  index0->replace_invlists(il, true);
138  index0->ntotal = ntotal;
139  }
140  // search only on first index
141  index_shards->at(0)->search(nq, cd.queries.data(),
142  k, newD.data(), newI.data());
143 
144  size_t ndiff = 0;
145  for (size_t i = 0; i < k * nq; i++) {
146  if (refI[i] != newI[i]) {
147  ndiff ++;
148  }
149  }
150  return ndiff;
151 }
152 
153 } // namespace
154 
155 
156 // test on IVFFlat with implicit numbering
157 TEST(MERGE, merge_flat_no_ids) {
158  faiss::IndexShards index_shards(d);
159  index_shards.own_fields = true;
160  for (int i = 0; i < nindex; i++) {
161  index_shards.add_shard (
162  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
163  }
164  EXPECT_TRUE(index_shards.is_trained);
165  index_shards.add(nb, cd.database.data());
166  size_t prev_ntotal = index_shards.ntotal;
167  int ndiff = compare_merged(&index_shards, true);
168  EXPECT_EQ (prev_ntotal, index_shards.ntotal);
169  EXPECT_EQ(0, ndiff);
170 }
171 
172 
173 // test on IVFFlat, explicit ids
174 TEST(MERGE, merge_flat) {
175  faiss::IndexShards index_shards(d, false, false);
176  index_shards.own_fields = true;
177 
178  for (int i = 0; i < nindex; i++) {
179  index_shards.add_shard (
180  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
181  }
182 
183  EXPECT_TRUE(index_shards.is_trained);
184  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
185  int ndiff = compare_merged(&index_shards, false);
186  EXPECT_GE(0, ndiff);
187 }
188 
189 // test on IVFFlat and a VectorTransform
190 TEST(MERGE, merge_flat_vt) {
191  faiss::IndexShards index_shards(d, false, false);
192  index_shards.own_fields = true;
193 
194  // here we have to retrain because of the vectorTransform
195  faiss::RandomRotationMatrix rot(d, d);
196  rot.init(1234);
197  faiss::IndexFlatL2 quantizer (d);
198 
199  { // just to train the quantizer
200  faiss::IndexIVFFlat iflat (&quantizer, d, nlist);
201  faiss::IndexPreTransform ipt (&rot, &iflat);
202  ipt.train(nb, cd.database.data());
203  }
204 
205  for (int i = 0; i < nindex; i++) {
207  new faiss::RandomRotationMatrix (rot),
208  new faiss::IndexIVFFlat (&quantizer, d, nlist)
209  );
210  ipt->own_fields = true;
211  index_shards.add_shard (ipt);
212  }
213  EXPECT_TRUE(index_shards.is_trained);
214  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
215  size_t prev_ntotal = index_shards.ntotal;
216  int ndiff = compare_merged(&index_shards, false);
217  EXPECT_EQ (prev_ntotal, index_shards.ntotal);
218  EXPECT_GE(0, ndiff);
219 }
220 
221 
222 // put the merged invfile on disk
223 TEST(MERGE, merge_flat_ondisk) {
224  faiss::IndexShards index_shards(d, false, false);
225  index_shards.own_fields = true;
226  Tempfilename filename;
227 
228  for (int i = 0; i < nindex; i++) {
229  auto ivf = new faiss::IndexIVFFlat (&cd.quantizer, d, nlist);
230  if (i == 0) {
231  auto il = new faiss::OnDiskInvertedLists (
232  ivf->nlist, ivf->code_size,
233  filename.c_str());
234  ivf->replace_invlists(il, true);
235  }
236  index_shards.add_shard (ivf);
237  }
238 
239  EXPECT_TRUE(index_shards.is_trained);
240  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
241  int ndiff = compare_merged(&index_shards, false);
242 
243  EXPECT_EQ(ndiff, 0);
244 }
245 
246 // now use ondisk specific merge
247 TEST(MERGE, merge_flat_ondisk_2) {
248  faiss::IndexShards index_shards(d, false, false);
249  index_shards.own_fields = true;
250 
251  for (int i = 0; i < nindex; i++) {
252  index_shards.add_shard (
253  new faiss::IndexIVFFlat (&cd.quantizer, d, nlist));
254  }
255  EXPECT_TRUE(index_shards.is_trained);
256  index_shards.add_with_ids(nb, cd.database.data(), cd.ids.data());
257  int ndiff = compare_merged(&index_shards, false, false);
258  EXPECT_GE(0, ndiff);
259 }
Randomly rotate a set of vectors.
void add_with_ids(idx_t n, const float *x, const long *xids) override
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
void add(idx_t n, const float *x) override
supported only for sub-indices that implement add_with_ids
long idx_t
all indices are this type
Definition: Index.h:64
void replace_invlists(InvertedLists *il, bool own=false)
replace the inverted lists, old one is deallocated if own_invlists
Definition: IndexIVF.cpp:486
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
bool own_fields
! the sub-index
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
size_t nlist
number of possible key values
Definition: IndexIVF.h:34
size_t code_size
code size per vector in bytes
Definition: IndexIVF.h:96