/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include using namespace faiss; DEFINE_uint32(d, 128, "dimension"); DEFINE_uint32(n, 2000, "dimension"); DEFINE_uint32(iterations, 20, "iterations"); static void bench_reconstruction_error( benchmark::State& state, ScalarQuantizer::QuantizerType type, int d, int n) { std::vector x(d * n); float_rand(x.data(), d * n, 12345); // make sure it's idempotent ScalarQuantizer sq(d, type); sq.train(n, x.data()); size_t code_size = sq.code_size; state.counters["code_size"] = sq.code_size; // encode std::vector codes(code_size * n); sq.compute_codes(x.data(), codes.data(), n); // decode std::vector x2(d * n); sq.decode(codes.data(), x2.data(), n); state.counters["sql2_recons_error"] = fvec_L2sqr(x.data(), x2.data(), n * d) / n; // encode again std::vector codes2(code_size * n); sq.compute_codes(x2.data(), codes2.data(), n); size_t ndiff = 0; for (size_t i = 0; i < codes.size(); i++) { if (codes[i] != codes2[i]) ndiff++; } state.counters["ndiff_for_idempotence"] = ndiff; state.counters["code_size_two"] = codes.size(); } int main(int argc, char** argv) { benchmark::Initialize(&argc, argv); gflags::AllowCommandLineReparsing(); gflags::ParseCommandLineFlags(&argc, &argv, true); int iterations = FLAGS_iterations; int d = FLAGS_d; int n = FLAGS_n; auto benchs = ::perf_tests::sq_types(); for (auto& [bench_name, quantizer_type] : benchs) { benchmark::RegisterBenchmark( bench_name.c_str(), bench_reconstruction_error, quantizer_type, d, n) ->Iterations(iterations); } benchmark::RunSpecifiedBenchmarks(); benchmark::Shutdown(); }