faiss/faiss/IndexBinaryIVF.cpp
Andres Suarez ab2b7f5093 Apply clang-format 18
Summary: Previously this code conformed to clang-format 12.

Reviewed By: igorsugak

Differential Revision: D56065247

fbshipit-source-id: f5a985dd8f8b84f2f9e1818b3719b43c5a1b05b3
2024-04-14 11:28:32 -07:00

905 lines
28 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexBinaryIVF.h>
#include <omp.h>
#include <cinttypes>
#include <cstdio>
#include <algorithm>
#include <memory>
#include <faiss/IndexFlat.h>
#include <faiss/IndexLSH.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/sorting.h>
#include <faiss/utils/utils.h>
namespace faiss {
/// Construct an IVF index over d-bit binary vectors with nlist inverted
/// lists. The coarse quantizer is borrowed (ownership controlled by
/// own_fields) and must operate in the same dimension d.
IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
        : IndexBinary(d),
          invlists(new ArrayInvertedLists(nlist, code_size)),
          quantizer(quantizer),
          nlist(nlist) {
    FAISS_THROW_IF_NOT(d == quantizer->d);
    // usable immediately if the quantizer already holds exactly nlist
    // trained centroids
    is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);

    // default number of k-means iterations used by train()
    cp.niter = 10;
}
IndexBinaryIVF::IndexBinaryIVF() = default;

/// Add n vectors with sequentially assigned ids (ntotal .. ntotal + n - 1).
void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
    add_with_ids(n, x, nullptr);
}

