#include "../../IndexIVFFlat.h"
#include "../../IndexIVFPQ.h"
#include "../../IndexFlat.h"
#include "../../index_io.h"
#include "../test/TestUtils.h"

#include <algorithm>
#include <cstdio>

#include <gflags/gflags.h>
19 DEFINE_bool(ivfpq,
false,
"use IVFPQ encoding");
20 DEFINE_int32(codes, 4,
"number of PQ codes per vector");
21 DEFINE_int32(bits_per_code, 8,
"number of bits per PQ code");
24 DEFINE_bool(l2,
true,
"use L2 metric (versus IP metric)");
25 DEFINE_bool(ivfflat,
false,
"use IVF flat encoding");
28 DEFINE_string(out,
"/home/jhj/local/index.out",
"index file for output");
29 DEFINE_int32(dim, 128,
"vector dimension");
30 DEFINE_int32(num_coarse, 100,
"number of coarse centroids");
31 DEFINE_int32(num, 100000,
"total database size");
32 DEFINE_int32(num_train, -1,
"number of database vecs to train on");
35 void fillAndSave(T& index,
int numTrain,
int num,
int dim) {
36 auto trainVecs = faiss::gpu::randVecs(numTrain, dim);
37 index.train(numTrain, trainVecs.data());
39 constexpr
int kAddChunk = 1000000;
41 for (
int i = 0; i < num; i += kAddChunk) {
42 int numRemaining = (num - i) < kAddChunk ? (num - i) : kAddChunk;
43 auto vecs = faiss::gpu::randVecs(numRemaining, dim);
45 printf(
"adding at %d: %d\n", i, numRemaining);
46 index.add(numRemaining, vecs.data());
49 faiss::write_index(&index, FLAGS_out.c_str());
52 int main(
int argc,
char** argv) {
53 gflags::ParseCommandLineFlags(&argc, &argv,
true);
56 if ((FLAGS_ivfpq && FLAGS_ivfflat) ||
57 (!FLAGS_ivfpq && !FLAGS_ivfflat)) {
58 printf(
"must specify either ivfpq or ivfflat\n");
63 auto numCentroids = FLAGS_num_coarse;
65 auto numTrain = FLAGS_num_train;
66 numTrain = numTrain == -1 ? std::max((num / 4), 1) : numTrain;
67 numTrain = std::min(num, numTrain);
72 FLAGS_codes, FLAGS_bits_per_code);
75 printf(
"IVFPQ: codes %d bits per code %d\n",
76 FLAGS_codes, FLAGS_bits_per_code);
77 printf(
"Lists: %d\n", numCentroids);
78 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
79 printf(
"output file: %s\n", FLAGS_out.c_str());
81 fillAndSave(index, numTrain, num, dim);
82 }
else if (FLAGS_ivfflat) {
91 FLAGS_l2 ? faiss::METRIC_L2 :
92 faiss::METRIC_INNER_PRODUCT);
94 printf(
"IVFFlat: metric %s\n", FLAGS_l2 ?
"L2" :
"IP");
95 printf(
"Lists: %d\n", numCentroids);
96 printf(
"Database: dim %d num vecs %d trained on %d\n", dim, num, numTrain);
97 printf(
"output file: %s\n", FLAGS_out.c_str());
99 fillAndSave(index, numTrain, num, dim);