Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_ivfpq_indexing.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved
11 
12 #include <cstdio>
13 #include <cstdlib>
14 
15 #include <gtest/gtest.h>
16 
17 #include <faiss/IndexIVFPQ.h>
18 #include <faiss/IndexFlat.h>
19 #include <faiss/index_io.h>
20 
21 TEST(IVFPQ, accuracy) {
22 
23  // dimension of the vectors to index
24  int d = 64;
25 
26  // size of the database we plan to index
27  size_t nb = 1000;
28 
29  // make a set of nt training vectors in the unit cube
30  // (could be the database)
31  size_t nt = 1500;
32 
33  // make the index object and train it
34  faiss::IndexFlatL2 coarse_quantizer (d);
35 
36  // a reasonable number of cetroids to index nb vectors
37  int ncentroids = 25;
38 
39  faiss::IndexIVFPQ index (&coarse_quantizer, d,
40  ncentroids, 16, 8);
41 
42  // index that gives the ground-truth
43  faiss::IndexFlatL2 index_gt (d);
44 
45  srand48 (35);
46 
47  { // training
48 
49  std::vector <float> trainvecs (nt * d);
50  for (size_t i = 0; i < nt * d; i++) {
51  trainvecs[i] = drand48();
52  }
53  index.verbose = true;
54  index.train (nt, trainvecs.data());
55  }
56 
57  { // populating the database
58 
59  std::vector <float> database (nb * d);
60  for (size_t i = 0; i < nb * d; i++) {
61  database[i] = drand48();
62  }
63 
64  index.add (nb, database.data());
65  index_gt.add (nb, database.data());
66  }
67 
68  int nq = 200;
69  int n_ok;
70 
71  { // searching the database
72 
73  std::vector <float> queries (nq * d);
74  for (size_t i = 0; i < nq * d; i++) {
75  queries[i] = drand48();
76  }
77 
78  std::vector<faiss::Index::idx_t> gt_nns (nq);
79  std::vector<float> gt_dis (nq);
80 
81  index_gt.search (nq, queries.data(), 1,
82  gt_dis.data(), gt_nns.data());
83 
84  index.nprobe = 5;
85  int k = 5;
86  std::vector<faiss::Index::idx_t> nns (k * nq);
87  std::vector<float> dis (k * nq);
88 
89  index.search (nq, queries.data(), k, dis.data(), nns.data());
90 
91  n_ok = 0;
92  for (int q = 0; q < nq; q++) {
93 
94  for (int i = 0; i < k; i++)
95  if (nns[q * k + i] == gt_nns[q])
96  n_ok++;
97  }
98  EXPECT_GT(n_ok, nq * 0.4);
99  }
100 
101 }