/// Add n vectors with caller-provided ids; list assignment is computed by
/// the quantizer inside add_core().
void IndexBinaryIVF::add_with_ids(
        idx_t n,
        const uint8_t* x,
        const idx_t* xids) {
    add_core(n, x, xids, nullptr);
}
/// Add n binary codes (code_size bytes each).
/// @param xids            optional explicit ids; when null, sequential ids
///                        starting at ntotal are used
/// @param precomputed_idx optional per-vector inverted-list assignment;
///                        when null, the quantizer is queried
void IndexBinaryIVF::add_core(
        idx_t n,
        const uint8_t* x,
        const idx_t* xids,
        const idx_t* precomputed_idx) {
    FAISS_THROW_IF_NOT(is_trained);
    assert(invlists);
    // throws if explicit/implicit ids conflict with the direct map mode
    direct_map.check_can_add(xids);

    const idx_t* idx;
    std::unique_ptr<idx_t[]> scoped_idx;

    if (precomputed_idx) {
        idx = precomputed_idx;
    } else {
        // assign each vector to its nearest coarse centroid
        scoped_idx.reset(new idx_t[n]);
        quantizer->assign(n, x, scoped_idx.get());
        idx = scoped_idx.get();
    }

    idx_t n_add = 0;
    for (size_t i = 0; i < n; i++) {
        idx_t id = xids ? xids[i] : ntotal + i;
        idx_t list_no = idx[i];

        if (list_no < 0) {
            // no list assigned: the vector is recorded in the direct map
            // only (offset -1) and stored nowhere
            direct_map.add_single_id(id, -1, 0);
        } else {
            const uint8_t* xi = x + i * code_size;
            size_t offset = invlists->add_entry(list_no, id, xi);
            direct_map.add_single_id(id, list_no, offset);
        }

        // counts every input vector, including unassigned ones
        n_add++;
    }
    if (verbose) {
        printf("IndexBinaryIVF::add_with_ids: added "
               "%" PRId64 " / %" PRId64 " vectors\n",
               n_add,
               n);
    }
    ntotal += n_add;
}
void IndexBinaryIVF::make_direct_map(bool b) {
if (b) {
direct_map.set_type(DirectMap::Array, invlists, ntotal);
} else {
direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
}
}
/// Set the direct-map implementation explicitly (NoMap / Array / Hashtable),
/// rebuilding it from the current inverted lists.
void IndexBinaryIVF::set_direct_map_type(DirectMap::Type type) {
    direct_map.set_type(type, invlists, ntotal);
}
/// k-NN search: quantize each query to its nprobe nearest inverted lists,
/// then scan those lists via search_preassigned(). Timing is accumulated
/// into the global indexIVF_stats.
void IndexBinaryIVF::search(
        idx_t n,
        const uint8_t* x,
        idx_t k,
        int32_t* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    FAISS_THROW_IF_NOT(k > 0);
    FAISS_THROW_IF_NOT(nprobe > 0);
    // cannot probe more lists than exist
    const size_t nprobe_2 = std::min(nlist, this->nprobe);
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
    std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);

    double t0 = getmillisecs();
    quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
    indexIVF_stats.quantization_time += getmillisecs() - t0;

    t0 = getmillisecs();
    invlists->prefetch_lists(idx.get(), n * nprobe_2);

    search_preassigned(
            n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
    indexIVF_stats.search_time += getmillisecs() - t0;
}
/// Reconstruct the stored code for external id `key` into recons.
/// Requires an enabled direct map to resolve key to (list_no, offset).
void IndexBinaryIVF::reconstruct(idx_t key, uint8_t* recons) const {
    idx_t lo = direct_map.get(key);
    reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
}
/// Reconstruct all stored vectors whose ids fall in [i0, i0 + ni) by
/// scanning every inverted list (no direct map required).
void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
    FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));

    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        size_t list_size = invlists->list_size(list_no);
        const idx_t* idlist = invlists->get_ids(list_no);

        for (idx_t offset = 0; offset < list_size; offset++) {
            idx_t id = idlist[offset];
            if (!(id >= i0 && id < i0 + ni)) {
                continue;
            }

            // NOTE(review): output stride is d bytes per vector, while
            // reconstruct_from_offset() writes only code_size (= d/8)
            // bytes -- confirm the intended recons layout upstream.
            uint8_t* reconstructed = recons + (id - i0) * d;
            reconstruct_from_offset(list_no, offset, reconstructed);
        }
    }
}
/// k-NN search that also reconstructs each result vector into recons.
/// Runs search_preassigned() with store_pairs=true so that each label
/// temporarily encodes (list_no << 32 | offset); the labels are translated
/// back to real ids after reconstruction.
void IndexBinaryIVF::search_and_reconstruct(
        idx_t n,
        const uint8_t* __restrict x,
        idx_t k,
        int32_t* __restrict distances,
        idx_t* __restrict labels,
        uint8_t* __restrict recons,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    const size_t nprobe_2 = std::min(nlist, this->nprobe);
    FAISS_THROW_IF_NOT(k > 0);
    FAISS_THROW_IF_NOT(nprobe_2 > 0);

    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
    std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);

    quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());

    invlists->prefetch_lists(idx.get(), n * nprobe_2);

    // search_preassigned() with `store_pairs` enabled to obtain the list_no
    // and offset into `codes` for reconstruction
    search_preassigned(
            n,
            x,
            k,
            idx.get(),
            coarse_dis.get(),
            distances,
            labels,
            /* store_pairs */ true);
    for (idx_t i = 0; i < n; ++i) {
        for (idx_t j = 0; j < k; ++j) {
            idx_t ij = i * k + j;
            idx_t key = labels[ij];
            uint8_t* reconstructed = recons + ij * d;
            if (key < 0) {
                // missing result slot: mark with 0xff bytes (binary codes
                // have no NaN equivalent)
                memset(reconstructed, -1, sizeof(*reconstructed) * d);
            } else {
                // decode (list_no, offset) packed by store_pairs
                int list_no = key >> 32;
                int offset = key & 0xffffffff;

                // Update label to the actual id
                labels[ij] = invlists->get_single_id(list_no, offset);

                reconstruct_from_offset(list_no, offset, reconstructed);
            }
        }
    }
}
/// Copy the code stored at (list_no, offset) into recons (code_size bytes).
void IndexBinaryIVF::reconstruct_from_offset(
        idx_t list_no,
        idx_t offset,
        uint8_t* recons) const {
    memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
}
/// Remove all stored vectors; the quantizer and training state are kept.
void IndexBinaryIVF::reset() {
    direct_map.clear();
    invlists->reset();
    ntotal = 0;
}
/// Remove all vectors whose id is selected by sel.
/// @return number of vectors actually removed
size_t IndexBinaryIVF::remove_ids(const IDSelector& sel) {
    size_t nremove = direct_map.remove_ids(sel, invlists);
    ntotal -= nremove;
    return nremove;
}
/// Train the coarse quantizer by k-means over the binary training set.
/// Binary vectors are decoded to floats via an identity IndexLSH codec,
/// clustered, and the resulting centroids are re-binarized and added to
/// the quantizer. No-op if the quantizer is already trained with nlist
/// centroids.
void IndexBinaryIVF::train(idx_t n, const uint8_t* x) {
    if (verbose) {
        printf("Training quantizer\n");
    }

    if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
        if (verbose) {
            printf("IVF quantizer does not need training.\n");
        }
    } else {
        if (verbose) {
            printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
        }

        Clustering clus(d, nlist, cp);
        quantizer->reset();

        // fallback float index used when no clustering_index is set
        IndexFlatL2 index_tmp(d);

        if (clustering_index && verbose) {
            printf("using clustering_index of dimension %d to do the clustering\n",
                   clustering_index->d);
        }

        // LSH codec that is able to convert the binary vectors to floats.
        IndexLSH codec(d, d, false, false);

        clus.train_encoded(
                n, x, &codec, clustering_index ? *clustering_index : index_tmp);

        // convert clusters to binary
        std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
        real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());

        quantizer->add(clus.k, x_b.get());
        quantizer->is_trained = true;
    }

    is_trained = true;
}
/// Verify that otherIndex can be merged into this index (same dimension,
/// number of lists, code size, dynamic type, and no direct maps).
/// @throws FaissException if any compatibility condition fails
void IndexBinaryIVF::check_compatible_for_merge(
        const IndexBinary& otherIndex) const {
    auto other = dynamic_cast<const IndexBinaryIVF*>(&otherIndex);
    FAISS_THROW_IF_NOT(other);
    FAISS_THROW_IF_NOT(other->d == d);
    FAISS_THROW_IF_NOT(other->nlist == nlist);
    FAISS_THROW_IF_NOT(other->code_size == code_size);
    FAISS_THROW_IF_NOT_MSG(
            direct_map.no() && other->direct_map.no(),
            "direct map copy not implemented");
    // BUGFIX: compare the dynamic types of the two *objects*. The previous
    // code compared against typeid(other) -- the typeid of a pointer type,
    // which can never equal typeid(*this) -- so the check always threw.
    FAISS_THROW_IF_NOT_MSG(
            typeid(*this) == typeid(otherIndex),
            "can only merge indexes of the same type");
}
/// Move all entries of otherIndex into this index, adding add_id to every
/// incoming id; otherIndex is left empty.
void IndexBinaryIVF::merge_from(IndexBinary& otherIndex, idx_t add_id) {
    // minimal sanity checks
    check_compatible_for_merge(otherIndex);
    // cast is safe: check_compatible_for_merge throws on type mismatch
    auto other = static_cast<IndexBinaryIVF*>(&otherIndex);
    invlists->merge_from(other->invlists, add_id);

    ntotal += other->ntotal;
    other->ntotal = 0;
}
/// Swap in a different InvertedLists implementation (must match nlist and
/// code_size). If `own` is true, this index deletes `il` on destruction.
void IndexBinaryIVF::replace_invlists(InvertedLists* il, bool own) {
    FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
    // free the previous lists only if we owned them
    if (own_invlists) {
        delete invlists;
    }
    invlists = il;
    own_invlists = own;
}
namespace {
/// Inverted-list scanner computing Hamming distances with a compile-time
/// selected HammingComputer (one specialization per code_size family).
template <class HammingComputer>
struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
    HammingComputer hc; // holds the current query code
    size_t code_size;   // bytes per binary code
    bool store_pairs;   // report (list_no, offset) instead of stored ids

