faiss/perf_tests/bench_no_multithreading_rcq...

66 lines
2.0 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstddef>
#include <cstdint>
#include <vector>

#include <gflags/gflags.h>
#include <benchmark/benchmark.h>
#include <faiss/IndexAdditiveQuantizer.h> // @manual=//faiss:faiss_no_multithreading
#include <faiss/utils/random.h> // @manual=//faiss:faiss_no_multithreading
// Unqualified names used below (e.g. ResidualCoarseQuantizer, float_rand)
// rely on this using-directive.
using namespace faiss;
// Command-line flags (gflags). Each one is read once in main() and forwarded
// to the registered "search" benchmark.
// Fixed iteration count handed to the benchmark via Iterations().
DEFINE_uint32(iterations, 20, "iterations");
// Result count passed to search(); also sizes the per-query output buffers.
DEFINE_uint32(nprobe, 1, "nprobe");
// Number of query vectors submitted in each search() call.
DEFINE_uint32(batch_size, 1, "batch_size");
// Value copied into the search parameters' beam_factor field.
DEFINE_double(beam_factor, 4.0, "beam factor");
/**
 * Benchmark body: trains a ResidualCoarseQuantizer on random data, then
 * repeatedly calls its search() on a fixed random query batch.
 *
 * Data generation and training run once per invocation, outside the
 * benchmark loop; only the search() call is inside the loop.
 *
 * @param state       google-benchmark state that drives the loop.
 * @param batch_size  number of query vectors passed to each search() call.
 * @param nprobe      result count passed to search(); the output buffers
 *                    hold nprobe entries per query.
 * @param beam_factor value stored in the search parameters' beam_factor.
 */
static void bench_search(
        benchmark::State& state,
        int batch_size,
        int nprobe,
        float beam_factor) {
    // The caller narrows unsigned flag values to int, so very large inputs
    // arrive here as negative numbers. Reject them instead of computing
    // nonsensical buffer sizes.
    if (batch_size < 0 || nprobe < 0) {
        state.SkipWithError("batch_size and nprobe must be non-negative");
        return;
    }

    const int d = 512;
    const size_t nt = static_cast<size_t>(2) << 15; // 65536 training vectors

    // All element counts are computed in size_t: the previous int products
    // (d * batch_size, nprobe * batch_size) overflowed for large inputs.
    const size_t dim = static_cast<size_t>(d);
    const size_t nq = static_cast<size_t>(batch_size);
    const size_t k = static_cast<size_t>(nprobe);

    std::vector<float> xt(dim * nt);
    float_rand(xt.data(), xt.size(), 12345);

    ResidualCoarseQuantizer rq(d, {16, 8});
    rq.verbose = false;
    rq.train(nt, xt.data());

    // NOTE(review): the queries are generated with the same seed as the
    // training data — confirm whether they are meant to be distinct from it.
    std::vector<float> xq(dim * nq);
    float_rand(xq.data(), xq.size(), 12345);

    std::vector<float> distances(k * nq);
    std::vector<int64_t> clusterIndices(k * nq);

    SearchParametersResidualCoarseQuantizer param;
    param.beam_factor = beam_factor;

    for (auto _ : state) {
        rq.search(
                batch_size,
                xq.data(),
                nprobe,
                distances.data(),
                clusterIndices.data(),
                &param);
    }
}
int main(int argc, char** argv) {
benchmark::Initialize(&argc, argv);
gflags::AllowCommandLineReparsing();
gflags::ParseCommandLineFlags(&argc, &argv, true);
int iterations = FLAGS_iterations;
int nprobe = FLAGS_nprobe;
float beam_factor = FLAGS_beam_factor;
int batch_size = FLAGS_batch_size;
benchmark::RegisterBenchmark(
"search", bench_search, batch_size, nprobe, beam_factor)
->Iterations(iterations);
benchmark::RunSpecifiedBenchmarks();
benchmark::Shutdown();
}