Add heap_replace_top to simplify heap_pop + heap_push (#1597)

Summary:
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

Add heap_replace_top to simplify heap_pop + heap_push

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1597

Test Plan:
OMP_NUM_THREADS=1 buck run mode/opt //faiss/benchs/:bench_heap_replace
OMP_NUM_THREADS=8 buck run mode/opt //faiss/benchs/:bench_heap_replace

Reviewed By: beauby

Differential Revision: D25943140

Pulled By: mdouze

fbshipit-source-id: 66fe67779dd281a7753f597542c2e797ba0d7df5
pull/1647/head
shengjun.li 2021-01-20 11:26:13 -08:00 committed by Facebook GitHub Bot
parent a2791322d9
commit 908812266c
20 changed files with 219 additions and 60 deletions

View File

@ -0,0 +1,136 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <omp.h>
#include <faiss/utils/utils.h>
#include <faiss/utils/random.h>
#include <faiss/utils/Heap.h>
#include <faiss/impl/FaissAssert.h>
using namespace faiss;
void addn_default(
size_t n, size_t k,
const float *x, int64_t *heap_ids, float * heap_val)
{
for (size_t i = 0; i < k; i++) {
minheap_push(i + 1, heap_val, heap_ids, x[i], i);
}
for (size_t i = k; i < n; i++) {
if (x[i] > heap_val[0]) {
minheap_pop(k, heap_val, heap_ids);
minheap_push(k, heap_val, heap_ids, x[i], i);
}
}
minheap_reorder(k, heap_val, heap_ids);
}
void addn_replace(
size_t n, size_t k,
const float *x, int64_t *heap_ids, float * heap_val)
{
for (size_t i = 0; i < k; i++) {
minheap_push(i + 1, heap_val, heap_ids, x[i], i);
}
for (size_t i = k; i < n; i++) {
if (x[i] > heap_val[0]) {
minheap_replace_top(k, heap_val, heap_ids, x[i], i);
}
}
minheap_reorder(k, heap_val, heap_ids);
}
void addn_func(
size_t n, size_t k,
const float *x, int64_t *heap_ids, float * heap_val)
{
minheap_heapify(k, heap_val, heap_ids);
minheap_addn(k, heap_val, heap_ids, x, nullptr, n);
minheap_reorder(k, heap_val, heap_ids);
}
int main() {
size_t n = 10 * 1000 * 1000;
std::vector<size_t> ks({20, 50, 100, 200, 500, 1000, 2000, 5000});
std::vector<float> x(n);
float_randn(x.data(), n, 12345);
int nrun = 100;
for(size_t k: ks) {
printf("benchmark with k=%zd n=%zd nrun=%d\n", k, n, nrun);
FAISS_THROW_IF_NOT(k < n);
double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0;
#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3)
{
std::vector<float> heap_dis(k);
std::vector<float> heap_dis_2(k);
std::vector<float> heap_dis_3(k);
std::vector<int64_t> heap_ids(k);
std::vector<int64_t> heap_ids_2(k);
std::vector<int64_t> heap_ids_3(k);
#pragma omp for
for (int run = 0; run < nrun; run++) {
double t0, t1, t2, t3;
t0 = getmillisecs();
// default implem
addn_default(n, k, x.data(), heap_ids.data(), heap_dis.data());
t1 = getmillisecs();
// new implem from Zilliz
addn_replace(n, k, x.data(), heap_ids_2.data(), heap_dis_2.data());
t2 = getmillisecs();
// with addn
addn_func(n, k, x.data(), heap_ids_3.data(), heap_dis_3.data());
t3 = getmillisecs();
tot_t1 += t1 - t0;
tot_t2 += t2 - t1;
tot_t3 += t3 - t2;
}
for (size_t i = 0; i < k; i++) {
FAISS_THROW_IF_NOT_FMT(
heap_ids[i] == heap_ids_2[i],
"i=%ld (%ld, %g) != (%ld, %g)",
i, size_t(heap_ids[i]), heap_dis[i],
size_t(heap_ids_2[i]), heap_dis_2[i]);
FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_2[i]);
}
for (size_t i = 0; i < k; i++) {
FAISS_THROW_IF_NOT(heap_ids[i] == heap_ids_3[i]);
FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_3[i]);
}
}
printf("default implem: %.3f ms\n", tot_t1 / nrun);
printf("replace implem: %.3f ms\n", tot_t2 / nrun);
printf("addn implem: %.3f ms\n", tot_t3 / nrun);
}
return 0;
}

View File