    IVFBinaryScannerL2(size_t code_size, bool store_pairs)
            : code_size(code_size), store_pairs(store_pairs) {}

    void set_query(const uint8_t* query_vector) override {
        hc.set(query_vector, code_size);
    }

    idx_t list_no; // list currently being scanned (needed for store_pairs)
    void set_list(idx_t list_no_2, uint8_t /* coarse_dis */) override {
        this->list_no = list_no_2;
    }

    uint32_t distance_to_code(const uint8_t* code) const override {
        return hc.hamming(code);
    }

    /// Scan n codes, maintaining a max-heap of the k best (smallest)
    /// distances in (simi, idxi). `ids` may be null when store_pairs is
    /// set. @return number of heap updates performed
    size_t scan_codes(
            size_t n,
            const uint8_t* __restrict codes,
            const idx_t* __restrict ids,
            int32_t* __restrict simi,
            idx_t* __restrict idxi,
            size_t k) const override {
        using C = CMax<int32_t, idx_t>;

        size_t nup = 0;
        for (size_t j = 0; j < n; j++) {
            uint32_t dis = hc.hamming(codes);
            if (dis < simi[0]) { // simi[0] = current worst of the top-k
                idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
                heap_replace_top<C>(k, simi, idxi, dis, id);
                nup++;
            }
            codes += code_size;
        }
        return nup;
    }

