Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/IndexShards.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include "IndexShards.h"
11 
12 #include <cstdio>
13 #include <functional>
14 
15 #include "FaissAssert.h"
16 #include "Heap.h"
17 #include "WorkerThread.h"
18 
19 namespace faiss {
20 
21 // subroutines
22 namespace {
23 
24 typedef Index::idx_t idx_t;
25 
26 
27 // add translation to all valid labels
28 void translate_labels (long n, idx_t *labels, long translation)
29 {
30  if (translation == 0) return;
31  for (long i = 0; i < n; i++) {
32  if(labels[i] < 0) continue;
33  labels[i] += translation;
34  }
35 }
36 
37 
38 /** merge result tables from several shards.
39  * @param all_distances size nshard * n * k
40  * @param all_labels idem
41  * @param translartions label translations to apply, size nshard
42  */
43 
44 template <class IndexClass, class C>
45 void
46 merge_tables(long n, long k, long nshard,
47  typename IndexClass::distance_t *distances,
48  idx_t *labels,
49  const std::vector<typename IndexClass::distance_t>& all_distances,
50  const std::vector<idx_t>& all_labels,
51  const std::vector<long>& translations) {
52  if (k == 0) {
53  return;
54  }
55  using distance_t = typename IndexClass::distance_t;
56 
57  long stride = n * k;
58 #pragma omp parallel
59  {
60  std::vector<int> buf (2 * nshard);
61  int * pointer = buf.data();
62  int * shard_ids = pointer + nshard;
63  std::vector<distance_t> buf2 (nshard);
64  distance_t * heap_vals = buf2.data();
65 #pragma omp for
66  for (long i = 0; i < n; i++) {
67  // the heap maps values to the shard where they are
68  // produced.
69  const distance_t *D_in = all_distances.data() + i * k;
70  const idx_t *I_in = all_labels.data() + i * k;
71  int heap_size = 0;
72 
73  for (long s = 0; s < nshard; s++) {
74  pointer[s] = 0;
75  if (I_in[stride * s] >= 0) {
76  heap_push<C> (++heap_size, heap_vals, shard_ids,
77  D_in[stride * s], s);
78  }
79  }
80 
81  distance_t *D = distances + i * k;
82  idx_t *I = labels + i * k;
83 
84  for (int j = 0; j < k; j++) {
85  if (heap_size == 0) {
86  I[j] = -1;
87  D[j] = C::neutral();
88  } else {
89  // pop best element
90  int s = shard_ids[0];
91  int & p = pointer[s];
92  D[j] = heap_vals[0];
93  I[j] = I_in[stride * s + p] + translations[s];
94 
95  heap_pop<C> (heap_size--, heap_vals, shard_ids);
96  p++;
97  if (p < k && I_in[stride * s + p] >= 0) {
98  heap_push<C> (++heap_size, heap_vals, shard_ids,
99  D_in[stride * s + p], s);
100  }
101  }
102  }
103  }
104  }
105 }
106 
107 } // anonymous namespace
108 
// Constructor fragment: the declarator line (listing line 110, carrying the
// class name and the `d` parameter) is absent from this listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::IndexShardsTemplate;
// the type of `d` cannot be determined from here -- confirm against
// IndexShards.h.
// The visible part forwards `d` and `threaded` to the ThreadedIndex<IndexT>
// base and stores the `successive_ids` flag; the body is empty.
109 template <typename IndexT>
111  bool threaded,
112  bool successive_ids)
113  : ThreadedIndex<IndexT>(d, threaded),
114  successive_ids(successive_ids) {
115 }
116 
// Constructor fragment: the declarator line (listing line 118) is absent
// from this listing. The visible text is identical to the preceding
// constructor, so this is presumably an overload that differs only in the
// type of `d` -- confirm against IndexShards.h.
// Forwards `d` and `threaded` to the ThreadedIndex<IndexT> base and stores
// the `successive_ids` flag; the body is empty.
117 template <typename IndexT>
119  bool threaded,
120  bool successive_ids)
121  : ThreadedIndex<IndexT>(d, threaded),
122  successive_ids(successive_ids) {
123 }
124 
// Constructor fragment: the declarator line (listing line 126, carrying the
// class name and the `threaded` parameter) is absent from this listing.
// Unlike the two overloads before it, no dimension is passed: only
// `threaded` goes to the ThreadedIndex<IndexT> base, and the
// `successive_ids` flag is stored. The body is empty.
125 template <typename IndexT>
127  bool successive_ids)
128  : ThreadedIndex<IndexT>(threaded),
129  successive_ids(successive_ids) {
130 }
131 
132 template <typename IndexT>
133 void
134 IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
135  sync_with_shard_indexes();
136 }
137 
// Fragment: the declarator line (listing line 140) is absent from this
// listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::onAfterRemoveIndex,
// the counterpart of onAfterAddIndex -- confirm against IndexShards.h.
// The body only calls sync_with_shard_indexes().
138 template <typename IndexT>
139 void
141  sync_with_shard_indexes();
142 }
143 
// Fragment: the declarator line (listing line 146) is absent from this
// listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::sync_with_shard_indexes,
// which is what the add/remove hooks and train() call -- confirm.
//
// Recomputes this object's metric_type, is_trained and ntotal from the
// sub-indexes reachable through count() / at(i).
144 template <typename IndexT>
145 void
// No sub-index at all: report untrained and empty. metric_type is left
// unchanged in this case.
147  if (!this->count()) {
148  this->is_trained = false;
149  this->ntotal = 0;
150 
151  return;
152  }
153 
// The first sub-index provides metric_type and is_trained; is_trained of
// the other sub-indexes is never consulted below.
154  auto firstIndex = this->at(0);
155  this->metric_type = firstIndex->metric_type;
156  this->is_trained = firstIndex->is_trained;
157  this->ntotal = firstIndex->ntotal;
158 
// Every other sub-index must have the same metric as the first one and the
// same dimension as this object; ntotal becomes the sum over all of them.
159  for (int i = 1; i < this->count(); ++i) {
160  auto index = this->at(i);
161  FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
162  FAISS_THROW_IF_NOT(this->d == index->d);
163 
164  this->ntotal += index->ntotal;
165  }
166 }
167 
168 template <typename IndexT>
169 void
170 IndexShardsTemplate<IndexT>::train(idx_t n,
171  const component_t *x) {
172  auto fn =
173  [n, x](int no, IndexT *index) {
174  if (index->verbose) {
175  printf("begin train shard %d on %ld points\n", no, n);
176  }
177 
178  index->train(n, x);
179 
180  if (index->verbose) {
181  printf("end train shard %d\n", no);
182  }
183  };
184 
185  this->runOnIndex(fn);
186  sync_with_shard_indexes();
187 }
188 
// Fragment: the declarator line (listing line 191, carrying the function
// name and the `n` parameter) is absent from this listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::add(idx_t n, ...)
// -- confirm against IndexShards.h.
// Adds the n vectors in x without explicit ids by delegating to
// add_with_ids with a null id pointer.
189 template <typename IndexT>
190 void
192  const component_t *x) {
193  add_with_ids(n, x, nullptr);
194 }
195 
// Fragment: the declarator line (listing line 198, carrying the function
// name and the `n` parameter) is absent from this listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::add_with_ids(idx_t n, ...)
// -- confirm against IndexShards.h.
//
// Distributes n vectors over the shards: shard `no` receives the contiguous
// slice [no * n / nshard, (no + 1) * n / nshard) of x (and of the ids, when
// there are any). ntotal is increased by n at the end.
//
// x     vectors to add
// xids  optional ids, may be null; must be null when successive_ids is set
196 template <typename IndexT>
197 void
199  const component_t * x,
200  const idx_t *xids) {
201 
// Explicit ids and successive_ids are mutually exclusive.
202  FAISS_THROW_IF_NOT_MSG(!(successive_ids && xids),
203  "It makes no sense to pass in ids and "
204  "request them to be shifted");
205 
206  if (successive_ids) {
// NOTE(review): this !xids check repeats the one made just above and can
// no longer fail at this point.
207  FAISS_THROW_IF_NOT_MSG(!xids,
208  "It makes no sense to pass in ids and "
209  "request them to be shifted");
// With successive_ids, vectors can only be added while the index is empty.
210  FAISS_THROW_IF_NOT_MSG(this->ntotal == 0,
211  "when adding to IndexShards with sucessive_ids, "
212  "only add() in a single pass is supported");
213  }
214 
215  idx_t nshard = this->count();
216  const idx_t *ids = xids;
217 
// Backing storage for generated ids; it must stay alive until runOnIndex
// below has returned, because `ids` may point into it.
218  std::vector<idx_t> aids;
219 
// Neither caller ids nor successive_ids: number the new vectors
// ntotal, ntotal + 1, ..., ntotal + n - 1.
220  if (!ids && !successive_ids) {
221  aids.resize(n);
222 
223  for (idx_t i = 0; i < n; i++) {
224  aids[i] = this->ntotal + i;
225  }
226 
227  ids = aids.data();
228  }
229 
// Number of component_t elements per vector: (d + 7) / 8 for one-byte
// component types, d otherwise.
// NOTE(review): the one-byte case presumably corresponds to bit-packed
// binary vectors (IndexBinary) -- confirm.
230  size_t components_per_vec =
231  sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
232 
233  auto fn =
234  [n, ids, x, nshard, components_per_vec](int no, IndexT *index) {
// [i0, i1) is the slice of the input assigned to shard `no`.
235  idx_t i0 = (idx_t) no * n / nshard;
236  idx_t i1 = ((idx_t) no + 1) * n / nshard;
237  auto x0 = x + i0 * components_per_vec;
238 
// NOTE(review): this message prints the total n, whereas the "end add"
// message below prints the slice size i1 - i0.
239  if (index->verbose) {
240  printf ("begin add shard %d on %ld points\n", no, n);
241  }
242 
// ids is null only in the successive_ids case; the shard then assigns its
// own ids through add().
243  if (ids) {
244  index->add_with_ids (i1 - i0, x0, ids + i0);
245  } else {
246  index->add (i1 - i0, x0);
247  }
248 
249  if (index->verbose) {
250  printf ("end add shard %d on %ld points\n", no, i1 - i0);
251  }
252  };
253 
254  this->runOnIndex(fn);
255 
256  // This is safe to do here because the current thread controls execution in
257  // all threads, and nothing else is happening
258  this->ntotal += n;
259 }
260 
// Fragment: the declarator line (listing line 263, carrying the function
// name and the `n` parameter) is absent from this listing.
// NOTE(review): presumably IndexShardsTemplate<IndexT>::search(idx_t n, ...)
// -- confirm against IndexShards.h.
//
// Runs the same search (n queries in x, k results each) on every shard,
// then merges the per-shard results into `distances` / `labels` with
// merge_tables.
261 template <typename IndexT>
262 void
264  const component_t *x,
265  idx_t k,
266  distance_t *distances,
267  idx_t *labels) const {
268  long nshard = this->count();
269 
// One k * n result table per shard, stored back to back.
270  std::vector<distance_t> all_distances(nshard * k * n);
271  std::vector<idx_t> all_labels(nshard * k * n);
272 
273  auto fn =
274  [n, k, x, &all_distances, &all_labels](int no, const IndexT *index) {
275  if (index->verbose) {
276  printf ("begin query shard %d on %ld points\n", no, n);
277  }
278 
// Shard `no` writes to its own table, at offset no * k * n.
279  index->search (n, x, k,
280  all_distances.data() + no * k * n,
281  all_labels.data() + no * k * n);
282 
283  if (index->verbose) {
284  printf ("end query shard %d\n", no);
285  }
286  };
287 
288  this->runOnIndex(fn);
289 
// Per-shard offset added to the labels during the merge; stays all-zero
// unless successive_ids is set.
290  std::vector<long> translations(nshard, 0);
291 
292  // Because we just called runOnIndex above, it is safe to access the sub-index
293  // ntotal here
294  if (successive_ids) {
295  translations[0] = 0;
296 
// Offset of shard s + 1 = sum of ntotal over shards 0..s.
297  for (int s = 0; s + 1 < nshard; s++) {
298  translations[s + 1] = translations[s] + this->at(s)->ntotal;
299  }
300  }
301 
// METRIC_L2 merges with the CMin comparator, every other metric with CMax.
302  if (this->metric_type == METRIC_L2) {
303  merge_tables<IndexT, CMin<distance_t, int>>(
304  n, k, nshard, distances, labels,
305  all_distances, all_labels, translations);
306  } else {
307  merge_tables<IndexT, CMax<distance_t, int>>(
308  n, k, nshard, distances, labels,
309  all_distances, all_labels, translations);
310  }
311 }
312 
313 // explicit instantiations
314 template struct IndexShardsTemplate<Index>;
315 template struct IndexShardsTemplate<IndexBinary>;
316 
317 } // namespace faiss
int d
vector dimension
Definition: Index.h:66
long idx_t
all indices are this type
Definition: Index.h:62
void train(idx_t n, const float *x) override
Trains the quantizer and calls train_residual to train sub-quantizers.
Definition: IndexIVF.cpp:688
void add_with_ids(idx_t n, const component_t *x, const idx_t *xids) override
void add_with_ids(idx_t n, const float *x, const idx_t *xids) override
default implementation that calls encode_vectors
Definition: IndexIVF.cpp:149
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
bool verbose
verbosity level
Definition: Index.h:68
IndexShardsTemplate(bool threaded=false, bool successive_ids=true)
void add(idx_t n, const component_t *x) override
supported only for sub-indices that implement add_with_ids
void onAfterAddIndex(IndexT *index) override
Called just after an index is added.
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:74
void add(idx_t n, const float *x) override
Calls add_with_ids with NULL ids.
Definition: IndexIVF.cpp:143
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexIVF.cpp:228
void onAfterRemoveIndex(IndexT *index) override
Called just after an index is removed.