faiss/tests/test_hnsw.cpp

637 lines
19 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gtest/gtest.h>
#include <cstddef>
#include <limits>
#include <random>
#include <unordered_set>
#include <vector>
#include <faiss/IndexHNSW.h>
#include <faiss/impl/HNSW.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/random.h>
int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) {
assert(heap.k > 0);
// returns min. This is an O(n) operation
int i = heap.k - 1;
while (i >= 0) {
if (heap.ids[i] != -1)
break;
i--;
}
if (i == -1)
return -1;
int imin = i;
float vmin = heap.dis[i];
i--;
while (i >= 0) {
if (heap.ids[i] != -1 && heap.dis[i] < vmin) {
vmin = heap.dis[i];
imin = i;
}
i--;
}
if (vmin_out)
*vmin_out = vmin;
int ret = heap.ids[imin];
heap.ids[imin] = -1;
--heap.nvalid;
return ret;
}
void test_popmin(int heap_size, int amount_to_put) {
// create a heap
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
using storage_idx_t = faiss::HNSW::storage_idx_t;
std::default_random_engine rng(123 + heap_size * amount_to_put);
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
std::uniform_real_distribution<float> uf(0, 1);
// generate random unique indices
std::unordered_set<storage_idx_t> indices;
while (indices.size() < amount_to_put) {
const storage_idx_t index = u(rng);
indices.insert(index);
}
// put ones into the heap
for (const auto index : indices) {
float distance = uf(rng);
if (distance >= 0.7f) {
// add infinity values from time to time
distance = std::numeric_limits<float>::infinity();
}
mm_heap.push(index, distance);
}
// clone the heap
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
// takes ones out one by one
while (mm_heap.size() > 0) {
// compare heaps
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
// use the reference pop_min for the cloned heap
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t cloned_vmin_idx =
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
// compare returns
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
}
// compare heaps again
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
}
void test_popmin_identical_distances(
int heap_size,
int amount_to_put,
const float distance) {
// create a heap
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
using storage_idx_t = faiss::HNSW::storage_idx_t;
std::default_random_engine rng(123 + heap_size * amount_to_put);
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
// generate random unique indices
std::unordered_set<storage_idx_t> indices;
while (indices.size() < amount_to_put) {
const storage_idx_t index = u(rng);
indices.insert(index);
}
// put ones into the heap
for (const auto index : indices) {
mm_heap.push(index, distance);
}
// clone the heap
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
// takes ones out one by one
while (mm_heap.size() > 0) {
// compare heaps
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
// use the reference pop_min for the cloned heap
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t cloned_vmin_idx =
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
// compare returns
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
}
// compare heaps again
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
}
TEST(HNSW, Test_popmin) {
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32, 64, 128};
for (const size_t size : sizes) {
for (size_t amount = size; amount > 0; amount /= 2) {
test_popmin(size, amount);
}
}
}
TEST(HNSW, Test_popmin_identical_distances) {
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
for (const size_t size : sizes) {
for (size_t amount = size; amount > 0; amount /= 2) {
test_popmin_identical_distances(size, amount, 1.0f);
}
}
}
TEST(HNSW, Test_popmin_infinite_distances) {
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
for (const size_t size : sizes) {
for (size_t amount = size; amount > 0; amount /= 2) {
test_popmin_identical_distances(
size, amount, std::numeric_limits<float>::infinity());
}
}
}
class HNSWTest : public testing::Test {
protected:
HNSWTest() {
xb = std::make_unique<std::vector<float>>(d * nb);
xb->reserve(d * nb);
faiss::float_rand(xb->data(), d * nb, 12345);
index = std::make_unique<faiss::IndexHNSWFlat>(d, M);
index->add(nb, xb->data());
xq = std::unique_ptr<std::vector<float>>(
new std::vector<float>(d * nq));
xq->reserve(d * nq);
faiss::float_rand(xq->data(), d * nq, 12345);
dis = std::unique_ptr<faiss::DistanceComputer>(
index->storage->get_distance_computer());
dis->set_query(xq->data() + 0 * index->d);
}
const int d = 64;
const int nb = 2000;
const int M = 4;
const int nq = 10;
const int k = 10;
std::unique_ptr<std::vector<float>> xb;
std::unique_ptr<std::vector<float>> xq;
std::unique_ptr<faiss::DistanceComputer> dis;
std::unique_ptr<faiss::IndexHNSWFlat> index;
};
/** Do a BFS on the candidates list */
int reference_search_from_candidates(
const faiss::HNSW& hnsw,
faiss::DistanceComputer& qdis,
faiss::ResultHandler<faiss::HNSW::C>& res,
faiss::HNSW::MinimaxHeap& candidates,
faiss::VisitedTable& vt,
faiss::HNSWStats& stats,
int level,
int nres_in,
const faiss::SearchParametersHNSW* params) {
int nres = nres_in;
int ndis = 0;
// can be overridden by search params
bool do_dis_check = params ? params->check_relative_distance
: hnsw.check_relative_distance;
int efSearch = params ? params->efSearch : hnsw.efSearch;
const faiss::IDSelector* sel = params ? params->sel : nullptr;
faiss::HNSW::C::T threshold = res.threshold;
for (int i = 0; i < candidates.size(); i++) {
faiss::idx_t v1 = candidates.ids[i];
float d = candidates.dis[i];
FAISS_ASSERT(v1 >= 0);
if (!sel || sel->is_member(v1)) {
if (d < threshold) {
if (res.add_result(d, v1)) {
threshold = res.threshold;
}
}
}
vt.set(v1);
}
int nstep = 0;
while (candidates.size() > 0) {
float d0 = 0;
int v0 = candidates.pop_min(&d0);
if (do_dis_check) {
// tricky stopping condition: there are more that ef
// distances that are processed already that are smaller
// than d0
int n_dis_below = candidates.count_below(d0);
if (n_dis_below >= efSearch) {
break;
}
}
size_t begin, end;
hnsw.neighbor_range(v0, level, &begin, &end);
// a reference version
for (size_t j = begin; j < end; j++) {
int v1 = hnsw.neighbors[j];
if (v1 < 0)
break;
if (vt.get(v1)) {
continue;
}
vt.set(v1);
ndis++;
float d = qdis(v1);
if (!sel || sel->is_member(v1)) {
if (d < threshold) {
if (res.add_result(d, v1)) {
threshold = res.threshold;
nres += 1;
}
}
}
candidates.push(v1, d);
}
nstep++;
if (!do_dis_check && nstep > efSearch) {
break;
}
}
if (level == 0) {
stats.n1++;
if (candidates.size() == 0) {
stats.n2++;
}
stats.ndis += ndis;
stats.nhops += nstep;
}
return nres;
}
faiss::HNSWStats reference_greedy_update_nearest(
const faiss::HNSW& hnsw,
faiss::DistanceComputer& qdis,
int level,
faiss::HNSW::storage_idx_t& nearest,
float& d_nearest) {
faiss::HNSWStats stats;
for (;;) {
faiss::HNSW::storage_idx_t prev_nearest = nearest;
size_t begin, end;
hnsw.neighbor_range(nearest, level, &begin, &end);
size_t ndis = 0;
for (size_t i = begin; i < end; i++) {
faiss::HNSW::storage_idx_t v = hnsw.neighbors[i];
if (v < 0)
break;
ndis += 1;
float dis = qdis(v);
if (dis < d_nearest) {
nearest = v;
d_nearest = dis;
}
}
// update stats
stats.ndis += ndis;
stats.nhops += 1;
if (nearest == prev_nearest) {
return stats;
}
}
}
std::priority_queue<faiss::HNSW::Node> reference_search_from_candidate_unbounded(
const faiss::HNSW& hnsw,
const faiss::HNSW::Node& node,
faiss::DistanceComputer& qdis,
int ef,
faiss::VisitedTable* vt,
faiss::HNSWStats& stats) {
int ndis = 0;
std::priority_queue<faiss::HNSW::Node> top_candidates;
std::priority_queue<
faiss::HNSW::Node,
std::vector<faiss::HNSW::Node>,
std::greater<faiss::HNSW::Node>>
candidates;
top_candidates.push(node);
candidates.push(node);
vt->set(node.second);
while (!candidates.empty()) {
float d0;
faiss::HNSW::storage_idx_t v0;
std::tie(d0, v0) = candidates.top();
if (d0 > top_candidates.top().first) {
break;
}
candidates.pop();
size_t begin, end;
hnsw.neighbor_range(v0, 0, &begin, &end);
for (size_t j = begin; j < end; ++j) {
int v1 = hnsw.neighbors[j];
if (v1 < 0) {
break;
}
if (vt->get(v1)) {
continue;
}
vt->set(v1);
float d1 = qdis(v1);
++ndis;
if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
candidates.emplace(d1, v1);
top_candidates.emplace(d1, v1);
if (top_candidates.size() > ef) {
top_candidates.pop();
}
}
}
stats.nhops += 1;
}
++stats.n1;
if (candidates.size() == 0) {
++stats.n2;
}
stats.ndis += ndis;
return top_candidates;
}
TEST_F(HNSWTest, TEST_search_from_candidate_unbounded) {
omp_set_num_threads(1);
auto nearest = index->hnsw.entry_point;
float d_nearest = (*dis)(nearest);
auto node = faiss::HNSW::Node(d_nearest, nearest);
faiss::VisitedTable vt(index->ntotal);
faiss::HNSWStats stats;
// actual version
auto top_candidates = faiss::search_from_candidate_unbounded(
index->hnsw, node, *dis, k, &vt, stats);
auto reference_nearest = index->hnsw.entry_point;
float reference_d_nearest = (*dis)(nearest);
auto reference_node =
faiss::HNSW::Node(reference_d_nearest, reference_nearest);
faiss::VisitedTable reference_vt(index->ntotal);
faiss::HNSWStats reference_stats;
// reference version
auto reference_top_candidates = reference_search_from_candidate_unbounded(
index->hnsw,
reference_node,
*dis,
k,
&reference_vt,
reference_stats);
EXPECT_EQ(stats.ndis, reference_stats.ndis);
EXPECT_EQ(stats.nhops, reference_stats.nhops);
EXPECT_EQ(stats.n1, reference_stats.n1);
EXPECT_EQ(stats.n2, reference_stats.n2);
EXPECT_EQ(top_candidates.size(), reference_top_candidates.size());
}
TEST_F(HNSWTest, TEST_greedy_update_nearest) {
omp_set_num_threads(1);
auto nearest = index->hnsw.entry_point;
float d_nearest = (*dis)(nearest);
auto reference_nearest = index->hnsw.entry_point;
float reference_d_nearest = (*dis)(reference_nearest);
// actual version
auto stats = faiss::greedy_update_nearest(
index->hnsw, *dis, 0, nearest, d_nearest);
// reference version
auto reference_stats = reference_greedy_update_nearest(
index->hnsw, *dis, 0, reference_nearest, reference_d_nearest);
EXPECT_EQ(stats.ndis, reference_stats.ndis);
EXPECT_EQ(stats.nhops, reference_stats.nhops);
EXPECT_EQ(stats.n1, reference_stats.n1);
EXPECT_EQ(stats.n2, reference_stats.n2);
EXPECT_NEAR(d_nearest, reference_d_nearest, 0.01);
EXPECT_EQ(nearest, reference_nearest);
}
TEST_F(HNSWTest, TEST_search_from_candidates) {
omp_set_num_threads(1);
std::vector<faiss::idx_t> I(k * nq);
std::vector<float> D(k * nq);
std::vector<faiss::idx_t> reference_I(k * nq);
std::vector<float> reference_D(k * nq);
using RH = faiss::HeapBlockResultHandler<faiss::HNSW::C>;
faiss::VisitedTable vt(index->ntotal);
faiss::VisitedTable reference_vt(index->ntotal);
int num_candidates = 10;
faiss::HNSW::MinimaxHeap candidates(num_candidates);
faiss::HNSW::MinimaxHeap reference_candidates(num_candidates);
for (int i = 0; i < num_candidates; i++) {
vt.set(i);
reference_vt.set(i);
candidates.push(i, (*dis)(i));
reference_candidates.push(i, (*dis)(i));
}
faiss::HNSWStats stats;
RH bres(nq, D.data(), I.data(), k);
faiss::HeapBlockResultHandler<faiss::HNSW::C>::SingleResultHandler res(
bres);
res.begin(0);
faiss::search_from_candidates(
index->hnsw, *dis, res, candidates, vt, stats, 0, 0, nullptr);
res.end();
faiss::HNSWStats reference_stats;
RH reference_bres(nq, reference_D.data(), reference_I.data(), k);
faiss::HeapBlockResultHandler<faiss::HNSW::C>::SingleResultHandler
reference_res(reference_bres);
reference_res.begin(0);
reference_search_from_candidates(
index->hnsw,
*dis,
reference_res,
reference_candidates,
reference_vt,
reference_stats,
0,
0,
nullptr);
reference_res.end();
for (int i = 0; i < nq; i++) {
for (int j = 0; j < k; j++) {
EXPECT_NEAR(I[i * k + j], reference_I[i * k + j], 0.1);
EXPECT_NEAR(D[i * k + j], reference_D[i * k + j], 0.1);
}
}
EXPECT_EQ(reference_stats.ndis, stats.ndis);
EXPECT_EQ(reference_stats.nhops, stats.nhops);
EXPECT_EQ(reference_stats.n1, stats.n1);
EXPECT_EQ(reference_stats.n2, stats.n2);
}
TEST_F(HNSWTest, TEST_search_neighbors_to_add) {
omp_set_num_threads(1);
faiss::VisitedTable vt(index->ntotal);
faiss::VisitedTable reference_vt(index->ntotal);
std::priority_queue<faiss::HNSW::NodeDistCloser> link_targets;
std::priority_queue<faiss::HNSW::NodeDistCloser> reference_link_targets;
faiss::search_neighbors_to_add(
index->hnsw,
*dis,
link_targets,
index->hnsw.entry_point,
(*dis)(index->hnsw.entry_point),
index->hnsw.max_level,
vt,
false);
faiss::search_neighbors_to_add(
index->hnsw,
*dis,
reference_link_targets,
index->hnsw.entry_point,
(*dis)(index->hnsw.entry_point),
index->hnsw.max_level,
reference_vt,
true);
EXPECT_EQ(link_targets.size(), reference_link_targets.size());
while (!link_targets.empty()) {
auto val = link_targets.top();
auto reference_val = reference_link_targets.top();
EXPECT_EQ(val.d, reference_val.d);
EXPECT_EQ(val.id, reference_val.id);
link_targets.pop();
reference_link_targets.pop();
}
}
TEST_F(HNSWTest, TEST_nb_neighbors_bound) {
omp_set_num_threads(1);
EXPECT_EQ(index->hnsw.nb_neighbors(0), 8);
EXPECT_EQ(index->hnsw.nb_neighbors(1), 4);
EXPECT_EQ(index->hnsw.nb_neighbors(2), 4);
EXPECT_EQ(index->hnsw.nb_neighbors(3), 4);
// picking a large number to trigger an exception based on checking bounds
EXPECT_THROW(index->hnsw.nb_neighbors(100), faiss::FaissException);
}
TEST_F(HNSWTest, TEST_search_level_0) {
omp_set_num_threads(1);
std::vector<faiss::idx_t> I(k * nq);
std::vector<float> D(k * nq);
using RH = faiss::HeapBlockResultHandler<faiss::HNSW::C>;
RH bres1(nq, D.data(), I.data(), k);
faiss::HeapBlockResultHandler<faiss::HNSW::C>::SingleResultHandler res1(
bres1);
RH bres2(nq, D.data(), I.data(), k);
faiss::HeapBlockResultHandler<faiss::HNSW::C>::SingleResultHandler res2(
bres2);
faiss::HNSWStats stats1, stats2;
faiss::VisitedTable vt1(index->ntotal);
faiss::VisitedTable vt2(index->ntotal);
auto nprobe = 5;
const faiss::HNSW::storage_idx_t values[] = {1, 2, 3, 4, 5};
const faiss::HNSW::storage_idx_t* nearest_i = values;
const float distances[] = {0.1, 0.2, 0.3, 0.4, 0.5};
const float* nearest_d = distances;
// search_type == 1
res1.begin(0);
index->hnsw.search_level_0(
*dis, res1, nprobe, nearest_i, nearest_d, 1, stats1, vt1, nullptr);
res1.end();
// search_type == 2
res2.begin(0);
index->hnsw.search_level_0(
*dis, res2, nprobe, nearest_i, nearest_d, 2, stats2, vt2, nullptr);
res2.end();
// search_type 1 calls search_from_candidates in a loop nprobe times.
// search_type 2 pushes the candidates and just calls search_from_candidates
// once, so those stats will be much less.
EXPECT_GT(stats1.ndis, stats2.ndis);
EXPECT_GT(stats1.nhops, stats2.nhops);
EXPECT_GT(stats1.n1, stats2.n1);
EXPECT_GT(stats1.n2, stats2.n2);
}