    /// Collect all codes with distance strictly below radius into result.
    void scan_codes_range(
            size_t n,
            const uint8_t* __restrict codes,
            const idx_t* __restrict ids,
            int radius,
            RangeQueryResult& result) const override {
        for (size_t j = 0; j < n; j++) {
            uint32_t dis = hc.hamming(codes);
            if (dis < radius) {
                int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
                result.add(dis, id);
            }
            codes += code_size;
        }
    }
};
/// Heap-based k-NN over the preassigned lists: one result heap per query,
/// parallelized over queries with OpenMP. Near-verbatim port of
/// IndexIVF::search_preassigned to binary codes.
void search_knn_hamming_heap(
        const IndexBinaryIVF* ivf,
        size_t n,
        const uint8_t* __restrict x,
        idx_t k,
        const idx_t* __restrict keys,
        const int32_t* __restrict coarse_dis,
        int32_t* __restrict distances,
        idx_t* __restrict labels,
        bool store_pairs,
        const IVFSearchParameters* params) {
    // per-call params override the index-level defaults
    idx_t nprobe = params ? params->nprobe : ivf->nprobe;
    nprobe = std::min((idx_t)ivf->nlist, nprobe);
    idx_t max_codes = params ? params->max_codes : ivf->max_codes;
    MetricType metric_type = ivf->metric_type;

    // almost verbatim copy from IndexIVF::search_preassigned

    // thread-local stat counters, reduced at the end of the parallel region
    size_t nlistv = 0, ndis = 0, nheap = 0;
    using HeapForIP = CMin<int32_t, idx_t>; // maximize inner product
    using HeapForL2 = CMax<int32_t, idx_t>; // minimize distance

#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
    {
        std::unique_ptr<BinaryInvertedListScanner> scanner(
                ivf->get_InvertedListScanner(store_pairs));

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* xi = x + i * ivf->code_size;
            scanner->set_query(xi);

            const idx_t* keysi = keys + i * nprobe;
            int32_t* simi = distances + k * i;
            idx_t* idxi = labels + k * i;

            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_heapify<HeapForIP>(k, simi, idxi);
            } else {
                heap_heapify<HeapForL2>(k, simi, idxi);
            }

            size_t nscan = 0;

            for (size_t ik = 0; ik < nprobe; ik++) {
                idx_t key = keysi[ik]; /* select the list */
                if (key < 0) {
                    // not enough centroids for multiprobe
                    continue;
                }
                FAISS_THROW_IF_NOT_FMT(
                        key < (idx_t)ivf->nlist,
                        "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
                        key,
                        ik,
                        ivf->nlist);

                scanner->set_list(key, coarse_dis[i * nprobe + ik]);

                nlistv++;

                size_t list_size = ivf->invlists->list_size(key);
                InvertedLists::ScopedCodes scodes(ivf->invlists, key);
                std::unique_ptr<InvertedLists::ScopedIds> sids;
                const idx_t* ids = nullptr;

                if (!store_pairs) {
                    // ids only needed to report real labels
                    sids = std::make_unique<InvertedLists::ScopedIds>(
                            ivf->invlists, key);
                    ids = sids->get();
                }

                nheap += scanner->scan_codes(
                        list_size, scodes.get(), ids, simi, idxi, k);

                nscan += list_size;
                // stop probing once enough codes were examined
                if (max_codes && nscan >= max_codes)
                    break;
            }

            ndis += nscan;
            // sort the heap into increasing-distance order
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_reorder<HeapForIP>(k, simi, idxi);
            } else {
                heap_reorder<HeapForL2>(k, simi, idxi);
            }

        } // parallel for
    } // parallel

    indexIVF_stats.nq += n;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nheap_updates += nheap;
}
/// Counting-based k-NN (no heap): since Hamming distances lie in [0, d],
/// each query keeps one counter plus up to k ids per distance bucket
/// (HCounterState), and results are read out bucket by bucket in
/// increasing-distance order. Allocates nx * (d + 1) counters and
/// nx * (d + 1) * k ids up front.
template <class HammingComputer, bool store_pairs>
void search_knn_hamming_count(
        const IndexBinaryIVF* ivf,
        size_t nx,
        const uint8_t* __restrict x,
        const idx_t* __restrict keys,
        int k,
        int32_t* __restrict distances,
        idx_t* __restrict labels,
        const IVFSearchParameters* params) {
    const int nBuckets = ivf->d + 1; // one bucket per possible distance
    std::vector<int> all_counters(nx * nBuckets, 0);
    std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);

    idx_t nprobe = params ? params->nprobe : ivf->nprobe;
    nprobe = std::min((idx_t)ivf->nlist, nprobe);
    idx_t max_codes = params ? params->max_codes : ivf->max_codes;

    std::vector<HCounterState<HammingComputer>> cs;
    for (size_t i = 0; i < nx; ++i) {
        cs.push_back(HCounterState<HammingComputer>(
                all_counters.data() + i * nBuckets,
                all_ids_per_dis.get() + i * nBuckets * k,
                x + i * ivf->code_size,
                ivf->d,
                k));
    }

    size_t nlistv = 0, ndis = 0;