@ -136,8 +136,7 @@ struct KnnSearchResults {
inline void add (float dis, idx_t id) {
if (dis < heap_sim[0]) {
heap_pop<C> (k, heap_sim, heap_ids);
heap_push<C> (k, heap_sim, heap_ids, dis, id);
heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
}
}

View File

@ -319,9 +319,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
for (size_t j = 0; j < n; j++) {
uint32_t dis = hc.hamming (codes);
if (dis < simi[0]) {
heap_pop<C> (k, simi, idxi);
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
heap_push<C> (k, simi, idxi, dis, id);
heap_replace_top<C> (k, simi, idxi, dis, id);
nup++;
}
codes += code_size;

View File

@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw,
if (nres < k) {
faiss::maxheap_push (++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_pop (nres--, D, I);
faiss::maxheap_push (++nres, D, I, d, v1);
faiss::maxheap_replace_top (nres, D, I, d, v1);
}
}
vt.visited[v1] = vt.visno + 1;

View File

@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner {
float dis = metric == METRIC_INNER_PRODUCT ?
fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
if (C::cmp (simi[0], dis)) {
heap_pop<C> (k, simi, idxi);
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
heap_push<C> (k, simi, idxi, dis, id);
heap_replace_top<C> (k, simi, idxi, dis, id);
nup++;
}
}

View File

@ -828,9 +828,8 @@ struct KnnSearchResults {
inline void add (idx_t j, float dis) {
if (C::cmp (heap_sim[0], dis)) {
heap_pop<C> (k, heap_sim, heap_ids);
idx_t id = ids ? ids[j] : lo_build (key, j);
heap_push<C> (k, heap_sim, heap_ids, dis, id);
heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
nup++;
}
}

View File

@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned (
float dis = fvec_L2sqr (residual_1, residual_2, d);
if (dis < heap_sim[0]) {
maxheap_pop (k, heap_sim, heap_ids);
idx_t id_or_pair = store_pairs ? sl : id;
maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair);
maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair);
}
n_refine ++;
}

View File

@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner {
float dis = hc.hamming (codes);
if (dis < simi [0]) {
maxheap_pop (k, simi, idxi);
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
maxheap_push (k, simi, idxi, dis, id);
maxheap_replace_top (k, simi, idxi, dis, id);
nup++;
}
codes += code_size;

View File

@ -346,8 +346,7 @@ static size_t polysemous_inner_loop (
}
if (dis < heap_dis[0]) {
maxheap_pop (k, heap_dis, heap_ids);
maxheap_push (k, heap_dis, heap_ids, dis, bi);
maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
}
}
b_code += code_size;

View File

@ -539,8 +539,7 @@ int HNSW::search_from_candidates(
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_pop(nres--, D, I);
faiss::maxheap_push(++nres, D, I, d, v1);
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
vt.set(v1);
}
@ -578,8 +577,7 @@ int HNSW::search_from_candidates(
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_pop(nres--, D, I);
faiss::maxheap_push(++nres, D, I, d, v1);
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
candidates.push(v1, d);
}

View File

@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
}
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes,
dis += dt[*codes++];
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
dt += ksub;
}
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
}
if (C::cmp(heap_dis[0], dis)) {
heap_pop<C>(k, heap_dis, heap_ids);
heap_push<C>(k, heap_dis, heap_ids, dis, j);
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
}
}
}
@ -762,8 +758,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
tab += ksub * ksub;
}
if (dis < heap_dis[0]) {
maxheap_pop (k, heap_dis, heap_ids);
maxheap_push (k, heap_dis, heap_ids, dis, j);
maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
}
bcode += code_size;
}

View File

@ -75,8 +75,7 @@ struct HeapResultHandler {
/// add one result for query i
void add_result(T dis, TI idx) {
if (C::cmp(heap_dis[0], dis)) {
heap_pop<C>(k, heap_dis, heap_ids);
heap_push<C>(k, heap_dis, heap_ids, dis, idx);
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
thresh = heap_dis[0];
}
}
@ -113,8 +112,7 @@ struct HeapResultHandler {
for (size_t j = j0; j < j1; j++) {
T dis = *dis_tab++;
if (C::cmp(thresh, dis)) {
heap_pop<C>(k, heap_dis, heap_ids);
heap_push<C>(k, heap_dis, heap_ids, dis, j);
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
thresh = heap_dis[0];
}
}

View File

@ -1430,9 +1430,8 @@ struct IVFSQScannerIP: InvertedListScanner {
float accu = accu0 + dc.query_to_code (codes);
if (accu > simi [0]) {
minheap_pop (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
minheap_push (k, simi, idxi, accu, id);
minheap_replace_top (k, simi, idxi, accu, id);
nup++;
}
codes += code_size;
@ -1518,9 +1517,8 @@ struct IVFSQScannerL2: InvertedListScanner {
float dis = dc.query_to_code (codes);
if (dis < simi [0]) {
maxheap_pop (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
maxheap_push (k, simi, idxi, dis, id);
maxheap_replace_top (k, simi, idxi, dis, id);
nup++;
}
codes += code_size;

View File

@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, j + j0);
heap_replace_top<C> (k, simi, idxi, ip, j + j0);
}
}
}
@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, id_line [j]);
heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
}
}
}

