Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_ivfpq_indexing.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include <cstdio>
11 #include <cstdlib>
12 
13 #include <gtest/gtest.h>
14 
15 #include <faiss/IndexIVFPQ.h>
16 #include <faiss/IndexFlat.h>
17 #include <faiss/index_io.h>
18 
19 TEST(IVFPQ, accuracy) {
20 
21  // dimension of the vectors to index
22  int d = 64;
23 
24  // size of the database we plan to index
25  size_t nb = 1000;
26 
27  // make a set of nt training vectors in the unit cube
28  // (could be the database)
29  size_t nt = 1500;
30 
31  // make the index object and train it
32  faiss::IndexFlatL2 coarse_quantizer (d);
33 
34  // a reasonable number of cetroids to index nb vectors
35  int ncentroids = 25;
36 
37  faiss::IndexIVFPQ index (&coarse_quantizer, d,
38  ncentroids, 16, 8);
39 
40  // index that gives the ground-truth
41  faiss::IndexFlatL2 index_gt (d);
42 
43  srand48 (35);
44 
45  { // training
46 
47  std::vector <float> trainvecs (nt * d);
48  for (size_t i = 0; i < nt * d; i++) {
49  trainvecs[i] = drand48();
50  }
51  index.verbose = true;
52  index.train (nt, trainvecs.data());
53  }
54 
55  { // populating the database
56 
57  std::vector <float> database (nb * d);
58  for (size_t i = 0; i < nb * d; i++) {
59  database[i] = drand48();
60  }
61 
62  index.add (nb, database.data());
63  index_gt.add (nb, database.data());
64  }
65 
66  int nq = 200;
67  int n_ok;
68 
69  { // searching the database
70 
71  std::vector <float> queries (nq * d);
72  for (size_t i = 0; i < nq * d; i++) {
73  queries[i] = drand48();
74  }
75 
76  std::vector<faiss::Index::idx_t> gt_nns (nq);
77  std::vector<float> gt_dis (nq);
78 
79  index_gt.search (nq, queries.data(), 1,
80  gt_dis.data(), gt_nns.data());
81 
82  index.nprobe = 5;
83  int k = 5;
84  std::vector<faiss::Index::idx_t> nns (k * nq);
85  std::vector<float> dis (k * nq);
86 
87  index.search (nq, queries.data(), k, dis.data(), nns.data());
88 
89  n_ok = 0;
90  for (int q = 0; q < nq; q++) {
91 
92  for (int i = 0; i < k; i++)
93  if (nns[q * k + i] == gt_nns[q])
94  n_ok++;
95  }
96  EXPECT_GT(n_ok, nq * 0.4);
97  }
98 
99 }