#pragma omp parallel for reduction(+ : nlistv, ndis)
    for (int64_t i = 0; i < nx; i++) {
        const idx_t* keysi = keys + i * nprobe;
        HCounterState<HammingComputer>& csi = cs[i];

        size_t nscan = 0;

        for (size_t ik = 0; ik < nprobe; ik++) {
            idx_t key = keysi[ik]; /* select the list */
            if (key < 0) {
                // not enough centroids for multiprobe
                continue;
            }
            FAISS_THROW_IF_NOT_FMT(
                    key < (idx_t)ivf->nlist,
                    "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
                    key,
                    ik,
                    ivf->nlist);

            nlistv++;
            size_t list_size = ivf->invlists->list_size(key);
            InvertedLists::ScopedCodes scodes(ivf->invlists, key);
            const uint8_t* list_vecs = scodes.get();
            const idx_t* ids =
                    store_pairs ? nullptr : ivf->invlists->get_ids(key);

            for (size_t j = 0; j < list_size; j++) {
                const uint8_t* yj = list_vecs + ivf->code_size * j;
                // store_pairs packs (list_no, offset) into the id
                idx_t id = store_pairs ? (key << 32 | j) : ids[j];
                csi.update_counter(yj, id);
            }
            if (ids) {
                ivf->invlists->release_ids(key, ids);
            }

            nscan += list_size;
            if (max_codes && nscan >= max_codes)
                break;
        }
        ndis += nscan;

        // read out the k best results in increasing-distance order
        int nres = 0;
        for (int b = 0; b < nBuckets && nres < k; b++) {
            for (int l = 0; l < csi.counters[b] && nres < k; l++) {
                labels[i * k + nres] = csi.ids_per_dis[b * k + l];
                distances[i * k + nres] = b;
                nres++;
            }
        }
        // pad missing results with sentinel values
        while (nres < k) {
            labels[i * k + nres] = -1;
            distances[i * k + nres] = std::numeric_limits<int32_t>::max();
            ++nres;
        }
    }

    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}
/* Manages NQ queries at a time, stores results. K is the compile-time
 * number of results per query; the result heaps must be pre-initialized
 * by the caller. */
template <class HammingComputer, int NQ, int K>
struct BlockSearch {
    HammingComputer hcs[NQ]; // one Hamming computer per query in the block
    // heaps to update for each query
    int32_t* distances[NQ];
    idx_t* labels[NQ];
    // current top of each heap, cached to avoid re-reading distances[q][0]
    int32_t heap_tops[NQ];

    // keys[q] gives the global query number; x / all_distances /
    // all_labels are indexed by that query number
    BlockSearch(
            size_t code_size,
            const uint8_t* __restrict x,
            const int32_t* __restrict keys,
            int32_t* __restrict all_distances,
            idx_t* __restrict all_labels) {
        for (idx_t q = 0; q < NQ; q++) {
            idx_t qno = keys[q];
            hcs[q] = HammingComputer(x + qno * code_size, code_size);
            distances[q] = all_distances + qno * K;
            labels[q] = all_labels + qno * K;
            heap_tops[q] = distances[q][0];
        }
    }

