diff --git a/benchs/bench_heap_replace.cpp b/benchs/bench_heap_replace.cpp new file mode 100644 index 000000000..2963f5f02 --- /dev/null +++ b/benchs/bench_heap_replace.cpp @@ -0,0 +1,136 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +#include +#include + +#include +#include +#include +#include + +using namespace faiss; + +void addn_default( + size_t n, size_t k, + const float *x, int64_t *heap_ids, float * heap_val) +{ + for (size_t i = 0; i < k; i++) { + minheap_push(i + 1, heap_val, heap_ids, x[i], i); + } + + for (size_t i = k; i < n; i++) { + if (x[i] > heap_val[0]) { + minheap_pop(k, heap_val, heap_ids); + minheap_push(k, heap_val, heap_ids, x[i], i); + } + } + + minheap_reorder(k, heap_val, heap_ids); +} + +void addn_replace( + size_t n, size_t k, + const float *x, int64_t *heap_ids, float * heap_val) +{ + for (size_t i = 0; i < k; i++) { + minheap_push(i + 1, heap_val, heap_ids, x[i], i); + } + + for (size_t i = k; i < n; i++) { + if (x[i] > heap_val[0]) { + minheap_replace_top(k, heap_val, heap_ids, x[i], i); + } + } + + minheap_reorder(k, heap_val, heap_ids); +} + +void addn_func( + size_t n, size_t k, + const float *x, int64_t *heap_ids, float * heap_val) +{ + minheap_heapify(k, heap_val, heap_ids); + + minheap_addn(k, heap_val, heap_ids, x, nullptr, n); + + minheap_reorder(k, heap_val, heap_ids); +} + + +int main() { + + size_t n = 10 * 1000 * 1000; + + std::vector ks({20, 50, 100, 200, 500, 1000, 2000, 5000}); + + std::vector x(n); + float_randn(x.data(), n, 12345); + + int nrun = 100; + for(size_t k: ks) { + printf("benchmark with k=%zd n=%zd nrun=%d\n", k, n, nrun); + FAISS_THROW_IF_NOT(k < n); + + double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0; +#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3) + { + + std::vector heap_dis(k); + std::vector heap_dis_2(k); + std::vector heap_dis_3(k); + + std::vector heap_ids(k); + std::vector heap_ids_2(k); + std::vector heap_ids_3(k); + +#pragma omp for + for (int run = 0; run < nrun; run++) { + + double t0, t1, t2, t3; + + t0 = getmillisecs(); + + // default implem + addn_default(n, k, x.data(), heap_ids.data(), heap_dis.data()); + t1 = getmillisecs(); + + // new implem from Zilliz + addn_replace(n, k, x.data(), heap_ids_2.data(), heap_dis_2.data()); + t2 = getmillisecs(); + + // with addn + addn_func(n, k, x.data(), heap_ids_3.data(), heap_dis_3.data()); + t3 = getmillisecs(); + + tot_t1 += t1 - t0; + tot_t2 += t2 - t1; + tot_t3 += t3 - t2; + } + + for (size_t i = 0; i < k; i++) { + FAISS_THROW_IF_NOT_FMT( + heap_ids[i] == heap_ids_2[i], + "i=%ld (%ld, %g) != (%ld, %g)", + i, size_t(heap_ids[i]), heap_dis[i], + size_t(heap_ids_2[i]), heap_dis_2[i]); + FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_2[i]); + } + + for (size_t i = 0; i < k; i++) { + FAISS_THROW_IF_NOT(heap_ids[i] == heap_ids_3[i]); + FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_3[i]); + } + } + printf("default implem: %.3f ms\n", tot_t1 / nrun); + printf("replace implem: %.3f ms\n", tot_t2 / nrun); + printf("addn implem: %.3f ms\n", tot_t3 / nrun); + + } + return 0; +} diff --git a/faiss/IndexBinaryHash.cpp b/faiss/IndexBinaryHash.cpp index a0dcf3aa3..691545359 100644 --- a/faiss/IndexBinaryHash.cpp +++ b/faiss/IndexBinaryHash.cpp @@ -136,8 +136,7 @@ struct KnnSearchResults { inline void add (float dis, idx_t id) { if (dis < heap_sim[0]) { - heap_pop (k, heap_sim, heap_ids); - heap_push (k, heap_sim, heap_ids, dis, id); + heap_replace_top (k, heap_sim, heap_ids, dis, id); } } diff --git a/faiss/IndexBinaryIVF.cpp b/faiss/IndexBinaryIVF.cpp index 612db9eb8..3486a0d87 100644 --- a/faiss/IndexBinaryIVF.cpp +++ b/faiss/IndexBinaryIVF.cpp @@ -319,9 +319,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner { for (size_t j = 0; j < n; j++) { uint32_t dis = hc.hamming (codes); if (dis < simi[0]) { - heap_pop (k, simi, idxi); idx_t id = store_pairs ? lo_build(list_no, j) : ids[j]; - heap_push (k, simi, idxi, dis, id); + heap_replace_top (k, simi, idxi, dis, id); nup++; } codes += code_size; diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 1a6edbdfa..7944c8178 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw, if (nres < k) { faiss::maxheap_push (++nres, D, I, d, v1); } else if (d < D[0]) { - faiss::maxheap_pop (nres--, D, I); - faiss::maxheap_push (++nres, D, I, d, v1); + faiss::maxheap_replace_top (nres, D, I, d, v1); } } vt.visited[v1] = vt.visno + 1; diff --git a/faiss/IndexIVFFlat.cpp b/faiss/IndexIVFFlat.cpp index 80ecba593..bdd29db7c 100644 --- a/faiss/IndexIVFFlat.cpp +++ b/faiss/IndexIVFFlat.cpp @@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner { float dis = metric == METRIC_INNER_PRODUCT ? fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d); if (C::cmp (simi[0], dis)) { - heap_pop (k, simi, idxi); int64_t id = store_pairs ? lo_build (list_no, j) : ids[j]; - heap_push (k, simi, idxi, dis, id); + heap_replace_top (k, simi, idxi, dis, id); nup++; } } diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index a66455b49..cf2e0fc78 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -828,9 +828,8 @@ struct KnnSearchResults { inline void add (idx_t j, float dis) { if (C::cmp (heap_sim[0], dis)) { - heap_pop (k, heap_sim, heap_ids); idx_t id = ids ? ids[j] : lo_build (key, j); - heap_push (k, heap_sim, heap_ids, dis, id); + heap_replace_top (k, heap_sim, heap_ids, dis, id); nup++; } } diff --git a/faiss/IndexIVFPQR.cpp b/faiss/IndexIVFPQR.cpp index 04259a875..a87688397 100644 --- a/faiss/IndexIVFPQR.cpp +++ b/faiss/IndexIVFPQR.cpp @@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned ( float dis = fvec_L2sqr (residual_1, residual_2, d); if (dis < heap_sim[0]) { - maxheap_pop (k, heap_sim, heap_ids); idx_t id_or_pair = store_pairs ? sl : id; - maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair); + maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair); } n_refine ++; } diff --git a/faiss/IndexIVFSpectralHash.cpp b/faiss/IndexIVFSpectralHash.cpp index 9f42b98b5..6f9175296 100644 --- a/faiss/IndexIVFSpectralHash.cpp +++ b/faiss/IndexIVFSpectralHash.cpp @@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner { float dis = hc.hamming (codes); if (dis < simi [0]) { - maxheap_pop (k, simi, idxi); int64_t id = store_pairs ? lo_build (list_no, j) : ids[j]; - maxheap_push (k, simi, idxi, dis, id); + maxheap_replace_top (k, simi, idxi, dis, id); nup++; } codes += code_size; diff --git a/faiss/IndexPQ.cpp b/faiss/IndexPQ.cpp index fdb57a6d9..4a1eb4a2f 100644 --- a/faiss/IndexPQ.cpp +++ b/faiss/IndexPQ.cpp @@ -346,8 +346,7 @@ static size_t polysemous_inner_loop ( } if (dis < heap_dis[0]) { - maxheap_pop (k, heap_dis, heap_ids); - maxheap_push (k, heap_dis, heap_ids, dis, bi); + maxheap_replace_top (k, heap_dis, heap_ids, dis, bi); } } b_code += code_size; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index b92aaba3f..6c9bcf009 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -539,8 +539,7 @@ int HNSW::search_from_candidates( if (nres < k) { faiss::maxheap_push(++nres, D, I, d, v1); } else if (d < D[0]) { - faiss::maxheap_pop(nres--, D, I); - faiss::maxheap_push(++nres, D, I, d, v1); + faiss::maxheap_replace_top(nres, D, I, d, v1); } vt.set(v1); } @@ -578,8 +577,7 @@ int HNSW::search_from_candidates( if (nres < k) { faiss::maxheap_push(++nres, D, I, d, v1); } else if (d < D[0]) { - faiss::maxheap_pop(nres--, D, I); - faiss::maxheap_push(++nres, D, I, d, v1); + faiss::maxheap_replace_top(nres, D, I, d, v1); } candidates.push(v1, d); } diff --git a/faiss/impl/ProductQuantizer.cpp b/faiss/impl/ProductQuantizer.cpp index 692cc1fb6..c66fa8193 100644 --- a/faiss/impl/ProductQuantizer.cpp +++ b/faiss/impl/ProductQuantizer.cpp @@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes, } if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_replace_top (k, heap_dis, heap_ids, dis, j); } } } @@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes, dis += dt[*codes++]; if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_replace_top (k, heap_dis, heap_ids, dis, j); } } } @@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq, dt += ksub; } if (C::cmp (heap_dis[0], dis)) { - heap_pop (k, heap_dis, heap_ids); - heap_push (k, heap_dis, heap_ids, dis, j); + heap_replace_top (k, heap_dis, heap_ids, dis, j); } } } @@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq, } if (C::cmp(heap_dis[0], dis)) { - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, j); + heap_replace_top(k, heap_dis, heap_ids, dis, j); } } } @@ -762,8 +758,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes, tab += ksub * ksub; } if (dis < heap_dis[0]) { - maxheap_pop (k, heap_dis, heap_ids); - maxheap_push (k, heap_dis, heap_ids, dis, j); + maxheap_replace_top (k, heap_dis, heap_ids, dis, j); } bcode += code_size; } diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index a3765ce6a..e87dc364c 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -75,8 +75,7 @@ struct HeapResultHandler { /// add one result for query i void add_result(T dis, TI idx) { if (C::cmp(heap_dis[0], dis)) { - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, idx); + heap_replace_top(k, heap_dis, heap_ids, dis, idx); thresh = heap_dis[0]; } } @@ -113,8 +112,7 @@ struct HeapResultHandler { for (size_t j = j0; j < j1; j++) { T dis = *dis_tab++; if (C::cmp(thresh, dis)) { - heap_pop(k, heap_dis, heap_ids); - heap_push(k, heap_dis, heap_ids, dis, j); + heap_replace_top(k, heap_dis, heap_ids, dis, j); thresh = heap_dis[0]; } } diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 704f59738..6a6ca3b5a 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -1430,9 +1430,8 @@ struct IVFSQScannerIP: InvertedListScanner { float accu = accu0 + dc.query_to_code (codes); if (accu > simi [0]) { - minheap_pop (k, simi, idxi); int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - minheap_push (k, simi, idxi, accu, id); + minheap_replace_top (k, simi, idxi, accu, id); nup++; } codes += code_size; @@ -1518,9 +1517,8 @@ struct IVFSQScannerL2: InvertedListScanner { float dis = dc.query_to_code (codes); if (dis < simi [0]) { - maxheap_pop (k, simi, idxi); int64_t id = store_pairs ? (list_no << 32 | j) : ids[j]; - maxheap_push (k, simi, idxi, dis, id); + maxheap_replace_top (k, simi, idxi, dis, id); nup++; } codes += code_size; diff --git a/faiss/utils/Heap.cpp b/faiss/utils/Heap.cpp index b281c14ee..8e70891a8 100644 --- a/faiss/utils/Heap.cpp +++ b/faiss/utils/Heap.cpp @@ -46,8 +46,7 @@ void HeapArray::addn (size_t nj, const T *vin, TI j0, for (size_t j = 0; j < nj; j++) { T ip = ip_line [j]; if (C::cmp(simi[0], ip)) { - heap_pop (k, simi, idxi); - heap_push (k, simi, idxi, ip, j + j0); + heap_replace_top (k, simi, idxi, ip, j + j0); } } } @@ -74,8 +73,7 @@ void HeapArray::addn_with_ids ( for (size_t j = 0; j < nj; j++) { T ip = ip_line [j]; if (C::cmp(simi[0], ip)) { - heap_pop (k, simi, idxi); - heap_push (k, simi, idxi, ip, id_line [j]); + heap_replace_top (k, simi, idxi, ip, id_line [j]); } } } diff --git a/faiss/utils/Heap.h b/faiss/utils/Heap.h index 3cef9532f..7f90567c7 100644 --- a/faiss/utils/Heap.h +++ b/faiss/utils/Heap.h @@ -105,6 +105,43 @@ void heap_push (size_t k, +/** Replace the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. + */ +template inline +void heap_replace_top (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + typename C::T val, typename C::TI ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { + if (C::cmp(val, bh_val[i1])) + break; + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + i = i1; + } + else { + if (C::cmp(val, bh_val[i2])) + break; + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + i = i2; + } + } + bh_val[i] = val; + bh_ids[i] = ids; +} + + + /* Partial instanciation for heaps with TI = int64_t */ template inline @@ -121,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) } +template inline +void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_replace_top > (k, bh_val, bh_ids, val, ids); +} + + template inline void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids) { @@ -135,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) } +template inline +void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_replace_top > (k, bh_val, bh_ids, val, ids); +} + /******************************************************************* * Heap initialization @@ -212,15 +262,13 @@ void heap_addn (size_t k, if (ids) for (i = 0; i < n; i++) { if (C::cmp (bh_val[0], x[i])) { - heap_pop (k, bh_val, bh_ids); - heap_push (k, bh_val, bh_ids, x[i], ids[i]); + heap_replace_top (k, bh_val, bh_ids, x[i], ids[i]); } } else for (i = 0; i < n; i++) { if (C::cmp (bh_val[0], x[i])) { - heap_pop (k, bh_val, bh_ids); - heap_push (k, bh_val, bh_ids, x[i], i); + heap_replace_top (k, bh_val, bh_ids, x[i], i); } } } diff --git a/faiss/utils/distances.cpp b/faiss/utils/distances.cpp index 21e78037d..1b88b2b34 100644 --- a/faiss/utils/distances.cpp +++ b/faiss/utils/distances.cpp @@ -522,8 +522,7 @@ void knn_inner_products_by_idx (const float * x, float ip = fvec_inner_product (x_, y + d * idsi[j], d); if (ip > simi[0]) { - minheap_pop (k, simi, idxi); - minheap_push (k, simi, idxi, ip, idsi[j]); + minheap_replace_top (k, simi, idxi, ip, idsi[j]); } } minheap_reorder (k, simi, idxi); @@ -550,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x, float disij = fvec_L2sqr (x_, y + d * idsi[j], d); if (disij < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, disij, idsi[j]); + maxheap_replace_top (k, simi, idxi, disij, idsi[j]); } } maxheap_reorder (res->k, simi, idxi); diff --git a/faiss/utils/extra_distances.cpp b/faiss/utils/extra_distances.cpp index 5e5e23d22..295dc47e2 100644 --- a/faiss/utils/extra_distances.cpp +++ b/faiss/utils/extra_distances.cpp @@ -176,8 +176,7 @@ void knn_extra_metrics_template ( float disij = vd (x_i, y_j); if (disij < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, disij, j); + maxheap_replace_top (k, simi, idxi, disij, j); } y_j += d; } diff --git a/faiss/utils/hamming.cpp b/faiss/utils/hamming.cpp index bca8a895e..44c5147a5 100644 --- a/faiss/utils/hamming.cpp +++ b/faiss/utils/hamming.cpp @@ -292,8 +292,7 @@ void hammings_knn_hc ( for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { dis = hc.hamming (bs2_); if (dis < bh_val_[0]) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + faiss::maxheap_replace_top (k, bh_val_, bh_ids_, dis, j); } } } @@ -391,8 +390,7 @@ void hammings_knn_hc_1 ( for (j = 0; j < n2; j++, bs2_+= nwords) { dis = popcount64 (bs1_ ^ *bs2_); if (dis < bh_val_0) { - faiss::maxheap_pop (k, bh_val_, bh_ids_); - faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + faiss::maxheap_replace_top (k, bh_val_, bh_ids_, dis, j); bh_val_0 = bh_val_[0]; } } @@ -818,8 +816,7 @@ static void hamming_dis_inner_loop ( int ndiff = hc.hamming (cb); cb += code_size; if (ndiff < bh_val_[0]) { - maxheap_pop (k, bh_val_, bh_ids_); - maxheap_push (k, bh_val_, bh_ids_, ndiff, j); + maxheap_replace_top (k, bh_val_, bh_ids_, ndiff, j); } } } diff --git a/tests/test_index_accuracy.py b/tests/test_index_accuracy.py index 59388d412..5f3b557ec 100644 --- a/tests/test_index_accuracy.py +++ b/tests/test_index_accuracy.py @@ -683,7 +683,7 @@ class TestSpectralHash(unittest.TestCase): key = (nbit, tt, period) print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter)) - assert abs(ninter - self.ref_results[key]) <= 4 + assert abs(ninter - self.ref_results[key]) <= 12 class TestRefine(unittest.TestCase): diff --git a/tests/test_index_binary.py b/tests/test_index_binary.py index 8b4294a43..5227d2dde 100644 --- a/tests/test_index_binary.py +++ b/tests/test_index_binary.py @@ -183,7 +183,9 @@ class TestBinaryIVF(unittest.TestCase): index.add(self.xb) Divfflat, _ = index.search(self.xq, 10) - self.assertEqual((self.Dref == Divfflat).sum(), 4122) + # Some centroids are equidistant from the query points. + # So the answer will depend on the implementation of the heap. + self.assertGreater((self.Dref == Divfflat).sum(), 4100) def test_ivf_range(self): d = self.xq.shape[1] * 8