Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/IndexPQ.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  Index based on product quantization.
11 */
12 
13 #include "IndexPQ.h"
14 
15 
16 #include <cstddef>
17 #include <cstring>
18 #include <cstdio>
19 #include <cmath>
20 
21 #include <algorithm>
22 
23 #include "FaissAssert.h"
24 #include "hamming.h"
25 
26 namespace faiss {
27 
28 /*********************************************************
29  * IndexPQ implementation
30  ********************************************************/
31 
32 
33 IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
34  Index(d, metric), pq(d, M, nbits)
35 {
36  is_trained = false;
37  do_polysemous_training = false;
38  polysemous_ht = nbits * M + 1;
39  search_type = ST_PQ;
40  encode_signs = false;
41 }
42 
43 IndexPQ::IndexPQ ()
44 {
45  metric_type = METRIC_L2;
46  is_trained = false;
47  do_polysemous_training = false;
48  polysemous_ht = pq.nbits * pq.M + 1;
49  search_type = ST_PQ;
50  encode_signs = false;
51 }
52 
53 
54 void IndexPQ::train (idx_t n, const float *x)
55 {
56  if (!do_polysemous_training) { // standard training
57  pq.train(n, x);
58  } else {
59  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
60 
61  if (ntrain_perm > n / 4)
62  ntrain_perm = n / 4;
63  if (verbose) {
64  printf ("PQ training on %ld points, remains %ld points: "
65  "training polysemous on %s\n",
66  n - ntrain_perm, ntrain_perm,
67  ntrain_perm == 0 ? "centroids" : "these");
68  }
69  pq.train(n - ntrain_perm, x);
70 
72  pq, ntrain_perm, x + (n - ntrain_perm) * d);
73  }
74  is_trained = true;
75 }
76 
77 
78 void IndexPQ::add (idx_t n, const float *x)
79 {
80  FAISS_THROW_IF_NOT (is_trained);
81  codes.resize ((n + ntotal) * pq.code_size);
83  ntotal += n;
84 }
85 
86 
87 
89 {
90  codes.clear();
91  ntotal = 0;
92 }
93 
94 void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
95 {
96  FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
97  for (idx_t i = 0; i < ni; i++) {
98  const uint8_t * code = &codes[(i0 + i) * pq.code_size];
99  pq.decode (code, recons + i * d);
100  }
101 }
102 
103 
104 void IndexPQ::reconstruct (idx_t key, float * recons) const
105 {
106  FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
107  pq.decode (&codes[key * pq.code_size], recons);
108 }
109 
110 
111 
112 
113 
114 
115 
116 /*****************************************
117  * IndexPQ polysemous search routines
118  ******************************************/
119 
120 
121 
122 
123 
124 void IndexPQ::search (idx_t n, const float *x, idx_t k,
125  float *distances, idx_t *labels) const
126 {
127  FAISS_THROW_IF_NOT (is_trained);
128  if (search_type == ST_PQ) { // Simple PQ search
129 
130  if (metric_type == METRIC_L2) {
131  float_maxheap_array_t res = {
132  size_t(n), size_t(k), labels, distances };
133  pq.search (x, n, codes.data(), ntotal, &res, true);
134  } else {
135  float_minheap_array_t res = {
136  size_t(n), size_t(k), labels, distances };
137  pq.search_ip (x, n, codes.data(), ntotal, &res, true);
138  }
139  indexPQ_stats.nq += n;
140  indexPQ_stats.ncode += n * ntotal;
141 
142  } else if (search_type == ST_polysemous ||
143  search_type == ST_polysemous_generalize) {
144 
145  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
146 
147  search_core_polysemous (n, x, k, distances, labels);
148 
149  } else { // code-to-code distances
150 
151  uint8_t * q_codes = new uint8_t [n * pq.code_size];
152  ScopeDeleter<uint8_t> del (q_codes);
153 
154 
155  if (!encode_signs) {
156  pq.compute_codes (x, q_codes, n);
157  } else {
158  FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
159  memset (q_codes, 0, n * pq.code_size);
160  for (size_t i = 0; i < n; i++) {
161  const float *xi = x + i * d;
162  uint8_t *code = q_codes + i * pq.code_size;
163  for (int j = 0; j < d; j++)
164  if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
165  }
166  }
167 
168  if (search_type == ST_SDC) {
169 
170  float_maxheap_array_t res = {
171  size_t(n), size_t(k), labels, distances};
172 
173  pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
174 
175  } else {
176  int * idistances = new int [n * k];
177  ScopeDeleter<int> del (idistances);
178 
179  int_maxheap_array_t res = {
180  size_t (n), size_t (k), labels, idistances};
181 
182  if (search_type == ST_HE) {
183 
184  hammings_knn (&res, q_codes, codes.data(),
185  ntotal, pq.code_size, true);
186 
187  } else if (search_type == ST_generalized_HE) {
188 
189  generalized_hammings_knn (&res, q_codes, codes.data(),
190  ntotal, pq.code_size, true);
191  }
192 
193  // convert distances to floats
194  for (int i = 0; i < k * n; i++)
195  distances[i] = idistances[i];
196 
197  }
198 
199 
200  indexPQ_stats.nq += n;
201  indexPQ_stats.ncode += n * ntotal;
202  }
203 }
204 
205 
206 
207 
208 
209 void IndexPQStats::reset()
210 {
211  nq = ncode = n_hamming_pass = 0;
212 }
213 
214 IndexPQStats indexPQ_stats;
215 
216 
217 template <class HammingComputer>
218 static size_t polysemous_inner_loop (
219  const IndexPQ & index,
220  const float *dis_table_qi, const uint8_t *q_code,
221  size_t k, float *heap_dis, long *heap_ids)
222 {
223 
224  int M = index.pq.M;
225  int code_size = index.pq.code_size;
226  int ksub = index.pq.ksub;
227  size_t ntotal = index.ntotal;
228  int ht = index.polysemous_ht;
229 
230  const uint8_t *b_code = index.codes.data();
231 
232  size_t n_pass_i = 0;
233 
234  HammingComputer hc (q_code, code_size);
235 
236  for (long bi = 0; bi < ntotal; bi++) {
237  int hd = hc.hamming (b_code);
238 
239  if (hd < ht) {
240  n_pass_i ++;
241 
242  float dis = 0;
243  const float * dis_table = dis_table_qi;
244  for (int m = 0; m < M; m++) {
245  dis += dis_table [b_code[m]];
246  dis_table += ksub;
247  }
248 
249  if (dis < heap_dis[0]) {
250  maxheap_pop (k, heap_dis, heap_ids);
251  maxheap_push (k, heap_dis, heap_ids, dis, bi);
252  }
253  }
254  b_code += code_size;
255  }
256  return n_pass_i;
257 }
258 
259 
260 void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
261  float *distances, idx_t *labels) const
262 {
263  FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
264  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
265 
266  // PQ distance tables
267  float * dis_tables = new float [n * pq.ksub * pq.M];
268  ScopeDeleter<float> del (dis_tables);
269  pq.compute_distance_tables (n, x, dis_tables);
270 
271  // Hamming embedding queries
272  uint8_t * q_codes = new uint8_t [n * pq.code_size];
273  ScopeDeleter<uint8_t> del2 (q_codes);
274 
275  if (false) {
276  pq.compute_codes (x, q_codes, n);
277  } else {
278 #pragma omp parallel for
279  for (idx_t qi = 0; qi < n; qi++) {
281  (dis_tables + qi * pq.M * pq.ksub,
282  q_codes + qi * pq.code_size);
283  }
284  }
285 
286  size_t n_pass = 0;
287 
288 #pragma omp parallel for reduction (+: n_pass)
289  for (idx_t qi = 0; qi < n; qi++) {
290  const uint8_t * q_code = q_codes + qi * pq.code_size;
291 
292  const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
293 
294  long * heap_ids = labels + qi * k;
295  float *heap_dis = distances + qi * k;
296  maxheap_heapify (k, heap_dis, heap_ids);
297 
298  if (search_type == ST_polysemous) {
299 
300  switch (pq.code_size) {
301  case 4:
302  n_pass += polysemous_inner_loop<HammingComputer4>
303  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
304  break;
305  case 8:
306  n_pass += polysemous_inner_loop<HammingComputer8>
307  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
308  break;
309  case 16:
310  n_pass += polysemous_inner_loop<HammingComputer16>
311  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
312  break;
313  case 32:
314  n_pass += polysemous_inner_loop<HammingComputer32>
315  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
316  break;
317  case 20:
318  n_pass += polysemous_inner_loop<HammingComputer20>
319  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
320  break;
321  default:
322  if (pq.code_size % 8 == 0)
323  n_pass += polysemous_inner_loop<HammingComputerM8>
324  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
325  else
326  n_pass += polysemous_inner_loop<HammingComputerM4>
327  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
328  break;
329  }
330  } else {
331  switch (pq.code_size) {
332  case 8:
333  n_pass += polysemous_inner_loop<GenHammingComputer8>
334  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
335  break;
336  case 16:
337  n_pass += polysemous_inner_loop<GenHammingComputer16>
338  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
339  break;
340  case 32:
341  n_pass += polysemous_inner_loop<GenHammingComputer32>
342  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
343  break;
344  default:
345  n_pass += polysemous_inner_loop<GenHammingComputerM8>
346  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
347  break;
348  }
349  }
350  maxheap_reorder (k, heap_dis, heap_ids);
351  }
352 
353  indexPQ_stats.nq += n;
354  indexPQ_stats.ncode += n * ntotal;
355  indexPQ_stats.n_hamming_pass += n_pass;
356 
357 
358 }
359 
360 
361 
362 
363 /*****************************************
364  * Stats of IndexPQ codes
365  ******************************************/
366 
367 
368 
369 
370 void IndexPQ::hamming_distance_table (idx_t n, const float *x,
371  int32_t *dis) const
372 {
373  uint8_t * q_codes = new uint8_t [n * pq.code_size];
374  ScopeDeleter<uint8_t> del (q_codes);
375 
376  pq.compute_codes (x, q_codes, n);
377 
378  hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
379 }
380 
381 
383  idx_t nb, const float *xb,
384  long *hist)
385 {
386  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
387  FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
388  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
389 
390  // Hamming embedding queries
391  uint8_t * q_codes = new uint8_t [n * pq.code_size];
392  ScopeDeleter <uint8_t> del (q_codes);
393  pq.compute_codes (x, q_codes, n);
394 
395  uint8_t * b_codes ;
396  ScopeDeleter <uint8_t> del_b_codes;
397 
398  if (xb) {
399  b_codes = new uint8_t [nb * pq.code_size];
400  del_b_codes.set (b_codes);
401  pq.compute_codes (xb, b_codes, nb);
402  } else {
403  nb = ntotal;
404  b_codes = codes.data();
405  }
406  int nbits = pq.M * pq.nbits;
407  memset (hist, 0, sizeof(*hist) * (nbits + 1));
408  size_t bs = 256;
409 
410 #pragma omp parallel
411  {
412  std::vector<long> histi (nbits + 1);
413  hamdis_t *distances = new hamdis_t [nb * bs];
414  ScopeDeleter<hamdis_t> del (distances);
415 #pragma omp for
416  for (size_t q0 = 0; q0 < n; q0 += bs) {
417  // printf ("dis stats: %ld/%ld\n", q0, n);
418  size_t q1 = q0 + bs;
419  if (q1 > n) q1 = n;
420 
421  hammings (q_codes + q0 * pq.code_size, b_codes,
422  q1 - q0, nb,
423  pq.code_size, distances);
424 
425  for (size_t i = 0; i < nb * (q1 - q0); i++)
426  histi [distances [i]]++;
427  }
428 #pragma omp critical
429  {
430  for (int i = 0; i <= nbits; i++)
431  hist[i] += histi[i];
432  }
433  }
434 
435 }
436 
437 
438 
439 
440 
441 
442 
443 
444 
445 
446 
447 
448 
449 
450 
451 
452 
453 
454 
455 
456 /*****************************************
457  * MultiIndexQuantizer
458  ******************************************/
459 
460 
461 
/** Comparator for indirect sorting: orders indices i, j by the values
 * x[i], x[j] (strictly increasing). Intended to be brace-initialized,
 * e.g. ArgSort<float> cmp = {values}; the array is not owned.
 */
