faiss/perf_tests/bench_no_multithreading_rcq_search.cpp
Facebook Community Bot c8d1474fc5
Re-sync with internal repository (#3885)
The internal and external repositories are out of sync. This Pull Request attempts to bring them back in sync by patching the GitHub repository. Please carefully review this patch. You must disable ShipIt for your project in order to merge this pull request. DO NOT IMPORT this pull request. Instead, merge it directly on GitHub using the MERGE BUTTON. Re-enable ShipIt after merging.
2024-09-24 07:49:31 -07:00

66 lines
2.0 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

#include <benchmark/benchmark.h>
#include <gflags/gflags.h>

#include <faiss/IndexAdditiveQuantizer.h> // @manual=//faiss:faiss_no_multithreading
#include <faiss/utils/random.h> // @manual=//faiss:faiss_no_multithreading
using namespace faiss;
// Benchmark configuration flags; main() reads them after gflags parses argv.
// Number of timed iterations, forwarded to Benchmark::Iterations() in main().
DEFINE_uint32(iterations, 20, "iterations");
// Number of results (the k argument of rq.search) requested per query.
DEFINE_uint32(nprobe, 1, "nprobe");
// Number of query vectors searched in each timed iteration.
DEFINE_uint32(batch_size, 1, "batch_size");
// Value copied into SearchParametersResidualCoarseQuantizer::beam_factor.
DEFINE_double(beam_factor, 4.0, "beam factor");
/**
 * Times ResidualCoarseQuantizer::search on random data.
 *
 * Untimed setup:
 * - Builds a residual coarse quantizer with two levels of 16 and 8 bits.
 * - Trains it on `nt` random vectors of dimension `d`.
 * - Draws `batch_size` random query vectors.
 *
 * Each timed iteration then searches the whole query batch with
 * k = `nprobe` and the given beam factor.
 *
 * @param state       google-benchmark state that drives the timed loop.
 * @param batch_size  number of query vectors searched per iteration.
 * @param nprobe      number of results (k) requested per query.
 * @param beam_factor copied into SearchParametersResidualCoarseQuantizer.
 */
static void bench_search(
        benchmark::State& state,
        int batch_size,
        int nprobe,
        float beam_factor) {
    const int d = 512;
    const int nt = 2 << 15; // 65536 training vectors

    // Buffer sizes are computed in size_t: batch_size and nprobe come from
    // command-line flags, and multiplying them as int can overflow (UB).
    const size_t n_train_floats = static_cast<size_t>(d) * nt;
    const size_t n_query_floats = static_cast<size_t>(d) * batch_size;
    const size_t n_results = static_cast<size_t>(nprobe) * batch_size;

    std::vector<float> xt(n_train_floats);
    float_rand(xt.data(), n_train_floats, 12345);

    ResidualCoarseQuantizer rq(d, {16, 8});
    rq.verbose = false;
    rq.train(nt, xt.data());

    // NOTE(review): the queries use the same seed as the training data, so
    // they may reproduce (part of) the training vectors; kept unchanged so
    // timings stay comparable with earlier runs of this benchmark.
    std::vector<float> xq(n_query_floats);
    float_rand(xq.data(), n_query_floats, 12345);

    std::vector<float> distances(n_results);
    std::vector<int64_t> clusterIndices(n_results);

    SearchParametersResidualCoarseQuantizer param;
    param.beam_factor = beam_factor;

    for (auto _ : state) {
        rq.search(
                batch_size,
                xq.data(),
                nprobe,
                distances.data(),
                clusterIndices.data(),
                &param);
    }

    // Report throughput (queries per second) in addition to wall time.
    state.SetItemsProcessed(
            static_cast<int64_t>(state.iterations()) * batch_size);
}
/**
 * Converts an unsigned command-line flag value to a positive int.
 *
 * @param flag_name name of the flag, used only in the error message.
 * @param value     raw flag value.
 * @param out       receives the converted value on success.
 * @return false (after printing a message to stderr) when `value` is 0 or
 *         does not fit in an int; true otherwise.
 */
static bool flag_to_positive_int(
        const char* flag_name,
        uint32_t value,
        int& out) {
    const int int_max = std::numeric_limits<int>::max();
    if (value == 0 || value > static_cast<uint32_t>(int_max)) {
        std::fprintf(
                stderr,
                "--%s must be in [1, %d], got %u\n",
                flag_name,
                int_max,
                static_cast<unsigned>(value));
        return false;
    }
    out = static_cast<int>(value);
    return true;
}

/**
 * Entry point.
 *
 * Lets google-benchmark and then gflags parse argv, and validates the
 * benchmark flags. It then registers a single "search" benchmark with a
 * fixed iteration count and runs it.
 *
 * @return 0 on success, 1 if a flag is out of range.
 */
int main(int argc, char** argv) {
    benchmark::Initialize(&argc, argv);
    // Tolerate flags gflags does not know about.
    gflags::AllowCommandLineReparsing();
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    // The flags are uint32 but the benchmark takes int. Reject values that
    // would wrap when narrowed, and zero iterations, which
    // Benchmark::Iterations() refuses with an internal check failure.
    int iterations = 0;
    int nprobe = 0;
    int batch_size = 0;
    if (!flag_to_positive_int("iterations", FLAGS_iterations, iterations) ||
        !flag_to_positive_int("nprobe", FLAGS_nprobe, nprobe) ||
        !flag_to_positive_int("batch_size", FLAGS_batch_size, batch_size)) {
        benchmark::Shutdown();
        return 1;
    }
    const float beam_factor = static_cast<float>(FLAGS_beam_factor);

    benchmark::RegisterBenchmark(
            "search", bench_search, batch_size, nprobe, beam_factor)
            ->Iterations(iterations);
    benchmark::RunSpecifiedBenchmarks();
    benchmark::Shutdown();
    return 0;
}