Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_ivfpq_indexing.cpp
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved

#include <cstdio>
#include <cstdlib>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/index_io.h>

20 TEST(IVFPQ, accuracy) {
21 
22  // dimension of the vectors to index
23  int d = 64;
24 
25  // size of the database we plan to index
26  size_t nb = 1000;
27 
28  // make a set of nt training vectors in the unit cube
29  // (could be the database)
30  size_t nt = 1500;
31 
32  // make the index object and train it
33  faiss::IndexFlatL2 coarse_quantizer (d);
34 
35  // a reasonable number of cetroids to index nb vectors
36  int ncentroids = 25;
37 
38  faiss::IndexIVFPQ index (&coarse_quantizer, d,
39  ncentroids, 16, 8);
40 
41  // index that gives the ground-truth
42  faiss::IndexFlatL2 index_gt (d);
43 
44  srand48 (35);
45 
46  { // training
47 
48  std::vector <float> trainvecs (nt * d);
49  for (size_t i = 0; i < nt * d; i++) {
50  trainvecs[i] = drand48();
51  }
52  index.verbose = true;
53  index.train (nt, trainvecs.data());
54  }
55 
56  { // populating the database
57 
58  std::vector <float> database (nb * d);
59  for (size_t i = 0; i < nb * d; i++) {
60  database[i] = drand48();
61  }
62 
63  index.add (nb, database.data());
64  index_gt.add (nb, database.data());
65  }
66 
67  int nq = 200;
68  int n_ok;
69 
70  { // searching the database
71 
72  std::vector <float> queries (nq * d);
73  for (size_t i = 0; i < nq * d; i++) {
74  queries[i] = drand48();
75  }
76 
77  std::vector<faiss::Index::idx_t> gt_nns (nq);
78  std::vector<float> gt_dis (nq);
79 
80  index_gt.search (nq, queries.data(), 1,
81  gt_dis.data(), gt_nns.data());
82 
83  index.nprobe = 5;
84  int k = 5;
85  std::vector<faiss::Index::idx_t> nns (k * nq);
86  std::vector<float> dis (k * nq);
87 
88  index.search (nq, queries.data(), k, dis.data(), nns.data());
89 
90  n_ok = 0;
91  for (int q = 0; q < nq; q++) {
92 
93  for (int i = 0; i < k; i++)
94  if (nns[q * k + i] == gt_nns[q])
95  n_ok++;
96  }
97  EXPECT_GT(n_ok, nq * 0.4);
98  }
99 
100 }