Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/AuxIndexStructures.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include <cstring>
11 
12 #include "AuxIndexStructures.h"
13 
14 #include "FaissAssert.h"
15 
16 
17 namespace faiss {
18 
19 
20 /***********************************************************************
21  * RangeSearchResult
22  ***********************************************************************/
23 
24 RangeSearchResult::RangeSearchResult (idx_t nq, bool alloc_lims): nq (nq) {
25  if (alloc_lims) {
26  lims = new size_t [nq + 1];
27  memset (lims, 0, sizeof(*lims) * (nq + 1));
28  } else {
29  lims = nullptr;
30  }
31  labels = nullptr;
32  distances = nullptr;
33  buffer_size = 1024 * 256;
34 }
35 
36 /// called when lims contains the nb of elements result entries
37 /// for each query
39  size_t ofs = 0;
40  for (int i = 0; i < nq; i++) {
41  size_t n = lims[i];
42  lims [i] = ofs;
43  ofs += n;
44  }
45  lims [nq] = ofs;
46  labels = new idx_t [ofs];
47  distances = new float [ofs];
48 }
49 
50 RangeSearchResult::~RangeSearchResult () {
51  delete [] labels;
52  delete [] distances;
53  delete [] lims;
54 }
55 
56 
57 
58 
59 
60 /***********************************************************************
61  * BufferList
62  ***********************************************************************/
63 
64 
65 BufferList::BufferList (size_t buffer_size):
66  buffer_size (buffer_size)
67 {
68  wp = buffer_size;
69 }
70 
71 BufferList::~BufferList ()
72 {
73  for (int i = 0; i < buffers.size(); i++) {
74  delete [] buffers[i].ids;
75  delete [] buffers[i].dis;
76  }
77 }
78 
79 void BufferList::add (idx_t id, float dis) {
80  if (wp == buffer_size) { // need new buffer
81  append_buffer();
82  }
83  Buffer & buf = buffers.back();
84  buf.ids [wp] = id;
85  buf.dis [wp] = dis;
86  wp++;
87 }
88 
89 
91 {
92  Buffer buf = {new idx_t [buffer_size], new float [buffer_size]};
93  buffers.push_back (buf);
94  wp = 0;
95 }
96 
97 /// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
98 /// tables dest_ids, dest_dis
99 void BufferList::copy_range (size_t ofs, size_t n,
100  idx_t * dest_ids, float *dest_dis)
101 {
102  size_t bno = ofs / buffer_size;
103  ofs -= bno * buffer_size;
104  while (n > 0) {
105  size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
106  Buffer buf = buffers [bno];
107  memcpy (dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
108  memcpy (dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
109  dest_ids += ncopy;
110  dest_dis += ncopy;
111  ofs = 0;
112  bno ++;
113  n -= ncopy;
114  }
115 }
116 
117 
118 /***********************************************************************
119  * RangeSearchPartialResult
120  ***********************************************************************/
121 
122 void RangeQueryResult::add (float dis, idx_t id) {
123  nres++;
124  pres->add (id, dis);
125 }
126 
127 
128 
130  BufferList(res_in->buffer_size),
131  res(res_in)
132 {}
133 
134 
135 /// begin a new result
138 {
139  RangeQueryResult qres = {qno, 0, this};
140  queries.push_back (qres);
141  return queries.back();
142 }
143 
144 
145 void RangeSearchPartialResult::finalize ()
146 {
147  set_lims ();
148 #pragma omp barrier
149 
150 #pragma omp single
151  res->do_allocation ();
152 
153 #pragma omp barrier
154  copy_result ();
155 }
156 
157 
158 /// called by range_search before do_allocation
160 {
161  for (int i = 0; i < queries.size(); i++) {
162  RangeQueryResult & qres = queries[i];
163  res->lims[qres.qno] = qres.nres;
164  }
165 }
166 
167 /// called by range_search after do_allocation
169 {
170  size_t ofs = 0;
171  for (int i = 0; i < queries.size(); i++) {
172  RangeQueryResult & qres = queries[i];
173 
174  copy_range (ofs, qres.nres,
175  res->labels + res->lims[qres.qno],
176  res->distances + res->lims[qres.qno]);
177  if (incremental) {
178  res->lims[qres.qno] += qres.nres;
179  }
180  ofs += qres.nres;
181  }
182 }
183 
184 void RangeSearchPartialResult::merge (std::vector <RangeSearchPartialResult *> &
185  partial_results, bool do_delete)
186 {
187 
188  int npres = partial_results.size();
189  if (npres == 0) return;
190  RangeSearchResult *result = partial_results[0]->res;
191  size_t nx = result->nq;
192 
193  // count
194  for (size_t i = 0; i < nx; i++) {
195  for (int j = 0; j < npres; j++) {
196  if (!partial_results[j]) continue;
197  result->lims[i] += partial_results[j]->queries[i].nres;
198  }
199  }
200  result->do_allocation ();
201  for (int j = 0; j < npres; j++) {
202  if (!partial_results[j]) continue;
203  partial_results[j]->copy_result (true);
204  if (do_delete) {
205  delete partial_results[j];
206  partial_results[j] = nullptr;
207  }
208  }
209 
210  // reset the limits
211  for (size_t i = nx; i > 0; i--) {
212  result->lims [i] = result->lims [i - 1];
213  }
214  result->lims [0] = 0;
215 }
216 
217 /***********************************************************************
218  * IDSelectorRange
219  ***********************************************************************/
220 
221 IDSelectorRange::IDSelectorRange (idx_t imin, idx_t imax):
222  imin (imin), imax (imax)
223 {
224 }
225 
226 bool IDSelectorRange::is_member (idx_t id) const
227 {
228  return id >= imin && id < imax;
229 }
230 
231 
232 /***********************************************************************
233  * IDSelectorBatch
234  ***********************************************************************/
235 
236 IDSelectorBatch::IDSelectorBatch (long n, const idx_t *indices)
237 {
238  nbits = 0;
239  while (n > (1L << nbits)) nbits++;
240  nbits += 5;
241  // for n = 1M, nbits = 25 is optimal, see P56659518
242 
243  mask = (1L << nbits) - 1;
244  bloom.resize (1UL << (nbits - 3), 0);
245  for (long i = 0; i < n; i++) {
246  long id = indices[i];
247  set.insert(id);
248  id &= mask;
249  bloom[id >> 3] |= 1 << (id & 7);
250  }
251 }
252 
253 bool IDSelectorBatch::is_member (idx_t i) const
254 {
255  long im = i & mask;
256  if(!(bloom[im>>3] & (1 << (im & 7)))) {
257  return 0;
258  }
259  return set.count(i);
260 }
261 
262 
263 /***********************************************************************
264  * IO functions
265  ***********************************************************************/
266 
267 
268 int IOReader::fileno ()
269 {
270  FAISS_THROW_MSG ("IOReader does not support memory mapping");
271 }
272 
273 int IOWriter::fileno ()
274 {
275  FAISS_THROW_MSG ("IOWriter does not support memory mapping");
276 }
277 
278 
279 size_t VectorIOWriter::operator()(
280  const void *ptr, size_t size, size_t nitems)
281 {
282  size_t o = data.size();
283  data.resize(o + size * nitems);
284  memcpy (&data[o], ptr, size * nitems);
285  return nitems;
286 }
287 
288 size_t VectorIOReader::operator()(
289  void *ptr, size_t size, size_t nitems)
290 {
291  if (rp >= data.size()) return 0;
292  size_t nremain = (data.size() - rp) / size;
293  if (nremain < nitems) nitems = nremain;
294  memcpy (ptr, &data[rp], size * nitems);
295  rp += size * nitems;
296  return nitems;
297 }
298 
299 
300 /***********************************************************
301  * Interrupt callback
302  ***********************************************************/
303 
304 
305 std::unique_ptr<InterruptCallback> InterruptCallback::instance;
306 
308  if (!instance.get()) {
309  return;
310  }
311  if (instance->want_interrupt ()) {
312  FAISS_THROW_MSG ("computation interrupted");
313  }
314 }
315 
317  if (!instance.get()) {
318  return false;
319  }
320  return instance->want_interrupt();
321 }
322 
323 
324 size_t InterruptCallback::get_period_hint (size_t flops) {
325  if (!instance.get()) {
326  return 1L << 30; // never check
327  }
328  // for 10M flops, it is reasonable to check once every 10 iterations
329  return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
330 }
331 
332 
333 
334 
335 } // namespace faiss
std::vector< RangeQueryResult > queries
query ids + nb of results per query.
size_t nq
nb of queries
result structure for a single query
void append_buffer()
create a new buffer
void copy_range(size_t ofs, size_t n, idx_t *dest_ids, float *dest_dis)
void set_lims()
called by range_search before do_allocation
RangeSearchResult(idx_t nq, bool alloc_lims=true)
lims must be allocated on input to range_search.
size_t wp
write pointer in the last buffer.
float * distances
corresponding distances (not sorted)
void add(idx_t id, float dis)
add one result, possibly appending a new buffer if needed
void copy_result(bool incremental=false)
called by range_search after do_allocation
static void merge(std::vector< RangeSearchPartialResult * > &partial_results, bool do_delete=true)
RangeQueryResult & new_result(idx_t qno)
begin a new result
size_t buffer_size
size of the result buffers used
static size_t get_period_hint(size_t flops)
size_t * lims
size (nq + 1)
void add(float dis, idx_t id)
called by search function to report a new result
idx_t * labels
result for query i is labels[lims[i]:lims[i+1]]
RangeSearchPartialResult(RangeSearchResult *res_in)
eventually the result will be stored in res_in