Add heap_replace_top to simplify heap_pop + heap_push (#1597)
Summary: Signed-off-by: shengjun.li <shengjun.li@zilliz.com> Add heap_replace_top to simplify heap_pop + heap_push Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1597 Test Plan: OMP_NUM_THREADS=1 buck run mode/opt //faiss/benchs/:bench_heap_replace OMP_NUM_THREADS=8 buck run mode/opt //faiss/benchs/:bench_heap_replace Reviewed By: beauby Differential Revision: D25943140 Pulled By: mdouze fbshipit-source-id: 66fe67779dd281a7753f597542c2e797ba0d7df5pull/1647/head
parent
a2791322d9
commit
908812266c
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
|
||||
#include <cstdio>
|
||||
#include <omp.h>
|
||||
|
||||
#include <faiss/utils/utils.h>
|
||||
#include <faiss/utils/random.h>
|
||||
#include <faiss/utils/Heap.h>
|
||||
#include <faiss/impl/FaissAssert.h>
|
||||
|
||||
using namespace faiss;
|
||||
|
||||
void addn_default(
|
||||
size_t n, size_t k,
|
||||
const float *x, int64_t *heap_ids, float * heap_val)
|
||||
{
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
minheap_push(i + 1, heap_val, heap_ids, x[i], i);
|
||||
}
|
||||
|
||||
for (size_t i = k; i < n; i++) {
|
||||
if (x[i] > heap_val[0]) {
|
||||
minheap_pop(k, heap_val, heap_ids);
|
||||
minheap_push(k, heap_val, heap_ids, x[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
minheap_reorder(k, heap_val, heap_ids);
|
||||
}
|
||||
|
||||
void addn_replace(
|
||||
size_t n, size_t k,
|
||||
const float *x, int64_t *heap_ids, float * heap_val)
|
||||
{
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
minheap_push(i + 1, heap_val, heap_ids, x[i], i);
|
||||
}
|
||||
|
||||
for (size_t i = k; i < n; i++) {
|
||||
if (x[i] > heap_val[0]) {
|
||||
minheap_replace_top(k, heap_val, heap_ids, x[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
minheap_reorder(k, heap_val, heap_ids);
|
||||
}
|
||||
|
||||
void addn_func(
|
||||
size_t n, size_t k,
|
||||
const float *x, int64_t *heap_ids, float * heap_val)
|
||||
{
|
||||
minheap_heapify(k, heap_val, heap_ids);
|
||||
|
||||
minheap_addn(k, heap_val, heap_ids, x, nullptr, n);
|
||||
|
||||
minheap_reorder(k, heap_val, heap_ids);
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
|
||||
size_t n = 10 * 1000 * 1000;
|
||||
|
||||
std::vector<size_t> ks({20, 50, 100, 200, 500, 1000, 2000, 5000});
|
||||
|
||||
std::vector<float> x(n);
|
||||
float_randn(x.data(), n, 12345);
|
||||
|
||||
int nrun = 100;
|
||||
for(size_t k: ks) {
|
||||
printf("benchmark with k=%zd n=%zd nrun=%d\n", k, n, nrun);
|
||||
FAISS_THROW_IF_NOT(k < n);
|
||||
|
||||
double tot_t1 = 0, tot_t2 = 0, tot_t3 = 0;
|
||||
#pragma omp parallel reduction(+: tot_t1, tot_t2, tot_t3)
|
||||
{
|
||||
|
||||
std::vector<float> heap_dis(k);
|
||||
std::vector<float> heap_dis_2(k);
|
||||
std::vector<float> heap_dis_3(k);
|
||||
|
||||
std::vector<int64_t> heap_ids(k);
|
||||
std::vector<int64_t> heap_ids_2(k);
|
||||
std::vector<int64_t> heap_ids_3(k);
|
||||
|
||||
#pragma omp for
|
||||
for (int run = 0; run < nrun; run++) {
|
||||
|
||||
double t0, t1, t2, t3;
|
||||
|
||||
t0 = getmillisecs();
|
||||
|
||||
// default implem
|
||||
addn_default(n, k, x.data(), heap_ids.data(), heap_dis.data());
|
||||
t1 = getmillisecs();
|
||||
|
||||
// new implem from Zilliz
|
||||
addn_replace(n, k, x.data(), heap_ids_2.data(), heap_dis_2.data());
|
||||
t2 = getmillisecs();
|
||||
|
||||
// with addn
|
||||
addn_func(n, k, x.data(), heap_ids_3.data(), heap_dis_3.data());
|
||||
t3 = getmillisecs();
|
||||
|
||||
tot_t1 += t1 - t0;
|
||||
tot_t2 += t2 - t1;
|
||||
tot_t3 += t3 - t2;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
heap_ids[i] == heap_ids_2[i],
|
||||
"i=%ld (%ld, %g) != (%ld, %g)",
|
||||
i, size_t(heap_ids[i]), heap_dis[i],
|
||||
size_t(heap_ids_2[i]), heap_dis_2[i]);
|
||||
FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_2[i]);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < k; i++) {
|
||||
FAISS_THROW_IF_NOT(heap_ids[i] == heap_ids_3[i]);
|
||||
FAISS_THROW_IF_NOT(heap_dis[i] == heap_dis_3[i]);
|
||||
}
|
||||
}
|
||||
printf("default implem: %.3f ms\n", tot_t1 / nrun);
|
||||
printf("replace implem: %.3f ms\n", tot_t2 / nrun);
|
||||
printf("addn implem: %.3f ms\n", tot_t3 / nrun);
|
||||
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -136,8 +136,7 @@ struct KnnSearchResults {
|
|||
|
||||
inline void add (float dis, idx_t id) {
|
||||
if (dis < heap_sim[0]) {
|
||||
heap_pop<C> (k, heap_sim, heap_ids);
|
||||
heap_push<C> (k, heap_sim, heap_ids, dis, id);
|
||||
heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -319,9 +319,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
for (size_t j = 0; j < n; j++) {
|
||||
uint32_t dis = hc.hamming (codes);
|
||||
if (dis < simi[0]) {
|
||||
heap_pop<C> (k, simi, idxi);
|
||||
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
||||
heap_push<C> (k, simi, idxi, dis, id);
|
||||
heap_replace_top<C> (k, simi, idxi, dis, id);
|
||||
nup++;
|
||||
}
|
||||
codes += code_size;
|
||||
|
|
|
@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw,
|
|||
if (nres < k) {
|
||||
faiss::maxheap_push (++nres, D, I, d, v1);
|
||||
} else if (d < D[0]) {
|
||||
faiss::maxheap_pop (nres--, D, I);
|
||||
faiss::maxheap_push (++nres, D, I, d, v1);
|
||||
faiss::maxheap_replace_top (nres, D, I, d, v1);
|
||||
}
|
||||
}
|
||||
vt.visited[v1] = vt.visno + 1;
|
||||
|
|
|
@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner {
|
|||
float dis = metric == METRIC_INNER_PRODUCT ?
|
||||
fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
|
||||
if (C::cmp (simi[0], dis)) {
|
||||
heap_pop<C> (k, simi, idxi);
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
heap_push<C> (k, simi, idxi, dis, id);
|
||||
heap_replace_top<C> (k, simi, idxi, dis, id);
|
||||
nup++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -828,9 +828,8 @@ struct KnnSearchResults {
|
|||
|
||||
inline void add (idx_t j, float dis) {
|
||||
if (C::cmp (heap_sim[0], dis)) {
|
||||
heap_pop<C> (k, heap_sim, heap_ids);
|
||||
idx_t id = ids ? ids[j] : lo_build (key, j);
|
||||
heap_push<C> (k, heap_sim, heap_ids, dis, id);
|
||||
heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
|
||||
nup++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned (
|
|||
float dis = fvec_L2sqr (residual_1, residual_2, d);
|
||||
|
||||
if (dis < heap_sim[0]) {
|
||||
maxheap_pop (k, heap_sim, heap_ids);
|
||||
idx_t id_or_pair = store_pairs ? sl : id;
|
||||
maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair);
|
||||
maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair);
|
||||
}
|
||||
n_refine ++;
|
||||
}
|
||||
|
|
|
@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner {
|
|||
float dis = hc.hamming (codes);
|
||||
|
||||
if (dis < simi [0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
||||
maxheap_push (k, simi, idxi, dis, id);
|
||||
maxheap_replace_top (k, simi, idxi, dis, id);
|
||||
nup++;
|
||||
}
|
||||
codes += code_size;
|
||||
|
|
|
@ -346,8 +346,7 @@ static size_t polysemous_inner_loop (
|
|||
}
|
||||
|
||||
if (dis < heap_dis[0]) {
|
||||
maxheap_pop (k, heap_dis, heap_ids);
|
||||
maxheap_push (k, heap_dis, heap_ids, dis, bi);
|
||||
maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
|
||||
}
|
||||
}
|
||||
b_code += code_size;
|
||||
|
|
|
@ -539,8 +539,7 @@ int HNSW::search_from_candidates(
|
|||
if (nres < k) {
|
||||
faiss::maxheap_push(++nres, D, I, d, v1);
|
||||
} else if (d < D[0]) {
|
||||
faiss::maxheap_pop(nres--, D, I);
|
||||
faiss::maxheap_push(++nres, D, I, d, v1);
|
||||
faiss::maxheap_replace_top(nres, D, I, d, v1);
|
||||
}
|
||||
vt.set(v1);
|
||||
}
|
||||
|
@ -578,8 +577,7 @@ int HNSW::search_from_candidates(
|
|||
if (nres < k) {
|
||||
faiss::maxheap_push(++nres, D, I, d, v1);
|
||||
} else if (d < D[0]) {
|
||||
faiss::maxheap_pop(nres--, D, I);
|
||||
faiss::maxheap_push(++nres, D, I, d, v1);
|
||||
faiss::maxheap_replace_top(nres, D, I, d, v1);
|
||||
}
|
||||
candidates.push(v1, d);
|
||||
}
|
||||
|
|
|
@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
|
|||
}
|
||||
|
||||
if (C::cmp (heap_dis[0], dis)) {
|
||||
heap_pop<C> (k, heap_dis, heap_ids);
|
||||
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
||||
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes,
|
|||
dis += dt[*codes++];
|
||||
|
||||
if (C::cmp (heap_dis[0], dis)) {
|
||||
heap_pop<C> (k, heap_dis, heap_ids);
|
||||
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
||||
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
|
|||
dt += ksub;
|
||||
}
|
||||
if (C::cmp (heap_dis[0], dis)) {
|
||||
heap_pop<C> (k, heap_dis, heap_ids);
|
||||
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
||||
heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
|
|||
}
|
||||
|
||||
if (C::cmp(heap_dis[0], dis)) {
|
||||
heap_pop<C>(k, heap_dis, heap_ids);
|
||||
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
||||
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -762,8 +758,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
|
|||
tab += ksub * ksub;
|
||||
}
|
||||
if (dis < heap_dis[0]) {
|
||||
maxheap_pop (k, heap_dis, heap_ids);
|
||||
maxheap_push (k, heap_dis, heap_ids, dis, j);
|
||||
maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
|
||||
}
|
||||
bcode += code_size;
|
||||
}
|
||||
|
|
|
@ -75,8 +75,7 @@ struct HeapResultHandler {
|
|||
/// add one result for query i
|
||||
void add_result(T dis, TI idx) {
|
||||
if (C::cmp(heap_dis[0], dis)) {
|
||||
heap_pop<C>(k, heap_dis, heap_ids);
|
||||
heap_push<C>(k, heap_dis, heap_ids, dis, idx);
|
||||
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
||||
thresh = heap_dis[0];
|
||||
}
|
||||
}
|
||||
|
@ -113,8 +112,7 @@ struct HeapResultHandler {
|
|||
for (size_t j = j0; j < j1; j++) {
|
||||
T dis = *dis_tab++;
|
||||
if (C::cmp(thresh, dis)) {
|
||||
heap_pop<C>(k, heap_dis, heap_ids);
|
||||
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
||||
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
||||
thresh = heap_dis[0];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1430,9 +1430,8 @@ struct IVFSQScannerIP: InvertedListScanner {
|
|||
float accu = accu0 + dc.query_to_code (codes);
|
||||
|
||||
if (accu > simi [0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
minheap_push (k, simi, idxi, accu, id);
|
||||
minheap_replace_top (k, simi, idxi, accu, id);
|
||||
nup++;
|
||||
}
|
||||
codes += code_size;
|
||||
|
@ -1518,9 +1517,8 @@ struct IVFSQScannerL2: InvertedListScanner {
|
|||
float dis = dc.query_to_code (codes);
|
||||
|
||||
if (dis < simi [0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
||||
maxheap_push (k, simi, idxi, dis, id);
|
||||
maxheap_replace_top (k, simi, idxi, dis, id);
|
||||
nup++;
|
||||
}
|
||||
codes += code_size;
|
||||
|
|
|
@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
|
|||
for (size_t j = 0; j < nj; j++) {
|
||||
T ip = ip_line [j];
|
||||
if (C::cmp(simi[0], ip)) {
|
||||
heap_pop<C> (k, simi, idxi);
|
||||
heap_push<C> (k, simi, idxi, ip, j + j0);
|
||||
heap_replace_top<C> (k, simi, idxi, ip, j + j0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
|
|||
for (size_t j = 0; j < nj; j++) {
|
||||
T ip = ip_line [j];
|
||||
if (C::cmp(simi[0], ip)) {
|
||||
heap_pop<C> (k, simi, idxi);
|
||||
heap_push<C> (k, simi, idxi, ip, id_line [j]);
|
||||
heap_replace_top<C> (k, simi, idxi, ip, id_line [j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,6 +105,43 @@ void heap_push (size_t k,
|
|||
|
||||
|
||||
|
||||
/** Replace the top element from the heap defined by bh_val[0..k-1] and
|
||||
* bh_ids[0..k-1].
|
||||
*/
|
||||
template <class C> inline
|
||||
void heap_replace_top (size_t k,
|
||||
typename C::T * bh_val, typename C::TI * bh_ids,
|
||||
typename C::T val, typename C::TI ids)
|
||||
{
|
||||
bh_val--; /* Use 1-based indexing for easier node->child translation */
|
||||
bh_ids--;
|
||||
size_t i = 1, i1, i2;
|
||||
while (1) {
|
||||
i1 = i << 1;
|
||||
i2 = i1 + 1;
|
||||
if (i1 > k)
|
||||
break;
|
||||
if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) {
|
||||
if (C::cmp(val, bh_val[i1]))
|
||||
break;
|
||||
bh_val[i] = bh_val[i1];
|
||||
bh_ids[i] = bh_ids[i1];
|
||||
i = i1;
|
||||
}
|
||||
else {
|
||||
if (C::cmp(val, bh_val[i2]))
|
||||
break;
|
||||
bh_val[i] = bh_val[i2];
|
||||
bh_ids[i] = bh_ids[i2];
|
||||
i = i2;
|
||||
}
|
||||
}
|
||||
bh_val[i] = val;
|
||||
bh_ids[i] = ids;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Partial instanciation for heaps with TI = int64_t */
|
||||
|
||||
template <typename T> inline
|
||||
|
@ -121,6 +158,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|||
}
|
||||
|
||||
|
||||
template <typename T> inline
|
||||
void minheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
||||
{
|
||||
heap_replace_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
||||
}
|
||||
|
||||
|
||||
template <typename T> inline
|
||||
void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
|
||||
{
|
||||
|
@ -135,6 +179,12 @@ void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
|||
}
|
||||
|
||||
|
||||
template <typename T> inline
|
||||
void maxheap_replace_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
|
||||
{
|
||||
heap_replace_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
|
||||
}
|
||||
|
||||
|
||||
/*******************************************************************
|
||||
* Heap initialization
|
||||
|
@ -212,15 +262,13 @@ void heap_addn (size_t k,
|
|||
if (ids)
|
||||
for (i = 0; i < n; i++) {
|
||||
if (C::cmp (bh_val[0], x[i])) {
|
||||
heap_pop<C> (k, bh_val, bh_ids);
|
||||
heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
||||
heap_replace_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
for (i = 0; i < n; i++) {
|
||||
if (C::cmp (bh_val[0], x[i])) {
|
||||
heap_pop<C> (k, bh_val, bh_ids);
|
||||
heap_push<C> (k, bh_val, bh_ids, x[i], i);
|
||||
heap_replace_top<C> (k, bh_val, bh_ids, x[i], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -522,8 +522,7 @@ void knn_inner_products_by_idx (const float * x,
|
|||
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
|
||||
|
||||
if (ip > simi[0]) {
|
||||
minheap_pop (k, simi, idxi);
|
||||
minheap_push (k, simi, idxi, ip, idsi[j]);
|
||||
minheap_replace_top (k, simi, idxi, ip, idsi[j]);
|
||||
}
|
||||
}
|
||||
minheap_reorder (k, simi, idxi);
|
||||
|
@ -550,8 +549,7 @@ void knn_L2sqr_by_idx (const float * x,
|
|||
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
|
||||
|
||||
if (disij < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, disij, idsi[j]);
|
||||
maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
|
||||
}
|
||||
}
|
||||
maxheap_reorder (res->k, simi, idxi);
|
||||
|
|
|
@ -176,8 +176,7 @@ void knn_extra_metrics_template (
|
|||
float disij = vd (x_i, y_j);
|
||||
|
||||
if (disij < simi[0]) {
|
||||
maxheap_pop (k, simi, idxi);
|
||||
maxheap_push (k, simi, idxi, disij, j);
|
||||
maxheap_replace_top (k, simi, idxi, disij, j);
|
||||
}
|
||||
y_j += d;
|
||||
}
|
||||
|
|
|
@ -292,8 +292,7 @@ void hammings_knn_hc (
|
|||
for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
|
||||
dis = hc.hamming (bs2_);
|
||||
if (dis < bh_val_[0]) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -391,8 +390,7 @@ void hammings_knn_hc_1 (
|
|||
for (j = 0; j < n2; j++, bs2_+= nwords) {
|
||||
dis = popcount64 (bs1_ ^ *bs2_);
|
||||
if (dis < bh_val_0) {
|
||||
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
|
||||
bh_val_0 = bh_val_[0];
|
||||
}
|
||||
}
|
||||
|
@ -818,8 +816,7 @@ static void hamming_dis_inner_loop (
|
|||
int ndiff = hc.hamming (cb);
|
||||
cb += code_size;
|
||||
if (ndiff < bh_val_[0]) {
|
||||
maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
|
||||
maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
|
||||
maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -683,7 +683,7 @@ class TestSpectralHash(unittest.TestCase):
|
|||
key = (nbit, tt, period)
|
||||
|
||||
print('(%d, %s, %g): %d, ' % (nbit, repr(tt), period, ninter))
|
||||
assert abs(ninter - self.ref_results[key]) <= 4
|
||||
assert abs(ninter - self.ref_results[key]) <= 12
|
||||
|
||||
|
||||
class TestRefine(unittest.TestCase):
|
||||
|
|
|
@ -183,7 +183,9 @@ class TestBinaryIVF(unittest.TestCase):
|
|||
index.add(self.xb)
|
||||
Divfflat, _ = index.search(self.xq, 10)
|
||||
|
||||
self.assertEqual((self.Dref == Divfflat).sum(), 4122)
|
||||
# Some centroids are equidistant from the query points.
|
||||
# So the answer will depend on the implementation of the heap.
|
||||
self.assertGreater((self.Dref == Divfflat).sum(), 4100)
|
||||
|
||||
def test_ivf_range(self):
|
||||
d = self.xq.shape[1] * 8
|
||||
|
|
Loading…
Reference in New Issue