    // compare one database code against all NQ queries, updating the
    // heaps it improves
    void add_bcode(const uint8_t* bcode, idx_t id) {
        using C = CMax<int32_t, idx_t>;
        for (int q = 0; q < NQ; q++) {
            int dis = hcs[q].hamming(bcode);
            if (dis < heap_tops[q]) {
                heap_replace_top<C>(K, distances[q], labels[q], dis, id);
                heap_tops[q] = distances[q][0];
            }
        }
    }
};
/// Same as BlockSearch but with k chosen at runtime instead of as a
/// template parameter (used when k is not one of the specialized values).
template <class HammingComputer, int NQ>
struct BlockSearchVariableK {
    int k; // results per query
    HammingComputer hcs[NQ]; // one Hamming computer per query in the block
    // heaps to update for each query
    int32_t* distances[NQ];
    idx_t* labels[NQ];
    // current top of each heap, cached to avoid re-reading distances[q][0]
    int32_t heap_tops[NQ];

    BlockSearchVariableK(
            size_t code_size,
            int k,
            const uint8_t* __restrict x,
            const int32_t* __restrict keys,
            int32_t* __restrict all_distances,
            idx_t* __restrict all_labels)
            : k(k) {
        for (idx_t q = 0; q < NQ; q++) {
            idx_t qno = keys[q];
            hcs[q] = HammingComputer(x + qno * code_size, code_size);
            distances[q] = all_distances + qno * k;
            labels[q] = all_labels + qno * k;
            heap_tops[q] = distances[q][0];
        }
    }

