9 #include "../../IndexIVFFlat.h"
10 #include "../../IndexIVFPQ.h"
11 #include "../../IndexFlat.h"
12 #include "../../index_io.h"
13 #include "../test/TestUtils.h"
15 #include <gflags/gflags.h>
18 DEFINE_bool(ivfpq,
false,
"use IVFPQ encoding");
19 DEFINE_int32(codes, 4,
"number of PQ codes per vector");
20 DEFINE_int32(bits_per_code, 8,
"number of bits per PQ code");
23 DEFINE_bool(l2,
true,
"use L2 metric (versus IP metric)");
24 DEFINE_bool(ivfflat,
false,
"use IVF flat encoding");
27 DEFINE_string(out,
"/home/jhj/local/index.out",
"index file for output");
28 DEFINE_int32(dim, 128,
"vector dimension");
29 DEFINE_int32(num_coarse, 100,
"number of coarse centroids");
30 DEFINE_int32(num, 100000,
"total database size");
31 DEFINE_int32(num_train, -1,
"number of database vecs to train on");
34 void fillAndSave(T& index,
int numTrain,
int num,
int dim) {
35 auto trainVecs = faiss::gpu::randVecs(numTrain, dim);
36 index.train(numTrain, trainVecs.data());
38 constexpr
int kAddChunk = 1000000;
40 for (
int i = 0; i < num; i += kAddChunk) {
41 int numRemaining = (num - i) < kAddChunk ? (num - i) : kAddChunk;
42 auto vecs = faiss::gpu::randVecs(numRemaining, dim);
44 printf(
"adding at %d: %d\n", i, numRemaining);
45 index.add(numRemaining, vecs.data());
48 faiss::write_index(&index, FLAGS_out.c_str());
51 int main(
int argc,
char** argv) {
52 gflags::ParseCommandLineFlags(&argc, &argv,
true);
55 if ((FLAGS_ivfpq && FLAGS_ivfflat) ||
56 (!FLAGS_ivfpq && !FLAGS_ivfflat)) {
57 printf(
"must specify either ivfpq or ivfflat\n");
62 auto numCentroids = FLAGS_num_coarse;
64 auto numTrain = FLAGS_num_train;
65 numTrain = numTrain == -1 ? std::max((num / 4), 1) : numTrain;
66 numTrain = std::min(num, numTrain);
71 FLAGS_codes, FLAGS_bits_per_code);
74 printf(
"IVFPQ: codes %d bits per code %d\n",
75 FLAGS_codes, FLAGS_bits_per_code);
76 printf(
"Lists: %d\n", numCentroids);
77 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
78 printf(
"output file: %s\n", FLAGS_out.c_str());
80 fillAndSave(index, numTrain, num, dim);
81 }
else if (FLAGS_ivfflat) {
90 FLAGS_l2 ? faiss::METRIC_L2 :
91 faiss::METRIC_INNER_PRODUCT);
93 printf(
"IVFFlat: metric %s\n", FLAGS_l2 ?
"L2" :
"IP");
94 printf(
"Lists: %d\n", numCentroids);
95 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
96 printf(
"output file: %s\n", FLAGS_out.c_str());
98 fillAndSave(index, numTrain, num, dim);