1154 lines
32 KiB
C++
1154 lines
32 KiB
C++
/**
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
// -*- c++ -*-
|
|
|
|
#include <faiss/impl/HNSW.h>
|
|
|
|
#include <string>
|
|
|
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
#include <faiss/impl/DistanceComputer.h>
|
|
#include <faiss/impl/IDSelector.h>
|
|
#include <faiss/utils/prefetch.h>
|
|
|
|
#include <faiss/impl/platform_macros.h>
|
|
|
|
#ifdef __AVX2__
|
|
#include <immintrin.h>
|
|
|
|
#include <limits>
|
|
#include <type_traits>
|
|
#endif
|
|
|
|
namespace faiss {
|
|
|
|
/**************************************************************
|
|
* HNSW structure implementation
|
|
**************************************************************/
|
|
|
|
int HNSW::nb_neighbors(int layer_no) const {
|
|
return cum_nneighbor_per_level[layer_no + 1] -
|
|
cum_nneighbor_per_level[layer_no];
|
|
}
|
|
|
|
void HNSW::set_nb_neighbors(int level_no, int n) {
|
|
FAISS_THROW_IF_NOT(levels.size() == 0);
|
|
int cur_n = nb_neighbors(level_no);
|
|
for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
|
|
cum_nneighbor_per_level[i] += n - cur_n;
|
|
}
|
|
}
|
|
|
|
int HNSW::cum_nb_neighbors(int layer_no) const {
|
|
return cum_nneighbor_per_level[layer_no];
|
|
}
|
|
|
|
/// Compute the half-open range [*begin, *end) of positions in `neighbors`
/// that hold the links of vertex `no` on layer `layer_no`.
void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
        const {
    // start of this vertex's storage, then the per-layer displacement
    const size_t vertex_base = offsets[no];
    *begin = vertex_base + cum_nneighbor_per_level[layer_no];
    *end = vertex_base + cum_nneighbor_per_level[layer_no + 1];
}
|
|
|
|
/// Build an empty graph whose level probabilities and neighbor counts are
/// derived from `M`, with a fixed RNG seed.
HNSW::HNSW(int M) : rng(12345) {
    const double level_mult = 1.0 / log(M);
    set_default_probas(M, level_mult);
    // offsets always holds one more entry than there are vertices
    offsets.push_back(0);
}
|
|
|
|
int HNSW::random_level() {
|
|
double f = rng.rand_float();
|
|
// could be a bit faster with bissection
|
|
for (int level = 0; level < assign_probas.size(); level++) {
|
|
if (f < assign_probas[level]) {
|
|
return level;
|
|
}
|
|
f -= assign_probas[level];
|
|
}
|
|
// happens with exponentially low probability
|
|
return assign_probas.size() - 1;
|
|
}
|
|
|
|
void HNSW::set_default_probas(int M, float levelMult) {
|
|
int nn = 0;
|
|
cum_nneighbor_per_level.push_back(0);
|
|
for (int level = 0;; level++) {
|
|
float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
|
|
if (proba < 1e-9)
|
|
break;
|
|
assign_probas.push_back(proba);
|
|
nn += level == 0 ? M * 2 : M;
|
|
cum_nneighbor_per_level.push_back(nn);
|
|
}
|
|
}
|
|
|
|
void HNSW::clear_neighbor_tables(int level) {
|
|
for (int i = 0; i < levels.size(); i++) {
|
|
size_t begin, end;
|
|
neighbor_range(i, level, &begin, &end);
|
|
for (size_t j = begin; j < end; j++) {
|
|
neighbors[j] = -1;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Drop all vertices and links, returning to the empty-graph state
/// (no entry point, max_level -1, a single 0 in offsets).
void HNSW::reset() {
    levels.clear();
    neighbors.clear();

    offsets.clear();
    offsets.push_back(0);

    entry_point = -1;
    max_level = -1;
}
|
|
|
|
void HNSW::print_neighbor_stats(int level) const {
|
|
FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
|
|
printf("stats on level %d, max %d neighbors per vertex:\n",
|
|
level,
|
|
nb_neighbors(level));
|
|
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
|
|
#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
|
|
reduction(+: tot_reciprocal) reduction(+: n_node)
|
|
for (int i = 0; i < levels.size(); i++) {
|
|
if (levels[i] > level) {
|
|
n_node++;
|
|
size_t begin, end;
|
|
neighbor_range(i, level, &begin, &end);
|
|
std::unordered_set<int> neighset;
|
|
for (size_t j = begin; j < end; j++) {
|
|
if (neighbors[j] < 0)
|
|
break;
|
|
neighset.insert(neighbors[j]);
|
|
}
|
|
int n_neigh = neighset.size();
|
|
int n_common = 0;
|
|
int n_reciprocal = 0;
|
|
for (size_t j = begin; j < end; j++) {
|
|
storage_idx_t i2 = neighbors[j];
|
|
if (i2 < 0)
|
|
break;
|
|
FAISS_ASSERT(i2 != i);
|
|
size_t begin2, end2;
|
|
neighbor_range(i2, level, &begin2, &end2);
|
|
for (size_t j2 = begin2; j2 < end2; j2++) {
|
|
storage_idx_t i3 = neighbors[j2];
|
|
if (i3 < 0)
|
|
break;
|
|
if (i3 == i) {
|
|
n_reciprocal++;
|
|
continue;
|
|
}
|
|
if (neighset.count(i3)) {
|
|
neighset.erase(i3);
|
|
n_common++;
|
|
}
|
|
}
|
|
}
|
|
tot_neigh += n_neigh;
|
|
tot_common += n_common;
|
|
tot_reciprocal += n_reciprocal;
|
|
}
|
|
}
|
|
float normalizer = n_node;
|
|
printf(" nb of nodes at that level %zd\n", n_node);
|
|
printf(" neighbors per node: %.2f (%zd)\n",
|
|
tot_neigh / normalizer,
|
|
tot_neigh);
|
|
printf(" nb of reciprocal neighbors: %.2f\n",
|
|
tot_reciprocal / normalizer);
|
|
printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
|
|
tot_common / normalizer,
|
|
tot_common);
|
|
}
|
|
|
|
void HNSW::fill_with_random_links(size_t n) {
|
|
int max_level = prepare_level_tab(n);
|
|
RandomGenerator rng2(456);
|
|
|
|
for (int level = max_level - 1; level >= 0; --level) {
|
|
std::vector<int> elts;
|
|
for (int i = 0; i < n; i++) {
|
|
if (levels[i] > level) {
|
|
elts.push_back(i);
|
|
}
|
|
}
|
|
printf("linking %zd elements in level %d\n", elts.size(), level);
|
|
|
|
if (elts.size() == 1)
|
|
continue;
|
|
|
|
for (int ii = 0; ii < elts.size(); ii++) {
|
|
int i = elts[ii];
|
|
size_t begin, end;
|
|
neighbor_range(i, 0, &begin, &end);
|
|
for (size_t j = begin; j < end; j++) {
|
|
int other = 0;
|
|
do {
|
|
other = elts[rng2.rand_int(elts.size())];
|
|
} while (other == i);
|
|
|
|
neighbors[j] = other;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
|
|
size_t n0 = offsets.size() - 1;
|
|
|
|
if (preset_levels) {
|
|
FAISS_ASSERT(n0 + n == levels.size());
|
|
} else {
|
|
FAISS_ASSERT(n0 == levels.size());
|
|
for (int i = 0; i < n; i++) {
|
|
int pt_level = random_level();
|
|
levels.push_back(pt_level + 1);
|
|
}
|
|
}
|
|
|
|
int max_level = 0;
|
|
for (int i = 0; i < n; i++) {
|
|
int pt_level = levels[i + n0] - 1;
|
|
if (pt_level > max_level)
|
|
max_level = pt_level;
|
|
offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
|
|
neighbors.resize(offsets.back(), -1);
|
|
}
|
|
|
|
return max_level;
|
|
}
|
|
|
|
/** Enumerate vertices from farthest to nearest from query, keep a
|
|
* neighbor only if there is no previous neighbor that is closer to
|
|
* that vertex than the query.
|
|
*/
|
|
void HNSW::shrink_neighbor_list(
|
|
DistanceComputer& qdis,
|
|
std::priority_queue<NodeDistFarther>& input,
|
|
std::vector<NodeDistFarther>& output,
|
|
int max_size) {
|
|
while (input.size() > 0) {
|
|
NodeDistFarther v1 = input.top();
|
|
input.pop();
|
|
float dist_v1_q = v1.d;
|
|
|
|
bool good = true;
|
|
for (NodeDistFarther v2 : output) {
|
|
float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
|
|
|
|
if (dist_v1_v2 < dist_v1_q) {
|
|
good = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (good) {
|
|
output.push_back(v1);
|
|
if (output.size() >= max_size) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
using storage_idx_t = HNSW::storage_idx_t;
|
|
using NodeDistCloser = HNSW::NodeDistCloser;
|
|
using NodeDistFarther = HNSW::NodeDistFarther;
|
|
|
|
/**************************************************************
|
|
* Addition subroutines
|
|
**************************************************************/
|
|
|
|
/// remove neighbors from the list to make it smaller than max_size
|
|
void shrink_neighbor_list(
|
|
DistanceComputer& qdis,
|
|
std::priority_queue<NodeDistCloser>& resultSet1,
|
|
int max_size) {
|
|
if (resultSet1.size() < max_size) {
|
|
return;
|
|
}
|
|
std::priority_queue<NodeDistFarther> resultSet;
|
|
std::vector<NodeDistFarther> returnlist;
|
|
|
|
while (resultSet1.size() > 0) {
|
|
resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
|
|
resultSet1.pop();
|
|
}
|
|
|
|
HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
|
|
|
|
for (NodeDistFarther curen2 : returnlist) {
|
|
resultSet1.emplace(curen2.d, curen2.id);
|
|
}
|
|
}
|
|
|
|
/// add a link between two elements, possibly shrinking the list
|
|
/// of links to make room for it.
|
|
void add_link(
|
|
HNSW& hnsw,
|
|
DistanceComputer& qdis,
|
|
storage_idx_t src,
|
|
storage_idx_t dest,
|
|
int level) {
|
|
size_t begin, end;
|
|
hnsw.neighbor_range(src, level, &begin, &end);
|
|
if (hnsw.neighbors[end - 1] == -1) {
|
|
// there is enough room, find a slot to add it
|
|
size_t i = end;
|
|
while (i > begin) {
|
|
if (hnsw.neighbors[i - 1] != -1)
|
|
break;
|
|
i--;
|
|
}
|
|
hnsw.neighbors[i] = dest;
|
|
return;
|
|
}
|
|
|
|
// otherwise we let them fight out which to keep
|
|
|
|
// copy to resultSet...
|
|
std::priority_queue<NodeDistCloser> resultSet;
|
|
resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
|
|
for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
|
|
storage_idx_t neigh = hnsw.neighbors[i];
|
|
resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
|
|
}
|
|
|
|
shrink_neighbor_list(qdis, resultSet, end - begin);
|
|
|
|
// ...and back
|
|
size_t i = begin;
|
|
while (resultSet.size()) {
|
|
hnsw.neighbors[i++] = resultSet.top().id;
|
|
resultSet.pop();
|
|
}
|
|
// they may have shrunk more than just by 1 element
|
|
while (i < end) {
|
|
hnsw.neighbors[i++] = -1;
|
|
}
|
|
}
|
|
|
|
/// search neighbors on a single level, starting from an entry point
|
|
void search_neighbors_to_add(
|
|
HNSW& hnsw,
|
|
DistanceComputer& qdis,
|
|
std::priority_queue<NodeDistCloser>& results,
|
|
int entry_point,
|
|
float d_entry_point,
|
|
int level,
|
|
VisitedTable& vt) {
|
|
// top is nearest candidate
|
|
std::priority_queue<NodeDistFarther> candidates;
|
|
|
|
NodeDistFarther ev(d_entry_point, entry_point);
|
|
candidates.push(ev);
|
|
results.emplace(d_entry_point, entry_point);
|
|
vt.set(entry_point);
|
|
|
|
while (!candidates.empty()) {
|
|
// get nearest
|
|
const NodeDistFarther& currEv = candidates.top();
|
|
|
|
if (currEv.d > results.top().d) {
|
|
break;
|
|
}
|
|
int currNode = currEv.id;
|
|
candidates.pop();
|
|
|
|
// loop over neighbors
|
|
size_t begin, end;
|
|
hnsw.neighbor_range(currNode, level, &begin, &end);
|
|
for (size_t i = begin; i < end; i++) {
|
|
storage_idx_t nodeId = hnsw.neighbors[i];
|
|
if (nodeId < 0)
|
|
break;
|
|
if (vt.get(nodeId))
|
|
continue;
|
|
vt.set(nodeId);
|
|
|
|
float dis = qdis(nodeId);
|
|
NodeDistFarther evE1(dis, nodeId);
|
|
|
|
if (results.size() < hnsw.efConstruction || results.top().d > dis) {
|
|
results.emplace(dis, nodeId);
|
|
candidates.emplace(dis, nodeId);
|
|
if (results.size() > hnsw.efConstruction) {
|
|
results.pop();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
vt.advance();
|
|
}
|
|
|
|
/**************************************************************
|
|
* Searching subroutines
|
|
**************************************************************/
|
|
|
|
/// greedily update a nearest vector at a given level
|
|
void greedy_update_nearest(
|
|
const HNSW& hnsw,
|
|
DistanceComputer& qdis,
|
|
int level,
|
|
storage_idx_t& nearest,
|
|
float& d_nearest) {
|
|
for (;;) {
|
|
storage_idx_t prev_nearest = nearest;
|
|
|
|
size_t begin, end;
|
|
hnsw.neighbor_range(nearest, level, &begin, &end);
|
|
for (size_t i = begin; i < end; i++) {
|
|
storage_idx_t v = hnsw.neighbors[i];
|
|
if (v < 0)
|
|
break;
|
|
float dis = qdis(v);
|
|
if (dis < d_nearest) {
|
|
nearest = v;
|
|
d_nearest = dis;
|
|
}
|
|
}
|
|
if (nearest == prev_nearest) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Finds neighbors and builds links with them, starting from an entry
|
|
/// point. The own neighbor list is assumed to be locked.
|
|
void HNSW::add_links_starting_from(
|
|
DistanceComputer& ptdis,
|
|
storage_idx_t pt_id,
|
|
storage_idx_t nearest,
|
|
float d_nearest,
|
|
int level,
|
|
omp_lock_t* locks,
|
|
VisitedTable& vt) {
|
|
std::priority_queue<NodeDistCloser> link_targets;
|
|
|
|
search_neighbors_to_add(
|
|
*this, ptdis, link_targets, nearest, d_nearest, level, vt);
|
|
|
|
// but we can afford only this many neighbors
|
|
int M = nb_neighbors(level);
|
|
|
|
::faiss::shrink_neighbor_list(ptdis, link_targets, M);
|
|
|
|
std::vector<storage_idx_t> neighbors;
|
|
neighbors.reserve(link_targets.size());
|
|
while (!link_targets.empty()) {
|
|
storage_idx_t other_id = link_targets.top().id;
|
|
add_link(*this, ptdis, pt_id, other_id, level);
|
|
neighbors.push_back(other_id);
|
|
link_targets.pop();
|
|
}
|
|
|
|
omp_unset_lock(&locks[pt_id]);
|
|
for (storage_idx_t other_id : neighbors) {
|
|
omp_set_lock(&locks[other_id]);
|
|
add_link(*this, ptdis, other_id, pt_id, level);
|
|
omp_unset_lock(&locks[other_id]);
|
|
}
|
|
omp_set_lock(&locks[pt_id]);
|
|
}
|
|
|
|
/**************************************************************
|
|
* Building, parallel
|
|
**************************************************************/
|
|
|
|
/// Insert vertex `pt_id`, whose highest level index is `pt_level`, into
/// the graph. The first vertex ever added only becomes the entry point.
///
/// NOTE(review): the final update of max_level / entry_point and the read
/// of max_level below happen outside the critical section -- confirm this
/// is acceptable for concurrent insertions.
void HNSW::add_with_locks(
        DistanceComputer& ptdis,
        int pt_level,
        int pt_id,
        std::vector<omp_lock_t>& locks,
        VisitedTable& vt) {
    storage_idx_t nearest;
#pragma omp critical
    {
        nearest = entry_point;

        const bool graph_is_empty = (nearest == -1);
        if (graph_is_empty) {
            max_level = pt_level;
            entry_point = pt_id;
        }
    }

    if (nearest < 0) {
        return;
    }

    omp_set_lock(&locks[pt_id]);

    int level = max_level; // level at which the descent starts
    float d_nearest = ptdis(nearest);

    // greedy search on the levels above the vertex's own top level
    while (level > pt_level) {
        greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
        --level;
    }

    // from there down to level 0, actually create links
    while (level >= 0) {
        add_links_starting_from(
                ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
        --level;
    }

    omp_unset_lock(&locks[pt_id]);

    // a vertex reaching above the current top becomes the new entry point
    if (pt_level > max_level) {
        max_level = pt_level;
        entry_point = pt_id;
    }
}
|
|
|
|
/**************************************************************
|
|
* Searching
|
|
**************************************************************/
|
|
|
|
namespace {
|
|
|
|
using MinimaxHeap = HNSW::MinimaxHeap;
|
|
using Node = HNSW::Node;
|
|
/** Do a BFS on the candidates list */
|
|
|
|
int search_from_candidates(
|
|
const HNSW& hnsw,
|
|
DistanceComputer& qdis,
|
|
int k,
|
|
idx_t* I,
|
|
float* D,
|
|
MinimaxHeap& candidates,
|
|
VisitedTable& vt,
|
|
HNSWStats& stats,
|
|
int level,
|
|
int nres_in = 0,
|
|
const SearchParametersHNSW* params = nullptr) {
|
|
int nres = nres_in;
|
|
int ndis = 0;
|
|
|
|
// can be overridden by search params
|
|
bool do_dis_check = params ? params->check_relative_distance
|
|
: hnsw.check_relative_distance;
|
|
int efSearch = params ? params->efSearch : hnsw.efSearch;
|
|
const IDSelector* sel = params ? params->sel : nullptr;
|
|
|
|
for (int i = 0; i < candidates.size(); i++) {
|
|
idx_t v1 = candidates.ids[i];
|
|
float d = candidates.dis[i];
|
|
FAISS_ASSERT(v1 >= 0);
|
|
if (!sel || sel->is_member(v1)) {
|
|
if (nres < k) {
|
|
faiss::maxheap_push(++nres, D, I, d, v1);
|
|
} else if (d < D[0]) {
|
|
faiss::maxheap_replace_top(nres, D, I, d, v1);
|
|
}
|
|
}
|
|
vt.set(v1);
|
|
}
|
|
|
|
int nstep = 0;
|
|
|
|
while (candidates.size() > 0) {
|
|
float d0 = 0;
|
|
int v0 = candidates.pop_min(&d0);
|
|
|
|
if (do_dis_check) {
|
|
// tricky stopping condition: there are more that ef
|
|
// distances that are processed already that are smaller
|
|
// than d0
|
|
|
|
int n_dis_below = candidates.count_below(d0);
|
|
if (n_dis_below >= efSearch) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
size_t begin, end;
|
|
hnsw.neighbor_range(v0, level, &begin, &end);
|
|
|
|
// // baseline version
|
|
// for (size_t j = begin; j < end; j++) {
|
|
// int v1 = hnsw.neighbors[j];
|
|
// if (v1 < 0)
|
|
// break;
|
|
// if (vt.get(v1)) {
|
|
// continue;
|
|
// }
|
|
// vt.set(v1);
|
|
// ndis++;
|
|
// float d = qdis(v1);
|
|
// if (!sel || sel->is_member(v1)) {
|
|
// if (nres < k) {
|
|
// faiss::maxheap_push(++nres, D, I, d, v1);
|
|
// } else if (d < D[0]) {
|
|
// faiss::maxheap_replace_top(nres, D, I, d, v1);
|
|
// }
|
|
// }
|
|
// candidates.push(v1, d);
|
|
// }
|
|
|
|
// the following version processes 4 neighbors at a time
|
|
size_t jmax = begin;
|
|
for (size_t j = begin; j < end; j++) {
|
|
int v1 = hnsw.neighbors[j];
|
|
if (v1 < 0)
|
|
break;
|
|
|
|
prefetch_L2(vt.visited.data() + v1);
|
|
jmax += 1;
|
|
}
|
|
|
|
int counter = 0;
|
|
size_t saved_j[4];
|
|
|
|
ndis += jmax - begin;
|
|
|
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
if (!sel || sel->is_member(idx)) {
|
|
if (nres < k) {
|
|
faiss::maxheap_push(++nres, D, I, dis, idx);
|
|
} else if (dis < D[0]) {
|
|
faiss::maxheap_replace_top(nres, D, I, dis, idx);
|
|
}
|
|
}
|
|
candidates.push(idx, dis);
|
|
};
|
|
|
|
for (size_t j = begin; j < jmax; j++) {
|
|
int v1 = hnsw.neighbors[j];
|
|
|
|
bool vget = vt.get(v1);
|
|
vt.set(v1);
|
|
saved_j[counter] = v1;
|
|
counter += vget ? 0 : 1;
|
|
|
|
if (counter == 4) {
|
|
float dis[4];
|
|
qdis.distances_batch_4(
|
|
saved_j[0],
|
|
saved_j[1],
|
|
saved_j[2],
|
|
saved_j[3],
|
|
dis[0],
|
|
dis[1],
|
|
dis[2],
|
|
dis[3]);
|
|
|
|
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
add_to_heap(saved_j[id4], dis[id4]);
|
|
}
|
|
|
|
counter = 0;
|
|
}
|
|
}
|
|
|
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
float dis = qdis(saved_j[icnt]);
|
|
add_to_heap(saved_j[icnt], dis);
|
|
}
|
|
|
|
nstep++;
|
|
if (!do_dis_check && nstep > efSearch) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (level == 0) {
|
|
stats.n1++;
|
|
if (candidates.size() == 0) {
|
|
stats.n2++;
|
|
}
|
|
stats.n3 += ndis;
|
|
}
|
|
|
|
return nres;
|
|
}
|
|
|
|
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
const HNSW& hnsw,
|
|
const Node& node,
|
|
DistanceComputer& qdis,
|
|
int ef,
|
|
VisitedTable* vt,
|
|
HNSWStats& stats) {
|
|
int ndis = 0;
|
|
std::priority_queue<Node> top_candidates;
|
|
std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
|
|
|
|
top_candidates.push(node);
|
|
candidates.push(node);
|
|
|
|
vt->set(node.second);
|
|
|
|
while (!candidates.empty()) {
|
|
float d0;
|
|
storage_idx_t v0;
|
|
std::tie(d0, v0) = candidates.top();
|
|
|
|
if (d0 > top_candidates.top().first) {
|
|
break;
|
|
}
|
|
|
|
candidates.pop();
|
|
|
|
size_t begin, end;
|
|
hnsw.neighbor_range(v0, 0, &begin, &end);
|
|
|
|
// // baseline version
|
|
// for (size_t j = begin; j < end; ++j) {
|
|
// int v1 = hnsw.neighbors[j];
|
|
//
|
|
// if (v1 < 0) {
|
|
// break;
|
|
// }
|
|
// if (vt->get(v1)) {
|
|
// continue;
|
|
// }
|
|
//
|
|
// vt->set(v1);
|
|
//
|
|
// float d1 = qdis(v1);
|
|
// ++ndis;
|
|
//
|
|
// if (top_candidates.top().first > d1 ||
|
|
// top_candidates.size() < ef) {
|
|
// candidates.emplace(d1, v1);
|
|
// top_candidates.emplace(d1, v1);
|
|
//
|
|
// if (top_candidates.size() > ef) {
|
|
// top_candidates.pop();
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// the following version processes 4 neighbors at a time
|
|
size_t jmax = begin;
|
|
for (size_t j = begin; j < end; j++) {
|
|
int v1 = hnsw.neighbors[j];
|
|
if (v1 < 0)
|
|
break;
|
|
|
|
prefetch_L2(vt->visited.data() + v1);
|
|
jmax += 1;
|
|
}
|
|
|
|
int counter = 0;
|
|
size_t saved_j[4];
|
|
|
|
ndis += jmax - begin;
|
|
|
|
auto add_to_heap = [&](const size_t idx, const float dis) {
|
|
if (top_candidates.top().first > dis ||
|
|
top_candidates.size() < ef) {
|
|
candidates.emplace(dis, idx);
|
|
top_candidates.emplace(dis, idx);
|
|
|
|
if (top_candidates.size() > ef) {
|
|
top_candidates.pop();
|
|
}
|
|
}
|
|
};
|
|
|
|
for (size_t j = begin; j < jmax; j++) {
|
|
int v1 = hnsw.neighbors[j];
|
|
|
|
bool vget = vt->get(v1);
|
|
vt->set(v1);
|
|
saved_j[counter] = v1;
|
|
counter += vget ? 0 : 1;
|
|
|
|
if (counter == 4) {
|
|
float dis[4];
|
|
qdis.distances_batch_4(
|
|
saved_j[0],
|
|
saved_j[1],
|
|
saved_j[2],
|
|
saved_j[3],
|
|
dis[0],
|
|
dis[1],
|
|
dis[2],
|
|
dis[3]);
|
|
|
|
for (size_t id4 = 0; id4 < 4; id4++) {
|
|
add_to_heap(saved_j[id4], dis[id4]);
|
|
}
|
|
|
|
counter = 0;
|
|
}
|
|
}
|
|
|
|
for (size_t icnt = 0; icnt < counter; icnt++) {
|
|
float dis = qdis(saved_j[icnt]);
|
|
add_to_heap(saved_j[icnt], dis);
|
|
}
|
|
}
|
|
|
|
++stats.n1;
|
|
if (candidates.size() == 0) {
|
|
++stats.n2;
|
|
}
|
|
stats.n3 += ndis;
|
|
|
|
return top_candidates;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
/// Search the k nearest vertices of the query bound to `qdis`.
///
/// @param qdis    distance computer for the query
/// @param k       number of results requested
/// @param I       result ids, maintained as a max-heap together with D
/// @param D       result distances
/// @param vt      visited table; advanced after each explored stage
/// @param params  optional per-call overrides (efSearch, relative distance
///                check, id selector); may be nullptr
/// @return search statistics; returned untouched when the graph is empty
///
/// NOTE(review): the unbounded-queue path receives no selector, so
/// params->sel has no effect there -- confirm whether that is intended.
HNSWStats HNSW::search(
        DistanceComputer& qdis,
        int k,
        idx_t* I,
        float* D,
        VisitedTable& vt,
        const SearchParametersHNSW* params) const {
    HNSWStats stats;
    if (entry_point == -1) {
        return stats;
    }
    if (upper_beam == 1) {
        // greedy search on upper levels
        storage_idx_t nearest = entry_point;
        float d_nearest = qdis(nearest);

        for (int level = max_level; level >= 1; level--) {
            greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
        }

        // size the candidate structures with the efSearch that the level-0
        // search will actually use: the per-call value when supplied
        const int requested_ef = params ? params->efSearch : efSearch;
        int ef = std::max(requested_ef, k);
        if (search_bounded_queue) { // this is the most common branch
            MinimaxHeap candidates(ef);

            candidates.push(nearest, d_nearest);

            search_from_candidates(
                    *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
        } else {
            std::priority_queue<Node> top_candidates =
                    search_from_candidate_unbounded(
                            *this,
                            Node(d_nearest, nearest),
                            qdis,
                            ef,
                            &vt,
                            stats);

            // drop the farthest entries until at most k remain
            while (top_candidates.size() > k) {
                top_candidates.pop();
            }

            int nres = 0;
            while (!top_candidates.empty()) {
                float d;
                storage_idx_t label;
                std::tie(d, label) = top_candidates.top();
                faiss::maxheap_push(++nres, D, I, d, label);
                top_candidates.pop();
            }
        }

        vt.advance();

    } else {
        int candidates_size = upper_beam;
        MinimaxHeap candidates(candidates_size);

        // results of one level, used to seed the next one
        std::vector<idx_t> I_to_next(candidates_size);
        std::vector<float> D_to_next(candidates_size);

        int nres = 1;
        I_to_next[0] = entry_point;
        D_to_next[0] = qdis(entry_point);

        for (int level = max_level; level >= 0; level--) {
            // copy I, D -> candidates

            candidates.clear();

            for (int i = 0; i < nres; i++) {
                candidates.push(I_to_next[i], D_to_next[i]);
            }

            if (level == 0) {
                // final stage: honor the per-call parameters, as the
                // greedy path above does
                nres = search_from_candidates(
                        *this,
                        qdis,
                        k,
                        I,
                        D,
                        candidates,
                        vt,
                        stats,
                        0,
                        0,
                        params);
            } else {
                // intermediate stage: plain beam, no per-call parameters
                nres = search_from_candidates(
                        *this,
                        qdis,
                        candidates_size,
                        I_to_next.data(),
                        D_to_next.data(),
                        candidates,
                        vt,
                        stats,
                        level);
            }
            vt.advance();
        }
    }

    return stats;
}
|
|
|
|
/// Level-0 search seeded with `nprobe` externally provided start vertices
/// (`nearest_i`, with distances `nearest_d`; a negative id ends the list).
///
/// search_type 1 runs one search per seed that is not visited yet,
/// accumulating into the same result heap; search_type 2 runs a single
/// search seeded with all of them. Any other value does nothing.
void HNSW::search_level_0(
        DistanceComputer& qdis,
        int k,
        idx_t* idxi,
        float* simi,
        idx_t nprobe,
        const storage_idx_t* nearest_i,
        const float* nearest_d,
        int search_type,
        HNSWStats& search_stats,
        VisitedTable& vt) const {
    switch (search_type) {
        case 1: {
            const int heap_capacity = std::max(efSearch, int(k));
            int nres = 0;

            for (int probe = 0; probe < nprobe; ++probe) {
                const storage_idx_t seed = nearest_i[probe];

                if (seed < 0) {
                    break;
                }
                if (vt.get(seed)) {
                    continue;
                }

                // fresh candidate heap for each seed
                MinimaxHeap candidates(heap_capacity);
                candidates.push(seed, nearest_d[probe]);

                nres = search_from_candidates(
                        *this,
                        qdis,
                        k,
                        idxi,
                        simi,
                        candidates,
                        vt,
                        search_stats,
                        0,
                        nres);
            }
            break;
        }
        case 2: {
            // the heap must be able to hold all the seeds
            int heap_capacity = std::max(efSearch, int(k));
            heap_capacity = std::max(heap_capacity, int(nprobe));

            MinimaxHeap candidates(heap_capacity);
            for (int probe = 0; probe < nprobe; ++probe) {
                const storage_idx_t seed = nearest_i[probe];
                if (seed < 0) {
                    break;
                }
                candidates.push(seed, nearest_d[probe]);
            }

            search_from_candidates(
                    *this,
                    qdis,
                    k,
                    idxi,
                    simi,
                    candidates,
                    vt,
                    search_stats,
                    0);
            break;
        }
        default:
            break;
    }
}
|
|
|
|
void HNSW::permute_entries(const idx_t* map) {
|
|
// remap levels
|
|
storage_idx_t ntotal = levels.size();
|
|
std::vector<storage_idx_t> imap(ntotal); // inverse mapping
|
|
// map: new index -> old index
|
|
// imap: old index -> new index
|
|
for (int i = 0; i < ntotal; i++) {
|
|
assert(map[i] >= 0 && map[i] < ntotal);
|
|
imap[map[i]] = i;
|
|
}
|
|
if (entry_point != -1) {
|
|
entry_point = imap[entry_point];
|
|
}
|
|
std::vector<int> new_levels(ntotal);
|
|
std::vector<size_t> new_offsets(ntotal + 1);
|
|
std::vector<storage_idx_t> new_neighbors(neighbors.size());
|
|
size_t no = 0;
|
|
for (int i = 0; i < ntotal; i++) {
|
|
storage_idx_t o = map[i]; // corresponding "old" index
|
|
new_levels[i] = levels[o];
|
|
for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
|
|
storage_idx_t neigh = neighbors[j];
|
|
new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
|
|
}
|
|
new_offsets[i + 1] = no;
|
|
}
|
|
assert(new_offsets[ntotal] == offsets[ntotal]);
|
|
// swap everyone
|
|
std::swap(levels, new_levels);
|
|
std::swap(offsets, new_offsets);
|
|
std::swap(neighbors, new_neighbors);
|
|
}
|
|
|
|
/**************************************************************
|
|
* MinimaxHeap
|
|
**************************************************************/
|
|
|
|
/// Insert id `i` with distance `v`. When the heap is full, the element at
/// its top is evicted first, unless `v` is not smaller than the top
/// distance, in which case nothing is inserted.
void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
    const bool heap_is_full = (k == n);
    if (heap_is_full) {
        if (v >= dis[0]) {
            return;
        }
        // the evicted top only counts as valid if it was not popped
        // (marked -1) before
        if (ids[0] != -1) {
            nvalid -= 1;
        }
        faiss::heap_pop<HC>(k, dis.data(), ids.data());
        k -= 1;
    }
    k += 1;
    faiss::heap_push<HC>(k, dis.data(), ids.data(), v, i);
    nvalid += 1;
}
|
|
|
|
float HNSW::MinimaxHeap::max() const {
|
|
return dis[0];
|
|
}
|
|
|
|
int HNSW::MinimaxHeap::size() const {
|
|
return nvalid;
|
|
}
|
|
|
|
/// Forget all entries; the storage itself is left untouched.
void HNSW::MinimaxHeap::clear() {
    k = 0;
    nvalid = 0;
}
|
|
|
|
#ifdef __AVX2__
|
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
assert(k > 0);
|
|
static_assert(
|
|
std::is_same<storage_idx_t, int32_t>::value,
|
|
"This code expects storage_idx_t to be int32_t");
|
|
|
|
int32_t min_idx = -1;
|
|
float min_dis = std::numeric_limits<float>::infinity();
|
|
|
|
size_t iii = 0;
|
|
|
|
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
|
__m256 min_distances =
|
|
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
|
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
__m256i offset = _mm256_set1_epi32(8);
|
|
|
|
// The baseline version is available in non-AVX2 branch.
|
|
|
|
// The following loop tracks the rightmost index with the min distance.
|
|
// -1 index values are ignored.
|
|
const int k8 = (k / 8) * 8;
|
|
for (; iii < k8; iii += 8) {
|
|
__m256i indices =
|
|
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
|
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
|
|
|
// This mask filters out -1 values among indices.
|
|
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
|
|
|
__m256i dmask = _mm256_castps_si256(
|
|
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
|
|
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
|
|
|
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
|
_mm256_castsi256_ps(current_indices),
|
|
_mm256_castsi256_ps(min_indices),
|
|
finalmask));
|
|
|
|
const __m256 min_distances_new =
|
|
_mm256_blendv_ps(distances, min_distances, finalmask);
|
|
|
|
min_indices = min_indices_new;
|
|
min_distances = min_distances_new;
|
|
|
|
current_indices = _mm256_add_epi32(current_indices, offset);
|
|
}
|
|
|
|
// Vectorizing is doable, but is not practical
|
|
int32_t vidx8[8];
|
|
float vdis8[8];
|
|
_mm256_storeu_ps(vdis8, min_distances);
|
|
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
|
|
|
|
for (size_t j = 0; j < 8; j++) {
|
|
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
|
|
min_idx = vidx8[j];
|
|
min_dis = vdis8[j];
|
|
}
|
|
}
|
|
|
|
// process last values. Vectorizing is doable, but is not practical
|
|
for (; iii < k; iii++) {
|
|
if (ids[iii] != -1 && dis[iii] <= min_dis) {
|
|
min_dis = dis[iii];
|
|
min_idx = iii;
|
|
}
|
|
}
|
|
|
|
if (min_idx == -1) {
|
|
return -1;
|
|
}
|
|
|
|
if (vmin_out) {
|
|
*vmin_out = min_dis;
|
|
}
|
|
int ret = ids[min_idx];
|
|
ids[min_idx] = -1;
|
|
--nvalid;
|
|
return ret;
|
|
}
|
|
|
|
#else
|
|
|
|
// baseline non-vectorized version
|
|
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|
assert(k > 0);
|
|
// returns min. This is an O(n) operation
|
|
int i = k - 1;
|
|
while (i >= 0) {
|
|
if (ids[i] != -1) {
|
|
break;
|
|
}
|
|
i--;
|
|
}
|
|
if (i == -1) {
|
|
return -1;
|
|
}
|
|
int imin = i;
|
|
float vmin = dis[i];
|
|
i--;
|
|
while (i >= 0) {
|
|
if (ids[i] != -1 && dis[i] < vmin) {
|
|
vmin = dis[i];
|
|
imin = i;
|
|
}
|
|
i--;
|
|
}
|
|
if (vmin_out) {
|
|
*vmin_out = vmin;
|
|
}
|
|
int ret = ids[imin];
|
|
ids[imin] = -1;
|
|
--nvalid;
|
|
|
|
return ret;
|
|
}
|
|
#endif
|
|
|
|
int HNSW::MinimaxHeap::count_below(float thresh) {
|
|
int n_below = 0;
|
|
for (int i = 0; i < k; i++) {
|
|
if (dis[i] < thresh) {
|
|
n_below++;
|
|
}
|
|
}
|
|
|
|
return n_below;
|
|
}
|
|
|
|
} // namespace faiss
|