Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/IndexFlat.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved
10 
#include "IndexFlat.h"

#include <cstring>
#include <limits>
#include "utils.h"
#include "Heap.h"

#include "FaissAssert.h"

#include "AuxIndexStructures.h"
20 
21 namespace faiss {
22 
23 IndexFlat::IndexFlat (idx_t d, MetricType metric):
24  Index(d, metric)
25 {
26 }
27 
28 
29 
30 void IndexFlat::add (idx_t n, const float *x) {
31  xb.insert(xb.end(), x, x + n * d);
32  ntotal += n;
33 }
34 
35 
37  xb.clear();
38  ntotal = 0;
39 }
40 
41 
42 void IndexFlat::search (idx_t n, const float *x, idx_t k,
43  float *distances, idx_t *labels) const
44 {
45  // we see the distances and labels as heaps
46 
47  if (metric_type == METRIC_INNER_PRODUCT) {
48  float_minheap_array_t res = {
49  size_t(n), size_t(k), labels, distances};
50  knn_inner_product (x, xb.data(), d, n, ntotal, &res);
51  } else if (metric_type == METRIC_L2) {
52  float_maxheap_array_t res = {
53  size_t(n), size_t(k), labels, distances};
54  knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
55  }
56 }
57 
58 void IndexFlat::range_search (idx_t n, const float *x, float radius,
59  RangeSearchResult *result) const
60 {
61  switch (metric_type) {
62  case METRIC_INNER_PRODUCT:
63  range_search_inner_product (x, xb.data(), d, n, ntotal,
64  radius, result);
65  break;
66  case METRIC_L2:
67  range_search_L2sqr (x, xb.data(), d, n, ntotal, radius, result);
68  break;
69  }
70 }
71 
72 
74  idx_t n,
75  const float *x,
76  idx_t k,
77  float *distances,
78  const idx_t *labels) const
79 {
80  switch (metric_type) {
81  case METRIC_INNER_PRODUCT:
82  fvec_inner_products_by_idx (
83  distances,
84  x, xb.data(), labels, d, n, k);
85  break;
86  case METRIC_L2:
87  fvec_L2sqr_by_idx (
88  distances,
89  x, xb.data(), labels, d, n, k);
90  break;
91  }
92 
93 }
94 
96 {
97  idx_t j = 0;
98  for (idx_t i = 0; i < ntotal; i++) {
99  if (sel.is_member (i)) {
100  // should be removed
101  } else {
102  if (i > j) {
103  memmove (&xb[d * j], &xb[d * i], sizeof(xb[0]) * d);
104  }
105  j++;
106  }
107  }
108  long nremove = ntotal - j;
109  if (nremove > 0) {
110  ntotal = j;
111  xb.resize (ntotal * d);
112  }
113  return nremove;
114 }
115 
116 
117 
118 void IndexFlat::reconstruct (idx_t key, float * recons) const
119 {
120  memcpy (recons, &(xb[key * d]), sizeof(*recons) * d);
121 }
122 
123 /***************************************************
124  * IndexFlatL2BaseShift
125  ***************************************************/
126 
127 IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift):
128  IndexFlatL2 (d), shift (nshift)
129 {
130  memcpy (this->shift.data(), shift, sizeof(float) * nshift);
131 }
132 
134  idx_t n,
135  const float *x,
136  idx_t k,
137  float *distances,
138  idx_t *labels) const
139 {
140  FAISS_THROW_IF_NOT (shift.size() == ntotal);
141 
142  float_maxheap_array_t res = {
143  size_t(n), size_t(k), labels, distances};
144  knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data());
145 }
146 
147 
148 
149 /***************************************************
150  * IndexRefineFlat
151  ***************************************************/
152 
153 IndexRefineFlat::IndexRefineFlat (Index *base_index):
154  Index (base_index->d, base_index->metric_type),
155  refine_index (base_index->d, base_index->metric_type),
156  base_index (base_index), own_fields (false),
157  k_factor (1)
158 {
159  is_trained = base_index->is_trained;
160  FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
161  "base_index should be empty in the beginning");
162 }
163 
164 IndexRefineFlat::IndexRefineFlat () {
165  base_index = nullptr;
166  own_fields = false;
167  k_factor = 1;
168 }
169 
170 
171 void IndexRefineFlat::train (idx_t n, const float *x)
172 {
173  base_index->train (n, x);
174  is_trained = true;
175 }
176 
// Adds the n vectors in x to both the base index and the refinement index.
// Throws (via FAISS_THROW_IF_NOT) when the index has not been trained.
177 void IndexRefineFlat::add (idx_t n, const float *x) {
178  FAISS_THROW_IF_NOT (is_trained);
179  base_index->add (n, x);
180  refine_index.add (n, x);
// NOTE(review): listing line 181 is absent from this extract. None of the
// visible statements updates ntotal -- presumably the dropped line does;
// confirm against the original file.
182 }
183 
185 {
186  base_index->reset ();
187  refine_index.reset ();
188  ntotal = 0;
189 }
190 
191 namespace {
192 typedef faiss::Index::idx_t idx_t;
193 
194 template<class C>
195 static void reorder_2_heaps (
196  idx_t n,
197  idx_t k, idx_t *labels, float *distances,
198  idx_t k_base, const idx_t *base_labels, const float *base_distances)
199 {
200 #pragma omp parallel for
201  for (idx_t i = 0; i < n; i++) {
202  idx_t *idxo = labels + i * k;
203  float *diso = distances + i * k;
204  const idx_t *idxi = base_labels + i * k_base;
205  const float *disi = base_distances + i * k_base;
206 
207  heap_heapify<C> (k, diso, idxo, disi, idxi, k);
208  if (k_base != k) { // add remaining elements
209  heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
210  }
211  heap_reorder<C> (k, diso, idxo);
212  }
213 }
214 
215 
216 }
217 
218 
// NOTE(review): listing line 219 (the line opening this definition) is
// absent from this extract; the symbol index places
// `void search(idx_t n, const float *x, idx_t k, float *distances,
// idx_t *labels) const override` at IndexFlat.cpp:219.
//
// Two-stage search: the base index is asked for k_base = k * k_factor
// candidates per query, and the candidates are then reduced to the k
// results written to distances / labels.
220  idx_t n, const float *x, idx_t k,
221  float *distances, idx_t *labels) const
222 {
223  FAISS_THROW_IF_NOT (is_trained);
// number of candidates requested from the base index (truncated to idx_t)
224  idx_t k_base = idx_t (k * k_factor);
// by default the candidates are written straight into the output buffers
225  idx_t * base_labels = labels;
226  float * base_distances = distances;
// presumably release the temporary buffers at scope exit -- see ScopeDeleter
227  ScopeDeleter<idx_t> del1;
228  ScopeDeleter<float> del2;
229 
230 
// the output buffers only hold k entries per query, so a different k_base
// needs separate candidate buffers
231  if (k != k_base) {
232  base_labels = new idx_t [n * k_base];
233  del1.set (base_labels);
234  base_distances = new float [n * k_base];
235  del2.set (base_distances);
236  }
237 
238  base_index->search (n, x, k_base, base_distances, base_labels);
239 
// sanity check on every candidate id (active only when assert is enabled)
// NOTE(review): the counter is an int while n * k_base is an idx_t product;
// the counter overflows if the product exceeds INT_MAX.
240  for (int i = 0; i < n * k_base; i++)
241  assert (base_labels[i] >= -1 &&
242  base_labels[i] < ntotal);
243 
244  // compute refined distances
// NOTE(review): listing line 245 (the start of this call) is absent from
// this extract; only its argument list survives on the next line.
246  n, x, k_base, base_distances, base_labels);
247 
248  // sort and store result
// the comparator type C is selected from the metric
249  if (metric_type == METRIC_L2) {
250  typedef CMax <float, idx_t> C;
251  reorder_2_heaps<C> (
252  n, k, labels, distances,
253  k_base, base_labels, base_distances);
254 
255  } else if (metric_type == METRIC_INNER_PRODUCT) {
256  typedef CMin <float, idx_t> C;
257  reorder_2_heaps<C> (
258  n, k, labels, distances,
259  k_base, base_labels, base_distances);
260  }
261 
262 }
263 
264 
265 
266 IndexRefineFlat::~IndexRefineFlat ()
267 {
268  if (own_fields) delete base_index;
269 }
270 
271 /***************************************************
272  * IndexFlat1D
273  ***************************************************/
274 
275 
276 IndexFlat1D::IndexFlat1D (bool continuous_update):
277  IndexFlatL2 (1),
278  continuous_update (continuous_update)
279 {
280 }
281 
282 /// if not continuous_update, call this between the last add and
283 /// the first search
285 {
286  perm.resize (ntotal);
287  if (ntotal < 1000000) {
288  fvec_argsort (ntotal, xb.data(), (size_t*)perm.data());
289  } else {
290  fvec_argsort_parallel (ntotal, xb.data(), (size_t*)perm.data());
291  }
292 }
293 
// Appends n values to the database through IndexFlatL2::add.
294 void IndexFlat1D::add (idx_t n, const float *x)
295 {
296  IndexFlatL2::add (n, x);
297  if (continuous_update)
// NOTE(review): listing line 298 (the statement governed by the `if` above)
// is absent from this extract; presumably it refreshes the permutation --
// confirm against the original file.
299 }
300 
// NOTE(review): listing line 301 (the line opening this definition) is
// absent from this extract; the symbol index places `void reset() override`
// ("removes all elements from the database") at IndexFlat.cpp:301.
302 {
// NOTE(review): listing line 303 is also absent; the only visible statement
// clears the permutation.
304  perm.clear();
305 }
306 
308  idx_t n,
309  const float *x,
310  idx_t k,
311  float *distances,
312  idx_t *labels) const
313 {
314  FAISS_THROW_IF_NOT_MSG (perm.size() == ntotal,
315  "Call update_permutation before search");
316 
317 #pragma omp parallel for
318  for (idx_t i = 0; i < n; i++) {
319 
320  float q = x[i]; // query
321  float *D = distances + i * k;
322  idx_t *I = labels + i * k;
323 
324  // binary search
325  idx_t i0 = 0, i1 = ntotal;
326  idx_t wp = 0;
327 
328  if (xb[perm[i0]] > q) {
329  i1 = 0;
330  goto finish_right;
331  }
332 
333  if (xb[perm[i1 - 1]] <= q) {
334  i0 = i1 - 1;
335  goto finish_left;
336  }
337 
338  while (i0 + 1 < i1) {
339  idx_t imed = (i0 + i1) / 2;
340  if (xb[perm[imed]] <= q) i0 = imed;
341  else i1 = imed;
342  }
343 
344  // query is between xb[perm[i0]] and xb[perm[i1]]
345  // expand to nearest neighs
346 
347  while (wp < k) {
348  float xleft = xb[perm[i0]];
349  float xright = xb[perm[i1]];
350 
351  if (q - xleft < xright - q) {
352  D[wp] = q - xleft;
353  I[wp] = perm[i0];
354  i0--; wp++;
355  if (i0 < 0) { goto finish_right; }
356  } else {
357  D[wp] = xright - q;
358  I[wp] = perm[i1];
359  i1++; wp++;
360  if (i1 >= ntotal) { goto finish_left; }
361  }
362  }
363  goto done;
364 
365  finish_right:
366  // grow to the right from i1
367  while (wp < k) {
368  if (i1 < ntotal) {
369  D[wp] = xb[perm[i1]] - q;
370  I[wp] = perm[i1];
371  i1++;
372  } else {
373  D[wp] = 1.0 / 0.0;
374  I[wp] = -1;
375  }
376  wp++;
377  }
378  goto done;
379 
380  finish_left:
381  // grow to the left from i0
382  while (wp < k) {
383  if (i0 >= 0) {
384  D[wp] = q - xb[perm[i0]];
385  I[wp] = perm[i0];
386  i0--;
387  } else {
388  D[wp] = 1.0 / 0.0;
389  I[wp] = -1;
390  }
391  wp++;
392  }
393  done: ;
394  }
395 
396 }
397 
398 
399 
400 } // namespace faiss
void knn_L2sqr_base_shift(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res, const float *base_shift)
Definition: utils.cpp:955
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:133
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:184
virtual void reset()=0
removes all elements from the database.
bool continuous_update
is the permutation updated continuously?
Definition: IndexFlat.h:139
virtual void train(idx_t, const float *)
Definition: Index.h:89
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:42
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:36
void update_permutation()
Definition: IndexFlat.cpp:284
void reconstruct(idx_t key, float *recons) const override
Definition: IndexFlat.cpp:118
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:294
long remove_ids(const IDSelector &sel) override
Definition: IndexFlat.cpp:95
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexFlat.cpp:219
Index * base_index
faster index to pre-select the vectors that should be filtered
Definition: IndexFlat.h:109
IndexFlat refine_index
storage for full vectors
Definition: IndexFlat.h:106
bool own_fields
should the base index be deallocated?
Definition: IndexFlat.h:110
int d
vector dimension
Definition: Index.h:64
void range_search(idx_t n, const float *x, float radius, RangeSearchResult *result) const override
Definition: IndexFlat.cpp:58
void train(idx_t n, const float *x) override
Definition: IndexFlat.cpp:171
virtual void add(idx_t n, const float *x)=0
long idx_t
all indices are this type
Definition: Index.h:62
void range_search_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
same as range_search_L2sqr for the inner product similarity
Definition: utils.cpp:1249
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:65
void knn_inner_product(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_minheap_array_t *res)
Definition: utils.cpp:915
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:30
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Warn: the distances returned are L1 not L2.
Definition: IndexFlat.cpp:307
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
void range_search_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float radius, RangeSearchResult *res)
Definition: utils.cpp:1234
void compute_distance_subset(idx_t n, const float *x, idx_t k, float *distances, const idx_t *labels) const
Definition: IndexFlat.cpp:73
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:72
void knn_L2sqr(const float *x, const float *y, size_t d, size_t nx, size_t ny, float_maxheap_array_t *res)
Definition: utils.cpp:935
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
std::vector< float > xb
database vectors, size ntotal * d
Definition: IndexFlat.h:25
void reset() override
removes all elements from the database.
Definition: IndexFlat.cpp:301
std::vector< idx_t > perm
sorted database indices
Definition: IndexFlat.h:141
MetricType
Some algorithms support both an inner product version and an L2 search version.
Definition: Index.h:43
void add(idx_t n, const float *x) override
Definition: IndexFlat.cpp:177