Faiss
/data/users/matthijs/github_faiss/faiss/IndexPQ.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  Index based on product quantization.
11 */
12 
13 #include "IndexPQ.h"
14 
15 
16 #include <cstddef>
17 #include <cstring>
18 #include <cstdio>
19 #include <cmath>
20 
21 #include <algorithm>
22 
23 #include "FaissAssert.h"
24 #include "hamming.h"
25 
26 namespace faiss {
27 
28 /*********************************************************
29  * IndexPQ implementation
30  ********************************************************/
31 
32 
33 IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
34  Index(d, metric), pq(d, M, nbits)
35 {
36  is_trained = false;
37  do_polysemous_training = false;
38  polysemous_ht = nbits * M + 1;
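 // nbits * M is the number of bits in a code, i.e. the largest possible
 // Hamming distance between two codes, so this default threshold lets every
 // code pass the polysemous filter until the caller lowers polysemous_ht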
39  search_type = ST_PQ;
40  encode_signs = false;
41 }
42 
43 IndexPQ::IndexPQ ()
44 {
45  metric_type = METRIC_L2;
46  is_trained = false;
47  do_polysemous_training = false;
48  polysemous_ht = pq.nbits * pq.M + 1;
49  search_type = ST_PQ;
50  encode_signs = false;
51 }
52 
53 
54 void IndexPQ::train (idx_t n, const float *x)
55 {
56  if (!do_polysemous_training) { // standard training
57  pq.train(n, x);
58  } else {
59  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
60 
61  if (ntrain_perm > n / 4)
62  ntrain_perm = n / 4;
63  if (verbose) {
64  printf ("PQ training on %ld points, remains %ld points: "
65  "training polysemous on %s\n",
66  n - ntrain_perm, ntrain_perm,
67  ntrain_perm == 0 ? "centroids" : "these");
68  }
69  pq.train(n - ntrain_perm, x);
70 
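 // the held-out ntrain_perm vectors are used by PolysemousTraining to
 // reorder the centroid indices so that Hamming distances between codes
 // reflect the underlying distances (see optimize_pq_for_hamming)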
71  polysemous_training.optimize_pq_for_hamming (
72  pq, ntrain_perm, x + (n - ntrain_perm) * d);
73  }
74  is_trained = true;
75 }
76 
77 
78 void IndexPQ::add (idx_t n, const float *x)
79 {
80  FAISS_THROW_IF_NOT (is_trained);
81  codes.resize ((n + ntotal) * pq.code_size);
82  pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
83  ntotal += n;
84 }
85 
86 
87 
89 {
90  codes.clear();
91  ntotal = 0;
92 }
93 
94 void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
95 {
96  FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
97  for (idx_t i = 0; i < ni; i++) {
98  const uint8_t * code = &codes[(i0 + i) * pq.code_size];
99  pq.decode (code, recons + i * d);
100  }
101 }
102 
103 
104 void IndexPQ::reconstruct (idx_t key, float * recons) const
105 {
106  FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
107  pq.decode (&codes[key * pq.code_size], recons);
108 }
109 
110 
111 
112 
113 
114 
115 
116 /*****************************************
117  * IndexPQ polysemous search routines
118  ******************************************/
119 
120 
121 
122 
123 
124 void IndexPQ::search (idx_t n, const float *x, idx_t k,
125  float *distances, idx_t *labels) const
126 {
127  FAISS_THROW_IF_NOT (is_trained);
128  if (search_type == ST_PQ) { // Simple PQ search
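 // L2 search keeps the k smallest distances, hence a max-heap whose root
 // is the current worst candidate; inner-product search keeps the k
 // largest scores, hence a min-heap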
129 
130  if (metric_type == METRIC_L2) {
131  float_maxheap_array_t res = {
132  size_t(n), size_t(k), labels, distances };
133  pq.search (x, n, codes.data(), ntotal, &res, true);
134  } else {
135  float_minheap_array_t res = {
136  size_t(n), size_t(k), labels, distances };
137  pq.search_ip (x, n, codes.data(), ntotal, &res, true);
138  }
139  indexPQ_stats.nq += n;
140  indexPQ_stats.ncode += n * ntotal;
141 
142  } else if (search_type == ST_polysemous ||
143  search_type == ST_polysemous_generalize) {
144 
145  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
146 
147  search_core_polysemous (n, x, k, distances, labels);
148 
149  } else { // code-to-code distances
150 
151  uint8_t * q_codes = new uint8_t [n * pq.code_size];
152  ScopeDeleter<uint8_t> del (q_codes);
153 
154 
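 // Two ways to turn the queries into codes: run them through the PQ
 // encoder, or (encode_signs) take one bit per dimension, set when the
 // component is positive, which requires exactly d = nbits * M bits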
155  if (!encode_signs) {
156  pq.compute_codes (x, q_codes, n);
157  } else {
158  FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
159  memset (q_codes, 0, n * pq.code_size);
160  for (size_t i = 0; i < n; i++) {
161  const float *xi = x + i * d;
162  uint8_t *code = q_codes + i * pq.code_size;
163  for (int j = 0; j < d; j++)
164  if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
165  }
166  }
167 
168  if (search_type == ST_SDC) {
169 
170  float_maxheap_array_t res = {
171  size_t(n), size_t(k), labels, distances};
172 
173  pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
174 
175  } else {
176  int * idistances = new int [n * k];
177  ScopeDeleter<int> del (idistances);
178 
179  int_maxheap_array_t res = {
180  size_t (n), size_t (k), labels, idistances};
181 
182  if (search_type == ST_HE) {
183 
184  hammings_knn (&res, q_codes, codes.data(),
185  ntotal, pq.code_size, true);
186 
187  } else if (search_type == ST_generalized_HE) {
188 
189  generalized_hammings_knn (&res, q_codes, codes.data(),
190  ntotal, pq.code_size, true);
191  }
192 
193  // convert distances to floats
194  for (int i = 0; i < k * n; i++)
195  distances[i] = idistances[i];
196 
197  }
198 
199 
200  indexPQ_stats.nq += n;
201  indexPQ_stats.ncode += n * ntotal;
202  }
203 }
204 
205 
206 
207 
208 
209 void IndexPQStats::reset()
210 {
211  nq = ncode = n_hamming_pass = 0;
212 }
213 
214 IndexPQStats indexPQ_stats;
215 
216 
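// polysemous_inner_loop scans all ntotal database codes for one query: it
// first computes the Hamming distance between the query code and each
// database code, and only when that distance is below polysemous_ht does it
// evaluate the full asymmetric PQ distance from the per-subquantizer lookup
// table (dis_table_qi) and push the candidate onto the result heap.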
217 template <class HammingComputer>
218 static size_t polysemous_inner_loop (
219  const IndexPQ & index,
220  const float *dis_table_qi, const uint8_t *q_code,
221  size_t k, float *heap_dis, long *heap_ids)
222 {
223 
224  int M = index.pq.M;
225  int code_size = index.pq.code_size;
226  int ksub = index.pq.ksub;
227  size_t ntotal = index.ntotal;
228  int ht = index.polysemous_ht;
229 
230  const uint8_t *b_code = index.codes.data();
231 
232  size_t n_pass_i = 0;
233 
234  HammingComputer hc (q_code, code_size);
235 
236  for (long bi = 0; bi < ntotal; bi++) {
237  int hd = hc.hamming (b_code);
238 
239  if (hd < ht) {
240  n_pass_i ++;
241 
242  float dis = 0;
243  const float * dis_table = dis_table_qi;
244  for (int m = 0; m < M; m++) {
245  dis += dis_table [b_code[m]];
246  dis_table += ksub;
247  }
248 
249  if (dis < heap_dis[0]) {
250  maxheap_pop (k, heap_dis, heap_ids);
251  maxheap_push (k, heap_dis, heap_ids, dis, bi);
252  }
253  }
254  b_code += code_size;
255  }
256  return n_pass_i;
257 }
258 
259 
260 void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
261  float *distances, idx_t *labels) const
262 {
263  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
264 
265  // PQ distance tables
266  float * dis_tables = new float [n * pq.ksub * pq.M];
267  ScopeDeleter<float> del (dis_tables);
268  pq.compute_distance_tables (n, x, dis_tables);
269 
270  // Hamming embedding queries
271  uint8_t * q_codes = new uint8_t [n * pq.code_size];
272  ScopeDeleter<uint8_t> del2 (q_codes);
273 
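 // the query codes are obtained by taking, for each subquantizer, the
 // argmin of the distance table that was just computed; the disabled
 // pq.compute_codes branch below would give the same codes but recompute
 // the tables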
274  if (false) {
275  pq.compute_codes (x, q_codes, n);
276  } else {
277 #pragma omp parallel for
278  for (idx_t qi = 0; qi < n; qi++) {
279  pq.compute_code_from_distance_table
280  (dis_tables + qi * pq.M * pq.ksub,
281  q_codes + qi * pq.code_size);
282  }
283  }
284 
285  size_t n_pass = 0;
286 
287 #pragma omp parallel for reduction (+: n_pass)
288  for (idx_t qi = 0; qi < n; qi++) {
289  const uint8_t * q_code = q_codes + qi * pq.code_size;
290 
291  const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
292 
293  long * heap_ids = labels + qi * k;
294  float *heap_dis = distances + qi * k;
295  maxheap_heapify (k, heap_dis, heap_ids);
296 
297  if (search_type == ST_polysemous) {
298 
299  switch (pq.code_size) {
300  case 4:
301  n_pass += polysemous_inner_loop<HammingComputer4>
302  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
303  break;
304  case 8:
305  n_pass += polysemous_inner_loop<HammingComputer8>
306  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
307  break;
308  case 16:
309  n_pass += polysemous_inner_loop<HammingComputer16>
310  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
311  break;
312  case 32:
313  n_pass += polysemous_inner_loop<HammingComputer32>
314  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
315  break;
316  case 20:
317  n_pass += polysemous_inner_loop<HammingComputer20>
318  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
319  break;
320  default:
321  if (pq.code_size % 8 == 0) {
322  n_pass += polysemous_inner_loop<HammingComputerM8>
323  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
324  } else if (pq.code_size % 4 == 0) {
325  n_pass += polysemous_inner_loop<HammingComputerM4>
326  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
327  } else {
328  FAISS_THROW_FMT(
329  "code size %zd not supported for polysemous",
330  pq.code_size);
331  }
332  break;
333  }
334  } else {
335  switch (pq.code_size) {
336  case 8:
337  n_pass += polysemous_inner_loop<GenHammingComputer8>
338  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
339  break;
340  case 16:
341  n_pass += polysemous_inner_loop<GenHammingComputer16>
342  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
343  break;
344  case 32:
345  n_pass += polysemous_inner_loop<GenHammingComputer32>
346  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
347  break;
348  default:
349  if (pq.code_size % 8 == 0) {
350  n_pass += polysemous_inner_loop<GenHammingComputerM8>
351  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
352  } else {
353  FAISS_THROW_FMT(
354  "code size %zd not supported for polysemous",
355  pq.code_size);
356  }
357  break;
358  }
359  }
360  maxheap_reorder (k, heap_dis, heap_ids);
361  }
362 
363  indexPQ_stats.nq += n;
364  indexPQ_stats.ncode += n * ntotal;
365  indexPQ_stats.n_hamming_pass += n_pass;
366 
367 
368 }
369 
370 
371 
372 
373 /*****************************************
374  * Stats of IndexPQ codes
375  ******************************************/
376 
377 
378 
379 
380 void IndexPQ::hamming_distance_table (idx_t n, const float *x,
381  int32_t *dis) const
382 {
383  uint8_t * q_codes = new uint8_t [n * pq.code_size];
384  ScopeDeleter<uint8_t> del (q_codes);
385 
386  pq.compute_codes (x, q_codes, n);
387 
388  hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
389 }
390 
391 
392 void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
393  idx_t nb, const float *xb,
394  long *hist)
395 {
396  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
397  FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
398  FAISS_THROW_IF_NOT (pq.byte_per_idx == 1);
399 
400  // Hamming embedding queries
401  uint8_t * q_codes = new uint8_t [n * pq.code_size];
402  ScopeDeleter <uint8_t> del (q_codes);
403  pq.compute_codes (x, q_codes, n);
404 
405  uint8_t * b_codes ;
406  ScopeDeleter <uint8_t> del_b_codes;
407 
408  if (xb) {
409  b_codes = new uint8_t [nb * pq.code_size];
410  del_b_codes.set (b_codes);
411  pq.compute_codes (xb, b_codes, nb);
412  } else {
413  nb = ntotal;
414  b_codes = codes.data();
415  }
416  int nbits = pq.M * pq.nbits;
417  memset (hist, 0, sizeof(*hist) * (nbits + 1));
418  size_t bs = 256;
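 // queries are processed in blocks of bs to bound the size of the
 // temporary distance matrix; each thread accumulates into its own
 // histogram (histi), which is merged into hist in the critical section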
419 
420 #pragma omp parallel
421  {
422  std::vector<long> histi (nbits + 1);
423  hamdis_t *distances = new hamdis_t [nb * bs];
424  ScopeDeleter<hamdis_t> del (distances);
425 #pragma omp for
426  for (size_t q0 = 0; q0 < n; q0 += bs) {
427  // printf ("dis stats: %ld/%ld\n", q0, n);
428  size_t q1 = q0 + bs;
429  if (q1 > n) q1 = n;
430 
431  hammings (q_codes + q0 * pq.code_size, b_codes,
432  q1 - q0, nb,
433  pq.code_size, distances);
434 
435  for (size_t i = 0; i < nb * (q1 - q0); i++)
436  histi [distances [i]]++;
437  }
438 #pragma omp critical
439  {
440  for (int i = 0; i <= nbits; i++)
441  hist[i] += histi[i];
442  }
443  }
444 
445 }
446 
447 
448 
449 
450 
451 
452 
453 
454 
455 
456 
457 
458 
459 
460 
461 
462 
463 
464 
465 
466 /*****************************************
467  * MultiIndexQuantizer
468  ******************************************/
469 
470 namespace {
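// The three "sorted array" views below are interchangeable template
// arguments for MinSumK: PreSortedArray assumes its input is already sorted,
// SortedArray argsorts everything up front, and SemiSortedArray sorts only
// the prefix that is actually requested.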
471 
472 template <typename T>
473 struct PreSortedArray {
474 
475  const T * x;
476  int N;
477 
478  explicit PreSortedArray (int N): N(N) {
479  }
480  void init (const T*x) {
481  this->x = x;
482  }
483  // get smallest value
484  T get_0 () {
485  return x[0];
486  }
487 
488  // get delta between n-smallest and n-1 -smallest
489  T get_diff (int n) {
490  return x[n] - x[n - 1];
491  }
492 
493  // remap orders counted from smallest to indices in array
494  int get_ord (int n) {
495  return n;
496  }
497 
498 };
499 
500 template <typename T>
501 struct ArgSort {
502  const T * x;
503  bool operator() (size_t i, size_t j) {
504  return x[i] < x[j];
505  }
506 };
507 
508 
509 /** Array that maintains a permutation of its elements so that the
510  * array's elements are sorted
511  */
512 template <typename T>
513 struct SortedArray {
514  const T * x;
515  int N;
516  std::vector<int> perm;
517 
518  explicit SortedArray (int N) {
519  this->N = N;
520  perm.resize (N);
521  }
522 
523  void init (const T*x) {
524  this->x = x;
525  for (int n = 0; n < N; n++)
526  perm[n] = n;
527  ArgSort<T> cmp = {x };
528  std::sort (perm.begin(), perm.end(), cmp);
529  }
530 
531  // get smallest value
532  T get_0 () {
533  return x[perm[0]];
534  }
535 
536  // get delta between n-smallest and n-1 -smallest
537  T get_diff (int n) {
538  return x[perm[n]] - x[perm[n - 1]];
539  }
540 
541  // remap orders counted from smallest to indices in array
542  int get_ord (int n) {
543  return perm[n];
544  }
545 };
546 
547 
548 
549 /** perm indexes n values in vals. Reorder perm so that perm[0..k-1] lists
550  * the k best values according to C, in order; the rest ends up in perm[k..n-1].
551  */
552 template <class C>
553 void partial_sort (int k, int n,
554  const typename C::T * vals, typename C::TI * perm) {
555  // insert first k elts in heap
556  for (int i = 1; i < k; i++) {
557  indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
558  }
559 
560  // insert next n - k elts in heap
561  for (int i = k; i < n; i++) {
562  typename C::TI id = perm[i];
563  typename C::TI top = perm[0];
564 
565  if (C::cmp(vals[top], vals[id])) {
566  indirect_heap_pop<C> (k, vals, perm);
567  indirect_heap_push<C> (k, vals, perm, id);
568  perm[i] = top;
569  } else {
570  // nothing, elt at i is good where it is.
571  }
572  }
573 
574  // order the k first elements in heap
575  for (int i = k - 1; i > 0; i--) {
576  typename C::TI top = perm[0];
577  indirect_heap_pop<C> (i + 1, vals, perm);
578  perm[i] = top;
579  }
580 }
581 
582 /** same as SortedArray, but only the k first elements are sorted */
583 template <typename T>
584 struct SemiSortedArray {
585  const T * x;
586  int N;
587 
588  // type of the heap: CMax = sort ascending
589  typedef CMax<T, int> HC;
590  std::vector<int> perm;
591 
592  int k; // k elements are sorted
593 
594  int initial_k, k_factor;
595 
596  explicit SemiSortedArray (int N) {
597  this->N = N;
598  perm.resize (N);
599  perm.resize (N);
600  initial_k = 3;
601  k_factor = 4;
602  }
603 
604  void init (const T*x) {
605  this->x = x;
606  for (int n = 0; n < N; n++)
607  perm[n] = n;
608  k = 0;
609  grow (initial_k);
610  }
611 
612  /// grow the sorted part of the array to size next_k
613  void grow (int next_k) {
614  if (next_k < N) {
615  partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
616  k = next_k;
617  } else { // full sort of remainder of array
618  ArgSort<T> cmp = {x };
619  std::sort (perm.begin() + k, perm.end(), cmp);
620  k = N;
621  }
622  }
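 // get_diff grows the sorted prefix on demand; the geometric growth by
 // k_factor keeps the number of partial sorts small even when many
 // elements end up being requested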
623 
624  // get smallest value
625  T get_0 () {
626  return x[perm[0]];
627  }
628 
629  // get delta between n-smallest and n-1 -smallest
630  T get_diff (int n) {
631  if (n >= k) {
632  // want to keep powers of 2 - 1
633  int next_k = (k + 1) * k_factor - 1;
634  grow (next_k);
635  }
636  return x[perm[n]] - x[perm[n - 1]];
637  }
638 
639  // remap orders counted from smallest to indices in array
640  int get_ord (int n) {
641  assert (n < k);
642  return perm[n];
643  }
644 };
645 
646 
647 
648 /*****************************************
649  * Find the k smallest sums of M terms, where each term is taken in a
650  * table x of n values.
651  *
652  * A combination of terms is encoded as a scalar 0 <= t < n^M. The
653  * combination t0 ... t(M-1) that corresponds to the sum
654  *
655  * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
656  *
657  * is encoded as
658  *
659  * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
660  *
661  * MinSumK is an object rather than a function, so that storage can be
662  * re-used over several computations with the same sizes. Set use_seen
663  * when the x array may contain ties and returning the same t several
664  * times would be a problem.
665  *
666  * @param x size M * n, values to add up
667  * @param k nb of results to retrieve
668  * @param M nb of terms
669  * @param n nb of distinct values
670  * @param sums output, size k, sorted
671  * @param terms output, size k, with encoding as above
672  *
673  ******************************************/
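// Worked example: with M = 2 and n = 4, the combination (t0 = 3, t1 = 1),
// i.e. the sum x[0, 3] + x[1, 1], is encoded as t = 3 + 1 * 4 = 7.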
674 template <typename T, class SSA, bool use_seen>
675 struct MinSumK {
676  int K; ///< nb of sums to return
677  int M; ///< nb of elements to sum up
678  int nbit; ///< nb of bits to encode one entry
679  int N; ///< nb of possible elements for each of the M terms
680 
681  /** the heap.
682  * We use a heap to maintain a queue of sums, with the associated
683  * terms involved in the sum.
684  */
685  typedef CMin<T, long> HC;
686  size_t heap_capacity, heap_size;
687  T *bh_val;
688  long *bh_ids;
689 
690  std::vector <SSA> ssx;
691 
692  // all results get pushed several times. When there are ties, they
693  // are popped interleaved with others, so it is not easy to
694  // identify them. Therefore, this bit array just marks elements
695  // that were seen before.
696  std::vector <uint8_t> seen;
697 
698  MinSumK (int K, int M, int nbit, int N):
699  K(K), M(M), nbit(nbit), N(N) {
700  heap_capacity = K * M;
701  assert (N <= (1 << nbit));
702 
703  // we'll do k steps, each step pushes at most M vals
704  bh_val = new T[heap_capacity];
705  bh_ids = new long[heap_capacity];
706 
707  if (use_seen) {
708  long n_ids = weight(M);
709  seen.resize ((n_ids + 7) / 8);
710  }
711 
712  for (int m = 0; m < M; m++)
713  ssx.push_back (SSA(N));
714 
715  }
716 
717  long weight (int i) {
718  return 1 << (i * nbit);
719  }
720 
721  bool is_seen (long i) {
722  return (seen[i >> 3] >> (i & 7)) & 1;
723  }
724 
725  void mark_seen (long i) {
726  if (use_seen)
727  seen [i >> 3] |= 1 << (i & 7);
728  }
729 
730  void run (const T *x, long ldx,
731  T * sums, long * terms) {
732  heap_size = 0;
733 
734  for (int m = 0; m < M; m++) {
735  ssx[m].init(x);
736  x += ldx;
737  }
738 
739  { // initial result: take the min for all elements
740  T sum = 0;
741  terms[0] = 0;
742  mark_seen (0);
743  for (int m = 0; m < M; m++) {
744  sum += ssx[m].get_0();
745  }
746  sums[0] = sum;
747  for (int m = 0; m < M; m++) {
748  heap_push<HC> (++heap_size, bh_val, bh_ids,
749  sum + ssx[m].get_diff(1),
750  weight(m));
751  }
752  }
753 
754  for (int k = 1; k < K; k++) {
755  // pop smallest value from heap
756  if (use_seen) {// skip already seen elements
757  while (is_seen (bh_ids[0])) {
758  assert (heap_size > 0);
759  heap_pop<HC> (heap_size--, bh_val, bh_ids);
760  }
761  }
762  assert (heap_size > 0);
763 
764  T sum = sums[k] = bh_val[0];
765  long ti = terms[k] = bh_ids[0];
766 
767  if (use_seen) {
768  mark_seen (ti);
769  heap_pop<HC> (heap_size--, bh_val, bh_ids);
770  } else {
771  do {
772  heap_pop<HC> (heap_size--, bh_val, bh_ids);
773  } while (heap_size > 0 && bh_ids[0] == ti);
774  }
775 
776  // enqueue followers
777  long ii = ti;
778  for (int m = 0; m < M; m++) {
779  long n = ii & ((1 << nbit) - 1);
780  ii >>= nbit;
781  if (n + 1 >= N) continue;
782 
783  enqueue_follower (ti, m, n, sum);
784  }
785  }
786 
787  /*
788  for (int k = 0; k < K; k++)
789  for (int l = k + 1; l < K; l++)
790  assert (terms[k] != terms[l]);
791  */
792 
793  // convert indices by applying permutation
794  for (int k = 0; k < K; k++) {
795  long ii = terms[k];
796  if (use_seen) {
797  // clear seen for reuse at next loop
798  seen[ii >> 3] = 0;
799  }
800  long ti = 0;
801  for (int m = 0; m < M; m++) {
802  long n = ii & ((1 << nbit) - 1);
803  ti += ssx[m].get_ord(n) << (nbit * m);
804  ii >>= nbit;
805  }
806  terms[k] = ti;
807  }
808  }
809 
810 
811  void enqueue_follower (long ti, int m, int n, T sum) {
812  T next_sum = sum + ssx[m].get_diff(n + 1);
813  long next_ti = ti + weight(m);
814  heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
815  }
816 
817  ~MinSumK () {
818  delete [] bh_ids;
819  delete [] bh_val;
820  }
821 };
822 
823 } // anonymous namespace
824 
825 
826 MultiIndexQuantizer::MultiIndexQuantizer (int d,
827  size_t M,
828  size_t nbits):
829  Index(d, METRIC_L2), pq(d, M, nbits)
830 {
831  is_trained = false;
832  pq.verbose = verbose;
833 }
834 
835 
836 
837 void MultiIndexQuantizer::train(idx_t n, const float *x)
838 {
839  pq.verbose = verbose;
840  pq.train (n, x);
841  is_trained = true;
842  // count virtual elements in index
843  ntotal = 1;
844  for (int m = 0; m < pq.M; m++)
845  ntotal *= pq.ksub;
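 // e.g. M = 2 subquantizers with nbits = 8 give ntotal = 256 * 256 = 65536
 // virtual centroids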
846 }
847 
848 
849 void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
850  float *distances, idx_t *labels) const {
851  if (n == 0) return;
852 
853  float * dis_tables = new float [n * pq.ksub * pq.M];
854  ScopeDeleter<float> del (dis_tables);
855 
856  pq.compute_distance_tables (n, x, dis_tables);
857 
858  if (k == 1) {
859  // simple version that just finds the min in each table
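 // the resulting label packs the argmin of subquantizer s into bits
 // [s * nbits, (s + 1) * nbits), matching the encoding used by MinSumK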
860 
861 #pragma omp parallel for
862  for (int i = 0; i < n; i++) {
863  const float * dis_table = dis_tables + i * pq.ksub * pq.M;
864  float dis = 0;
865  idx_t label = 0;
866 
867  for (int s = 0; s < pq.M; s++) {
868  float vmin = HUGE_VALF;
869  idx_t lmin = -1;
870 
871  for (idx_t j = 0; j < pq.ksub; j++) {
872  if (dis_table[j] < vmin) {
873  vmin = dis_table[j];
874  lmin = j;
875  }
876  }
877  dis += vmin;
878  label |= lmin << (s * pq.nbits);
879  dis_table += pq.ksub;
880  }
881 
882  distances [i] = dis;
883  labels [i] = label;
884  }
885 
886 
887  } else {
888 
889 #pragma omp parallel if(n > 1)
890  {
891  MinSumK <float, SemiSortedArray<float>, false>
892  msk(k, pq.M, pq.nbits, pq.ksub);
893 #pragma omp for
894  for (int i = 0; i < n; i++) {
895  msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
896  distances + i * k, labels + i * k);
897 
898  }
899  }
900  }
901 
902 }
903 
904 
905 void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
906 {
907 
908  long jj = key;
909  for (int m = 0; m < pq.M; m++) {
910  long n = jj & ((1L << pq.nbits) - 1);
911  jj >>= pq.nbits;
912  memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
913  recons += pq.dsub;
914  }
915 }
916 
917 void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
918  FAISS_THROW_MSG(
919  "This index has virtual elements, "
920  "it does not support add");
921 }
922 
923 void MultiIndexQuantizer::reset ()
924 {
925  FAISS_THROW_MSG ( "This index has virtual elements, "
926  "it does not support reset");
927 }
928 
929 
930 
931 
932 
933 
934 
935 
936 
937 
938 /*****************************************
939  * MultiIndexQuantizer2
940  ******************************************/
941 
942 
943 
944 MultiIndexQuantizer2::MultiIndexQuantizer2 (
945  int d, size_t M, size_t nbits,
946  Index **indexes):
947  MultiIndexQuantizer (d, M, nbits)
948 {
949  assign_indexes.resize (M);
950  for (int i = 0; i < M; i++) {
951  FAISS_THROW_IF_NOT_MSG(
952  indexes[i]->d == pq.dsub,
953  "Provided sub-index has incorrect size");
954  assign_indexes[i] = indexes[i];
955  }
956  own_fields = false;
957 }
958 
959 MultiIndexQuantizer2::MultiIndexQuantizer2 (
960  int d, size_t nbits,
961  Index *assign_index_0,
962  Index *assign_index_1):
963  MultiIndexQuantizer (d, 2, nbits)
964 {
965  FAISS_THROW_IF_NOT_MSG(
966  assign_index_0->d == pq.dsub &&
967  assign_index_1->d == pq.dsub,
968  "Provided sub-index has incorrect size");
969  assign_indexes.resize (2);
970  assign_indexes [0] = assign_index_0;
971  assign_indexes [1] = assign_index_1;
972  own_fields = false;
973 }
974 
975 void MultiIndexQuantizer2::train(idx_t n, const float* x)
976 {
977  MultiIndexQuantizer::train(n, x);
978  // add centroids to sub-indexes
979  for (int i = 0; i < pq.M; i++) {
980  assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
981  }
982 }
983 
984 
985 void MultiIndexQuantizer2::search(
986  idx_t n, const float* x, idx_t K,
987  float* distances, idx_t* labels) const
988 {
989 
990  if (n == 0) return;
991 
992  int k2 = std::min(K, long(pq.ksub));
993 
994  long M = pq.M;
995  long dsub = pq.dsub, ksub = pq.ksub;
996 
997  // size (M, n, k2)
998  std::vector<idx_t> sub_ids(n * M * k2);
999  std::vector<float> sub_dis(n * M * k2);
1000  std::vector<float> xsub(n * dsub);
1001 
1002  for (int m = 0; m < M; m++) {
1003  float *xdest = xsub.data();
1004  const float *xsrc = x + m * dsub;
1005  for (int j = 0; j < n; j++) {
1006  memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1007  xsrc += d;
1008  xdest += dsub;
1009  }
1010 
1011  assign_indexes[m]->search(
1012  n, xsub.data(), k2,
1013  &sub_dis[k2 * n * m],
1014  &sub_ids[k2 * n * m]);
1015  }
1016 
1017  if (K == 1) {
1018  // simple version that just finds the min in each table
1019  assert (k2 == 1);
1020 
1021  for (int i = 0; i < n; i++) {
1022  float dis = 0;
1023  idx_t label = 0;
1024 
1025  for (int m = 0; m < M; m++) {
1026  float vmin = sub_dis[i + m * n];
1027  idx_t lmin = sub_ids[i + m * n];
1028  dis += vmin;
1029  label |= lmin << (m * pq.nbits);
1030  }
1031  distances [i] = dis;
1032  labels [i] = label;
1033  }
1034 
1035  } else {
1036 
1037 #pragma omp parallel if(n > 1)
1038  {
1039  MinSumK <float, PreSortedArray<float>, false>
1040  msk(K, pq.M, pq.nbits, k2);
1041 #pragma omp for
1042  for (int i = 0; i < n; i++) {
1043  idx_t *li = labels + i * K;
1044  msk.run (&sub_dis[i * k2], k2 * n,
1045  distances + i * K, li);
1046 
1047  // remap ids
1048 
1049  const idx_t *idmap0 = sub_ids.data() + i * k2;
1050  long ld_idmap = k2 * n;
1051  long mask1 = ksub - 1L;
1052 
1053  for (int k = 0; k < K; k++) {
1054  const idx_t *idmap = idmap0;
1055  long vin = li[k];
1056  long vout = 0;
1057  int bs = 0;
1058  for (int m = 0; m < M; m++) {
1059  long s = vin & mask1;
1060  vin >>= pq.nbits;
1061  vout |= idmap[s] << bs;
1062  bs += pq.nbits;
1063  idmap += ld_idmap;
1064  }
1065  li[k] = vout;
1066  }
1067  }
1068  }
1069  }
1070 }
1071 
1072 
1073 
1074 } // END namespace faiss