Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/hoss/faiss/IndexPQ.cpp
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include "IndexPQ.h"
11 
12 
13 #include <cstddef>
14 #include <cstring>
15 #include <cstdio>
16 #include <cmath>
17 
18 #include <algorithm>
19 
20 #include "FaissAssert.h"
21 #include "AuxIndexStructures.h"
22 #include "hamming.h"
23 
24 namespace faiss {
25 
26 /*********************************************************
27  * IndexPQ implementation
28  ********************************************************/
29 
30 
31 IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
32  Index(d, metric), pq(d, M, nbits)
33 {
34  is_trained = false;
35  do_polysemous_training = false;
36  polysemous_ht = nbits * M + 1;
37  search_type = ST_PQ;
38  encode_signs = false;
39 }
40 
41 IndexPQ::IndexPQ ()
42 {
43  metric_type = METRIC_L2;
44  is_trained = false;
45  do_polysemous_training = false;
46  polysemous_ht = pq.nbits * pq.M + 1;
47  search_type = ST_PQ;
48  encode_signs = false;
49 }
50 
51 
52 void IndexPQ::train (idx_t n, const float *x)
53 {
54  if (!do_polysemous_training) { // standard training
55  pq.train(n, x);
56  } else {
57  idx_t ntrain_perm = polysemous_training.ntrain_permutation;
58 
59  if (ntrain_perm > n / 4)
60  ntrain_perm = n / 4;
61  if (verbose) {
62  printf ("PQ training on %ld points, remains %ld points: "
63  "training polysemous on %s\n",
64  n - ntrain_perm, ntrain_perm,
65  ntrain_perm == 0 ? "centroids" : "these");
66  }
67  pq.train(n - ntrain_perm, x);
68 
70  pq, ntrain_perm, x + (n - ntrain_perm) * d);
71  }
72  is_trained = true;
73 }
74 
75 
76 void IndexPQ::add (idx_t n, const float *x)
77 {
78  FAISS_THROW_IF_NOT (is_trained);
79  codes.resize ((n + ntotal) * pq.code_size);
81  ntotal += n;
82 }
83 
84 
85 long IndexPQ::remove_ids (const IDSelector & sel)
86 {
87  idx_t j = 0;
88  for (idx_t i = 0; i < ntotal; i++) {
89  if (sel.is_member (i)) {
90  // should be removed
91  } else {
92  if (i > j) {
93  memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
94  }
95  j++;
96  }
97  }
98  long nremove = ntotal - j;
99  if (nremove > 0) {
100  ntotal = j;
101  codes.resize (ntotal * pq.code_size);
102  }
103  return nremove;
104 }
105 
106 
108 {
109  codes.clear();
110  ntotal = 0;
111 }
112 
113 void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
114 {
115  FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
116  for (idx_t i = 0; i < ni; i++) {
117  const uint8_t * code = &codes[(i0 + i) * pq.code_size];
118  pq.decode (code, recons + i * d);
119  }
120 }
121 
122 
123 void IndexPQ::reconstruct (idx_t key, float * recons) const
124 {
125  FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
126  pq.decode (&codes[key * pq.code_size], recons);
127 }
128 
129 
130 
131 
132 
133 
134 
135 /*****************************************
136  * IndexPQ polysemous search routines
137  ******************************************/
138 
139 
140 
141 
142 
143 void IndexPQ::search (idx_t n, const float *x, idx_t k,
144  float *distances, idx_t *labels) const
145 {
146  FAISS_THROW_IF_NOT (is_trained);
147  if (search_type == ST_PQ) { // Simple PQ search
148 
149  if (metric_type == METRIC_L2) {
150  float_maxheap_array_t res = {
151  size_t(n), size_t(k), labels, distances };
152  pq.search (x, n, codes.data(), ntotal, &res, true);
153  } else {
154  float_minheap_array_t res = {
155  size_t(n), size_t(k), labels, distances };
156  pq.search_ip (x, n, codes.data(), ntotal, &res, true);
157  }
158  indexPQ_stats.nq += n;
159  indexPQ_stats.ncode += n * ntotal;
160 
161  } else if (search_type == ST_polysemous ||
162  search_type == ST_polysemous_generalize) {
163 
164  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
165 
166  search_core_polysemous (n, x, k, distances, labels);
167 
168  } else { // code-to-code distances
169 
170  uint8_t * q_codes = new uint8_t [n * pq.code_size];
171  ScopeDeleter<uint8_t> del (q_codes);
172 
173 
174  if (!encode_signs) {
175  pq.compute_codes (x, q_codes, n);
176  } else {
177  FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
178  memset (q_codes, 0, n * pq.code_size);
179  for (size_t i = 0; i < n; i++) {
180  const float *xi = x + i * d;
181  uint8_t *code = q_codes + i * pq.code_size;
182  for (int j = 0; j < d; j++)
183  if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
184  }
185  }
186 
187  if (search_type == ST_SDC) {
188 
189  float_maxheap_array_t res = {
190  size_t(n), size_t(k), labels, distances};
191 
192  pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
193 
194  } else {
195  int * idistances = new int [n * k];
196  ScopeDeleter<int> del (idistances);
197 
198  int_maxheap_array_t res = {
199  size_t (n), size_t (k), labels, idistances};
200 
201  if (search_type == ST_HE) {
202 
203  hammings_knn_hc (&res, q_codes, codes.data(),
204  ntotal, pq.code_size, true);
205 
206  } else if (search_type == ST_generalized_HE) {
207 
208  generalized_hammings_knn_hc (&res, q_codes, codes.data(),
209  ntotal, pq.code_size, true);
210  }
211 
212  // convert distances to floats
213  for (int i = 0; i < k * n; i++)
214  distances[i] = idistances[i];
215 
216  }
217 
218 
219  indexPQ_stats.nq += n;
220  indexPQ_stats.ncode += n * ntotal;
221  }
222 }
223 
224 
225 
226 
227 
228 void IndexPQStats::reset()
229 {
230  nq = ncode = n_hamming_pass = 0;
231 }
232 
233 IndexPQStats indexPQ_stats;
234 
235 
236 template <class HammingComputer>
237 static size_t polysemous_inner_loop (
238  const IndexPQ & index,
239  const float *dis_table_qi, const uint8_t *q_code,
240  size_t k, float *heap_dis, long *heap_ids)
241 {
242 
243  int M = index.pq.M;
244  int code_size = index.pq.code_size;
245  int ksub = index.pq.ksub;
246  size_t ntotal = index.ntotal;
247  int ht = index.polysemous_ht;
248 
249  const uint8_t *b_code = index.codes.data();
250 
251  size_t n_pass_i = 0;
252 
253  HammingComputer hc (q_code, code_size);
254 
255  for (long bi = 0; bi < ntotal; bi++) {
256  int hd = hc.hamming (b_code);
257 
258  if (hd < ht) {
259  n_pass_i ++;
260 
261  float dis = 0;
262  const float * dis_table = dis_table_qi;
263  for (int m = 0; m < M; m++) {
264  dis += dis_table [b_code[m]];
265  dis_table += ksub;
266  }
267 
268  if (dis < heap_dis[0]) {
269  maxheap_pop (k, heap_dis, heap_ids);
270  maxheap_push (k, heap_dis, heap_ids, dis, bi);
271  }
272  }
273  b_code += code_size;
274  }
275  return n_pass_i;
276 }
277 
278 
279 void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
280  float *distances, idx_t *labels) const
281 {
282  FAISS_THROW_IF_NOT (pq.nbits == 8);
283 
284  // PQ distance tables
285  float * dis_tables = new float [n * pq.ksub * pq.M];
286  ScopeDeleter<float> del (dis_tables);
287  pq.compute_distance_tables (n, x, dis_tables);
288 
289  // Hamming embedding queries
290  uint8_t * q_codes = new uint8_t [n * pq.code_size];
291  ScopeDeleter<uint8_t> del2 (q_codes);
292 
293  if (false) {
294  pq.compute_codes (x, q_codes, n);
295  } else {
296 #pragma omp parallel for
297  for (idx_t qi = 0; qi < n; qi++) {
299  (dis_tables + qi * pq.M * pq.ksub,
300  q_codes + qi * pq.code_size);
301  }
302  }
303 
304  size_t n_pass = 0;
305 
306 #pragma omp parallel for reduction (+: n_pass)
307  for (idx_t qi = 0; qi < n; qi++) {
308  const uint8_t * q_code = q_codes + qi * pq.code_size;
309 
310  const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
311 
312  long * heap_ids = labels + qi * k;
313  float *heap_dis = distances + qi * k;
314  maxheap_heapify (k, heap_dis, heap_ids);
315 
316  if (search_type == ST_polysemous) {
317 
318  switch (pq.code_size) {
319  case 4:
320  n_pass += polysemous_inner_loop<HammingComputer4>
321  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
322  break;
323  case 8:
324  n_pass += polysemous_inner_loop<HammingComputer8>
325  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
326  break;
327  case 16:
328  n_pass += polysemous_inner_loop<HammingComputer16>
329  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
330  break;
331  case 32:
332  n_pass += polysemous_inner_loop<HammingComputer32>
333  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
334  break;
335  case 20:
336  n_pass += polysemous_inner_loop<HammingComputer20>
337  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
338  break;
339  default:
340  if (pq.code_size % 8 == 0) {
341  n_pass += polysemous_inner_loop<HammingComputerM8>
342  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
343  } else if (pq.code_size % 4 == 0) {
344  n_pass += polysemous_inner_loop<HammingComputerM4>
345  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
346  } else {
347  FAISS_THROW_FMT(
348  "code size %zd not supported for polysemous",
349  pq.code_size);
350  }
351  break;
352  }
353  } else {
354  switch (pq.code_size) {
355  case 8:
356  n_pass += polysemous_inner_loop<GenHammingComputer8>
357  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
358  break;
359  case 16:
360  n_pass += polysemous_inner_loop<GenHammingComputer16>
361  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
362  break;
363  case 32:
364  n_pass += polysemous_inner_loop<GenHammingComputer32>
365  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
366  break;
367  default:
368  if (pq.code_size % 8 == 0) {
369  n_pass += polysemous_inner_loop<GenHammingComputerM8>
370  (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
371  } else {
372  FAISS_THROW_FMT(
373  "code size %zd not supported for polysemous",
374  pq.code_size);
375  }
376  break;
377  }
378  }
379  maxheap_reorder (k, heap_dis, heap_ids);
380  }
381 
382  indexPQ_stats.nq += n;
383  indexPQ_stats.ncode += n * ntotal;
384  indexPQ_stats.n_hamming_pass += n_pass;
385 
386 
387 }
388 
389 
390 
391 
392 /*****************************************
393  * Stats of IndexPQ codes
394  ******************************************/
395 
396 
397 
398 
399 void IndexPQ::hamming_distance_table (idx_t n, const float *x,
400  int32_t *dis) const
401 {
402  uint8_t * q_codes = new uint8_t [n * pq.code_size];
403  ScopeDeleter<uint8_t> del (q_codes);
404 
405  pq.compute_codes (x, q_codes, n);
406 
407  hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
408 }
409 
410 
412  idx_t nb, const float *xb,
413  long *hist)
414 {
415  FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
416  FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
417  FAISS_THROW_IF_NOT (pq.nbits == 8);
418 
419  // Hamming embedding queries
420  uint8_t * q_codes = new uint8_t [n * pq.code_size];
421  ScopeDeleter <uint8_t> del (q_codes);
422  pq.compute_codes (x, q_codes, n);
423 
424  uint8_t * b_codes ;
425  ScopeDeleter <uint8_t> del_b_codes;
426 
427  if (xb) {
428  b_codes = new uint8_t [nb * pq.code_size];
429  del_b_codes.set (b_codes);
430  pq.compute_codes (xb, b_codes, nb);
431  } else {
432  nb = ntotal;
433  b_codes = codes.data();
434  }
435  int nbits = pq.M * pq.nbits;
436  memset (hist, 0, sizeof(*hist) * (nbits + 1));
437  size_t bs = 256;
438 
439 #pragma omp parallel
440  {
441  std::vector<long> histi (nbits + 1);
442  hamdis_t *distances = new hamdis_t [nb * bs];
443  ScopeDeleter<hamdis_t> del (distances);
444 #pragma omp for
445  for (size_t q0 = 0; q0 < n; q0 += bs) {
446  // printf ("dis stats: %ld/%ld\n", q0, n);
447  size_t q1 = q0 + bs;
448  if (q1 > n) q1 = n;
449 
450  hammings (q_codes + q0 * pq.code_size, b_codes,
451  q1 - q0, nb,
452  pq.code_size, distances);
453 
454  for (size_t i = 0; i < nb * (q1 - q0); i++)
455  histi [distances [i]]++;
456  }
457 #pragma omp critical
458  {
459  for (int i = 0; i <= nbits; i++)
460  hist[i] += histi[i];
461  }
462  }
463 
464 }
465 
466 
467 
468 
469 
470 
471 
472 
473 
474 
475 
476 
477 
478 
479 
480 
481 
482 
483 
484 
485 /*****************************************
486  * MultiIndexQuantizer
487  ******************************************/
488 
489 namespace {
490 
/** Adapter over a borrowed array of N values that is assumed to already be
 *  sorted in increasing order, so that the n-th smallest value is x[n].
 */
template <typename T>
struct PreSortedArray {

    const T * x;   ///< borrowed values, set by init()
    int N;         ///< number of values

    explicit PreSortedArray (int N): N(N) {}

    void init (const T *values) { x = values; }

    /// smallest value
    T get_0 () { return *x; }

    /// difference between the n-th and the (n-1)-th smallest values
    T get_diff (int n) {
        const T *cur = x + n;
        return cur[0] - cur[-1];
    }

    /// position in x of the n-th smallest value (identity: x is sorted)
    int get_ord (int n) { return n; }
};
518 
/// Index comparator: orders positions i, j by increasing value x[i] < x[j].
template <typename T>
struct ArgSort {
    const T * x;
    bool operator() (size_t i, size_t j) {
        const T & lhs = x[i];
        const T & rhs = x[j];
        return lhs < rhs;
    }
};
526 
527 
528 /** Array that maintains a permutation of its elements so that the
529  * array's elements are sorted
530  */
531 template <typename T>
532 struct SortedArray {
533  const T * x;
534  int N;
535  std::vector<int> perm;
536 
537  explicit SortedArray (int N) {
538  this->N = N;
539  perm.resize (N);
540  }
541 
542  void init (const T*x) {
543  this->x = x;
544  for (int n = 0; n < N; n++)
545  perm[n] = n;
546  ArgSort<T> cmp = {x };
547  std::sort (perm.begin(), perm.end(), cmp);
548  }
549 
550  // get smallest value
551  T get_0 () {
552  return x[perm[0]];
553  }
554 
555  // get delta between n-smallest and n-1 -smallest
556  T get_diff (int n) {
557  return x[perm[n]] - x[perm[n - 1]];
558  }
559 
560  // remap orders counted from smallest to indices in array
561  int get_ord (int n) {
562  return perm[n];
563  }
564 };
565 
566 
567 
568 /** Array has n values. Sort the k first ones and copy the other ones
569  * into elements k..n-1
570  */
571 template <class C>
572 void partial_sort (int k, int n,
573  const typename C::T * vals, typename C::TI * perm) {
574  // insert first k elts in heap
575  for (int i = 1; i < k; i++) {
576  indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
577  }
578 
579  // insert next n - k elts in heap
580  for (int i = k; i < n; i++) {
581  typename C::TI id = perm[i];
582  typename C::TI top = perm[0];
583 
584  if (C::cmp(vals[top], vals[id])) {
585  indirect_heap_pop<C> (k, vals, perm);
586  indirect_heap_push<C> (k, vals, perm, id);
587  perm[i] = top;
588  } else {
589  // nothing, elt at i is good where it is.
590  }
591  }
592 
593  // order the k first elements in heap
594  for (int i = k - 1; i > 0; i--) {
595  typename C::TI top = perm[0];
596  indirect_heap_pop<C> (i + 1, vals, perm);
597  perm[i] = top;
598  }
599 }
600 
601 /** same as SortedArray, but only the k first elements are sorted */
602 template <typename T>
603 struct SemiSortedArray {
604  const T * x;
605  int N;
606 
607  // type of the heap: CMax = sort ascending
608  typedef CMax<T, int> HC;
609  std::vector<int> perm;
610 
611  int k; // k elements are sorted
612 
613  int initial_k, k_factor;
614 
615  explicit SemiSortedArray (int N) {
616  this->N = N;
617  perm.resize (N);
618  perm.resize (N);
619  initial_k = 3;
620  k_factor = 4;
621  }
622 
623  void init (const T*x) {
624  this->x = x;
625  for (int n = 0; n < N; n++)
626  perm[n] = n;
627  k = 0;
628  grow (initial_k);
629  }
630 
631  /// grow the sorted part of the array to size next_k
632  void grow (int next_k) {
633  if (next_k < N) {
634  partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
635  k = next_k;
636  } else { // full sort of remainder of array
637  ArgSort<T> cmp = {x };
638  std::sort (perm.begin() + k, perm.end(), cmp);
639  k = N;
640  }
641  }
642 
643  // get smallest value
644  T get_0 () {
645  return x[perm[0]];
646  }
647 
648  // get delta between n-smallest and n-1 -smallest
649  T get_diff (int n) {
650  if (n >= k) {
651  // want to keep powers of 2 - 1
652  int next_k = (k + 1) * k_factor - 1;
653  grow (next_k);
654  }
655  return x[perm[n]] - x[perm[n - 1]];
656  }
657 
658  // remap orders counted from smallest to indices in array
659  int get_ord (int n) {
660  assert (n < k);
661  return perm[n];
662  }
663 };
664 
665 
666 
667 /*****************************************
668  * Find the k smallest sums of M terms, where each term is taken in a
669  * table x of n values.
670  *
671  * A combination of terms is encoded as a scalar 0 <= t < n^M. The
672  * combination t0 ... t(M-1) that correspond to the sum
673  *
674  * sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
675  *
676  * is encoded as
677  *
678  * t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
679  *
680  * MinSumK is an object rather than a function, so that storage can be
681  * re-used over several computations with the same sizes. use_seen is
682  * good when there may be ties in the x array and it is a concern if
683  * occasionally several t's are returned.
684  *
685  * @param x size M * n, values to add up
686  * @parms k nb of results to retrieve
687  * @param M nb of terms
688  * @param n nb of distinct values
689  * @param sums output, size k, sorted
690  * @prarm terms output, size k, with encoding as above
691  *
692  ******************************************/
693 template <typename T, class SSA, bool use_seen>
694 struct MinSumK {
695  int K; ///< nb of sums to return
696  int M; ///< nb of elements to sum up
697  int nbit; ///< nb of bits to encode one entry
698  int N; ///< nb of possible elements for each of the M terms
699 
700  /** the heap.
701  * We use a heap to maintain a queue of sums, with the associated
702  * terms involved in the sum.
703  */
704  typedef CMin<T, long> HC;
705  size_t heap_capacity, heap_size;
706  T *bh_val;
707  long *bh_ids;
708 
709  std::vector <SSA> ssx;
710 
711  // all results get pushed several times. When there are ties, they
712  // are popped interleaved with others, so it is not easy to
713  // identify them. Therefore, this bit array just marks elements
714  // that were seen before.
715  std::vector <uint8_t> seen;
716 
717  MinSumK (int K, int M, int nbit, int N):
718  K(K), M(M), nbit(nbit), N(N) {
719  heap_capacity = K * M;
720  assert (N <= (1 << nbit));
721 
722  // we'll do k steps, each step pushes at most M vals
723  bh_val = new T[heap_capacity];
724  bh_ids = new long[heap_capacity];
725 
726  if (use_seen) {
727  long n_ids = weight(M);
728  seen.resize ((n_ids + 7) / 8);
729  }
730 
731  for (int m = 0; m < M; m++)
732  ssx.push_back (SSA(N));
733 
734  }
735 
736  long weight (int i) {
737  return 1 << (i * nbit);
738  }
739 
740  bool is_seen (long i) {
741  return (seen[i >> 3] >> (i & 7)) & 1;
742  }
743 
744  void mark_seen (long i) {
745  if (use_seen)
746  seen [i >> 3] |= 1 << (i & 7);
747  }
748 
749  void run (const T *x, long ldx,
750  T * sums, long * terms) {
751  heap_size = 0;
752 
753  for (int m = 0; m < M; m++) {
754  ssx[m].init(x);
755  x += ldx;
756  }
757 
758  { // intial result: take min for all elements
759  T sum = 0;
760  terms[0] = 0;
761  mark_seen (0);
762  for (int m = 0; m < M; m++) {
763  sum += ssx[m].get_0();
764  }
765  sums[0] = sum;
766  for (int m = 0; m < M; m++) {
767  heap_push<HC> (++heap_size, bh_val, bh_ids,
768  sum + ssx[m].get_diff(1),
769  weight(m));
770  }
771  }
772 
773  for (int k = 1; k < K; k++) {
774  // pop smallest value from heap
775  if (use_seen) {// skip already seen elements
776  while (is_seen (bh_ids[0])) {
777  assert (heap_size > 0);
778  heap_pop<HC> (heap_size--, bh_val, bh_ids);
779  }
780  }
781  assert (heap_size > 0);
782 
783  T sum = sums[k] = bh_val[0];
784  long ti = terms[k] = bh_ids[0];
785 
786  if (use_seen) {
787  mark_seen (ti);
788  heap_pop<HC> (heap_size--, bh_val, bh_ids);
789  } else {
790  do {
791  heap_pop<HC> (heap_size--, bh_val, bh_ids);
792  } while (heap_size > 0 && bh_ids[0] == ti);
793  }
794 
795  // enqueue followers
796  long ii = ti;
797  for (int m = 0; m < M; m++) {
798  long n = ii & ((1L << nbit) - 1);
799  ii >>= nbit;
800  if (n + 1 >= N) continue;
801 
802  enqueue_follower (ti, m, n, sum);
803  }
804  }
805 
806  /*
807  for (int k = 0; k < K; k++)
808  for (int l = k + 1; l < K; l++)
809  assert (terms[k] != terms[l]);
810  */
811 
812  // convert indices by applying permutation
813  for (int k = 0; k < K; k++) {
814  long ii = terms[k];
815  if (use_seen) {
816  // clear seen for reuse at next loop
817  seen[ii >> 3] = 0;
818  }
819  long ti = 0;
820  for (int m = 0; m < M; m++) {
821  long n = ii & ((1L << nbit) - 1);
822  ti += long(ssx[m].get_ord(n)) << (nbit * m);
823  ii >>= nbit;
824  }
825  terms[k] = ti;
826  }
827  }
828 
829 
830  void enqueue_follower (long ti, int m, int n, T sum) {
831  T next_sum = sum + ssx[m].get_diff(n + 1);
832  long next_ti = ti + weight(m);
833  heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
834  }
835 
836  ~MinSumK () {
837  delete [] bh_ids;
838  delete [] bh_val;
839  }
840 };
841 
842 } // anonymous namespace
843 
844 
845 MultiIndexQuantizer::MultiIndexQuantizer (int d,
846  size_t M,
847  size_t nbits):
848  Index(d, METRIC_L2), pq(d, M, nbits)
849 {
850  is_trained = false;
851  pq.verbose = verbose;
852 }
853 
854 
855 
856 void MultiIndexQuantizer::train(idx_t n, const float *x)
857 {
858  pq.verbose = verbose;
859  pq.train (n, x);
860  is_trained = true;
861  // count virtual elements in index
862  ntotal = 1;
863  for (int m = 0; m < pq.M; m++)
864  ntotal *= pq.ksub;
865 }
866 
867 
868 void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
869  float *distances, idx_t *labels) const {
870  if (n == 0) return;
871 
872  // the allocation just below can be severe...
873  idx_t bs = 32768;
874  if (n > bs) {
875  for (idx_t i0 = 0; i0 < n; i0 += bs) {
876  idx_t i1 = std::min(i0 + bs, n);
877  if (verbose) {
878  printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
879  i0, i1, n);
880  }
881  search (i1 - i0, x + i0 * d, k,
882  distances + i0 * k,
883  labels + i0 * k);
884  }
885  return;
886  }
887 
888  float * dis_tables = new float [n * pq.ksub * pq.M];
889  ScopeDeleter<float> del (dis_tables);
890 
891  pq.compute_distance_tables (n, x, dis_tables);
892 
893  if (k == 1) {
894  // simple version that just finds the min in each table
895 
896 #pragma omp parallel for
897  for (int i = 0; i < n; i++) {
898  const float * dis_table = dis_tables + i * pq.ksub * pq.M;
899  float dis = 0;
900  idx_t label = 0;
901 
902  for (int s = 0; s < pq.M; s++) {
903  float vmin = HUGE_VALF;
904  idx_t lmin = -1;
905 
906  for (idx_t j = 0; j < pq.ksub; j++) {
907  if (dis_table[j] < vmin) {
908  vmin = dis_table[j];
909  lmin = j;
910  }
911  }
912  dis += vmin;
913  label |= lmin << (s * pq.nbits);
914  dis_table += pq.ksub;
915  }
916 
917  distances [i] = dis;
918  labels [i] = label;
919  }
920 
921 
922  } else {
923 
924 #pragma omp parallel if(n > 1)
925  {
926  MinSumK <float, SemiSortedArray<float>, false>
927  msk(k, pq.M, pq.nbits, pq.ksub);
928 #pragma omp for
929  for (int i = 0; i < n; i++) {
930  msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
931  distances + i * k, labels + i * k);
932 
933  }
934  }
935  }
936 
937 }
938 
939 
940 void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
941 {
942 
943  long jj = key;
944  for (int m = 0; m < pq.M; m++) {
945  long n = jj & ((1L << pq.nbits) - 1);
946  jj >>= pq.nbits;
947  memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
948  recons += pq.dsub;
949  }
950 }
951 
952 void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
953  FAISS_THROW_MSG(
954  "This index has virtual elements, "
955  "it does not support add");
956 }
957 
959 {
960  FAISS_THROW_MSG ( "This index has virtual elements, "
961  "it does not support reset");
962 }
963 
964 
965 
966 
967 
968 
969 
970 
971 
972 
973 /*****************************************
974  * MultiIndexQuantizer2
975  ******************************************/
976 
977 
978 
979 MultiIndexQuantizer2::MultiIndexQuantizer2 (
980  int d, size_t M, size_t nbits,
981  Index **indexes):
982  MultiIndexQuantizer (d, M, nbits)
983 {
984  assign_indexes.resize (M);
985  for (int i = 0; i < M; i++) {
986  FAISS_THROW_IF_NOT_MSG(
987  indexes[i]->d == pq.dsub,
988  "Provided sub-index has incorrect size");
989  assign_indexes[i] = indexes[i];
990  }
991  own_fields = false;
992 }
993 
994 MultiIndexQuantizer2::MultiIndexQuantizer2 (
995  int d, size_t nbits,
996  Index *assign_index_0,
997  Index *assign_index_1):
998  MultiIndexQuantizer (d, 2, nbits)
999 {
1000  FAISS_THROW_IF_NOT_MSG(
1001  assign_index_0->d == pq.dsub &&
1002  assign_index_1->d == pq.dsub,
1003  "Provided sub-index has incorrect size");
1004  assign_indexes.resize (2);
1005  assign_indexes [0] = assign_index_0;
1006  assign_indexes [1] = assign_index_1;
1007  own_fields = false;
1008 }
1009 
1010 void MultiIndexQuantizer2::train(idx_t n, const float* x)
1011 {
1013  // add centroids to sub-indexes
1014  for (int i = 0; i < pq.M; i++) {
1015  assign_indexes[i]->add(pq.ksub, pq.get_centroids(i, 0));
1016  }
1017 }
1018 
1019 
1021  idx_t n, const float* x, idx_t K,
1022  float* distances, idx_t* labels) const
1023 {
1024 
1025  if (n == 0) return;
1026 
1027  int k2 = std::min(K, long(pq.ksub));
1028 
1029  long M = pq.M;
1030  long dsub = pq.dsub, ksub = pq.ksub;
1031 
1032  // size (M, n, k2)
1033  std::vector<idx_t> sub_ids(n * M * k2);
1034  std::vector<float> sub_dis(n * M * k2);
1035  std::vector<float> xsub(n * dsub);
1036 
1037  for (int m = 0; m < M; m++) {
1038  float *xdest = xsub.data();
1039  const float *xsrc = x + m * dsub;
1040  for (int j = 0; j < n; j++) {
1041  memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
1042  xsrc += d;
1043  xdest += dsub;
1044  }
1045 
1046  assign_indexes[m]->search(
1047  n, xsub.data(), k2,
1048  &sub_dis[k2 * n * m],
1049  &sub_ids[k2 * n * m]);
1050  }
1051 
1052  if (K == 1) {
1053  // simple version that just finds the min in each table
1054  assert (k2 == 1);
1055 
1056  for (int i = 0; i < n; i++) {
1057  float dis = 0;
1058  idx_t label = 0;
1059 
1060  for (int m = 0; m < M; m++) {
1061  float vmin = sub_dis[i + m * n];
1062  idx_t lmin = sub_ids[i + m * n];
1063  dis += vmin;
1064  label |= lmin << (m * pq.nbits);
1065  }
1066  distances [i] = dis;
1067  labels [i] = label;
1068  }
1069 
1070  } else {
1071 
1072 #pragma omp parallel if(n > 1)
1073  {
1074  MinSumK <float, PreSortedArray<float>, false>
1075  msk(K, pq.M, pq.nbits, k2);
1076 #pragma omp for
1077  for (int i = 0; i < n; i++) {
1078  idx_t *li = labels + i * K;
1079  msk.run (&sub_dis[i * k2], k2 * n,
1080  distances + i * K, li);
1081 
1082  // remap ids
1083 
1084  const idx_t *idmap0 = sub_ids.data() + i * k2;
1085  long ld_idmap = k2 * n;
1086  long mask1 = ksub - 1L;
1087 
1088  for (int k = 0; k < K; k++) {
1089  const idx_t *idmap = idmap0;
1090  long vin = li[k];
1091  long vout = 0;
1092  int bs = 0;
1093  for (int m = 0; m < M; m++) {
1094  long s = vin & mask1;
1095  vin >>= pq.nbits;
1096  vout |= idmap[s] << bs;
1097  bs += pq.nbits;
1098  idmap += ld_idmap;
1099  }
1100  li[k] = vout;
1101  }
1102  }
1103  }
1104  }
1105 }
1106 
1107 
1108 } // namespace faiss
void hammings_knn_hc(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t ncodes, int order)
Definition: hamming.cpp:517
std::vector< uint8_t > codes
Codes. Size ntotal * pq.code_size.
Definition: IndexPQ.h:32
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
Hamming distance on codes.
Definition: IndexPQ.h:77
bool do_polysemous_training
false = standard PQ
Definition: IndexPQ.h:69
void train(idx_t n, const float *x) override
Definition: IndexPQ.cpp:52
void reset() override
removes all elements from the database.
Definition: IndexPQ.cpp:958
void train(idx_t n, const float *x) override
Definition: IndexPQ.cpp:856
size_t dsub
dimensionality of each subvector
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
int d
vector dimension
Definition: Index.h:66
long idx_t
all indices are this type
Definition: Index.h:62
void hamming_distance_histogram(idx_t n, const float *x, idx_t nb, const float *xb, long *dist_histogram)
Definition: IndexPQ.cpp:411
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
void train(idx_t n, const float *x) override
Definition: IndexPQ.cpp:1010
size_t code_size
byte per indexed vector
long remove_ids(const IDSelector &sel) override
Definition: IndexPQ.cpp:85
Filter on generalized Hamming.
Definition: IndexPQ.h:81
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
ProductQuantizer pq
The product quantizer used to encode the vectors.
Definition: IndexPQ.h:29
idx_t ntotal
total nb of indexed vectors
Definition: Index.h:67
bool verbose
verbosity level
Definition: Index.h:68
void add(idx_t n, const float *x) override
Definition: IndexPQ.cpp:76
void hamming_distance_table(idx_t n, const float *x, int32_t *dis) const
Definition: IndexPQ.cpp:399
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexPQ.cpp:868
void reconstruct(idx_t key, float *recons) const override
Definition: IndexPQ.cpp:123
MetricType metric_type
type of metric this index uses for search
Definition: Index.h:74
size_t M
number of subquantizers
void reconstruct_n(idx_t i0, idx_t ni, float *recons) const override
Definition: IndexPQ.cpp:113
asymmetric product quantizer (default)
Definition: IndexPQ.h:76
void reconstruct(idx_t key, float *recons) const override
Definition: IndexPQ.cpp:940
HE filter (using ht) + PQ combination.
Definition: IndexPQ.h:80
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexPQ.cpp:1020
void add(idx_t n, const float *x) override
add and reset throw an exception at runtime (the elements are virtual)
Definition: IndexPQ.cpp:952
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
void reset() override
removes all elements from the database.
Definition: IndexPQ.cpp:107
float * get_centroids(size_t m, size_t i)
return the centroids associated with subvector m
void optimize_pq_for_hamming(ProductQuantizer &pq, size_t n, const float *x) const
bool verbose
verbose during training?
void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const override
Definition: IndexPQ.cpp:143
symmetric product quantizer (SDC)
Definition: IndexPQ.h:79
int polysemous_ht
Hamming threshold used for polysemy.
Definition: IndexPQ.h:91
PolysemousTraining polysemous_training
parameters used for the polysemous training
Definition: IndexPQ.h:72
std::vector< Index * > assign_indexes
M Indexes on d / M dimensions.
Definition: IndexPQ.h:163
MetricType
Some algorithms support both an inner product version and a L2 search version.
Definition: Index.h:44
void generalized_hammings_knn_hc(int_maxheap_array_t *ha, const uint8_t *a, const uint8_t *b, size_t nb, size_t code_size, int ordered)
Definition: hamming.cpp:728