Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/IndexFlat.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 #include "IndexFlat.h"
12 
13 #include <cstring>
14 #include "utils.h"
15 #include "Heap.h"
16 
17 #include "FaissAssert.h"
18 
19 #include "AuxIndexStructures.h"
20 
21 
22 namespace faiss {
23 
24 IndexFlat::IndexFlat (idx_t d, MetricType metric):
25  Index(d, metric)
26 {
27 }
28 
29 
30 
31 void IndexFlat::add (idx_t n, const float *x) {
32  xb.insert(xb.end(), x, x + n * d);
33  ntotal += n;
34 }
35 
36 
38  xb.clear();
39  ntotal = 0;
40 }
41 
42 
43 void IndexFlat::search (idx_t n, const float *x, idx_t k,
44  float *distances, idx_t *labels) const
45 {
46  // we see the distances and labels as heaps
47 
48  if (metric_type == METRIC_INNER_PRODUCT) {
49  float_minheap_array_t res = {
50  size_t(n), size_t(k), labels, distances};
51  knn_inner_product (x, xb.data(), d, n, ntotal, &res);
52  } else if (metric_type == METRIC_L2) {
53  float_maxheap_array_t res = {
54  size_t(n), size_t(k), labels, distances};
55  knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
56  }
57 }
58 
59 void IndexFlat::range_search (idx_t n, const float *x, float radius,
60  RangeSearchResult *result) const
61 {
62  switch (metric_type) {
63  case METRIC_INNER_PRODUCT:
64  range_search_inner_product (x, xb.data(), d, n, ntotal,
65  radius, result);
66  break;
67  case METRIC_L2:
68  range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
69  break;
70  }
71 }
72 
73 
75  idx_t n,
76  const float *x,
77  idx_t k,
78  float *distances,
79  const idx_t *labels) const
80 {
81  switch (metric_type) {
82  case METRIC_INNER_PRODUCT:
83  fvec_inner_products_by_idx (
84  distances,
85  x, xb.data(), labels, d, n, k);
86  break;
87  case METRIC_L2:
88  fvec_L2sqr_by_idx (
89  distances,
90  x, xb.data(), labels, d, n, k);
91  break;
92  }
93 
94 }
95 
97 {
98  idx_t j = 0;
99  for (idx_t i = 0; i < ntotal; i++) {
100  if (sel.is_member (i)) {
101  // should be removed
102  } else {
103  if (i > j) {
104  memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
105  }
106  j++;
107  }
108  }
109  long nremove = ntotal - j;
110  if (nremove > 0) {
111  ntotal = j;
112  xb.resize (ntotal * d);
113  }
114  return nremove;
115 }
116 
117 
118 
119 void IndexFlat::reconstruct (idx_t key, float * recons) const
120 {
121  memcpy (recons, &(xb[key * d]), sizeof(*recons) * d);
122 }
123 
124 /***************************************************
125  * IndexFlatL2BaseShift
126  ***************************************************/
127 
128 IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift):
129  IndexFlatL2 (d), shift (nshift)
130 {
131  memcpy (this->shift.data(), shift, sizeof(float) * nshift);
132 }
133 
135  idx_t n,
136  const float *x,
137  idx_t k,
138  float *distances,
139  idx_t *labels) const
140 {
141  FAISS_THROW_IF_NOT (shift.size() == ntotal);
142 
143  float_maxheap_array_t res = {
144  size_t(n), size_t(k), labels, distances};
145  knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data());
146 }
147 
148 
149 
150 /***************************************************
151  * IndexRefineFlat
152  ***************************************************/
153 
154 IndexRefineFlat::IndexRefineFlat (Index *base_index):
155  Index (base_index->d, base_index->metric_type),
156  refine_index (base_index->d, base_index->metric_type),
157  base_index (base_index), own_fields (false),
158  k_factor (1)
159 {
160  is_trained = base_index->is_trained;
161  FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
162  "base_index should be empty in the beginning");
163 }
164 
165 IndexRefineFlat::IndexRefineFlat () {
166  base_index = nullptr;
167  own_fields = false;
168  k_factor = 1;
169 }
170 
171 
172 void IndexRefineFlat::train (idx_t n, const float *x)
173 {
174  base_index->train (n, x);
175  is_trained = true;
176 }
177 
178 void IndexRefineFlat::add (idx_t n, const float *x) {
179  FAISS_THROW_IF_NOT (is_trained);
180  base_index->add (n, x);
181  refine_index.add (n, x);
183 }
184 
186 {
187  base_index->reset ();
188  refine_index.reset ();
189  ntotal = 0;
190 }
191 
192 namespace {
193 typedef faiss::Index::idx_t idx_t;
194 
195 template<class C>
196 static void reorder_2_heaps (
197  idx_t n,
198  idx_t k, idx_t *labels, float *distances,
199  idx_t k_base, const idx_t *base_labels, const float *base_distances)
200 {
201 #pragma omp parallel for
202  for (idx_t i = 0; i < n; i++) {
203  idx_t *idxo = labels + i * k;
204  float *diso = distances + i * k;
205  const idx_t *idxi = base_labels + i * k_base;
206  const float *disi = base_distances + i * k_base;
207 
208  heap_heapify<C> (k, diso, idxo, disi, idxi, k);
209  if (k_base != k) { // add remaining elements
210  heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
211  }
212  heap_reorder<C> (k, diso, idxo);
213  }
214 }
215 
216 
217 }
218 
219 
221  idx_t n, const float *x, idx_t k,
222  float *distances, idx_t *labels) const
223 {
224  FAISS_THROW_IF_NOT (is_trained);
225  idx_t k_base = idx_t (k * k_factor);
226  idx_t * base_labels = labels;
227  float * base_distances = distances;
228  ScopeDeleter<idx_t> del1;
229  ScopeDeleter<float> del2;
230 
231 
232  if (k != k_base) {
233  base_labels = new idx_t [n * k_base];
234  del1.set (base_labels);
235  base_distances = new float [n * k_base];
236  del2.set (base_distances);
237  }
238 
239  base_index->search (n, x, k_base, base_distances, base_labels);
240 
241  for (int i = 0; i < n * k_base; i++)
242  assert (base_labels[i] >= -1 &&
243  base_labels[i] < ntotal);
244 
245  // compute refined distances
247  n, x, k_base, base_distances, base_labels);
248 
249  // sort and store result
250  if (metric_type == METRIC_L2) {
251  typedef CMax <float, idx_t> C;
252  reorder_2_heaps<C> (
253  n, k, labels, distances,
254  k_base, base_labels, base_distances);
255 
256  } else if (metric_type == METRIC_INNER_PRODUCT) {
257  typedef CMin <float, idx_t> C;
258  reorder_2_heaps<C> (
259  n, k, labels, distances,
260  k_base, base_labels, base_distances);
261  }
262 
263 }
264 
265 
266 
267 IndexRefineFlat::~IndexRefineFlat ()
268 {
269  if (own_fields) delete base_index;
270 }
271 
272 /***************************************************
273  * IndexFlat1D
274  ***************************************************/
275 
276 
277 IndexFlat1D::IndexFlat1D (bool continuous_update):
278  IndexFlatL2 (1),
279  continuous_update (continuous_update)
280 {
281 }
282 
283 /// if not continuous_update, call this between the last add and
284 /// the first search
286 {
287  perm.resize (ntotal);
288  if (ntotal < 1000000) {
289  fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
290  } else {
291  fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
292  }
293 }
294 
295 void IndexFlat1D::add (idx_t n, const float *x)
296 {
297  IndexFlatL2::add (n, x);
298  if (continuous_update)
300 }
301 
303 {
305  perm.clear();
306 }
307 
309  idx_t n,
310  const float *x,
311  idx_t k,
312  float *distances,
313  idx_t *labels) const
314 {
315  FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
316  "Call update_permutation before search");
317 
318 #pragma omp parallel for
319  for (idx_t i = 0; i < n; i++) {
320 
321  float q = x[i]; // query
322  float *D = distances + i * k;
323  idx_t *I = labels + i * k;
324 
325  // binary search
326  idx_t i0 = 0, i1 = ntotal;
327  idx_t wp = 0;
328 
329  if (xb[perm[i0]] > q) {
330  i1 = 0;
331  goto finish_right;
332  }
333 
334  if (xb[perm[i1 - 1]] <= q) {
335  i0 = i1 - 1;
336  goto finish_left;
337  }
338 
339  while (i0 + 1 < i1) {
340  idx_t imed = (i0 + i1) / 2;
341  if (xb[perm[imed]] <= q) i0 = imed;
342  else i1 = imed;
343  }
344 
345  // query is between xb[perm[i0]] and xb[perm[i1]]
346  // expand to nearest neighs
347 
348  while (wp < k) {
349  float xleft = xb[perm[i0]];
350  float xright = xb[perm[i1]];
351 
352  if (q - xleft < xright - q) {
353  D[wp] = q - xleft;
354  I[wp] = perm[i0];
355  i0--; wp++;
356  if (i0 < 0) { goto finish_right; }
357  } else {
358  D[wp] = xright - q;
359  I[wp] = perm[i1];
360  i1++; wp++;
361  if (i1 >= ntotal) { goto finish_left; }
362  }
363  }
364  goto done;
365 
366  finish_right:
367  // grow to the right from i1
368  while (wp < k) {
369  if (i1 < ntotal) {
370  D[wp] = xb[perm[i1]] - q;
371  I[wp] = perm[i1];
372  i1++;
373  } else {
374  D[wp] = 1.0 / 0.0;
375  I[wp] = -1;
376  }
377  wp++;
378  }
379  goto done;
380 
381  finish_left:
382  // grow to the left from i0
383  while (wp < k) {
384  if (i0 >= 0) {
385  D[wp] = q - xb[perm[i0]];
386  I[wp] = perm[i0];
387  i0--;
388  } else {
389  D[wp] = 1.0 / 0.0;
390  I[wp] = -1;
391  }
392  wp++;
393  }
394  done: ;
395  }
396 
397 }
398 
399 
400 
401 } // namespace faiss
void knn_L2sqr_base_shift(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res, const float *base_shift)
Definition: utils.cpp:653
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:134
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:185
virtual void reset()=0
removes all elements from the database.
bool continuous_update
is the permutation updated continuously?
Definition: IndexFlat.h:138
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:43
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:37
void update_permutation()
Definition: IndexFlat.cpp:285
virtual void train(idx_t n, const float *x)
Definition: Index.cpp:24
void reconstruct(idx_t key, float *recons) const override
Definition: IndexFlat.cpp:119
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:295
long remove_ids(const IDSelector &sel) override
Definition: IndexFlat.cpp:96
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:220
Index * base_index
faster index to pre-select the vectors that should be filtered
Definition: IndexFlat.h:108
IndexFlat refine_index
storage for full vectors
Definition: IndexFlat.h:105
bool own_fields
should the base index be deallocated?
Definition: IndexFlat.h:109
int d
vector dimension
Definition: Index.h:66
void range_search(idx_t n, const float *x, float radius, RangeSearchResult *result) const override
Definition: IndexFlat.cpp:59
void train(idx_t n, const float *x) override
Definition: IndexFlat.cpp:172
virtual void add(idx_t n, const float *x)=0
long idx_t
all indices are this type
Definition: Index.h:64
void range_search_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
same as range_search_L2sqr for the inner product similarity
Definition: utils.cpp:947
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
void knn_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_minheap_array_t *res)
Definition: utils.cpp:613
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:31
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Warn: the distances returned are L1 not L2.
Definition: IndexFlat.cpp:308
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void range_search_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
Definition: utils.cpp:932
void compute_distance_subset(idx_t n, const float *x, idx_t k, float *distances, const idx_t *labels) const
Definition: IndexFlat.cpp:74
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:74
void knn_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res)
Definition: utils.cpp:633
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
std::vector< float > xb
database vectors, size ntotal * d
Definition: IndexFlat.h:24
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:302
std::vector< idx_t > perm
sorted database indices
Definition: IndexFlat.h:140
MetricType
Some algorithms support both an inner product version and a L2 search version.
Definition: Index.h:45
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:178