faiss/faiss/IndexHNSW.cpp
divyegala 079fd5559c Fix seg faults in CAGRA C++ unit tests (#3552)
Summary:
The issue was that `uniform_int_distribution` generates numbers in the range `[0, ntotal]` and not `[0, ntotal)`, which was an oversight on my part. This PR also attempts to reduce the tolerance for `copyTo` tests as we have seen those fail intermittently.
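For reference, here is a minimal standalone illustration of the off-by-one (the variable names are illustrative, not taken from the test code):

```cpp
#include <cstdint>
#include <iostream>
#include <random>

int main() {
    std::mt19937 gen(42);
    int64_t ntotal = 4;
    // Both bounds of uniform_int_distribution are INCLUSIVE, so the first
    // distribution can return ntotal itself -- one past the last valid id.
    std::uniform_int_distribution<int64_t> bad(0, ntotal);      // [0, ntotal]
    std::uniform_int_distribution<int64_t> good(0, ntotal - 1); // [0, ntotal)
    for (int i = 0; i < 20; i++) {
        std::cout << bad(gen) << ' '; // may print 4, an out-of-range id
    }
    std::cout << '\n' << good(gen) << '\n'; // always in 0..3
    return 0;
}
```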

cc ramilbakhshyiev mdouze cjnolet

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3552

Reviewed By: junjieqi

Differential Revision: D59097786

Pulled By: ramilbakhshyiev

fbshipit-source-id: 9dac4367e25c6c219b116ed172089a2fa2a39c4f
2024-06-27 00:50:02 -07:00


/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexHNSW.h>
#include <omp.h>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <random>
#include <unordered_set>
#include <sys/stat.h>
#include <sys/types.h>
#include <cstdint>
#include <faiss/Index2Layer.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/random.h>
#include <faiss/utils/sorting.h>
extern "C" {
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
int sgemm_(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc);
}
namespace faiss {
using MinimaxHeap = HNSW::MinimaxHeap;
using storage_idx_t = HNSW::storage_idx_t;
using NodeDistFarther = HNSW::NodeDistFarther;
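// process-wide statistics accumulator, updated by all HNSW searches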
HNSWStats hnsw_stats;
/**************************************************************
* add / search blocks of descriptors
**************************************************************/
namespace {
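// Wrap the storage's distance computer so that HNSW can always treat
// "smaller is better": for similarity metrics (e.g. inner product) the
// values are negated here, and the callers negate them back after search.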
DistanceComputer* storage_distance_computer(const Index* storage) {
if (is_similarity_metric(storage->metric_type)) {
return new NegativeDistanceComputer(storage->get_distance_computer());
} else {
return storage->get_distance_computer();
}
}
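// Add n vectors (with ids starting at n0) to the graph. Points are bucketed
// by their precomputed level and inserted from the highest level down, so
// that upper-layer links exist before the lower layers are populated.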
void hnsw_add_vertices(
IndexHNSW& index_hnsw,
size_t n0,
size_t n,
const float* x,
bool verbose,
bool preset_levels = false) {
size_t d = index_hnsw.d;
HNSW& hnsw = index_hnsw.hnsw;
size_t ntotal = n0 + n;
double t0 = getmillisecs();
if (verbose) {
printf("hnsw_add_vertices: adding %zd elements on top of %zd "
"(preset_levels=%d)\n",
n,
n0,
int(preset_levels));
}
if (n == 0) {
return;
}
int max_level = hnsw.prepare_level_tab(n, preset_levels);
if (verbose) {
printf(" max_level = %d\n", max_level);
}
std::vector<omp_lock_t> locks(ntotal);
for (int i = 0; i < ntotal; i++)
omp_init_lock(&locks[i]);
// add vectors from highest to lowest level
std::vector<int> hist;
std::vector<int> order(n);
{ // make buckets with vectors of the same level
// build histogram
for (int i = 0; i < n; i++) {
storage_idx_t pt_id = i + n0;
int pt_level = hnsw.levels[pt_id] - 1;
while (pt_level >= hist.size())
hist.push_back(0);
hist[pt_level]++;
}
// accumulate
std::vector<int> offsets(hist.size() + 1, 0);
for (int i = 0; i < hist.size() - 1; i++) {
offsets[i + 1] = offsets[i] + hist[i];
}
// bucket sort
for (int i = 0; i < n; i++) {
storage_idx_t pt_id = i + n0;
int pt_level = hnsw.levels[pt_id] - 1;
order[offsets[pt_level]++] = pt_id;
}
}
idx_t check_period = InterruptCallback::get_period_hint(
max_level * index_hnsw.d * hnsw.efConstruction);
{ // perform add
RandomGenerator rng2(789);
int i1 = n;
for (int pt_level = hist.size() - 1;
pt_level >= !index_hnsw.init_level0;
pt_level--) {
int i0 = i1 - hist[pt_level];
if (verbose) {
printf("Adding %d elements at level %d\n", i1 - i0, pt_level);
}
// random permutation to get rid of dataset order bias
for (int j = i0; j < i1; j++)
std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
bool interrupt = false;
#pragma omp parallel if (i1 > i0 + 100)
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(index_hnsw.storage));
int prev_display =
verbose && omp_get_thread_num() == 0 ? 0 : -1;
size_t counter = 0;
// here we should do schedule(dynamic) but this segfaults for
// some versions of LLVM. The performance impact should not be
// too large when (i1 - i0) / num_threads >> 1
#pragma omp for schedule(static)
for (int i = i0; i < i1; i++) {
storage_idx_t pt_id = order[i];
dis->set_query(x + (pt_id - n0) * d);
                // an OpenMP for loop cannot break early, so on interrupt
                // we just skip the remaining loop bodies
if (interrupt) {
continue;
}
hnsw.add_with_locks(
*dis,
pt_level,
pt_id,
locks,
vt,
index_hnsw.keep_max_size_level0 && (pt_level == 0));
if (prev_display >= 0 && i - i0 > prev_display + 10000) {
prev_display = i - i0;
printf(" %d / %d\r", i - i0, i1 - i0);
fflush(stdout);
}
if (counter % check_period == 0) {
if (InterruptCallback::is_interrupted()) {
interrupt = true;
}
}
counter++;
}
}
if (interrupt) {
FAISS_THROW_MSG("computation interrupted");
}
i1 = i0;
}
if (index_hnsw.init_level0) {
FAISS_ASSERT(i1 == 0);
} else {
FAISS_ASSERT((i1 - hist[0]) == 0);
}
}
if (verbose) {
printf("Done in %.3f ms\n", getmillisecs() - t0);
}
for (int i = 0; i < ntotal; i++) {
omp_destroy_lock(&locks[i]);
}
}
} // namespace
/**************************************************************
* IndexHNSW implementation
**************************************************************/
IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
: Index(d, metric), hnsw(M) {}
IndexHNSW::IndexHNSW(Index* storage, int M)
: Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}
IndexHNSW::~IndexHNSW() {
if (own_fields) {
delete storage;
}
}
void IndexHNSW::train(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
// hnsw structure does not require training
storage->train(n, x);
is_trained = true;
}
namespace {
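// Shared implementation for knn- and range-search: runs the HNSW traversal
// for a block of queries in parallel and accumulates traversal statistics.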
template <class BlockResultHandler>
void hnsw_search(
const IndexHNSW* index,
idx_t n,
const float* x,
BlockResultHandler& bres,
const SearchParameters* params_in) {
FAISS_THROW_IF_NOT_MSG(
index->storage,
"No storage index, please use IndexHNSWFlat (or variants) "
"instead of IndexHNSW directly");
const SearchParametersHNSW* params = nullptr;
const HNSW& hnsw = index->hnsw;
int efSearch = hnsw.efSearch;
if (params_in) {
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
efSearch = params->efSearch;
}
size_t n1 = 0, n2 = 0, ndis = 0;
idx_t check_period = InterruptCallback::get_period_hint(
hnsw.max_level * index->d * efSearch);
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);
#pragma omp parallel
{
VisitedTable vt(index->ntotal);
typename BlockResultHandler::SingleResultHandler res(bres);
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(index->storage));
#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
for (idx_t i = i0; i < i1; i++) {
res.begin(i);
dis->set_query(x + i * index->d);
HNSWStats stats = hnsw.search(*dis, res, vt, params);
n1 += stats.n1;
n2 += stats.n2;
ndis += stats.ndis;
res.end();
}
}
InterruptCallback::check();
}
hnsw_stats.combine({n1, n2, ndis});
}
} // anonymous namespace
void IndexHNSW::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);
using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);
hnsw_search(this, n, x, bres, params_in);
if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
for (size_t i = 0; i < k * n; i++) {
distances[i] = -distances[i];
}
}
}
void IndexHNSW::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params) const {
using RH = RangeSearchBlockResultHandler<HNSW::C>;
RH bres(result, radius);
hnsw_search(this, n, x, bres, params);
if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
for (size_t i = 0; i < result->lims[result->nq]; i++) {
result->distances[i] = -result->distances[i];
}
}
}
void IndexHNSW::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
storage,
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
FAISS_THROW_IF_NOT(is_trained);
int n0 = ntotal;
storage->add(n, x);
ntotal = storage->ntotal;
hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal);
}
void IndexHNSW::reset() {
hnsw.reset();
storage->reset();
ntotal = 0;
}
void IndexHNSW::reconstruct(idx_t key, float* recons) const {
storage->reconstruct(key, recons);
}
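// Re-run the neighbor-selection heuristic on the existing level-0 links so
// that each node keeps at most new_size neighbors; freed slots are set to -1.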
void IndexHNSW::shrink_level_0_neighbors(int new_size) {
#pragma omp parallel
{
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(storage));
#pragma omp for
for (idx_t i = 0; i < ntotal; i++) {
size_t begin, end;
hnsw.neighbor_range(i, 0, &begin, &end);
std::priority_queue<NodeDistFarther> initial_list;
for (size_t j = begin; j < end; j++) {
int v1 = hnsw.neighbors[j];
if (v1 < 0)
break;
initial_list.emplace(dis->symmetric_dis(i, v1), v1);
// initial_list.emplace(qdis(v1), v1);
}
std::vector<NodeDistFarther> shrunk_list;
HNSW::shrink_neighbor_list(
*dis, initial_list, shrunk_list, new_size);
for (size_t j = begin; j < end; j++) {
if (j - begin < shrunk_list.size())
hnsw.neighbors[j] = shrunk_list[j - begin].id;
else
hnsw.neighbors[j] = -1;
}
}
}
}
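// Search only the base level of the graph, seeded with nprobe entry points
// per query (nearest / nearest_d) instead of descending from the top level.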
void IndexHNSW::search_level_0(
idx_t n,
const float* x,
idx_t k,
const storage_idx_t* nearest,
const float* nearest_d,
float* distances,
idx_t* labels,
int nprobe,
int search_type,
const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(nprobe > 0);
const SearchParametersHNSW* params = nullptr;
if (params_in) {
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
}
storage_idx_t ntotal = hnsw.levels.size();
using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances, labels, k);
#pragma omp parallel
{
std::unique_ptr<DistanceComputer> qdis(
storage_distance_computer(storage));
HNSWStats search_stats;
VisitedTable vt(ntotal);
RH::SingleResultHandler res(bres);
#pragma omp for
for (idx_t i = 0; i < n; i++) {
res.begin(i);
qdis->set_query(x + i * d);
hnsw.search_level_0(
*qdis.get(),
res,
nprobe,
nearest + i * nprobe,
nearest_d + i * nprobe,
search_type,
search_stats,
vt,
params);
res.end();
vt.advance();
}
#pragma omp critical
{ hnsw_stats.combine(search_stats); }
}
if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
#pragma omp parallel for
for (int64_t i = 0; i < k * n; i++) {
distances[i] = -distances[i];
}
}
}
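// Overwrite the level-0 links from a precomputed k-NN graph (distances D,
// indices I), applying the usual shrinking heuristic to each neighbor list.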
void IndexHNSW::init_level_0_from_knngraph(
int k,
const float* D,
const idx_t* I) {
int dest_size = hnsw.nb_neighbors(0);
#pragma omp parallel for
for (idx_t i = 0; i < ntotal; i++) {
DistanceComputer* qdis = storage_distance_computer(storage);
std::vector<float> vec(d);
storage->reconstruct(i, vec.data());
qdis->set_query(vec.data());
std::priority_queue<NodeDistFarther> initial_list;
for (size_t j = 0; j < k; j++) {
int v1 = I[i * k + j];
if (v1 == i)
continue;
if (v1 < 0)
break;
initial_list.emplace(D[i * k + j], v1);
}
std::vector<NodeDistFarther> shrunk_list;
HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size);
size_t begin, end;
hnsw.neighbor_range(i, 0, &begin, &end);
for (size_t j = begin; j < end; j++) {
if (j - begin < shrunk_list.size())
hnsw.neighbors[j] = shrunk_list[j - begin].id;
else
hnsw.neighbors[j] = -1;
}
}
}
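// (Re-)link n points into level 0, using the provided per-point nearest
// entries as starting locations for the greedy link insertion.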
void IndexHNSW::init_level_0_from_entry_points(
int n,
const storage_idx_t* points,
const storage_idx_t* nearests) {
std::vector<omp_lock_t> locks(ntotal);
for (int i = 0; i < ntotal; i++)
omp_init_lock(&locks[i]);
#pragma omp parallel
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(storage));
std::vector<float> vec(storage->d);
#pragma omp for schedule(dynamic)
for (int i = 0; i < n; i++) {
storage_idx_t pt_id = points[i];
storage_idx_t nearest = nearests[i];
storage->reconstruct(pt_id, vec.data());
dis->set_query(vec.data());
hnsw.add_links_starting_from(
*dis, pt_id, nearest, (*dis)(nearest), 0, locks.data(), vt);
if (verbose && i % 10000 == 0) {
printf(" %d / %d\r", i, n);
fflush(stdout);
}
}
}
if (verbose) {
printf("\n");
}
for (int i = 0; i < ntotal; i++)
omp_destroy_lock(&locks[i]);
}
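// Sort each node's level-0 neighbor list by increasing distance to the node,
// so that the closest links are visited first during traversal.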
void IndexHNSW::reorder_links() {
int M = hnsw.nb_neighbors(0);
#pragma omp parallel
{
std::vector<float> distances(M);
std::vector<size_t> order(M);
std::vector<storage_idx_t> tmp(M);
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(storage));
#pragma omp for
for (storage_idx_t i = 0; i < ntotal; i++) {
size_t begin, end;
hnsw.neighbor_range(i, 0, &begin, &end);
for (size_t j = begin; j < end; j++) {
storage_idx_t nj = hnsw.neighbors[j];
if (nj < 0) {
end = j;
break;
}
distances[j - begin] = dis->symmetric_dis(i, nj);
tmp[j - begin] = nj;
}
fvec_argsort(end - begin, distances.data(), order.data());
for (size_t j = begin; j < end; j++) {
hnsw.neighbors[j] = tmp[order[j - begin]];
}
}
}
}
void IndexHNSW::link_singletons() {
printf("search for singletons\n");
std::vector<bool> seen(ntotal);
for (size_t i = 0; i < ntotal; i++) {
size_t begin, end;
hnsw.neighbor_range(i, 0, &begin, &end);
for (size_t j = begin; j < end; j++) {
storage_idx_t ni = hnsw.neighbors[j];
if (ni >= 0)
seen[ni] = true;
}
}
int n_sing = 0, n_sing_l1 = 0;
std::vector<storage_idx_t> singletons;
for (storage_idx_t i = 0; i < ntotal; i++) {
if (!seen[i]) {
singletons.push_back(i);
n_sing++;
if (hnsw.levels[i] > 1)
n_sing_l1++;
}
}
printf(" Found %d / %" PRId64 " singletons (%d appear in a level above)\n",
n_sing,
ntotal,
n_sing_l1);
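    // actually re-linking the singletons into the graph is not implemented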
std::vector<float> recons(singletons.size() * d);
for (int i = 0; i < singletons.size(); i++) {
FAISS_ASSERT(!"not implemented");
}
}
void IndexHNSW::permute_entries(const idx_t* perm) {
auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage);
FAISS_THROW_IF_NOT_MSG(
flat_storage, "don't know how to permute this index");
flat_storage->permute_entries(perm);
hnsw.permute_entries(perm);
}
/**************************************************************
* IndexHNSWFlat implementation
**************************************************************/
IndexHNSWFlat::IndexHNSWFlat() {
is_trained = true;
}
IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
: IndexHNSW(
(metric == METRIC_L2) ? new IndexFlatL2(d)
: new IndexFlat(d, metric),
M) {
own_fields = true;
is_trained = true;
}
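/* Minimal usage sketch (illustrative only; the variable names below are not
 * part of this file):
 *
 *   faiss::IndexHNSWFlat index(d, 32);  // 32 = M, graph links per vector
 *   index.add(nb, xb);                  // xb: nb * d floats, no training step
 *   index.hnsw.efSearch = 64;           // larger -> more accurate but slower
 *   index.search(nq, xq, k, D, I);      // D: nq * k floats, I: nq * k labels
 */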
/**************************************************************
* IndexHNSWPQ implementation
**************************************************************/
IndexHNSWPQ::IndexHNSWPQ() = default;
IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
own_fields = true;
is_trained = false;
}
void IndexHNSWPQ::train(idx_t n, const float* x) {
IndexHNSW::train(n, x);
(dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table();
}
/**************************************************************
* IndexHNSWSQ implementation
**************************************************************/
IndexHNSWSQ::IndexHNSWSQ(
int d,
ScalarQuantizer::QuantizerType qtype,
int M,
MetricType metric)
: IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
is_trained = this->storage->is_trained;
own_fields = true;
}
IndexHNSWSQ::IndexHNSWSQ() = default;
/**************************************************************
* IndexHNSW2Level implementation
**************************************************************/
IndexHNSW2Level::IndexHNSW2Level(
Index* quantizer,
size_t nlist,
int m_pq,
int M)
: IndexHNSW(new Index2Layer(quantizer, nlist, m_pq), M) {
own_fields = true;
is_trained = false;
}
IndexHNSW2Level::IndexHNSW2Level() = default;
namespace {
// Same as HNSW::search_from_candidates, but uses the VisitedTable counter to
// encode two states per node:
//   visno     -> already in the result list
//   visno + 1 -> in the result list and in the candidate heap
int search_from_candidates_2(
const HNSW& hnsw,
DistanceComputer& qdis,
int k,
idx_t* I,
float* D,
MinimaxHeap& candidates,
VisitedTable& vt,
HNSWStats& stats,
int level,
int nres_in = 0) {
int nres = nres_in;
for (int i = 0; i < candidates.size(); i++) {
idx_t v1 = candidates.ids[i];
FAISS_ASSERT(v1 >= 0);
vt.visited[v1] = vt.visno + 1;
}
int nstep = 0;
while (candidates.size() > 0) {
float d0 = 0;
int v0 = candidates.pop_min(&d0);
size_t begin, end;
hnsw.neighbor_range(v0, level, &begin, &end);
for (size_t j = begin; j < end; j++) {
int v1 = hnsw.neighbors[j];
if (v1 < 0)
break;
if (vt.visited[v1] == vt.visno + 1) {
// nothing to do
} else {
float d = qdis(v1);
candidates.push(v1, d);
// never seen before --> add to heap
if (vt.visited[v1] < vt.visno) {
if (nres < k) {
faiss::maxheap_push(++nres, D, I, d, v1);
} else if (d < D[0]) {
faiss::maxheap_replace_top(nres, D, I, d, v1);
}
}
vt.visited[v1] = vt.visno + 1;
}
}
nstep++;
if (nstep > hnsw.efSearch) {
break;
}
}
stats.n1++;
if (candidates.size() == 0)
stats.n2++;
return nres;
}
} // namespace
void IndexHNSW2Level::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
if (dynamic_cast<const Index2Layer*>(storage)) {
IndexHNSW::search(n, x, k, distances, labels);
} else { // "mixed" search
size_t n1 = 0, n2 = 0, ndis = 0;
const IndexIVFPQ* index_ivfpq =
dynamic_cast<const IndexIVFPQ*>(storage);
int nprobe = index_ivfpq->nprobe;
std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
index_ivfpq->quantizer->search(
n, x, nprobe, coarse_dis.get(), coarse_assign.get());
index_ivfpq->search_preassigned(
n,
x,
k,
coarse_assign.get(),
coarse_dis.get(),
distances,
labels,
false);
#pragma omp parallel
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(storage));
int candidates_size = hnsw.upper_beam;
MinimaxHeap candidates(candidates_size);
#pragma omp for reduction(+ : n1, n2, ndis)
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
dis->set_query(x + i * d);
// mark all inverted list elements as visited
for (int j = 0; j < nprobe; j++) {
idx_t key = coarse_assign[j + i * nprobe];
if (key < 0)
break;
size_t list_length = index_ivfpq->get_list_size(key);
const idx_t* ids = index_ivfpq->invlists->get_ids(key);
for (int jj = 0; jj < list_length; jj++) {
vt.set(ids[jj]);
}
}
candidates.clear();
for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
if (idxi[j] < 0)
break;
candidates.push(idxi[j], simi[j]);
}
// reorder from sorted to heap
maxheap_heapify(k, simi, idxi, simi, idxi, k);
HNSWStats search_stats;
search_from_candidates_2(
hnsw,
*dis,
k,
idxi,
simi,
candidates,
vt,
search_stats,
0,
k);
n1 += search_stats.n1;
n2 += search_stats.n2;
ndis += search_stats.ndis;
                // both visno and visno + 1 were used as visit marks, so
                // advance the visit counter twice
                vt.advance();
                vt.advance();
maxheap_reorder(k, simi, idxi);
}
}
hnsw_stats.combine({n1, n2, ndis});
}
}
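// Replace the Index2Layer storage with an equivalent IndexIVFPQ, transferring
// the trained quantizer and the codes, so IVF-style search becomes available.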
void IndexHNSW2Level::flip_to_ivf() {
Index2Layer* storage2l = dynamic_cast<Index2Layer*>(storage);
FAISS_THROW_IF_NOT(storage2l);
IndexIVFPQ* index_ivfpq = new IndexIVFPQ(
storage2l->q1.quantizer,
d,
storage2l->q1.nlist,
storage2l->pq.M,
8);
index_ivfpq->pq = storage2l->pq;
index_ivfpq->is_trained = storage2l->is_trained;
index_ivfpq->precompute_table();
index_ivfpq->own_fields = storage2l->q1.own_fields;
storage2l->transfer_to_IVFPQ(*index_ivfpq);
index_ivfpq->make_direct_map(true);
storage = index_ivfpq;
delete storage2l;
}
/**************************************************************
* IndexHNSWCagra implementation
**************************************************************/
IndexHNSWCagra::IndexHNSWCagra() {
is_trained = true;
}
IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
: IndexHNSW(
(metric == METRIC_L2)
? static_cast<IndexFlat*>(new IndexFlatL2(d))
: static_cast<IndexFlat*>(new IndexFlatIP(d)),
M) {
FAISS_THROW_IF_NOT_MSG(
((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
"unsupported metric type for IndexHNSWCagra");
own_fields = true;
is_trained = true;
init_level0 = true;
keep_max_size_level0 = true;
}
void IndexHNSWCagra::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT_MSG(
!base_level_only,
"Cannot add vectors when base_level_only is set to True");
IndexHNSW::add(n, x);
}
void IndexHNSWCagra::search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params) const {
if (!base_level_only) {
IndexHNSW::search(n, x, k, distances, labels, params);
} else {
std::vector<storage_idx_t> nearest(n);
std::vector<float> nearest_d(n);
#pragma omp for
for (idx_t i = 0; i < n; i++) {
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(this->storage));
dis->set_query(x + i * d);
nearest[i] = -1;
nearest_d[i] = std::numeric_limits<float>::max();
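            // sample random entry points on the base level; note that the
            // bounds of uniform_int_distribution are inclusive, hence
            // ntotal - 1 below (see the commit message above)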
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);
for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
auto idx = distrib(gen);
auto distance = (*dis)(idx);
if (distance < nearest_d[i]) {
nearest[i] = idx;
nearest_d[i] = distance;
}
}
FAISS_THROW_IF_NOT_MSG(
nearest[i] >= 0, "Could not find a valid entrypoint.");
}
search_level_0(
n,
x,
k,
nearest.data(),
nearest_d.data(),
distances,
labels,
1, // n_probes
1, // search_type
params);
}
}
} // namespace faiss