faiss/benchs/bench_heap_replace.cpp

137 lines
3.7 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <omp.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/random.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
using namespace faiss;
/**
 * Reference selection loop: pop the heap top, then push the new value.
 *
 * The first k values of x are pushed into the heap one by one. The rest of
 * x is then scanned, and every value strictly greater than heap_val[0]
 * causes a minheap_pop followed by a minheap_push of that value, using its
 * position in x as the id. The k entries are finally passed to
 * minheap_reorder.
 *
 * @param n        number of values in x
 * @param k        number of entries in heap_ids / heap_val. The code reads
 *                 x[0..k) and heap_val[0], so callers need 1 <= k <= n.
 * @param x        input values
 * @param heap_ids output ids (positions in x), k entries
 * @param heap_val output values, k entries
 */
void addn_default(
        size_t n, size_t k,
        const float *x, int64_t *heap_ids, float * heap_val)
{
    size_t pos = 0;

    // Seed phase: the heap grows from size 0 to size k.
    while (pos < k) {
        minheap_push(pos + 1, heap_val, heap_ids, x[pos], pos);
        ++pos;
    }

    // Streaming phase: pos continues from k up to n.
    for (; pos < n; ++pos) {
        const float candidate = x[pos];
        if (!(candidate > heap_val[0])) {
            continue; // not better than the current heap top
        }
        minheap_pop(k, heap_val, heap_ids);
        minheap_push(k, heap_val, heap_ids, candidate, pos);
    }

    minheap_reorder(k, heap_val, heap_ids);
}
/**
 * Selection loop using a single minheap_replace_top call per accepted value.
 *
 * Same structure as the pop + push variant: x[0..k) is pushed into the
 * heap, then every later value strictly greater than heap_val[0] is handed
 * to minheap_replace_top together with its position in x. The k entries
 * are finally passed to minheap_reorder.
 *
 * @param n        number of values in x
 * @param k        number of entries in heap_ids / heap_val. The code reads
 *                 x[0..k) and heap_val[0], so callers need 1 <= k <= n.
 * @param x        input values
 * @param heap_ids output ids (positions in x), k entries
 * @param heap_val output values, k entries
 */
void addn_replace(
        size_t n, size_t k,
        const float *x, int64_t *heap_ids, float * heap_val)
{
    // Fill the heap with the first k inputs, growing it by one each step.
    for (size_t filled = 0; filled != k; ++filled) {
        minheap_push(filled + 1, heap_val, heap_ids, x[filled], filled);
    }

    // Offer every remaining input to the full heap.
    for (size_t idx = k; idx < n; ++idx) {
        const float candidate = x[idx];
        const bool beats_top = candidate > heap_val[0];
        if (beats_top) {
            minheap_replace_top(k, heap_val, heap_ids, candidate, idx);
        }
    }

    minheap_reorder(k, heap_val, heap_ids);
}
/**
 * Selection variant that delegates the whole scan to the library helpers
 * minheap_heapify / minheap_addn instead of looping over x here.
 *
 * @param n        number of values in x
 * @param k        number of entries in heap_ids / heap_val
 * @param x        input values
 * @param heap_ids output ids, k entries
 * @param heap_val output values, k entries
 */
void addn_func(
size_t n, size_t k,
const float *x, int64_t *heap_ids, float * heap_val)
{
// Called without input data: presumably resets the k slots to an "empty"
// heap before values are added -- TODO confirm against Heap.h.
minheap_heapify(k, heap_val, heap_ids);
// nullptr is passed for the ids array; main() expects the resulting ids to
// match the positions in x used by the other variants, so the helper
// presumably falls back to 0..n-1 -- TODO confirm against Heap.h.
minheap_addn(k, heap_val, heap_ids, x, nullptr, n);
minheap_reorder(k, heap_val, heap_ids);
}
/**
 * Benchmark driver.
 *
 * Generates n random floats once. For each k in a fixed list it runs the
 * three selection variants (addn_default, addn_replace, addn_func) nrun
 * times each, with the runs distributed over the OpenMP threads, and prints
 * the accumulated wall-clock time of each variant divided by nrun.
 *
 * After the timed runs, every thread compares the results left in its own
 * buffers by the three variants. A mismatch is reported on stderr and
 * counted; the exception is raised only after the parallel region has
 * ended, because an exception must not propagate out of an OpenMP
 * structured block.
 *
 * @return 0 on success; a failed check throws through FAISS_THROW_IF_NOT*.
 */
int main() {
    const size_t n = 10 * 1000 * 1000;
    const std::vector<size_t> ks({20, 50, 100, 200, 500, 1000, 2000, 5000});

    std::vector<float> x(n);
    float_randn(x.data(), n, 12345);

    const int nrun = 100;

    for (size_t k : ks) {
        // size_t arguments are printed with %zu
        printf("benchmark with k=%zu n=%zu nrun=%d\n", k, n, nrun);
        FAISS_THROW_IF_NOT(k < n);

        double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0;
        // number of threads whose buffers disagree between the variants
        int n_bad_threads = 0;

#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3, n_bad_threads)
        {
            // per-thread result buffers, one (values, ids) pair per variant
            std::vector<float> heap_dis(k);
            std::vector<float> heap_dis_2(k);
            std::vector<float> heap_dis_3(k);

            std::vector<int64_t> heap_ids(k);
            std::vector<int64_t> heap_ids_2(k);
            std::vector<int64_t> heap_ids_3(k);

#pragma omp for
            for (int run = 0; run < nrun; run++) {
                double t0, t1, t2, t3;

                t0 = getmillisecs();

                // default implem
                addn_default(n, k, x.data(), heap_ids.data(), heap_dis.data());
                t1 = getmillisecs();

                // new implem from Zilliz
                addn_replace(n, k, x.data(), heap_ids_2.data(), heap_dis_2.data());
                t2 = getmillisecs();

                // with addn
                addn_func(n, k, x.data(), heap_ids_3.data(), heap_dis_3.data());
                t3 = getmillisecs();

                tot_t1 += t1 - t0;
                tot_t2 += t2 - t1;
                tot_t3 += t3 - t2;
            }

            // Compare what the last run of this thread left in the buffers.
            // Only the first differing entry is reported per thread, and
            // nothing is thrown from inside the parallel region.
            for (size_t i = 0; i < k; i++) {
                const bool replace_ok = heap_ids[i] == heap_ids_2[i] &&
                        heap_dis[i] == heap_dis_2[i];
                const bool addn_ok = heap_ids[i] == heap_ids_3[i] &&
                        heap_dis[i] == heap_dis_3[i];
                if (replace_ok && addn_ok) {
                    continue;
                }
                fprintf(stderr,
                        "result mismatch at i=%zu: default (%lld, %g) "
                        "replace (%lld, %g) addn (%lld, %g)\n",
                        i,
                        static_cast<long long>(heap_ids[i]), heap_dis[i],
                        static_cast<long long>(heap_ids_2[i]), heap_dis_2[i],
                        static_cast<long long>(heap_ids_3[i]), heap_dis_3[i]);
                n_bad_threads++;
                break;
            }
        }

        FAISS_THROW_IF_NOT_FMT(
                n_bad_threads == 0,
                "%d thread(s) found differing results between the implementations for k=%zu",
                n_bad_threads, k);

        printf("default implem: %.3f ms\n", tot_t1 / nrun);
        printf("replace implem: %.3f ms\n", tot_t2 / nrun);
        printf("addn implem: %.3f ms\n", tot_t3 / nrun);
    }
    return 0;
}