    // compare one database code against all NQ queries, updating the
    // heaps it improves
    void add_bcode(const uint8_t* bcode, idx_t id) {
        using C = CMax<int32_t, idx_t>;
        for (int q = 0; q < NQ; q++) {
            int dis = hcs[q].hamming(bcode);
            if (dis < heap_tops[q]) {
                heap_replace_top<C>(k, distances[q], labels[q], dis, id);
                heap_tops[q] = distances[q][0];
            }
        }
    }
};
/// k-NN traversal in inverted-list order: (query, probe) pairs are bucket-
/// sorted by list id so each list's codes are read once and scanned for
/// every query probing that list. Queries are processed in blocks of BS=4
/// for locality, with specialized paths for k in {1, 2, 4}.
/// Does not support max_codes or store_pairs.
template <class HammingComputer>
void search_knn_hamming_per_invlist(
        const IndexBinaryIVF* ivf,
        size_t n,
        const uint8_t* __restrict x,
        idx_t k,
        const idx_t* __restrict keys_in,
        const int32_t* __restrict coarse_dis,
        int32_t* __restrict distances,
        idx_t* __restrict labels,
        bool store_pairs,
        const IVFSearchParameters* params) {
    idx_t nprobe = params ? params->nprobe : ivf->nprobe;
    nprobe = std::min((idx_t)ivf->nlist, nprobe);
    idx_t max_codes = params ? params->max_codes : ivf->max_codes;
    FAISS_THROW_IF_NOT(max_codes == 0);
    FAISS_THROW_IF_NOT(!store_pairs);

    // reorder buckets: copy list assignments to a mutable int32 array,
    // then bucket-sort them in place; lims[l]..lims[l+1] delimits the
    // (query, probe) entries assigned to list l
    std::vector<int64_t> lims(n + 1);
    int32_t* keys = new int32_t[n * nprobe];
    std::unique_ptr<int32_t[]> delete_keys(keys);
    for (idx_t i = 0; i < n * nprobe; i++) {
        keys[i] = keys_in[i];
    }
    matrix_bucket_sort_inplace(n, nprobe, keys, ivf->nlist, lims.data(), 0);

    using C = CMax<int32_t, idx_t>;
    // initialize all result heaps up front; they are updated list by list
    heap_heapify<C>(n * k, distances, labels);
    const size_t code_size = ivf->code_size;

    for (idx_t l = 0; l < ivf->nlist; l++) {
        idx_t l0 = lims[l], nq = lims[l + 1] - l0;
        InvertedLists::ScopedCodes scodes(ivf->invlists, l);
        InvertedLists::ScopedIds sidx(ivf->invlists, l);
        idx_t nb = ivf->invlists->list_size(l);
        const uint8_t* bcodes = scodes.get();
        const idx_t* ids = sidx.get();

        idx_t i = 0;

        // process as much as possible by blocks
        constexpr int BS = 4;

        if (k == 1) {
            for (; i + BS <= nq; i += BS) {
                BlockSearch<HammingComputer, BS, 1> bc(
                        code_size, x, keys + l0 + i, distances, labels);
                for (idx_t j = 0; j < nb; j++) {
                    bc.add_bcode(bcodes + j * code_size, ids[j]);
                }
            }
        } else if (k == 2) {
            for (; i + BS <= nq; i += BS) {
                BlockSearch<HammingComputer, BS, 2> bc(
                        code_size, x, keys + l0 + i, distances, labels);
                for (idx_t j = 0; j < nb; j++) {
                    bc.add_bcode(bcodes + j * code_size, ids[j]);
                }
            }
        } else if (k == 4) {
            for (; i + BS <= nq; i += BS) {
                BlockSearch<HammingComputer, BS, 4> bc(
                        code_size, x, keys + l0 + i, distances, labels);
                for (idx_t j = 0; j < nb; j++) {
                    bc.add_bcode(bcodes + j * code_size, ids[j]);
                }
            }
        } else {
            for (; i + BS <= nq; i += BS) {
                BlockSearchVariableK<HammingComputer, BS> bc(
                        code_size, k, x, keys + l0 + i, distances, labels);
                for (idx_t j = 0; j < nb; j++) {
                    bc.add_bcode(bcodes + j * code_size, ids[j]);
                }
            }
        }

        // leftovers: fewer than BS queries remain for this list
        for (; i < nq; i++) {
            idx_t qno = keys[l0 + i];
            HammingComputer hc(x + qno * code_size, code_size);

            idx_t* __restrict idxi = labels + qno * k;
            int32_t* __restrict simi = distances + qno * k;
            int32_t simi0 = simi[0]; // cached heap top
            for (idx_t j = 0; j < nb; j++) {
                int dis = hc.hamming(bcodes + j * code_size);

                if (dis < simi0) {
                    // store_pairs is always false here (checked above)
                    idx_t id = store_pairs ? lo_build(l, j) : ids[j];
                    heap_replace_top<C>(k, simi, idxi, dis, id);
                    simi0 = simi[0];
                }
            }
        }
    }
    // sort each query's heap into increasing-distance order
    for (idx_t i = 0; i < n; i++) {
        heap_reorder<C>(k, distances + i * k, labels + i * k);
    }
}
/// Functor for dispatch_HammingComputer: forwards all arguments to
/// search_knn_hamming_per_invlist with the selected HammingComputer.
struct Run_search_knn_hamming_per_invlist {
    using T = void; // return type expected by the dispatcher
    template <class HammingComputer, class... Types>
    void f(Types... args) {
        search_knn_hamming_per_invlist<HammingComputer>(args...);
    }
};
/// Functor for dispatch_HammingComputer: forwards all arguments to
/// search_knn_hamming_count (store_pairs fixed at compile time).
template <bool store_pairs>
struct Run_search_knn_hamming_count {
    using T = void; // return type expected by the dispatcher
    template <class HammingComputer, class... Types>
    void f(Types... args) {
        search_knn_hamming_count<HammingComputer, store_pairs>(args...);
    }
};
/// Functor for dispatch_HammingComputer: allocates an IVFBinaryScannerL2
/// specialized for the dispatched HammingComputer. Caller owns the result.
struct BuildScanner {
    using T = BinaryInvertedListScanner*; // return type for the dispatcher
    template <class HammingComputer>
    T f(size_t code_size, bool store_pairs) {
        return new IVFBinaryScannerL2<HammingComputer>(code_size, store_pairs);
    }
};
} // anonymous namespace
/// Build a scanner specialized for this index' code_size; the caller takes
/// ownership of the returned object.
BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
        bool store_pairs) const {
    BuildScanner bs;
    // selects the HammingComputer matching code_size, then calls bs.f()
    return dispatch_HammingComputer(code_size, bs, code_size, store_pairs);
}
/// Search with precomputed coarse assignments (cidx) and coarse distances
/// (cdis). Dispatches to one of three implementations:
/// - per_invlist_search: traverse lists in inverted-list order
/// - use_heap: per-query heap scan (default)
/// - otherwise: counting-based scan, with store_pairs resolved at
///   compile time
void IndexBinaryIVF::search_preassigned(
        idx_t n,
        const uint8_t* x,
        idx_t k,
        const idx_t* cidx,
        const int32_t* cdis,
        int32_t* dis,
        idx_t* idx,
        bool store_pairs,
        const IVFSearchParameters* params) const {
    if (per_invlist_search) {
        Run_search_knn_hamming_per_invlist r;
        // clang-format off
        dispatch_HammingComputer(
                code_size, r, this, n, x, k,
                cidx, cdis, dis, idx, store_pairs, params);
        // clang-format on
    } else if (use_heap) {
        search_knn_hamming_heap(
                this, n, x, k, cidx, cdis, dis, idx, store_pairs, params);
    } else if (store_pairs) { // !use_heap && store_pairs
        Run_search_knn_hamming_count<true> r;
        dispatch_HammingComputer(
                code_size, r, this, n, x, cidx, k, dis, idx, params);
    } else { // !use_heap && !store_pairs
        Run_search_knn_hamming_count<false> r;
        dispatch_HammingComputer(
                code_size, r, this, n, x, cidx, k, dis, idx, params);
    }
}
/// Range search: return all vectors within Hamming distance < radius of
/// each query. Quantizes queries first, then delegates to
/// range_search_preassigned().
void IndexBinaryIVF::range_search(
        idx_t n,
        const uint8_t* __restrict x,
        int radius,
        RangeSearchResult* __restrict res,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    // cannot probe more lists than exist
    const size_t nprobe_2 = std::min(nlist, this->nprobe);
    std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
    std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);

    double t0 = getmillisecs();
    quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
    indexIVF_stats.quantization_time += getmillisecs() - t0;

    t0 = getmillisecs();
    invlists->prefetch_lists(idx.get(), n * nprobe_2);

    range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);

    indexIVF_stats.search_time += getmillisecs() - t0;
}
/// Range search with precomputed coarse assignments. Each thread collects
/// partial results (RangeSearchPartialResult) that are merged into res by
/// pres.finalize().
void IndexBinaryIVF::range_search_preassigned(
        idx_t n,
        const uint8_t* __restrict x,
        int radius,
        const idx_t* __restrict assign,
        const int32_t* __restrict centroid_dis,
        RangeSearchResult* __restrict res) const {
    const size_t nprobe_2 = std::min(nlist, this->nprobe);
    bool store_pairs = false;
    size_t nlistv = 0, ndis = 0;

    std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());