template <typename T>
struct ArgSort {
    const T * x;   ///< values the indices refer to
    // const so that the comparator can be invoked through const objects
    bool operator() (size_t i, size_t j) const {
        return x[i] < x[j];
    }
};
469 
470 
/** View over an external array of N values that keeps a permutation
 * of the indices such that the values are visited in increasing order.
 * The array itself is neither copied nor modified.
 */
template <typename T>
struct SortedArray {
    const T * x;            ///< values being ordered (set by init)
    int N;                  ///< number of values
    std::vector<int> perm;  ///< perm[r] = index of the value of rank r

    explicit SortedArray (int N): N(N), perm(N) {}

    /// attach to an array of N values and sort the permutation
    void init (const T*x) {
        this->x = x;
        int next = 0;
        for (std::vector<int>::iterator it = perm.begin();
             it != perm.end(); ++it) {
            *it = next++;
        }
        const T * vals = x;
        std::sort (perm.begin(), perm.end(),
                   [vals] (size_t i, size_t j) { return vals[i] < vals[j]; });
    }

    /// smallest value of the array
    T get_0 () {
        const int first = perm[0];
        return x[first];
    }

    /// difference between the value of rank n and the value of rank n - 1
    T get_diff (int n) {
        const int cur = perm[n];
        const int prev = perm[n - 1];
        return x[cur] - x[prev];
    }

