Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WriteIndex.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "../../IndexIVF.h"
13 #include "../../IndexIVFPQ.h"
14 #include "../../IndexFlat.h"
15 #include "../../index_io.h"
16 #include "../test/TestUtils.h"
17 #include <vector>
18 #include <gflags/gflags.h>
19 
20 // For IVFPQ:
21 DEFINE_bool(ivfpq, false, "use IVFPQ encoding");
22 DEFINE_int32(codes, 4, "number of PQ codes per vector");
23 DEFINE_int32(bits_per_code, 8, "number of bits per PQ code");
24 
25 // For IVFFlat:
26 DEFINE_bool(l2, true, "use L2 metric (versus IP metric)");
27 DEFINE_bool(ivfflat, false, "use IVF flat encoding");
28 
29 // For both:
30 DEFINE_string(out, "/home/jhj/local/index.out", "index file for output");
31 DEFINE_int32(dim, 128, "vector dimension");
32 DEFINE_int32(num_coarse, 100, "number of coarse centroids");
33 DEFINE_int32(num, 100000, "total database size");
34 DEFINE_int32(num_train, -1, "number of database vecs to train on");
35 
36 template <typename T>
37 void fillAndSave(T& index, int numTrain, int num, int dim) {
38  auto trainVecs = faiss::gpu::randVecs(numTrain, dim);
39  index.train(numTrain, trainVecs.data());
40 
41  constexpr int kAddChunk = 1000000;
42 
43  for (int i = 0; i < num; i += kAddChunk) {
44  int numRemaining = (num - i) < kAddChunk ? (num - i) : kAddChunk;
45  auto vecs = faiss::gpu::randVecs(numRemaining, dim);
46 
47  printf("adding at %d: %d\n", i, numRemaining);
48  index.add(numRemaining, vecs.data());
49  }
50 
51  faiss::write_index(&index, FLAGS_out.c_str());
52 }
53 
54 int main(int argc, char** argv) {
55  google::ParseCommandLineFlags(&argc, &argv, true);
56 
57  // Either ivfpq or ivfflat must be set
58  if ((FLAGS_ivfpq && FLAGS_ivfflat) ||
59  (!FLAGS_ivfpq && !FLAGS_ivfflat)) {
60  printf("must specify either ivfpq or ivfflat\n");
61  return 1;
62  }
63 
64  auto dim = FLAGS_dim;
65  auto numCentroids = FLAGS_num_coarse;
66  auto num = FLAGS_num;
67  auto numTrain = FLAGS_num_train;
68  numTrain = numTrain == -1 ? std::max((num / 4), 1) : numTrain;
69  numTrain = std::min(num, numTrain);
70 
71  if (FLAGS_ivfpq) {
72  faiss::IndexFlatL2 quantizer(dim);
73  faiss::IndexIVFPQ index(&quantizer, dim, numCentroids,
74  FLAGS_codes, FLAGS_bits_per_code);
75  index.verbose = true;
76 
77  printf("IVFPQ: codes %d bits per code %d\n",
78  FLAGS_codes, FLAGS_bits_per_code);
79  printf("Lists: %d\n", numCentroids);
80  printf("Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
81  printf("output file: %s\n", FLAGS_out.c_str());
82 
83  fillAndSave(index, numTrain, num, dim);
84  } else if (FLAGS_ivfflat) {
85  faiss::IndexFlatL2 quantizerL2(dim);
86  faiss::IndexFlatIP quantizerIP(dim);
87 
88  faiss::IndexFlat* quantizer = FLAGS_l2 ?
89  (faiss::IndexFlat*) &quantizerL2 :
90  (faiss::IndexFlat*) &quantizerIP;
91 
92  faiss::IndexIVFFlat index(quantizer, dim, numCentroids,
93  FLAGS_l2 ? faiss::METRIC_L2 :
94  faiss::METRIC_INNER_PRODUCT);
95 
96  printf("IVFFlat: metric %s\n", FLAGS_l2 ? "L2" : "IP");
97  printf("Lists: %d\n", numCentroids);
98  printf("Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
99  printf("output file: %s\n", FLAGS_out.c_str());
100 
101  fillAndSave(index, numTrain, num, dim);
102  }
103 
104  return 0;
105 }