Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
test_ivfpq_codec.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <random>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/utils.h>
18 
19 
// Dimension of the vectors to index.
// `static const` gives the constant internal linkage, so this very short
// name cannot clash with a symbol from another test source linked into the
// same binary, and the value cannot be changed by accident.
static const int d = 64;

// Size of the database we plan to index (number of vectors).
// Internal linkage and const for the same reasons as `d`.
static const size_t nb = 8000;
25 
26 
27 double eval_codec_error (long ncentroids, long m, const std::vector<float> &v)
28 {
29  faiss::IndexFlatL2 coarse_quantizer (d);
30  faiss::IndexIVFPQ index (&coarse_quantizer, d,
31  ncentroids, m, 8);
32  index.pq.cp.niter = 10; // speed up train
33  index.train (nb, v.data());
34 
35  // encode and decode to compute reconstruction error
36 
37  std::vector<long> keys (nb);
38  std::vector<uint8_t> codes (nb * m);
39  index.encode_multiple (nb, keys.data(), v.data(), codes.data(), true);
40 
41  std::vector<float> v2 (nb * d);
42  index.decode_multiple (nb, keys.data(), codes.data(), v2.data());
43 
44  return faiss::fvec_L2sqr (v.data(), v2.data(), nb * d);
45 }
46 
47 
48 
49 TEST(IVFPQ, codec) {
50 
51  std::vector <float> database (nb * d);
52  for (size_t i = 0; i < nb * d; i++) {
53  database[i] = drand48();
54  }
55 
56  double err0 = eval_codec_error(16, 8, database);
57 
58  // should be more accurate as there are more coarse centroids
59  double err1 = eval_codec_error(128, 8, database);
60  EXPECT_GT(err0, err1);
61 
62  // should be more accurate as there are more PQ codes
63  double err2 = eval_codec_error(16, 16, database);
64  EXPECT_GT(err0, err2);
65 }
float fvec_L2sqr(const float *x, const float *y, size_t d)
Squared L2 distance between two vectors.
Definition: utils.cpp:432