Faiss
/tmp/faiss/IndexPQ.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 #include "IndexPQ.h"
12 
13 
14 #include <cstddef>
15 #include <cstring>
16 #include <cstdio>
17 #include <cmath>
18 
19 #include <algorithm>
20 
21 #include "FaissAssert.h"
22 #include "AuxIndexStructures.h"
23 #include "hamming.h"
24 
25 namespace faiss {
26 
27 /*********************************************************
28  * IndexPQ implementation
29  ********************************************************/
30 
31 
32 IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
33  Index(d, metric), pq(d, M, nbits)
34 {
35  is_trained = false;
36  do_polysemous_training = false;
37  polysemous_ht = nbits * M + 1;
38  search_type = ST_PQ;
39  encode_signs = false;
40 }
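// Note: the default polysemous_ht of nbits * M + 1 exceeds the largest
// possible Hamming distance between two codes, so the polysemous filter
// initially lets every code through until a tighter threshold is set.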
41 
42 IndexPQ::IndexPQ ()
43 {
44  metric_type = METRIC_L2;
45  is_trained = false;
46  do_polysemous_training = false;
47  polysemous_ht = pq.nbits * pq.M + 1;
48  search_type = ST_PQ;
49  encode_signs = false;
50 }
51 
52 
53 void IndexPQ::train (idx_t n, const float *x)
54 {
55  if (!do_polysemous_training) { // standard training
56  pq.train(n, x);
57  } else {
58  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
59 
60  if (ntrain_perm > n / 4)
61  ntrain_perm = n / 4;
62  if (verbose) {
63  printf ("PQ training on %ld points, %ld points remain: "
64  "training polysemous on %s\n",
65  n - ntrain_perm, ntrain_perm,
66  ntrain_perm == 0 ? "centroids" : "these");
67  }
68  pq.train(n - ntrain_perm, x);
69 
70  polysemous_training.optimize_pq_for_hamming (
71  pq, ntrain_perm, x + (n - ntrain_perm) * d);
72  }
73  is_trained = true;
74 }
75 
76 
77 void IndexPQ::add (idx_t n, const float *x)
78 {
79  FAISS_THROW_IF_NOT (is_trained);
80  codes.resize ((n + ntotal) * pq.code_size);
81  pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
82  ntotal += n;
83 }
84 
85 
86 long IndexPQ::remove_ids (const IDSelector & sel)
87 {
88  idx_t j = 0;
89  for (idx_t i = 0; i < ntotal; i++) {
90  if (sel.is_member (i)) {
91  // should be removed
92  } else {
93  if (i > j) {
94  memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
95  }
96  j++;
97  }
98  }
99  long nremove = ntotal - j;
100  if (nremove > 0) {
101  ntotal = j;
102  codes.resize (ntotal * pq.code_size);
103  }
104  return nremove;
105 }
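// Note: remove_ids compacts the codes in place, so the sequential ids of
// the vectors stored after the removed ones shift down accordingly.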
106 
107 
108 void IndexPQ::reset()
109 {
110  codes.clear();
111  ntotal = 0;
112 }
113 
114 void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
115 {
116  FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
117  for (idx_t i = 0; i < ni; i++) {
118  const uint8_t * code = &codes[(i0 + i) * pq.code_size];
119  pq.decode (code, recons + i * d);
120  }
121 }
122 
123 
124 void IndexPQ::reconstruct (idx_t key, float * recons) const
125 {
126  FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
127  pq.decode (&codes[key * pq.code_size], recons);
128 }
129 
130 
131 
132 
133 
134 
135 
136 /*****************************************
137  * IndexPQ polysemous search routines
138  ******************************************/
139 
140 
141 
142 
143 
144 void IndexPQ::search (idx_t n, const float *x, idx_t k,
145  float *distances, idx_t *labels) const
146 {
147  FAISS_THROW_IF_NOT (is_trained);
148  if (search_type == ST_PQ) { // Simple PQ search
149 
150  if (metric_type == METRIC_L2) {
151  float_maxheap_array_t res = {
152  size_t(n), size_t(k), labels, distances };
153  pq.search (x, n, codes.data(), ntotal, &res, true);
154  } else {
155  float_minheap_array_t res = {
156  size_t(n), size_t(k), labels, distances };
157  pq.search_ip (x, n, codes.data(), ntotal, &res, true);
158  }
159  indexPQ_stats.nq += n;
160  indexPQ_stats.ncode += n * ntotal;
161 
162  } else if (search_type == ST_polysemous ||
163  search_type == ST_polysemous_generalize) {
164 
165  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
166 
167  search_core_polysemous (n, x, k, distances, labels);
168 
169  } else { // code-to-code distances
170 
171  uint8_t * q_codes = new uint8_t [n * pq.code_size];
172  ScopeDeleter<uint8_t> del (q_codes);
173 
174 
175  if (!encode_signs) {
176  pq.compute_codes (x, q_codes, n);
177  } else {
178  FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
179  memset (q_codes, 0, n * pq.code_size);
180  for (size_t i = 0; i < n; i++) {
181  const float *xi = x + i * d;
182  uint8_t *code = q_codes + i * pq.code_size;
183  for (int j = 0; j < d; j++)
184  if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
185  }
186  }
187 
188  if (search_type == ST_SDC) {
189 
190  float_maxheap_array_t res = {
191  size_t(n), size_t(k), labels, distances};
192 
193  pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
194 
195  } else {
196  int * idistances = new int [n * k];
197  ScopeDeleter<int> del (idistances);
198 
199  int_maxheap_array_t res = {
200  size_t (n), size_t (k), labels, idistances};
201 
202  if (search_type == ST_HE) {
203 
204  hammings_knn_hc (&res, q_codes, codes.data(),
205  ntotal, pq.code_size, true);
206 
207  } else if (search_type == ST_generalized_HE) {
208 
209  generalized_hammings_knn_hc (&res, q_codes, codes.data(),
210  ntotal, pq.code_size, true);
211  }
212 
213  // convert distances to floats
214  for (int i = 0; i < k * n; i++)
215  distances[i] = idistances[i];
216 
217  }
218 
219 
220  indexPQ_stats.nq += n;
221  indexPQ_stats.ncode += n * ntotal;
222  }
223 }
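/* Illustrative usage sketch (not part of the library): train an IndexPQ with
 * polysemous reordering and search it with the Hamming filter enabled. The
 * function name and the parameter values (d = 64, M = 8, nbits = 8, ht = 54)
 * are hypothetical.
 */
static void example_polysemous_search (
        long nt, const float *xt,       // training vectors
        long nb, const float *xb,       // database vectors
        long nq, const float *xq,       // query vectors
        long k, float *distances, long *labels)
{
    IndexPQ index (64, 8, 8, METRIC_L2);
    index.do_polysemous_training = true;
    index.train (nt, xt);
    index.add (nb, xb);
    index.search_type = IndexPQ::ST_polysemous;
    index.polysemous_ht = 54;           // keep codes within Hamming distance < 54
    index.search (nq, xq, k, distances, labels);
}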
224 
225 
226 
227 
228 
229 void IndexPQStats::reset()
230 {
231  nq = ncode = n_hamming_pass = 0;
232 }
233 
234 IndexPQStats indexPQ_stats;
235 
236 
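// For one query, scan all database codes: codes whose Hamming distance to the
// query code is below the polysemous threshold get a full PQ (ADC) distance
// computation and may enter the result heap; the others are skipped. Returns
// the number of codes that passed the Hamming test.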
237 template <class HammingComputer>
238 static size_t polysemous_inner_loop (
239  const IndexPQ & index,
240  const float *dis_table_qi, const uint8_t *q_code,
241  size_t k, float *heap_dis, long *heap_ids)
242 {
243 
244  int M = index.pq.M;
245  int code_size = index.pq.code_size;
246  int ksub = index.pq.ksub;
247  size_t ntotal = index.ntotal;
248  int ht = index.polysemous_ht;
249 
250  const uint8_t *b_code = index.codes.data();
251 
252  size_t n_pass_i = 0;
253 
254  HammingComputer hc (q_code, code_size);
255 
256  for (long bi = 0; bi < ntotal; bi++) {
257  int hd = hc.hamming (b_code);
258 
259  if (hd < ht) {
260  n_pass_i ++;
261 
262  float dis = 0;
263  const float * dis_table = dis_table_qi;
264  for (int m = 0; m < M; m++) {
265  dis += dis_table [b_code[m]];
266  dis_table += ksub;
267  }
268 
269  if (dis < heap_dis[0]) {
270  maxheap_pop (k, heap_dis, heap_ids);
271  maxheap_push (k, heap_dis, heap_ids, dis, bi);
272  }
273  }
274  b_code += code_size;
275  }
276  return n_pass_i;
277 }
278 
279 
280 void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
281  float *distances, idx_t *labels) const
282 {
283  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
284 
285  // PQ distance tables
286  float * dis_tables = new float [n * pq.ksub * pq.M];
287  ScopeDeleter<float> del (dis_tables);
288  pq.compute_distance_tables (n, x, dis_tables);
289 
290  // Hamming embedding queries
291  uint8_t * q_codes = new uint8_t [n * pq.code_size];
292  ScopeDeleter<uint8_t> del2 (q_codes);
293 
294  if (false) {
295  pq.compute_codes (x, q_codes, n);
296  } else {
297 #pragma omp parallel for
298  for (idx_t qi = 0; qi < n; qi++) {
299  pq.compute_code_from_distance_table
300  (dis_tables + qi * pq.M * pq.ksub,
301  q_codes + qi * pq.code_size);
302  }
303  }
304 
305  size_t n_pass = 0;
306 
307 #pragma omp parallel for reduction (+: n_pass)
308  for (idx_t qi = 0; qi < n; qi++) {
309  const uint8_t * q_code = q_codes + qi * pq.code_size;
310 
311  const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
312 
313  long * heap_ids = labels + qi * k;
314  float *heap_dis = distances + qi * k;
315  maxheap_heapify (k, heap_dis, heap_ids);
316 
317  if (search_type == ST_polysemous) {
318 
319  switch (pq.code_size) {
320  case 4:
321  n_pass += polysemous_inner_loop<HammingComputer4>
322  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
323  break;
324  case 8:
325  n_pass += polysemous_inner_loop<HammingComputer8>
326  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
327  break;
328  case 16:
329  n_pass += polysemous_inner_loop<HammingComputer16>
330  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
331  break;
332  case 32:
333  n_pass += polysemous_inner_loop<HammingComputer32>
334  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
335  break;
336  case 20:
337  n_pass += polysemous_inner_loop<HammingComputer20>
338  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
339  break;
340  default:
341  if (pq.code_size % 8 == 0) {
342  n_pass += polysemous_inner_loop<HammingComputerM8>
343  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
344  } else if (pq.code_size % 4 == 0) {
345  n_pass += polysemous_inner_loop<HammingComputerM4>
346  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
347  } else {
348  FAISS_THROW_FMT(
349  "code size %zd not supported for polysemous",
350  pq.code_size);
351  }
352  break;
353  }
354  } else {
355  switch (pq.code_size) {
356  case 8:
357  n_pass += polysemous_inner_loop<GenHammingComputer8>
358  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
359  break;
360  case 16:
361  n_pass += polysemous_inner_loop<GenHammingComputer16>
362  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
363  break;
364  case 32:
365  n_pass += polysemous_inner_loop<GenHammingComputer32>
366  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
367  break;
368  default:
369  if (pq.code_size % 8 == 0) {
370  n_pass += polysemous_inner_loop<GenHammingComputerM8>
371  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
372  } else {
373  FAISS_THROW_FMT(
374  "code size %zd not supported for polysemous",
375  pq.code_size);
376  }
377  break;
378  }
379  }
380  maxheap_reorder (k, heap_dis, heap_ids);
381  }
382 
383  indexPQ_stats.nq += n;
384  indexPQ_stats.ncode += n * ntotal;
385  indexPQ_stats.n_hamming_pass += n_pass;
386 
387 
388 }
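/* Note: after a polysemous search, indexPQ_stats.n_hamming_pass divided by
 * indexPQ_stats.ncode gives the fraction of database codes that passed the
 * Hamming filter, a direct measure of how selective polysemous_ht is. */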
389 
390 
391 
392 
393 /*****************************************
394  * Stats of IndexPQ codes
395  ******************************************/
396 
397 
398 
399 
400 void IndexPQ::hamming_distance_table (idx_t n, const float *x,
401  int32_t *dis) const
402 {
403  uint8_t * q_codes = new uint8_t [n * pq.code_size];
404  ScopeDeleter<uint8_t> del (q_codes);
405 
406  pq.compute_codes (x, q_codes, n);
407 
408  hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
409 }
410 
411 
412 void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
413  idx_t nb, const float *xb,
414  long *hist)
415 {
416  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
417  FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
418  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
419 
420  // Hamming embedding queries
421  uint8_t * q_codes = new uint8_t [n * pq.code_size];
422  ScopeDeleter <uint8_t> del (q_codes);
423  pq.compute_codes (x, q_codes, n);
424 
425  uint8_t * b_codes ;
426  ScopeDeleter <uint8_t> del_b_codes;
427 
428  if (xb) {
429  b_codes = new uint8_t [nb * pq.code_size];
430  del_b_codes.set (b_codes);
431  pq.compute_codes (xb, b_codes, nb);
432  } else {
433  nb = ntotal;
434  b_codes = codes.data();
435  }
436  int nbits = pq.M * pq.nbits;
437  memset (hist, 0, sizeof(*hist) * (nbits + 1));
438  size_t bs = 256;
439 
440 #pragma omp parallel
441  {
442  std::vector<long> histi (nbits + 1);
443  hamdis_t *distances = new hamdis_t [nb * bs];
444  ScopeDeleter<hamdis_t> del (distances);
445 #pragma omp for
446  for (size_t q0 = 0; q0 < n; q0 += bs) {
447  // printf ("dis stats: %ld/%ld\n", q0, n);
448  size_t q1 = q0 + bs;
449  if (q1 > n) q1 = n;
450 
451  hammings (q_codes + q0 * pq.code_size, b_codes,
452  q1 - q0, nb,
453  pq.code_size, distances);
454 
455  for (size_t i = 0; i < nb * (q1 - q0); i++)
456  histi [distances [i]]++;
457  }
458 #pragma omp critical
459  {
460  for (int i = 0; i <= nbits; i++)
461  hist[i] += histi[i];
462  }
463  }
464 
465 }
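/* Note: the histogram has pq.M * pq.nbits + 1 bins, one per possible Hamming
 * distance between codes. A typical (but not mandated) use is to inspect it
 * when choosing a value for polysemous_ht. */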
466 
467 
468 
469 
470 
471 
472 
473 
474 
475 
476 
477 
478 
479 
480 
481 
482 
483 
484 
485 
486 /*****************************************
487  * MultiIndexQuantizer
488  ******************************************/
489 
490 namespace {
491 
492 template <typename T>
493 struct PreSortedArray {
494 
495  const T * x;
496  int N;
497 
498  explicit PreSortedArray (int N): N(N) {
499  }
500  void init (const T*x) {
501  this->x = x;
502  }
503  // get smallest value
504  T get_0 () {
505  return x[0];
506  }
507 
508  // get delta between the n-smallest and (n-1)-smallest values
509  T get_diff (int n) {
510  return x[n] - x[n - 1];
511  }
512 
513  // remap orders counted from smallest to indices in array
514  int get_ord (int n) {
515  return n;
516  }
517 
518 };
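/* PreSortedArray above and SortedArray / SemiSortedArray below expose the
 * same small interface (init, get_0, get_diff, get_ord), so any of them can
 * be plugged into MinSumK as its SSA template argument. */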
519 
520 template <typename T>
521 struct ArgSort {
522  const T * x;
523  bool operator() (size_t i, size_t j) {
524  return x[i] < x[j];
525  }
526 };
527 
528 
529 /** Array that maintains a permutation of its elements so that the
530  * array's elements are sorted
531  */
532 template <typename T>
533 struct SortedArray {
534  const T * x;
535  int N;
536  std::vector<int> perm;
537 
538  explicit SortedArray (int N) {
539  this->N = N;
540  perm.resize (N);
541  }
542 
543  void init (const T*x) {
544  this->x = x;
545  for (int n = 0; n < N; n++)
546  perm[n] = n;
547  ArgSort<T> cmp = {x };
548  std::sort (perm.begin(), perm.end(), cmp);
549  }
550 
551  // get smallest value
552  T get_0 () {
553  return x[perm[0]];
554  }
555 
556  // get delta between the n-smallest and (n-1)-smallest values
557  T get_diff (int n) {
558  return x[perm[n]] - x[perm[n - 1]];
559  }
560 
561  // remap orders counted from smallest to indices in array
562  int get_ord (int n) {
563  return perm[n];
564  }
565 };
566 
567 
568 
569 /** Array has n values. Sort the k first ones and copy the other ones
570  * into elements k..n-1
571  */
572 template <class C>
573 void partial_sort (int k, int n,
574  const typename C::T * vals, typename C::TI * perm) {
575  // insert first k elts in heap
576  for (int i = 1; i < k; i++) {
577  indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
578  }
579 
580  // insert next n - k elts in heap
581  for (int i = k; i < n; i++) {
582  typename C::TI id = perm[i];
583  typename C::TI top = perm[0];
584 
585  if (C::cmp(vals[top], vals[id])) {
586  indirect_heap_pop<C> (k, vals, perm);
587  indirect_heap_push<C> (k, vals, perm, id);
588  perm[i] = top;
589  } else {
590  // nothing, elt at i is good where it is.
591  }
592  }
593 
594  // order the k first elements in heap
595  for (int i = k - 1; i > 0; i--) {
596  typename C::TI top = perm[0];
597  indirect_heap_pop<C> (i + 1, vals, perm);
598  perm[i] = top;
599  }
600 }
601 
602 /** same as SortedArray, but only the k first elements are sorted */
603 template <typename T>
604 struct SemiSortedArray {
605  const T * x;
606  int N;
607 
608  // type of the heap: CMax = sort ascending
609  typedef CMax<T, int> HC;
610  std::vector<int> perm;
611 
612  int k; // k elements are sorted
613 
614  int initial_k, k_factor;
615 
616  explicit SemiSortedArray (int N) {
617  this->N = N;
618  perm.resize (N);
619  perm.resize (N);
620  initial_k = 3;
621  k_factor = 4;
622  }
623 
624  void init (const T*x) {
625  this->x = x;
626  for (int n = 0; n < N; n++)
627  perm[n] = n;
628  k = 0;
629  grow (initial_k);
630  }
631 
632  /// grow the sorted part of the array to size next_k
633  void grow (int next_k) {
634  if (next_k < N) {
635  partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
636  k = next_k;
637  } else { // full sort of remainder of array
638  ArgSort<T> cmp = {x };
639  std::sort (perm.begin() + k, perm.end(), cmp);
640  k = N;
641  }
642  }
643 
644  // get smallest value
645  T get_0 () {
646  return x[perm[0]];
647  }
648 
649  // get delta between n-smallest and n-1 -smallest
650  T get_diff (int n) {
651  if (n >= k) {
652  // want to keep powers of 2 - 1
653  int next_k = (k + 1) * k_factor - 1;
654  grow (next_k);
655  }
656  return x[perm[n]] - x[perm[n - 1]];
657  }
658 
659  // remap orders counted from smallest to indices in array
660  int get_ord (int n) {
661  assert (n < k);
662  return perm[n];
663  }
664 };
665 
666 
667 
668 /*****************************************
669  * Find the k smallest sums of M terms, where each term is taken in a
670  * table x of n values.
671  *
672  * A combination of terms is encoded as a scalar 0 <= t < n^M. The
673  * combination t0 ... t(M-1) that corresponds to the sum
674  *
675  * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
676  *
677  * is encoded as
678  *
679  * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
680  *
681  * MinSumK is an object rather than a function, so that storage can be
682  * re-used over several computations with the same sizes. use_seen
683  * should be set when there may be ties in the x array and returning
684  * the same combination t several times would be a problem.
685  *
686  * @param x size M * n, values to add up
687  * @param k nb of results to retrieve
688  * @param M nb of terms
689  * @param n nb of distinct values
690  * @param sums output, size k, sorted
691  * @param terms output, size k, with encoding as above
692  *
693  ******************************************/
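/* Example (illustrative): with M = 2 terms and n = 4 values per term, the
 * combination (t0 = 3, t1 = 1) represents the sum x[0, 3] + x[1, 1] and is
 * encoded as t = 3 + 1 * 4 = 7. */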
694 template <typename T, class SSA, bool use_seen>
695 struct MinSumK {
696  int K; ///< nb of sums to return
697  int M; ///< nb of elements to sum up
698  int nbit; ///< nb of bits to encode one entry
699  int N; ///< nb of possible elements for each of the M terms
700 
701  /** the heap.
702  * We use a heap to maintain a queue of sums, with the associated
703  * terms involved in the sum.
704  */
705  typedef CMin<T, long> HC;
706  size_t heap_capacity, heap_size;
707  T *bh_val;
708  long *bh_ids;
709 
710  std::vector <SSA> ssx;
711 
712  // all results get pushed several times. When there are ties, they
713  // are popped interleaved with others, so it is not easy to
714  // identify them. Therefore, this bit array just marks elements
715  // that were seen before.
716  std::vector <uint8_t> seen;
717 
718  MinSumK (int K, int M, int nbit, int N):
719  K(K), M(M), nbit(nbit), N(N) {
720  heap_capacity = K * M;
721  assert (N <= (1 << nbit));
722 
723  // we'll do k steps, each step pushes at most M vals
724  bh_val = new T[heap_capacity];
725  bh_ids = new long[heap_capacity];
726 
727  if (use_seen) {
728  long n_ids = weight(M);
729  seen.resize ((n_ids + 7) / 8);
730  }
731 
732  for (int m = 0; m < M; m++)
733  ssx.push_back (SSA(N));
734 
735  }
736 
737  long weight (int i) {
738  return 1 << (i * nbit);
739  }
740 
741  bool is_seen (long i) {
742  return (seen[i >> 3] >> (i & 7)) & 1;
743  }
744 
745  void mark_seen (long i) {
746  if (use_seen)
747  seen [i >> 3] |= 1 << (i & 7);
748  }
749 
750  void run (const T *x, long ldx,
751  T * sums, long * terms) {
752  heap_size = 0;
753 
754  for (int m = 0; m < M; m++) {
755  ssx[m].init(x);
756  x += ldx;
757  }
758 
759  { // initial result: take min for all elements
760  T sum = 0;
761  terms[0] = 0;
762  mark_seen (0);
763  for (int m = 0; m < M; m++) {
764  sum += ssx[m].get_0();
765  }
766  sums[0] = sum;
767  for (int m = 0; m < M; m++) {
768  heap_push<HC> (++heap_size, bh_val, bh_ids,
769  sum + ssx[m].get_diff(1),
770  weight(m));
771  }
772  }
773 
774  for (int k = 1; k < K; k++) {
775  // pop smallest value from heap
776  if (use_seen) {// skip already seen elements
777  while (is_seen (bh_ids[0])) {
778  assert (heap_size > 0);
779  heap_pop<HC> (heap_size--, bh_val, bh_ids);
780  }
781  }
782  assert (heap_size > 0);
783 
784  T sum = sums[k] = bh_val[0];
785  long ti = terms[k] = bh_ids[0];
786 
787  if (use_seen) {
788  mark_seen (ti);
789  heap_pop<HC> (heap_size--, bh_val, bh_ids);
790  } else {
791  do {
792  heap_pop<HC> (heap_size--, bh_val, bh_ids);
793  } while (heap_size > 0 && bh_ids[0] == ti);
794  }
795 
796  // enqueue followers
797  long ii = ti;
798  for (int m = 0; m < M; m++) {
799  long n = ii & ((1 << nbit) - 1);
800  ii >>= nbit;
801  if (n + 1 >= N) continue;
802 
803  enqueue_follower (ti, m, n, sum);
804  }
805  }
806 
807  /*
808  for (int k = 0; k < K; k++)
809  for (int l = k + 1; l < K; l++)
810  assert (terms[k] != terms[l]);
811  */
812 
813  // convert indices by applying permutation
814  for (int k = 0; k < K; k++) {
815  long ii = terms[k];
816  if (use_seen) {
817  // clear seen for reuse at next loop
818  seen[ii >> 3] = 0;
819  }
820  long ti = 0;
821  for (int m = 0; m < M; m++) {
822  long n = ii & ((1 << nbit) - 1);
823  ti += ssx[m].get_ord(n) << (nbit * m);
824  ii >>= nbit;
825  }
826  terms[k] = ti;
827  }
828  }
829 
830 
831  void enqueue_follower (long ti, int m, int n, T sum) {
832  T next_sum = sum + ssx[m].get_diff(n + 1);
833  long next_ti = ti + weight(m);
834  heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
835  }
836 
837  ~MinSumK () {
838  delete [] bh_ids;
839  delete [] bh_val;
840  }
841 };
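/* Usage sketch (illustrative; the names tab, sums and terms are
 * hypothetical): find the 10 smallest sums of M = 4 terms, each term drawn
 * from a row of 256 values in a table tab of size 4 * 256 (row stride 256):
 *
 *   MinSumK<float, SemiSortedArray<float>, false> msk (10, 4, 8, 256);
 *   msk.run (tab, 256, sums, terms);   // sums: 10 floats, terms: 10 longs
 *
 * MultiIndexQuantizer::search below uses this pattern with the PQ distance
 * tables. */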
842 
843 } // anonymous namespace
844 
845 
846 MultiIndexQuantizer::MultiIndexQuantizer (int d,
847  size_t M,
848  size_t nbits):
849  Index(d, METRIC_L2), pq(d, M, nbits)
850 {
851  is_trained = false;
852  pq.verbose = verbose;
853 }
854 
855 
856 
857 void MultiIndexQuantizer::train(idx_t n, const float *x)
858 {
859  pq.verbose = verbose;
860  pq.train (n, x);
861  is_trained = true;
862  // count virtual elements in index
863  ntotal = 1;
864  for (int m = 0; m < pq.M; m++)
865  ntotal *= pq.ksub;
866 }
867 
868 
869 void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
870  float *distances, idx_t *labels) const {
871  if (n == 0) return;
872 
873  // the allocation just below can be severe...
874  idx_t bs = 32768;
875  if (n > bs) {
876  for (idx_t i0 = 0; i0 < n; i0 += bs) {
877  idx_t i1 = std::min(i0 + bs, n);
878  if (verbose) {
879  printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
880  i0, i1, n);
881  }
882  search (i1 - i0, x + i0 * d, k,
883  distances + i0 * k,
884  labels + i0 * k);
885  }
886  return;
887  }
888 
889  float * dis_tables = new float [n * pq.ksub * pq.M];
890  ScopeDeleter<float> del (dis_tables);
891 
892  pq.compute_distance_tables (n, x, dis_tables);
893 
894  if (k == 1) {
895  // simple version that just finds the min in each table
896 
897 #pragma omp parallel for
898  for (int i = 0; i < n; i++) {
899  const float * dis_table = dis_tables + i * pq.ksub * pq.M;
900  float dis = 0;
901  idx_t label = 0;
902 
903  for (int s = 0; s < pq.M; s++) {
904  float vmin = HUGE_VALF;
905  idx_t lmin = -1;
906 
907  for (idx_t j = 0; j < pq.ksub; j++) {
908  if (dis_table[j] < vmin) {
909  vmin = dis_table[j];
910  lmin = j;
911  }
912  }
913  dis += vmin;
914  label |= lmin << (s * pq.nbits);
915  dis_table += pq.ksub;
916  }
917 
918  distances [i] = dis;
919  labels [i] = label;
920  }
921 
922 
923  } else {
924 
925 #pragma omp parallel if(n > 1)
926  {
927  MinSumK <float, SemiSortedArray<float>, false>
928  msk(k, pq.M, pq.nbits, pq.ksub);
929 #pragma omp for
930  for (int i = 0; i < n; i++) {
931  msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
932  distances + i * k, labels + i * k);
933 
934  }
935  }
936  }
937 
938 }
939 
940 
941 void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
942 {
943 
944  long jj = key;
945  for (int m = 0; m < pq.M; m++) {
946  long n = jj & ((1L << pq.nbits) - 1);
947  jj >>= pq.nbits;
948  memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
949  recons += pq.dsub;
950  }
951 }
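/* Example (illustrative): with pq.M = 2 and pq.nbits = 8, the key 0x0103
 * decodes to centroid 3 of subquantizer 0 and centroid 1 of subquantizer 1;
 * reconstruct() concatenates the two corresponding centroids into recons. */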
952 
953 void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
954  FAISS_THROW_MSG(
955  "This index has virtual elements, "
956  "it does not support add");
957 }
958 
959 void MultiIndexQuantizer::reset ()
960 {
961  FAISS_THROW_MSG ( "This index has virtual elements, "
962  "it does not support reset");
963 }
964 
965 
966 
967 
968 
969 
970 
971 
972 
973 
974 /*****************************************
975  * MultiIndexQuantizer2
976  ******************************************/
977 
978 
979 
980 MultiIndexQuantizer2::MultiIndexQuantizer2 (
981  int d, size_t M, size_t nbits,
982  Index **indexes):
983  MultiIndexQuantizer (d, M, nbits)
984 {
985  assign_indexes.resize (M);
986  for (int i = 0; i < M; i++) {
987  FAISS_THROW_IF_NOT_MSG(
988  indexes[i]->d == pq.dsub,
989  "Provided sub-index has incorrect size");
990  assign_indexes[i] = indexes[i];
991  }
992  own_fields = false;
993 }
994 
995 MultiIndexQuantizer2::MultiIndexQuantizer2 (
996  int d, size_t nbits,
997  Index *assign_index_0,
998  Index *assign_index_1):
999  MultiIndexQuantizer (d, 2, nbits)
1000 {
1001  FAISS_THROW_IF_NOT_MSG(
1002  assign_index_0->d == pq.dsub &&
1003  assign_index_1->d == pq.dsub,
1004  "Provided sub-index has incorrect size");
1005  assign_indexes.resize (2);
1006  assign_indexes [0] = assign_index_0;
1007  assign_indexes [1] = assign_index_1;
1008  own_fields = false;
1009 }
1010 
1011 void MultiIndexQuantizer2::train(idx_t n, const float* x)
1012 {
1013  MultiIndexQuantizer::train(n, x);
1014  // add centroids to sub-indexes
1015  for (int i = 0; i < pq.M; i++) {
1016  assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
1017  }
1018 }
1019 
1020 
1021 void MultiIndexQuantizer2::search(
1022  idx_t n, const float* x, idx_t K,
1023  float* distances, idx_t* labels) const
1024 {
1025 
1026  if (n == 0) return;
1027 
1028  int k2 = std::min(K, long(pq.ksub));
1029 
1030  long M = pq.M;
1031  long dsub = pq.dsub, ksub = pq.ksub;
1032 
1033  // size (M, n, k2)
1034  std::vector<idx_t> sub_ids(n * M * k2);
1035  std::vector<float> sub_dis(n * M * k2);
1036  std::vector<float> xsub(n * dsub);
1037 
1038  for (int m = 0; m < M; m++) {
1039  float *xdest = xsub.data();
1040  const float *xsrc = x + m * dsub;
1041  for (int j = 0; j < n; j++) {
1042  memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1043  xsrc += d;
1044  xdest += dsub;
1045  }
1046 
1047  assign_indexes[m]->search(
1048  n, xsub.data(), k2,
1049  &sub_dis[k2 * n * m],
1050  &sub_ids[k2 * n * m]);
1051  }
1052 
1053  if (K == 1) {
1054  // simple version that just finds the min in each table
1055  assert (k2 == 1);
1056 
1057  for (int i = 0; i < n; i++) {
1058  float dis = 0;
1059  idx_t label = 0;
1060 
1061  for (int m = 0; m < M; m++) {
1062  float vmin = sub_dis[i + m * n];
1063  idx_t lmin = sub_ids[i + m * n];
1064  dis += vmin;
1065  label |= lmin << (m * pq.nbits);
1066  }
1067  distances [i] = dis;
1068  labels [i] = label;
1069  }
1070 
1071  } else {
1072 
1073 #pragma omp parallel if(n > 1)
1074  {
1075  MinSumK <float, PreSortedArray<float>, false>
1076  msk(K, pq.M, pq.nbits, k2);
1077 #pragma omp for
1078  for (int i = 0; i < n; i++) {
1079  idx_t *li = labels + i * K;
1080  msk.run (&sub_dis[i * k2], k2 * n,
1081  distances + i * K, li);
1082 
1083  // remap ids
1084 
1085  const idx_t *idmap0 = sub_ids.data() + i * k2;
1086  long ld_idmap = k2 * n;
1087  long mask1 = ksub - 1L;
1088 
1089  for (int k = 0; k < K; k++) {
1090  const idx_t *idmap = idmap0;
1091  long vin = li[k];
1092  long vout = 0;
1093  int bs = 0;
1094  for (int m = 0; m < M; m++) {
1095  long s = vin & mask1;
1096  vin >>= pq.nbits;
1097  vout |= idmap[s] << bs;
1098  bs += pq.nbits;
1099  idmap += ld_idmap;
1100  }
1101  li[k] = vout;
1102  }
1103  }
1104  }
1105  }
1106 }
1107 
1108 
1109 } // namespace faiss