#pragma omp parallel reduction(+ : nlistv, ndis)
    {
        RangeSearchPartialResult pres(res);
        std::unique_ptr<BinaryInvertedListScanner> scanner(
                get_InvertedListScanner(store_pairs));
        FAISS_THROW_IF_NOT(scanner.get());
        all_pres[omp_get_thread_num()] = &pres;

        // scan inverted list #ik for query #i
        auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
            idx_t key = assign[i * nprobe_2 + ik]; /* select the list */
            if (key < 0)
                return;
            FAISS_THROW_IF_NOT_FMT(
                    key < (idx_t)nlist,
                    "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
                    key,
                    ik,
                    nlist);
            const size_t list_size = invlists->list_size(key);

            if (list_size == 0)
                return;

            InvertedLists::ScopedCodes scodes(invlists, key);
            InvertedLists::ScopedIds ids(invlists, key);

            // NOTE(review): passes assign[...] (a centroid id) as the
            // coarse-distance argument; the scanner ignores it, but
            // centroid_dis[...] looks like the intended value -- confirm.
            scanner->set_list(key, assign[i * nprobe_2 + ik]);
            nlistv++;
            ndis += list_size;
            scanner->scan_codes_range(
                    list_size, scodes.get(), ids.get(), radius, qres);
        };

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            scanner->set_query(x + i * code_size);

            RangeQueryResult& qres = pres.new_result(i);

            for (size_t ik = 0; ik < nprobe_2; ik++) {
                scan_list_func(i, ik, qres);
            }
        }

        pres.finalize();
    }
    indexIVF_stats.nq += n;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}
/// Destructor: releases the inverted lists and the quantizer only when the
/// corresponding ownership flags are set (see replace_invlists).
IndexBinaryIVF::~IndexBinaryIVF() {
    if (own_invlists) {
        delete invlists;
    }

    if (own_fields) {
        delete quantizer;
    }
}
} // namespace faiss