Faiss
IndexBinaryHNSW.cpp
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "IndexBinaryHNSW.h"


#include <memory>
#include <cstdlib>
#include <cassert>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <omp.h>

#include <unordered_set>
#include <queue>

#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <stdint.h>

#include "utils.h"
#include "Heap.h"
#include "FaissAssert.h"
#include "IndexBinaryFlat.h"
#include "hamming.h"


namespace faiss {


/**************************************************************
 * add / search blocks of descriptors
 **************************************************************/

namespace {

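/* Add n new vectors (global ids n0 .. n0+n-1) to the HNSW graph.
 * Points are bucket-sorted by level and inserted from the highest
 * level down, so upper-layer links exist before the layers below are
 * wired. Insertion within one level runs in parallel under OpenMP,
 * with one lock per vertex protecting concurrent neighbor-list
 * updates. */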
void hnsw_add_vertices(IndexBinaryHNSW& index_hnsw,
                       size_t n0,
                       size_t n, const uint8_t *x,
                       bool verbose,
                       bool preset_levels = false) {
  HNSW& hnsw = index_hnsw.hnsw;
  size_t ntotal = n0 + n;
  double t0 = getmillisecs();
  if (verbose) {
    printf("hnsw_add_vertices: adding %ld elements on top of %ld "
           "(preset_levels=%d)\n",
           n, n0, int(preset_levels));
  }

  int max_level = hnsw.prepare_level_tab(n, preset_levels);

  if (verbose) {
    printf("  max_level = %d\n", max_level);
  }

  std::vector<omp_lock_t> locks(ntotal);
  for(int i = 0; i < ntotal; i++) {
    omp_init_lock(&locks[i]);
  }

  // add vectors from highest to lowest level
  std::vector<int> hist;
  std::vector<int> order(n);

  { // make buckets with vectors of the same level

    // build histogram
    for (int i = 0; i < n; i++) {
      HNSW::storage_idx_t pt_id = i + n0;
      int pt_level = hnsw.levels[pt_id] - 1;
      while (pt_level >= hist.size()) {
        hist.push_back(0);
      }
      hist[pt_level] ++;
    }

    // accumulate
    std::vector<int> offsets(hist.size() + 1, 0);
    for (int i = 0; i < hist.size() - 1; i++) {
      offsets[i + 1] = offsets[i] + hist[i];
    }

    // bucket sort
    for (int i = 0; i < n; i++) {
      HNSW::storage_idx_t pt_id = i + n0;
      int pt_level = hnsw.levels[pt_id] - 1;
      order[offsets[pt_level]++] = pt_id;
    }
  }

  { // perform add
    RandomGenerator rng2(789);

    int i1 = n;

    for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
      int i0 = i1 - hist[pt_level];

      if (verbose) {
        printf("Adding %d elements at level %d\n",
               i1 - i0, pt_level);
      }

      // random permutation to get rid of dataset order bias
      for (int j = i0; j < i1; j++) {
        std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
      }

#pragma omp parallel
      {
        VisitedTable vt (ntotal);

        std::unique_ptr<HNSW::DistanceComputer> dis(
          index_hnsw.get_distance_computer()
        );
        int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;

#pragma omp for schedule(dynamic)
        for (int i = i0; i < i1; i++) {
          HNSW::storage_idx_t pt_id = order[i];
          dis->set_query((float *)(x + (pt_id - n0) * index_hnsw.code_size));

          hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);

          if (prev_display >= 0 && i - i0 > prev_display + 10000) {
            prev_display = i - i0;
            printf("  %d / %d\r", i - i0, i1 - i0);
            fflush(stdout);
          }
        }
      }
      i1 = i0;
    }
    FAISS_ASSERT(i1 == 0);
  }
  if (verbose) {
    printf("Done in %.3f ms\n", getmillisecs() - t0);
  }

  for(int i = 0; i < ntotal; i++)
    omp_destroy_lock(&locks[i]);
}


} // anonymous namespace


/**************************************************************
 * IndexBinaryHNSW implementation
 **************************************************************/

IndexBinaryHNSW::IndexBinaryHNSW()
{
  is_trained = true;
}

IndexBinaryHNSW::IndexBinaryHNSW(int d, int M)
    : IndexBinary(d),
      hnsw(M),
      own_fields(true),
      storage(new IndexBinaryFlat(d))
{
  is_trained = true;
}

IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary *storage, int M)
    : IndexBinary(storage->d),
      hnsw(M),
      own_fields(false),
      storage(storage)
{
  is_trained = true;
}

IndexBinaryHNSW::~IndexBinaryHNSW() {
  if (own_fields) {
    delete storage;
  }
}

void IndexBinaryHNSW::train(idx_t n, const uint8_t *x)
{
  // hnsw structure does not require training
  storage->train(n, x);
  is_trained = true;
}

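/* Parallel k-NN search. Each OpenMP thread keeps its own VisitedTable
 * and DistanceComputer; per-query results are collected in a max-heap,
 * then reordered by increasing distance. */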
void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k,
                             int32_t *distances, idx_t *labels) const
{
#pragma omp parallel
  {
    VisitedTable vt(ntotal);
    std::unique_ptr<HNSW::DistanceComputer> dis(get_distance_computer());

#pragma omp for
    for(idx_t i = 0; i < n; i++) {
      idx_t *idxi = labels + i * k;
      float *simi = (float *)(distances + i * k);

      dis->set_query((float *)(x + i * code_size));

      maxheap_heapify(k, simi, idxi);
      hnsw.search(*dis, k, idxi, simi, vt);
      maxheap_reorder(k, simi, idxi);
    }
  }

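  // The generic HNSW code computes float distances; they were written
  // in place into the int32 output buffer above, so convert them back
  // to integer Hamming distances here.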
#pragma omp parallel for
  for (int i = 0; i < n * k; ++i) {
    distances[i] = std::round(((float *)distances)[i]);
  }
}


void IndexBinaryHNSW::add(idx_t n, const uint8_t *x)
{
  FAISS_THROW_IF_NOT(is_trained);
  int n0 = ntotal;
  storage->add(n, x);
  ntotal = storage->ntotal;

  hnsw_add_vertices(*this, n0, n, x, verbose,
                    hnsw.levels.size() == ntotal);
}

void IndexBinaryHNSW::reset()
{
  hnsw.reset();
  storage->reset();
  ntotal = 0;
}

void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t *recons) const
{
  storage->reconstruct(key, recons);
}


namespace {

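/* DistanceComputer over flat binary storage: set_query() fixes the
 * query code, then operator() returns the Hamming distance to stored
 * code i. Templating on the HammingComputer type lets the compiler
 * specialize the inner loop for each supported code size. */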
template<class HammingComputer>
struct FlatHammingDis : HNSW::DistanceComputer {
  const int code_size;
  const uint8_t *b;
  size_t ndis;
  HammingComputer hc;

  float operator () (HNSW::storage_idx_t i) override {
    ndis++;
    return hc.hamming(b + i * code_size);
  }

  float symmetric_dis(HNSW::storage_idx_t i, HNSW::storage_idx_t j) override {
    return HammingComputerDefault(b + j * code_size, code_size)
      .hamming(b + i * code_size);
  }


  explicit FlatHammingDis(const IndexBinaryFlat& storage)
      : code_size(storage.code_size),
        b(storage.xb.data()),
        ndis(0),
        hc() {}

  // NOTE: Pointers are cast from float in order to reuse the floating-point
  // DistanceComputer.
  void set_query(const float *x) override {
    hc.set((uint8_t *)x, code_size);
  }

  virtual ~FlatHammingDis() {
#pragma omp critical
    {
      hnsw_stats.ndis += ndis;
    }
  }
};


} // namespace

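/* Dispatch to the HammingComputer specialized for this index's code
 * size; the generic byte-wise computer is the fallback for sizes with
 * no dedicated implementation. */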
HNSW::DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
  IndexBinaryFlat *flat_storage = dynamic_cast<IndexBinaryFlat *>(storage);

  FAISS_ASSERT(flat_storage != nullptr);

  switch(code_size) {
    case 4:
      return new FlatHammingDis<HammingComputer4>(*flat_storage);
    case 8:
      return new FlatHammingDis<HammingComputer8>(*flat_storage);
    case 16:
      return new FlatHammingDis<HammingComputer16>(*flat_storage);
    case 20:
      return new FlatHammingDis<HammingComputer20>(*flat_storage);
    case 32:
      return new FlatHammingDis<HammingComputer32>(*flat_storage);
    case 64:
      return new FlatHammingDis<HammingComputer64>(*flat_storage);
    default:
      if (code_size % 8 == 0) {
        return new FlatHammingDis<HammingComputerM8>(*flat_storage);
      } else if (code_size % 4 == 0) {
        return new FlatHammingDis<HammingComputerM4>(*flat_storage);
      }
  }

  return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
}


} // namespace faiss
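
For reference, a minimal usage sketch (not part of this file). It relies only on the API visible in this translation unit — the IndexBinaryHNSW(d, M) constructor, add(), and search() — and fills the database with random codes; main() and the random fill are illustrative, not from the Faiss sources.

// --- usage sketch, separate program ---
#include "IndexBinaryHNSW.h"

#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  int d = 256;                          // dimension in bits; code_size = d / 8
  faiss::IndexBinaryHNSW index(d, 16);  // M = 16 links per node

  size_t nb = 10000;                    // database size
  std::vector<uint8_t> xb(nb * d / 8);
  for (size_t i = 0; i < xb.size(); i++) {
    xb[i] = rand() & 0xff;              // random binary codes
  }
  index.add(nb, xb.data());

  // query the first database vector against the index
  long k = 4;
  std::vector<int32_t> dist(k);
  std::vector<long> labels(k);          // idx_t is long in this version
  index.search(1, xb.data(), k, dist.data(), labels.data());

  for (long j = 0; j < k; j++) {
    printf("%ld (Hamming %d)\n", labels[j], dist[j]);
  }
  return 0;
}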