11 #include "IndexFlat.h"
17 #include "FaissAssert.h"
19 #include "AuxIndexStructures.h"
23 IndexFlat::IndexFlat (idx_t d,
MetricType metric):
31 xb.insert(
xb.end(), x, x + n *
d);
43 float *distances,
idx_t *labels)
const
49 size_t(n), size_t(k), labels, distances};
53 size_t(n), size_t(k), labels, distances};
62 case METRIC_INNER_PRODUCT:
78 const idx_t *labels)
const
81 case METRIC_INNER_PRODUCT:
82 fvec_inner_products_by_idx (
84 x,
xb.data(), labels,
d, n, k);
89 x,
xb.data(), labels,
d, n, k);
99 if (sel.is_member (i)) {
103 memmove (&
xb[
d * j], &
xb[
d * i],
sizeof(
xb[0]) *
d);
108 long nremove = ntotal - j;
111 xb.resize (ntotal *
d);
120 memcpy (recons, &(
xb[key *
d]),
sizeof(*recons) * d);
127 IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d,
size_t nshift,
const float *shift):
130 memcpy (this->shift.data(), shift,
sizeof(float) * nshift);
140 FAISS_THROW_IF_NOT (shift.size() ==
ntotal);
143 size_t(n), size_t(k), labels, distances};
153 IndexRefineFlat::IndexRefineFlat (
Index *base_index):
154 Index (base_index->d, base_index->metric_type),
155 refine_index (base_index->d, base_index->metric_type),
156 base_index (base_index), own_fields (false),
160 FAISS_THROW_IF_NOT_MSG (base_index->
ntotal == 0,
161 "base_index should be empty in the beginning");
164 IndexRefineFlat::IndexRefineFlat () {
195 static void reorder_2_heaps (
197 idx_t k, idx_t *labels,
float *distances,
198 idx_t k_base,
const idx_t *base_labels,
const float *base_distances)
200 #pragma omp parallel for
201 for (idx_t i = 0; i < n; i++) {
202 idx_t *idxo = labels + i * k;
203 float *diso = distances + i * k;
204 const idx_t *idxi = base_labels + i * k_base;
205 const float *disi = base_distances + i * k_base;
207 heap_heapify<C> (k, diso, idxo, disi, idxi, k);
209 heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
211 heap_reorder<C> (k, diso, idxo);
220 idx_t n,
const float *x, idx_t k,
221 float *distances, idx_t *labels)
const
225 idx_t * base_labels = labels;
226 float * base_distances = distances;
232 base_labels =
new idx_t [n * k_base];
233 del1.set (base_labels);
234 base_distances =
new float [n * k_base];
235 del2.set (base_distances);
240 for (
int i = 0; i < n * k_base; i++)
241 assert (base_labels[i] >= -1 &&
246 n, x, k_base, base_distances, base_labels);
252 n, k, labels, distances,
253 k_base, base_labels, base_distances);
258 n, k, labels, distances,
259 k_base, base_labels, base_distances);
266 IndexRefineFlat::~IndexRefineFlat ()
276 IndexFlat1D::IndexFlat1D (
bool continuous_update):
278 continuous_update (continuous_update)
290 fvec_argsort_parallel (
ntotal,
xb.data(), (
size_t*)
perm.data());
314 FAISS_THROW_IF_NOT_MSG (
perm.size() ==
ntotal,
315 "Call update_permutation before search");
317 #pragma omp parallel for
318 for (idx_t i = 0; i < n; i++) {
321 float *D = distances + i * k;
322 idx_t *I = labels + i * k;
325 idx_t i0 = 0, i1 =
ntotal;
333 if (
xb[
perm[i1 - 1]] <= q) {
338 while (i0 + 1 < i1) {
339 idx_t imed = (i0 + i1) / 2;
340 if (
xb[
perm[imed]] <= q) i0 = imed;
348 float xleft =
xb[
perm[i0]];
349 float xright =
xb[perm[i1]];
351 if (q - xleft < xright - q) {
355 if (i0 < 0) {
goto finish_right; }
360 if (i1 >=
ntotal) {
goto finish_left; }
void knn_L2sqr_base_shift(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res, const float *base_shift)
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
void reset() override
removes all elements from the database.
virtual void reset()=0
removes all elements from the database.
bool continuous_update
is the permutation updated continuously?
virtual void train(idx_t, const float *)
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
void reset() override
removes all elements from the database.
void update_permutation()
void reconstruct(idx_t key, float *recons) const override
void add(idx_t n, const float *x) override
long remove_ids(const IDSelector &sel) override
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Index * base_index
faster index to pre-select the vectors that should be filtered
IndexFlat refine_index
storage for full vectors
bool own_fields
should the base index be deallocated?
void range_search(idx_t n, const float *x, float radius, RangeSearchResult *result) const override
void train(idx_t n, const float *x) override
virtual void add(idx_t n, const float *x)=0
long idx_t
all indices are this type
void range_search_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
same as range_search_L2sqr for the inner product similarity
idx_t ntotal
total nb of indexed vectors
void knn_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_minheap_array_t *res)
void add(idx_t n, const float *x) override
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Warning: the distances returned are L1, not L2.
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void range_search_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
void compute_distance_subset(idx_t n, const float *x, idx_t k, float *distances, const idx_t *labels) const
MetricType metric_type
type of metric this index uses for search
void knn_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res)
bool is_trained
set if the Index does not require training, or if training is done already
std::vector< float > xb
database vectors, size ntotal * d
void reset() override
removes all elements from the database.
std::vector< idx_t > perm
sorted database indices
MetricType
Some algorithms support both an inner product version and an L2 search version.
void add(idx_t n, const float *x) override