Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_ivfpq_codec.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdio>
10 #include <cstdlib>
11 
12 #include <gtest/gtest.h>
13 
14 #include <faiss/IndexIVFPQ.h>
15 #include <faiss/IndexFlat.h>
16 #include <faiss/utils.h>
17 
18 
// Test-wide parameters. They are never modified, so they are declared
// `static const`: internal linkage keeps these short names from clashing
// with same-named globals in other test files linked into the same binary.

// dimension of the vectors to index
static const int d = 64;

// size of the database we plan to index
static const size_t nb = 8000;
24 
25 
26 double eval_codec_error (long ncentroids, long m, const std::vector<float> &v)
27 {
28  faiss::IndexFlatL2 coarse_quantizer (d);
29  faiss::IndexIVFPQ index (&coarse_quantizer, d,
30  ncentroids, m, 8);
31  index.pq.cp.niter = 10; // speed up train
32  index.train (nb, v.data());
33 
34  // encode and decode to compute reconstruction error
35 
36  std::vector<long> keys (nb);
37  std::vector<uint8_t> codes (nb * m);
38  index.encode_multiple (nb, keys.data(), v.data(), codes.data(), true);
39 
40  std::vector<float> v2 (nb * d);
41  index.decode_multiple (nb, keys.data(), codes.data(), v2.data());
42 
43  return faiss::fvec_L2sqr (v.data(), v2.data(), nb * d);
44 }
45 
46 
47 
48 TEST(IVFPQ, codec) {
49 
50  std::vector <float> database (nb * d);
51  for (size_t i = 0; i < nb * d; i++) {
52  database[i] = drand48();
53  }
54 
55  double err0 = eval_codec_error(16, 8, database);
56 
57  // should be more accurate as there are more coarse centroids
58  double err1 = eval_codec_error(128, 8, database);
59  EXPECT_GT(err0, err1);
60 
61  // should be more accurate as there are more PQ codes
62  double err2 = eval_codec_error(16, 16, database);
63  EXPECT_GT(err0, err2);
64 }
float fvec_L2sqr(const float *x, const float *y, size_t d)
Squared L2 distance between two vectors.
Definition: utils.cpp:481