/data/users/hoss/faiss/IndexBinaryHNSW.cpp
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "IndexBinaryHNSW.h"

#include <memory>
#include <cstdlib>
#include <cassert>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <omp.h>

#include <unordered_set>
#include <queue>

#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <stdint.h>

#include "utils.h"
#include "Heap.h"
#include "FaissAssert.h"
#include "IndexBinaryFlat.h"
#include "hamming.h"
#include "AuxIndexStructures.h"
namespace faiss {


/**************************************************************
 * add / search blocks of descriptors
 **************************************************************/

namespace {

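/** Add n binary vectors (with global ids starting at n0, codes at x) to the
 *  HNSW graph. The strategy, as implemented below: a level has already been
 *  assigned to each new point, so the points are grouped by level with a
 *  bucket sort, then inserted level by level from the top of the hierarchy
 *  down. */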
void hnsw_add_vertices(IndexBinaryHNSW& index_hnsw,
                       size_t n0,
                       size_t n, const uint8_t *x,
                       bool verbose,
                       bool preset_levels = false) {
  HNSW& hnsw = index_hnsw.hnsw;
  size_t ntotal = n0 + n;
  double t0 = getmillisecs();
  if (verbose) {
    printf("hnsw_add_vertices: adding %zu elements on top of %zu "
           "(preset_levels=%d)\n",
           n, n0, int(preset_levels));
  }

  int max_level = hnsw.prepare_level_tab(n, preset_levels);

  if (verbose) {
    printf("  max_level = %d\n", max_level);
  }

  // one lock per vertex: neighbor lists may be updated concurrently
  // by several insertion threads
  std::vector<omp_lock_t> locks(ntotal);
  for(size_t i = 0; i < ntotal; i++) {
    omp_init_lock(&locks[i]);
  }

  // add vectors from highest to lowest level
  std::vector<int> hist;
  std::vector<int> order(n);

  { // make buckets with vectors of the same level

    // build histogram
    for (int i = 0; i < n; i++) {
      HNSW::storage_idx_t pt_id = i + n0;
      int pt_level = hnsw.levels[pt_id] - 1;
      while (pt_level >= hist.size()) {
        hist.push_back(0);
      }
      hist[pt_level] ++;
    }

    // accumulate
    std::vector<int> offsets(hist.size() + 1, 0);
    for (int i = 0; i < hist.size() - 1; i++) {
      offsets[i + 1] = offsets[i] + hist[i];
    }

    // bucket sort
    for (int i = 0; i < n; i++) {
      HNSW::storage_idx_t pt_id = i + n0;
      int pt_level = hnsw.levels[pt_id] - 1;
      order[offsets[pt_level]++] = pt_id;
    }
  }

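  // Insertion proceeds from the highest level down: when a point is added
  // at level L, the greedy descent that locates its neighbors starts from
  // vertices that already exist at levels above L, so those must be linked
  // into the graph first.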
  { // perform add
    RandomGenerator rng2(789);

    int i1 = n;

    for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
      int i0 = i1 - hist[pt_level];

      if (verbose) {
        printf("Adding %d elements at level %d\n",
               i1 - i0, pt_level);
      }

      // random permutation to get rid of dataset order bias
      for (int j = i0; j < i1; j++) {
        std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
      }

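      // Each thread gets its own VisitedTable and DistanceComputer; the
      // shared graph itself is protected by the per-vertex locks taken
      // inside add_with_locks.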
#pragma omp parallel
      {
        VisitedTable vt (ntotal);

        std::unique_ptr<DistanceComputer> dis(
          index_hnsw.get_distance_computer()
        );
        int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;

#pragma omp for schedule(dynamic)
        for (int i = i0; i < i1; i++) {
          HNSW::storage_idx_t pt_id = order[i];
          dis->set_query((float *)(x + (pt_id - n0) * index_hnsw.code_size));

          hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);

          if (prev_display >= 0 && i - i0 > prev_display + 10000) {
            prev_display = i - i0;
            printf("  %d / %d\r", i - i0, i1 - i0);
            fflush(stdout);
          }
        }
      }
      i1 = i0;
    }
    FAISS_ASSERT(i1 == 0);
  }
  if (verbose) {
    printf("Done in %.3f ms\n", getmillisecs() - t0);
  }

  for(size_t i = 0; i < ntotal; i++)
    omp_destroy_lock(&locks[i]);
}


} // anonymous namespace


/**************************************************************
 * IndexBinaryHNSW implementation
 **************************************************************/

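// The index is a thin wrapper: `storage` (an IndexBinaryFlat when
// constructed from (d, M)) holds the binary codes themselves, while
// `hnsw` holds only the graph links and per-vertex levels.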
IndexBinaryHNSW::IndexBinaryHNSW()
{
  is_trained = true;
}

IndexBinaryHNSW::IndexBinaryHNSW(int d, int M)
    : IndexBinary(d),
      hnsw(M),
      own_fields(true),
      storage(new IndexBinaryFlat(d))
{
  is_trained = true;
}

IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary *storage, int M)
    : IndexBinary(storage->d),
      hnsw(M),
      own_fields(false),
      storage(storage)
{
  is_trained = true;
}

IndexBinaryHNSW::~IndexBinaryHNSW() {
  if (own_fields) {
    delete storage;
  }
}

void IndexBinaryHNSW::train(idx_t n, const uint8_t *x)
{
  // hnsw structure does not require training
  storage->train(n, x);
  is_trained = true;
}

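// Search reuses the float-based HNSW machinery on the int32 output buffer:
// `distances` is treated as float storage for the max-heap while the graph
// is traversed (Hamming distances are small integers, exactly representable
// as floats), then converted back to int32 in a final rounding pass.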
void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k,
                             int32_t *distances, idx_t *labels) const
{
#pragma omp parallel
  {
    VisitedTable vt(ntotal);
    std::unique_ptr<DistanceComputer> dis(get_distance_computer());

#pragma omp for
    for(idx_t i = 0; i < n; i++) {
      idx_t *idxi = labels + i * k;
      float *simi = (float *)(distances + i * k);

      dis->set_query((float *)(x + i * code_size));

      maxheap_heapify(k, simi, idxi);
      hnsw.search(*dis, k, idxi, simi, vt);
      maxheap_reorder(k, simi, idxi);
    }
  }

#pragma omp parallel for
  for (idx_t i = 0; i < n * k; ++i) {
    distances[i] = std::round(((float *)distances)[i]);
  }
}


void IndexBinaryHNSW::add(idx_t n, const uint8_t *x)
{
  FAISS_THROW_IF_NOT(is_trained);
  idx_t n0 = ntotal;
  storage->add(n, x);
  ntotal = storage->ntotal;

  hnsw_add_vertices(*this, n0, n, x, verbose,
                    hnsw.levels.size() == ntotal);
}

void IndexBinaryHNSW::reset()
{
  hnsw.reset();
  storage->reset();
  ntotal = 0;
}

void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t *recons) const
{
  storage->reconstruct(key, recons);
}
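
// A minimal usage sketch (not part of this file), assuming the caller
// provides packed binary codes (d/8 bytes per vector) in xb/xq and sizes
// nb, nq, k:
//
//   int d = 256, M = 16;                  // 256-bit codes, graph degree M
//   faiss::IndexBinaryHNSW index(d, M);   // owns an IndexBinaryFlat
//   index.add(nb, xb);                    // xb: nb * (d/8) uint8_t codes
//   std::vector<int32_t> dist(nq * k);
//   std::vector<faiss::IndexBinary::idx_t> ids(nq * k);
//   index.search(nq, xq, k, dist.data(), ids.data());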


namespace {


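/** Adapter that exposes a binary HammingComputer through the float-based
 *  DistanceComputer interface used by the HNSW routines, so the same graph
 *  code serves float and binary indexes. It also counts distance
 *  evaluations and flushes the count into hnsw_stats on destruction. */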
template<class HammingComputer>
struct FlatHammingDis : DistanceComputer {
  const int code_size;
  const uint8_t *b;
  size_t ndis;
  HammingComputer hc;

  float operator () (idx_t i) override {
    ndis++;
    return hc.hamming(b + i * code_size);
  }

  float symmetric_dis(idx_t i, idx_t j) override {
    return HammingComputerDefault(b + j * code_size, code_size)
      .hamming(b + i * code_size);
  }


  explicit FlatHammingDis(const IndexBinaryFlat& storage)
      : code_size(storage.code_size),
        b(storage.xb.data()),
        ndis(0),
        hc() {}

  // NOTE: Pointers are cast from float in order to reuse the
  // floating-point DistanceComputer.
  void set_query(const float *x) override {
    hc.set((uint8_t *)x, code_size);
  }

  ~FlatHammingDis() override {
#pragma omp critical
    {
      hnsw_stats.ndis += ndis;
    }
  }
};


} // namespace


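// Pick a FlatHammingDis specialization based on code_size: dedicated
// computers for the common sizes (4, 8, 16, 20, 32, 64 bytes), generic
// multiple-of-8 / multiple-of-4 variants next, and a byte-wise default
// as the last resort.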
DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
  IndexBinaryFlat *flat_storage = dynamic_cast<IndexBinaryFlat *>(storage);

  FAISS_ASSERT(flat_storage != nullptr);

  switch(code_size) {
    case 4:
      return new FlatHammingDis<HammingComputer4>(*flat_storage);
    case 8:
      return new FlatHammingDis<HammingComputer8>(*flat_storage);
    case 16:
      return new FlatHammingDis<HammingComputer16>(*flat_storage);
    case 20:
      return new FlatHammingDis<HammingComputer20>(*flat_storage);
    case 32:
      return new FlatHammingDis<HammingComputer32>(*flat_storage);
    case 64:
      return new FlatHammingDis<HammingComputer64>(*flat_storage);
    default:
      if (code_size % 8 == 0) {
        return new FlatHammingDis<HammingComputerM8>(*flat_storage);
      } else if (code_size % 4 == 0) {
        return new FlatHammingDis<HammingComputerM4>(*flat_storage);
      }
  }

  return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
}


} // namespace faiss