    /// map a rank (counted from the smallest) to an index in the array
    int get_ord (int n) {
        return perm[n];
    }
};
508 
509 
510 
511 /** Array has n values. Sort the k first ones and copy the other ones
512  * into elements k..n-1
513  */
514 template <class C>
515 void partial_sort (int k, int n,
516  const typename C::T * vals, typename C::TI * perm) {
517  // insert first k elts in heap
518  for (int i = 1; i < k; i++) {
519  indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
520  }
521 
522  // insert next n - k elts in heap
523  for (int i = k; i < n; i++) {
524  typename C::TI id = perm[i];
525  typename C::TI top = perm[0];
526 
527  if (C::cmp(vals[top], vals[id])) {
528  indirect_heap_pop<C> (k, vals, perm);
529  indirect_heap_push<C> (k, vals, perm, id);
530  perm[i] = top;
531  } else {
532  // nothing, elt at i is good where it is.
533  }
534  }
535 
536  // order the k first elements in heap
537  for (int i = k - 1; i > 0; i--) {
538  typename C::TI top = perm[0];
539  indirect_heap_pop<C> (i + 1, vals, perm);
540  perm[i] = top;
541  }
542 }
543 
544 /** same as SortedArray, but only the k first elements are sorted */
545 template <typename T>
547  const T * x;
548  int N;
549 
550  // type of the heap: CMax = sort ascending
551  typedef CMax<T, int> HC;
552  std::vector<int> perm;
553 
554  int k; // k elements are sorted
555 
556  int initial_k, k_factor;
557 
558  explicit SemiSortedArray (int N) {
559  this->N = N;
560  perm.resize (N);
561  perm.resize (N);
562  initial_k = 3;
563  k_factor = 4;
564  }
565 
566  void init (const T*x) {
567  this->x = x;
568  for (int n = 0; n < N; n++)
569  perm[n] = n;
570  k = 0;
571  grow (initial_k);
572  }
573 
574  /// grow the sorted part of the array to size next_k
575  void grow (int next_k) {
576  if (next_k < N) {
577  partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
578  k = next_k;
579  } else { // full sort of remainder of array
580  ArgSort<T> cmp = {x };
581  std::sort (perm.begin() + k, perm.end(), cmp);
582  k = N;
583  }
584  }
585 
586  // get smallest value
587  T get_0 () {
588  return x[perm[0]];
589  }
590 
591  // get delta between n-smallest and n-1 -smallest
592  T get_diff (int n) {
593  if (n >= k) {
594  // want to keep powers of 2 - 1
595  int next_k = (k + 1) * k_factor - 1;
596  grow (next_k);
597  }
598  return x[perm[n]] - x[perm[n - 1]];
599  }
600 
601  // remap orders counted from smallest to indices in array
602  int get_ord (int n) {
603  assert (n < k);
604  return perm[n];
605  }
606 };
607 
608 
609 
610 /*****************************************
611  * Find the k smallest sums of M terms, where each term is taken in a
612  * table x of n values.
613  *
614  * A combination of terms is encoded as a scalar 0 <= t < n^M. The
615  * combination t0 ... t(M-1) that correspond to the sum
616  *
617  * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
618  *
619  * is encoded as
620  *
621  * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
622  *
623  * MinSumK is an object rather than a function, so that storage can be
624  * re-used over several computations with the same sizes. use_seen is
625  * good when there may be ties in the x array and it is a concern if
626  * occasionally several t's are returned.
627  *
628  * @param x size M * n, values to add up
629  * @parms k nb of results to retrieve
630  * @param M nb of terms
631  * @param n nb of distinct values
632  * @param sums output, size k, sorted
633  * @prarm terms output, size k, with encoding as above
634  *
635  ******************************************/
636 template <typename T, class SSA, bool use_seen>
637 struct MinSumK {
638  int K; ///< nb of sums to return
639  int M; ///< nb of elements to sum up
640  int N; ///< nb of possible elements for each of the M terms
641 
642  /** the heap.
643  * We use a heap to maintain a queue of sums, with the associated
644  * terms involved in the sum.
645  */
646  typedef CMin<T, long> HC;
647  size_t heap_capacity, heap_size;
648  T *bh_val;
649  long *bh_ids;
650 
651  std::vector <SSA> ssx;
652  std::vector <long> weights;
653 
654  // all results get pushed several times. When there are ties, they
655  // are popped interleaved with others, so it is not easy to
656  // identify them. Therefore, this bit array just marks elements
657  // that were seen before.
658  std::vector <uint8_t> seen;
659 
660  MinSumK (int K, int M, int N): K(K), M(M), N(N) {
661  heap_capacity = K * M;
662  // we'll do k steps, each step pushes at most M vals
663  bh_val = new T[heap_capacity];
664  bh_ids = new long[heap_capacity];
665 
666  weights.push_back (1);
667  for (int m = 1; m < M; m++)
668  weights.push_back(weights[m - 1] * N);
669 
670  if (use_seen) {
671  long n_ids = weights.back() * N;
672  seen.resize ((n_ids + 7) / 8);
673  }
674 
675  for (int m = 0; m < M; m++)
676  ssx.push_back (SSA(N));
677 
678  }
679 
680  bool is_seen (long i) {
681  return (seen[i >> 3] >> (i & 7)) & 1;
682  }
683 
684  void mark_seen (long i) {
685  if (use_seen)
686  seen [i >> 3] |= 1 << (i & 7);
687  }
688 
689  void run (const T *x, T * sums, long * terms) {
690  heap_size = 0;
691 
692  for (int m = 0; m < M; m++)
693  ssx[m].init(x + N * m);
694 
695  { // intial result: take min for all elements
696  T sum = 0;
697  terms[0] = 0;
698  mark_seen (0);
699  for (int m = 0; m < M; m++) {
700  sum += ssx[m].get_0();
701  }
702  sums[0] = sum;
703  for (int m = 0; m < M; m++) {
704  heap_push<HC> (++heap_size, bh_val, bh_ids,
705  sum + ssx[m].get_diff(1),
706  weights[m]);
707  }
708  }
709 
710  for (int k = 1; k < K; k++) {
711  // pop smallest value from heap
712  if (use_seen) {// skip already seen elements
713  while (is_seen (bh_ids[0])) {
714  assert (heap_size > 0);
715  heap_pop<HC> (heap_size--, bh_val, bh_ids);
716  }
717  }
718  assert (heap_size > 0);
719 
720  T sum = sums[k] = bh_val[0];
721  long ti = terms[k] = bh_ids[0];
722 
723  if (use_seen) {
724  mark_seen (ti);
725  heap_pop<HC> (heap_size--, bh_val, bh_ids);
726  } else {
727  do {
728  heap_pop<HC> (heap_size--, bh_val, bh_ids);
729  } while (heap_size > 0 && bh_ids[0] == ti);
730  }
731 
732  // enqueue followers
733  long ii = ti;
734  for (int m = 0; m < M; m++) {
735  long n = ii % N;
736  ii /= N;
737  if (n + 1 >= N) continue;
738 
739  enqueue_follower (ti, m, n, sum);
740  }
741  }
742 
743  /*
744  for (int k = 0; k < K; k++)
745  for (int l = k + 1; l < K; l++)
746  assert (terms[k] != terms[l]);
747  */
748 
749  // convert indices by applying permutation
750  for (int k = 0; k < K; k++) {
751  long ii = terms[k];
752  if (use_seen) {
753  // clear seen for reuse at next loop
754  seen[ii >> 3] = 0;
755  }
756  long ti = 0;
757  for (int m = 0; m < M; m++) {
758  long n = ii % N;
759  ti += weights[m] * ssx[m].get_ord(n);
760  ii /= N;
761  }
762  terms[k] = ti;
763  }
764  }
765 
766 
767  void enqueue_follower (long ti, int m, int n, T sum) {
768  T next_sum = sum + ssx[m].get_diff(n + 1);
769  long next_ti = ti + weights[m];
770  heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
771  }
772 
773 
774  ~MinSumK () {
775  delete [] bh_ids;
776  delete [] bh_val;
777  }
778 };
779 
780 
781 
782 
783 MultiIndexQuantizer::MultiIndexQuantizer (int d,
784  size_t M,
785  size_t nbits):
786  Index(d, METRIC_L2), pq(d, M, nbits)
787 {
788  is_trained = false;
789  pq.verbose = verbose;
790 }
791 
792 
793 
794 void MultiIndexQuantizer::train(idx_t n, const float *x)
795 {
796  pq.train (n, x);
797  is_trained = true;
798  // count virtual elements in index
799  ntotal = 1;
800  for (int m = 0; m < pq.M; m++)
801  ntotal *= pq.ksub;
802 }
803 
804 
805 void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
806  float *distances, idx_t *labels) const {
807  if (n == 0) return;
808 
809  float * dis_tables = new float [n * pq.ksub * pq.M];
810  ScopeDeleter<float> del (dis_tables);
811 
812  pq.compute_distance_tables (n, x, dis_tables);
813 
814  if (k == 1) {
815  // simple version that just finds the min in each table
816 
817 #pragma omp parallel for
818  for (int i = 0; i < n; i++) {
819  const float * dis_table = dis_tables + i * pq.ksub * pq.M;
820  float dis = 0;
821  idx_t label = 0;
822 
823  for (int s = 0; s < pq.M; s++) {
824  float vmin = HUGE_VALF;
825  idx_t lmin = -1;
826 
827  for (idx_t j = 0; j < pq.ksub; j++) {
828  if (dis_table[j] < vmin) {
829  vmin = dis_table[j];
830  lmin = j;
831  }
832  }
833  dis += vmin;
834  label |= lmin << (s * pq.nbits);
835  dis_table += pq.ksub;
836  }
837 
838  distances [i] = dis;
839  labels [i] = label;
840  }
841 
842 
843  } else {
844 
845 #pragma omp parallel if(n > 1)
846  {
848  msk(k, pq.M, pq.ksub);
849 #pragma omp for
850  for (int i = 0; i < n; i++) {
851  msk.run (dis_tables + i * pq.ksub * pq.M,
852  distances + i * k, labels + i * k);
853 
854  }
855  }
856  }
857 
858 }
859 
860 
861 void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
862 {
863  if (pq.byte_per_idx == 1) {
864  uint8_t code[pq.M];
865  long jj = key;
866  for (int m = 0; m < pq.M; m++) {
867  long n = jj % pq.ksub;
868  jj /= pq.ksub;
869  code[m] = n;
870  }
871  pq.decode (code, recons);
872  } else if (pq.byte_per_idx == 2) {
873  uint16_t code[pq.M];
874  long jj = key;
875  for (int m = 0; m < pq.M; m++) {
876  long n = jj % pq.ksub;
877  jj /= pq.ksub;
878  code[m] = n;
879  }
880  pq.decode ((uint8_t*)code, recons);
881  } else FAISS_THROW_MSG( "only 1 or 2 bytes per index supported");
882 }
883 
884 void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
885  FAISS_THROW_MSG(
886  "This index has virtual elements, "
887  "it does not support add");
888 }
889 
891 {
892  FAISS_THROW_MSG ( "This index has virtual elements, "
893  "it does not support reset");
894 }
895 
896 
897 
898 
899 } // END namespace faiss
int M
nb of elements to sum up
Definition: IndexPQ.cpp:639
std::vector< uint8_t > codes
Codes. Size ntotal * pq.code_size.
Definition: IndexPQ.h:34
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
Hamming distance on codes.
Definition: IndexPQ.h:77
bool do_polysemous_training
false = standard PQ
Definition: IndexPQ.h:69
void train(idx_t n, const float *x) override
Definition: IndexPQ.cpp:54
size_t byte_per_idx
nb bytes per code component (1 or 2)
void reset() override
removes all elements from the database.
Definition: IndexPQ.cpp:890
void partial_sort(int k, int n, const typename C::T *vals, typename C::TI *perm)
Definition: IndexPQ.cpp:515
void train(idx_t n, const float *x) override
Definition: IndexPQ.cpp:794
CMin< T, long > HC
Definition: IndexPQ.cpp:646
void grow(int next_k)
grow the sorted part of the array to size next_k
Definition: IndexPQ.cpp:575
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void generalized_hammings_knn(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t code_size, int ordered)
Definition: hamming.cpp:639
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
int d
vector dimension
Definition: Index.h:64
void hamming_distance_histogram(idx_t n, const float *x, idx_t nb, const float *xb, long *dist_histogram)
Definition: IndexPQ.cpp:382
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
size_t code_size
byte per indexed vector
Filter on generalized Hamming.
Definition: IndexPQ.h:81
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
long idx_t
all indices are this type
Definition: Index.h:62
void hammings_knn(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t ncodes, int order)
Definition: hamming.cpp:474
ProductQuantizer pq
The product quantizer used to encode the vectors.
Definition: IndexPQ.h:31
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:65
bool verbose
verbosity level
Definition: Index.h:66
void add(idx_t n, const float *x) override
Definition: IndexPQ.cpp:78
int K
nb of sums to return
Definition: IndexPQ.cpp:638
void hamming_distance_table(idx_t n, const float *x, int32_t *dis) const
Definition: IndexPQ.cpp:370
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexPQ.cpp:805
void reconstruct(idx_t key, float *recons) const override
Definition: IndexPQ.cpp:104
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:72
size_t M
number of subquantizers
int N
nb of possible elements for each of the M terms
Definition: IndexPQ.cpp:640
void reconstruct_n(idx_t i0, idx_t ni, float *recons) const override
Definition: IndexPQ.cpp:94
asymmetric product quantizer (default)
Definition: IndexPQ.h:76
void reconstruct(idx_t key, float *recons) const override
Definition: IndexPQ.cpp:861
HE filter (using ht) + PQ combination.
Definition: IndexPQ.h:80
void add(idx_t n, const float *x) override
add and reset will crash at runtime
Definition: IndexPQ.cpp:884
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:69
void reset() override
removes all elements from the database.
Definition: IndexPQ.cpp:88
void optimize_pq_for_hamming(ProductQuantizer &pq, size_t n, const float *x) const
bool verbose
verbose during training?
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexPQ.cpp:124
symmetric product quantizer (SDC)
Definition: IndexPQ.h:79
int polysemous_ht
Hamming threshold used for polysemy.
Definition: IndexPQ.h:91
PolysemousTraining polysemous_training
parameters used for the polysemous training
Definition: IndexPQ.h:72
MetricType
Some algorithms support both an inner product version and an L2 search version.
Definition: Index.h:43