Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
IndexFlat.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved
11 
12 #include "IndexFlat.h"
13 
14 #include <cstring>
15 #include "utils.h"
16 #include "Heap.h"
17 
18 #include "FaissAssert.h"
19 
20 namespace faiss {
21 
22 IndexFlat::IndexFlat (idx_t d, MetricType metric):
23  Index(d, metric)
24 {
25  set_typename();
26 }
27 
28 
29 void IndexFlat::set_typename()
30 {
31  std::stringstream s;
32  if (metric_type == METRIC_INNER_PRODUCT)
33  s << "IP";
34  else if (metric_type == METRIC_L2)
35  s << "L2";
36  else s << "??";
37  index_typename = s.str();
38 }
39 
40 
41 void IndexFlat::add (idx_t n, const float *x) {
42  for (idx_t i = 0; i < n * d; i++)
43  xb.push_back (x[i]);
44  ntotal += n;
45 }
46 
47 
49  xb.clear();
50  ntotal = 0;
51 }
52 
53 
54 void IndexFlat::search (idx_t n, const float *x, idx_t k,
55  float *distances, idx_t *labels) const
56 {
57  // we see the distances and labels as heaps
58 
59  if (metric_type == METRIC_INNER_PRODUCT) {
60  float_minheap_array_t res = {
61  size_t(n), size_t(k), labels, distances};
62  knn_inner_product (x, xb.data(), d, n, ntotal, &res);
63  } else if (metric_type == METRIC_L2) {
64  float_maxheap_array_t res = {
65  size_t(n), size_t(k), labels, distances};
66  knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
67  }
68 }
69 
70 void IndexFlat::range_search (idx_t n, const float *x, float radius,
71  RangeSearchResult *result) const
72 {
73  switch (metric_type) {
74  case METRIC_INNER_PRODUCT:
75  range_search_inner_product (x, xb.data(), d, n, ntotal,
76  radius, result);
77  break;
78  case METRIC_L2:
79  range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
80  break;
81  }
82 }
83 
84 
86  idx_t n,
87  const float *x,
88  idx_t k,
89  float *distances,
90  const idx_t *labels) const
91 {
92  switch (metric_type) {
93  case METRIC_INNER_PRODUCT:
94  fvec_inner_products_by_idx (
95  distances,
96  x, xb.data(), labels, d, n, k);
97  break;
98  case METRIC_L2:
99  fvec_L2sqr_by_idx (
100  distances,
101  x, xb.data(), labels, d, n, k);
102  break;
103  }
104 
105 }
106 
107 
108 
109 void IndexFlat::reconstruct (idx_t key, float * recons) const
110 {
111  memcpy (recons, &(xb[key * d]), sizeof(*recons) * d);
112 }
113 
114 /***************************************************
115  * IndexFlatL2BaseShift
116  ***************************************************/
117 
118 IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift):
119  IndexFlatL2 (d), shift (nshift)
120 {
121  memcpy (this->shift.data(), shift, sizeof(float) * nshift);
122 }
123 
125  idx_t n,
126  const float *x,
127  idx_t k,
128  float *distances,
129  idx_t *labels) const
130 {
131  FAISS_ASSERT(shift.size() == ntotal);
132 
133  float_maxheap_array_t res = {
134  size_t(n), size_t(k), labels, distances};
135  knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data());
136 }
137 
138 
139 
140 /***************************************************
141  * IndexRefineFlat
142  ***************************************************/
143 
144 IndexRefineFlat::IndexRefineFlat (Index *base_index):
145  Index (base_index->d, base_index->metric_type),
146  refine_index (base_index->d, base_index->metric_type),
147  base_index (base_index), own_fields (false),
148  k_factor (1)
149 {
150  is_trained = base_index->is_trained;
151  assert (base_index->ntotal == 0 ||
152  !"base_index should be empty in the beginning");
153  set_typename ();
154 }
155 
156 IndexRefineFlat::IndexRefineFlat () {
157  base_index = nullptr;
158  own_fields = false;
159  k_factor = 1;
160 }
161 
162 void IndexRefineFlat::set_typename ()
163 {
164  std::stringstream s;
165  s << "Refine" << '[' << base_index->get_typename()
166  << ',' << refine_index.get_typename() << ']';
167  index_typename = s.str();
168 }
169 
170 
171 void IndexRefineFlat::train (idx_t n, const float *x)
172 {
173  base_index->train (n, x);
174  is_trained = true;
175 }
176 
177 void IndexRefineFlat::add (idx_t n, const float *x) {
178  FAISS_ASSERT (is_trained);
179  base_index->add (n, x);
180  refine_index.add (n, x);
182 }
183 
185 {
186  base_index->reset ();
187  refine_index.reset ();
188  ntotal = 0;
189 }
190 
191 namespace {
192 typedef faiss::Index::idx_t idx_t;
193 
194 template<class C>
195 static void reorder_2_heaps (
196  idx_t n,
197  idx_t k, idx_t *labels, float *distances,
198  idx_t k_base, const idx_t *base_labels, const float *base_distances)
199 {
200 #pragma omp parallel for
201  for (idx_t i = 0; i < n; i++) {
202  idx_t *idxo = labels + i * k;
203  float *diso = distances + i * k;
204  const idx_t *idxi = base_labels + i * k_base;
205  const float *disi = base_distances + i * k_base;
206 
207  heap_heapify<C> (k, diso, idxo, disi, idxi, k);
208  if (k_base != k) { // add remaining elements
209  heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
210  }
211  heap_reorder<C> (k, diso, idxo);
212  }
213 }
214 
215 
216 }
217 
218 
220  idx_t n, const float *x, idx_t k,
221  float *distances, idx_t *labels) const
222 {
223  FAISS_ASSERT (is_trained);
224  idx_t k_base = idx_t (k * k_factor);
225  idx_t * base_labels = labels;
226  float * base_distances = distances;
227 
228  if (k != k_base) {
229  base_labels = new idx_t [n * k_base];
230  base_distances = new float [n * k_base];
231  }
232 
233  base_index->search (n, x, k_base, base_distances, base_labels);
234 
235  for (int i = 0; i < n * k_base; i++)
236  assert (base_labels[i] >= -1 &&
237  base_labels[i] < ntotal);
238 
239  // compute refined distances
241  n, x, k_base, base_distances, base_labels);
242 
243  // sort and store result
244  if (metric_type == METRIC_L2) {
245  typedef CMax <float, idx_t> C;
246  reorder_2_heaps<C> (
247  n, k, labels, distances,
248  k_base, base_labels, base_distances);
249 
250  } else if (metric_type == METRIC_INNER_PRODUCT) {
251  typedef CMin <float, idx_t> C;
252  reorder_2_heaps<C> (
253  n, k, labels, distances,
254  k_base, base_labels, base_distances);
255  }
256 
257  if (k != k_base) {
258  delete [] base_labels;
259  delete [] base_distances;
260  }
261 }
262 
263 
264 
265 IndexRefineFlat::~IndexRefineFlat ()
266 {
267  if (own_fields) delete base_index;
268 }
269 
270 /***************************************************
271  * IndexFlat1D
272  ***************************************************/
273 
274 
275 IndexFlat1D::IndexFlat1D (bool continuous_update):
276  IndexFlatL2 (1),
277  continuous_update (continuous_update)
278 {
279 }
280 
281 /// if not continuous_update, call this between the last add and
282 /// the first search
284 {
285  perm.resize (ntotal);
286  if (ntotal < 1000000) {
287  fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
288  } else {
289  fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
290  }
291 }
292 
293 void IndexFlat1D::add (idx_t n, const float *x)
294 {
295  IndexFlatL2::add (n, x);
296  if (continuous_update)
298 }
299 
301 {
303  perm.clear();
304 }
305 
307  idx_t n,
308  const float *x,
309  idx_t k,
310  float *distances,
311  idx_t *labels) const
312 {
313  FAISS_ASSERT (perm.size() == ntotal ||
314  !"Call update_permutation before search");
315 
316 #pragma omp parallel for
317  for (idx_t i = 0; i < n; i++) {
318 
319  float q = x[i]; // query
320  float *D = distances + i * k;
321  idx_t *I = labels + i * k;
322 
323  // binary search
324  idx_t i0 = 0, i1 = ntotal;
325  idx_t wp = 0;
326 
327  if (xb[perm[i0]] > q) {
328  i1 = 0;
329  goto finish_right;
330  }
331 
332  if (xb[perm[i1 - 1]] <= q) {
333  i0 = i1 - 1;
334  goto finish_left;
335  }
336 
337  while (i0 + 1 < i1) {
338  idx_t imed = (i0 + i1) / 2;
339  if (xb[perm[imed]] <= q) i0 = imed;
340  else i1 = imed;
341  }
342 
343  // query is between xb[perm[i0]] and xb[perm[i1]]
344  // expand to nearest neighs
345 
346  while (wp < k) {
347  float xleft = xb[perm[i0]];
348  float xright = xb[perm[i1]];
349 
350  if (q - xleft < xright - q) {
351  D[wp] = q - xleft;
352  I[wp] = perm[i0];
353  i0--; wp++;
354  if (i0 < 0) { goto finish_right; }
355  } else {
356  D[wp] = xright - q;
357  I[wp] = perm[i1];
358  i1++; wp++;
359  if (i1 >= ntotal) { goto finish_left; }
360  }
361  }
362  goto done;
363 
364  finish_right:
365  // grow to the right from i1
366  while (wp < k) {
367  if (i1 < ntotal) {
368  D[wp] = xb[perm[i1]] - q;
369  I[wp] = perm[i1];
370  i1++;
371  } else {
372  D[wp] = 1.0 / 0.0;
373  I[wp] = -1;
374  }
375  wp++;
376  }
377  goto done;
378 
379  finish_left:
380  // grow to the left from i0
381  while (wp < k) {
382  if (i0 >= 0) {
383  D[wp] = q - xb[perm[i0]];
384  I[wp] = perm[i0];
385  i0--;
386  } else {
387  D[wp] = 1.0 / 0.0;
388  I[wp] = -1;
389  }
390  wp++;
391  }
392  done: ;
393  }
394 
395 }
396 
397 
398 
399 } // namespace faiss
void knn_L2sqr_base_shift(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res, const float *base_shift)
Definition: utils.cpp:870
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:124
virtual void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:184
virtual void reset()=0
removes all elements from the database.
bool continuous_update
is the permutation updated continuously?
Definition: IndexFlat.h:143
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:54
virtual void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:48
void update_permutation()
Definition: IndexFlat.cpp:283
virtual void reconstruct(idx_t key, float *recons) const override
Definition: IndexFlat.cpp:109
virtual void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:293
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:219
Index * base_index
faster index to pre-select the vectors that should be filtered
Definition: IndexFlat.h:111
IndexFlat refine_index
storage for full vectors
Definition: IndexFlat.h:108
bool own_fields
should the base index be deallocated?
Definition: IndexFlat.h:112
int d
vector dimension
Definition: Index.h:66
virtual void range_search(idx_t n, const float *x, float radius, RangeSearchResult *result) const override
Definition: IndexFlat.cpp:70
virtual void train(idx_t n, const float *x) override
Definition: IndexFlat.cpp:171
virtual void add(idx_t n, const float *x)=0
long idx_t
all indices are this type
Definition: Index.h:64
void range_search_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
same as range_search_L2sqr for the inner product similarity
Definition: utils.cpp:1166
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
virtual std::string get_typename() const
Definition: Index.h:188
void knn_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_minheap_array_t *res)
Definition: utils.cpp:830
virtual void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:41
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Warn: the distances returned are L1 not L2.
Definition: IndexFlat.cpp:306
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void range_search_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
Definition: utils.cpp:1151
void compute_distance_subset(idx_t n, const float *x, idx_t k, float *distances, const idx_t *labels) const
Definition: IndexFlat.cpp:85
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:74
void knn_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res)
Definition: utils.cpp:850
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
virtual void train(idx_t n, const float *x)
Definition: Index.h:92
std::vector< float > xb
database vectors, size ntotal * d
Definition: IndexFlat.h:26
virtual void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:300
std::vector< idx_t > perm
sorted database indices
Definition: IndexFlat.h:145
MetricType
Some algorithms support both an inner product version and an L2 search version.
Definition: Index.h:44
virtual void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:177