11 #include "../../IndexIVF.h"
12 #include "../../IndexIVFPQ.h"
13 #include "../../IndexFlat.h"
14 #include "../../index_io.h"
15 #include "../test/TestUtils.h"
17 #include <gflags/gflags.h>
20 DEFINE_bool(ivfpq,
false,
"use IVFPQ encoding");
21 DEFINE_int32(codes, 4,
"number of PQ codes per vector");
22 DEFINE_int32(bits_per_code, 8,
"number of bits per PQ code");
25 DEFINE_bool(l2,
true,
"use L2 metric (versus IP metric)");
26 DEFINE_bool(ivfflat,
false,
"use IVF flat encoding");
29 DEFINE_string(out,
"/home/jhj/local/index.out",
"index file for output");
30 DEFINE_int32(dim, 128,
"vector dimension");
31 DEFINE_int32(num_coarse, 100,
"number of coarse centroids");
32 DEFINE_int32(num, 100000,
"total database size");
33 DEFINE_int32(num_train, -1,
"number of database vecs to train on");
36 void fillAndSave(T& index,
int numTrain,
int num,
int dim) {
37 auto trainVecs = faiss::gpu::randVecs(numTrain, dim);
38 index.train(numTrain, trainVecs.data());
40 constexpr
int kAddChunk = 1000000;
42 for (
int i = 0; i < num; i += kAddChunk) {
43 int numRemaining = (num - i) < kAddChunk ? (num - i) : kAddChunk;
44 auto vecs = faiss::gpu::randVecs(numRemaining, dim);
46 printf(
"adding at %d: %d\n", i, numRemaining);
47 index.add(numRemaining, vecs.data());
50 faiss::write_index(&index, FLAGS_out.c_str());
53 int main(
int argc,
char** argv) {
54 gflags::ParseCommandLineFlags(&argc, &argv,
true);
57 if ((FLAGS_ivfpq && FLAGS_ivfflat) ||
58 (!FLAGS_ivfpq && !FLAGS_ivfflat)) {
59 printf(
"must specify either ivfpq or ivfflat\n");
64 auto numCentroids = FLAGS_num_coarse;
66 auto numTrain = FLAGS_num_train;
67 numTrain = numTrain == -1 ? std::max((num / 4), 1) : numTrain;
68 numTrain = std::min(num, numTrain);
73 FLAGS_codes, FLAGS_bits_per_code);
76 printf(
"IVFPQ: codes %d bits per code %d\n",
77 FLAGS_codes, FLAGS_bits_per_code);
78 printf(
"Lists: %d\n", numCentroids);
79 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
80 printf(
"output file: %s\n", FLAGS_out.c_str());
82 fillAndSave(index, numTrain, num, dim);
83 }
else if (FLAGS_ivfflat) {
92 FLAGS_l2 ? faiss::METRIC_L2 :
93 faiss::METRIC_INNER_PRODUCT);
95 printf(
"IVFFlat: metric %s\n", FLAGS_l2 ?
"L2" :
"IP");
96 printf(
"Lists: %d\n", numCentroids);
97 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
98 printf(
"output file: %s\n", FLAGS_out.c_str());
100 fillAndSave(index, numTrain, num, dim);