View File

@ -105,6 +105,43 @@ void heap_push (size_t k,
/** Replace the top element from the heap defined by bh_val[0..k-1] and
* bh_ids[0..k-1].
*/
template <class C> inline
void heap_replace_top (size_t k,
typename C::T * bh_val, typename C::TI * bh_ids,
typename C::T val, typename C::TI ids)
{
bh_val--; /* Use 1-based indexing for easier node->child translation */
bh_ids--;
size_t i = 1, i1, i2;
while (1) {
i1 = i << 1;
i2 = i1 + 1;
if (i1 > k)
break;
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
if (C::cmp(val, bh_val[i1]))
break;
bh_val[i] = bh_val[i1];
bh_ids[i] = bh_ids[i1];
i = i1;
}
else {
if (C::cmp(val, bh_val[i2]))
break;
bh_val[i] = bh_val[i2];
bh_ids[i] = bh_ids[i2];
i = i2;
}
}
bh_val[i] = val;
bh_ids[i] = ids;
}
/* Partial instanciation for heaps with TI = int64_t */
template <typename T> inline
@ -121,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
}
template <typename T> inline
void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
{
heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
}
template <typename T> inline
void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
{
@ -135,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
}
template <typename T> inline
void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
{
heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
}
/*******************************************************************
* Heap initialization
@ -212,15 +262,13 @@ void heap_addn (size_t k,
if (ids)
for (i = 0; i < n; i++) {
if (C::cmp (bh_val[0], x[i])) {
heap_pop<C> (k, bh_val, bh_ids);
heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
}
}
else
for (i = 0; i < n; i++) {
if (C::cmp (bh_val[0], x[i])) {
heap_pop<C> (k, bh_val, bh_ids);
heap_push<C> (k, bh_val, bh_ids, x[i], i);
heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
}
}
}

View File

@ -522,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
if (ip > simi[0]) {
minheap_pop (k, simi, idxi);
minheap_push (k, simi, idxi, ip, idsi[j]);
minheap_replace_top (k, simi, idxi, ip, idsi[j]);
}
}
minheap_reorder (k, simi, idxi);
@ -550,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
if (disij < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, disij, idsi[j]);
maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
}
}
maxheap_reorder (res->k, simi, idxi);

View File

@ -176,8 +176,7 @@ void knn_extra_metrics_template (
float disij = vd (x_i, y_j);
if (disij < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, disij, j);
maxheap_replace_top (k, simi, idxi, disij, j);
}
y_j += d;
}

View File

@ -292,8 +292,7 @@ void hammings_knn_hc (
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
dis = hc.hamming (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
@ -391,8 +390,7 @@ void hammings_knn_hc_1 (
for (j = 0; j < n2; j++, bs2_+= nwords) {
dis = popcount64 (bs1_ ^ *bs2_);
if (dis < bh_val_0) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
bh_val_0 = bh_val_[0];
}
}
@ -818,8 +816,7 @@ static void hamming_dis_inner_loop (
int ndiff = hc.hamming (cb);
cb += code_size;
if (ndiff < bh_val_[0]) {
maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
}
}
}

View File

@ -683,7 +683,7 @@ class TestSpectralHash(unittest.TestCase):
key = (nbit, tt, period)
print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter))
assert abs(ninter - self.ref_results[key]) <= 4
assert abs(ninter - self.ref_results[key]) <= 12
class TestRefine(unittest.TestCase):

View File

@ -183,7 +183,9 @@ class TestBinaryIVF(unittest.TestCase):
index.add(self.xb)
Divfflat, _ = index.search(self.xq, 10)
self.assertEqual((self.Dref == Divfflat).sum(), 4122)
# Some centroids are equidistant from the query points.
# So the answer will depend on the implementation of the heap.
self.assertGreater((self.Dref == Divfflat).sum(), 4100)
def test_ivf_range(self):
d = self.xq.shape[1] * 8