Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
ProductQuantizer.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 /* Copyright 2004-present Facebook. All Rights Reserved.
11  Index based on product quantization.
12 */
13 
14 #include "ProductQuantizer.h"
15 
16 
17 #include <cstddef>
18 #include <cstring>
19 #include <cstdio>
20 
21 #include <algorithm>
22 
23 #include "FaissAssert.h"
24 #include "VectorTransform.h"
25 #include "IndexFlat.h"
26 #include "utils.h"
27 
28 
29 extern "C" {
30 
31 /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
32 
33 int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
34  n, FINTEGER *k, const float *alpha, const float *a,
35  FINTEGER *lda, const float *b, FINTEGER *
36  ldb, float *beta, float *c, FINTEGER *ldc);
37 
38 }
39 
40 
41 namespace faiss {
42 
43 
44 
45 
46 /* compute an estimator using look-up tables for typical values of M */
47 template <typename CT, class C>
48 void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
49  size_t ncodes,
50  const float * __restrict dis_table,
51  size_t ksub,
52  size_t k,
53  float * heap_dis,
54  long * heap_ids)
55 {
56 
57  for (size_t j = 0; j < ncodes; j++) {
58  float dis = 0;
59  const float *dt = dis_table;
60 
61  for (size_t m = 0; m < M; m+=4) {
62  float dism = 0;
63  dism = dt[*codes++]; dt += ksub;
64  dism += dt[*codes++]; dt += ksub;
65  dism += dt[*codes++]; dt += ksub;
66  dism += dt[*codes++]; dt += ksub;
67  dis += dism;
68  }
69 
70  if (C::cmp (heap_dis[0], dis)) {
71  heap_pop<C> (k, heap_dis, heap_ids);
72  heap_push<C> (k, heap_dis, heap_ids, dis, j);
73  }
74  }
75 }
76 
77 
78 template <typename CT, class C>
79 void pq_estimators_from_tables_M4 (const CT * codes,
80  size_t ncodes,
81  const float * __restrict dis_table,
82  size_t ksub,
83  size_t k,
84  float * heap_dis,
85  long * heap_ids)
86 {
87 
88  for (size_t j = 0; j < ncodes; j++) {
89  float dis = 0;
90  const float *dt = dis_table;
91  dis = dt[*codes++]; dt += ksub;
92  dis += dt[*codes++]; dt += ksub;
93  dis += dt[*codes++]; dt += ksub;
94  dis += dt[*codes++];
95 
96  if (C::cmp (heap_dis[0], dis)) {
97  heap_pop<C> (k, heap_dis, heap_ids);
98  heap_push<C> (k, heap_dis, heap_ids, dis, j);
99  }
100  }
101 }
102 
103 
104 template <typename CT, class C>
105 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
106  const CT * codes,
107  size_t ncodes,
108  const float * dis_table,
109  size_t k,
110  float * heap_dis,
111  long * heap_ids)
112 {
113 
114  if (pq->M == 4) {
115 
116  pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
117  dis_table, pq->ksub, k,
118  heap_dis, heap_ids);
119  return;
120  }
121 
122  if (pq->M % 4 == 0) {
123  pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
124  dis_table, pq->ksub, k,
125  heap_dis, heap_ids);
126  return;
127  }
128 
129  /* Default is relatively slow */
130  const size_t M = pq->M;
131  const size_t ksub = pq->ksub;
132  for (size_t j = 0; j < ncodes; j++) {
133  float dis = 0;
134  const float * __restrict dt = dis_table;
135  for (int m = 0; m < M; m++) {
136  dis += dt[*codes++];
137  dt += ksub;
138  }
139  if (C::cmp (heap_dis[0], dis)) {
140  heap_pop<C> (k, heap_dis, heap_ids);
141  heap_push<C> (k, heap_dis, heap_ids, dis, j);
142  }
143  }
144 }
145 
146 
147 /*********************************************
148  * PQ implementation
149  *********************************************/
150 
151 
152 
153 ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
154  d(d), M(M), nbits(nbits)
155 {
156  set_derived_values ();
157 }
158 
159 ProductQuantizer::ProductQuantizer ():
160  d(0), M(1), nbits(0)
161 {
162  set_derived_values ();
163 }
164 
165 
166 
168  // quite a few derived values
169  FAISS_ASSERT (d % M == 0);
170  dsub = d / M;
171  byte_per_idx = (nbits + 7) / 8;
172  code_size = byte_per_idx * M;
173  ksub = 1 << nbits;
174  centroids.resize (d * ksub);
175  verbose = false;
176  train_type = Train_default;
177 }
178 
179 
180 void ProductQuantizer::set_params (const float * centroids_, int m)
181 {
182  memcpy (get_centroids(m, 0), centroids_,
183  ksub * dsub * sizeof (centroids_[0]));
184 }
185 
186 
/**
 * Initialize 2^nbits centroids on the corners of a hypercube.
 *
 * The mean of the n input vectors (d floats each, stored contiguously in
 * x) is computed, and maxm is the largest absolute component of that
 * mean. Centroid i equals the mean, except that each of its first nbits
 * components j is shifted by +maxm when bit j of i is set and by -maxm
 * otherwise.
 *
 * @param centroids  output, (1 << nbits) * d floats
 */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{
    // accumulate the per-dimension sums, one input vector at a time
    std::vector<float> avg (d);
    const float *row = x;
    for (int i = 0; i < n; i++, row += d) {
        for (int j = 0; j < d; j++) {
            avg [j] += row [j];
        }
    }

    // turn the sums into means and track the largest magnitude
    float radius = 0;
    for (int j = 0; j < d; j++) {
        avg [j] /= n;
        const float magnitude = fabs (avg [j]);
        if (magnitude > radius) {
            radius = magnitude;
        }
    }

    const int n_corners = 1 << nbits;
    for (int corner = 0; corner < n_corners; corner++) {
        float * out = centroids + corner * d;

        // the first nbits dimensions encode the corner
        for (int j = 0; j < nbits; j++) {
            const bool bit_set = ((corner >> j) & 1) != 0;
            out [j] = avg [j] + (bit_set ? 1 : -1) * radius;
        }
        // remaining dimensions stay at the mean
        for (int j = nbits; j < d; j++) {
            out [j] = avg [j];
        }
    }
}
213 
214 static void init_hypercube_pca (int d, int nbits,
215  int n, const float * x,
216  float *centroids)
217 {
218  PCAMatrix pca (d, nbits);
219  pca.train (n, x);
220 
221 
222  for (int i = 0; i < (1 << nbits); i++) {
223  float * cent = centroids + i * d;
224  for (int j = 0; j < d; j++) {
225  cent[j] = pca.mean[j];
226  float f = 1.0;
227  for (int k = 0; k < nbits; k++)
228  cent[j] += f *
229  sqrt (pca.eigenvalues [k]) *
230  (((i >> k) & 1) ? 1 : -1) *
231  pca.PCAMat [j + k * d];
232  }
233  }
234 
235 }
236 
237 void ProductQuantizer::train (int n, const float * x)
238 {
239  if (train_type != Train_shared) {
240  train_type_t final_train_type;
241  final_train_type = train_type;
242  if (train_type == Train_hypercube ||
243  train_type == Train_hypercube_pca) {
244  if (dsub < nbits) {
245  final_train_type = Train_default;
246  printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
247  nbits, dsub);
248  }
249  }
250 
251  float * xslice = new float[n * dsub];
252  for (int m = 0; m < M; m++) {
253  for (int j = 0; j < n; j++)
254  memcpy (xslice + j * dsub,
255  x + j * d + m * dsub,
256  dsub * sizeof(float));
257 
258  Clustering clus (dsub, ksub, cp);
259 
260  // we have some initialization for the centroids
261  if (final_train_type != Train_default) {
262  clus.centroids.resize (dsub * ksub);
263  }
264 
265  switch (final_train_type) {
266  case Train_hypercube:
267  init_hypercube (dsub, nbits, n, xslice,
268  clus.centroids.data ());
269  break;
270  case Train_hypercube_pca:
271  init_hypercube_pca (dsub, nbits, n, xslice,
272  clus.centroids.data ());
273  break;
274  case Train_hot_start:
275  memcpy (clus.centroids.data(),
276  get_centroids (m, 0),
277  dsub * ksub * sizeof (float));
278  break;
279  default: ;
280  }
281 
282  if(verbose) {
283  clus.verbose = true;
284  printf ("Training PQ slice %d/%zd\n", m, M);
285  }
286  IndexFlatL2 index (dsub);
287  clus.train (n, xslice, index);
288  set_params (clus.centroids.data(), m);
289  }
290  delete [] xslice;
291 
292  } else {
293 
294  Clustering clus (dsub, ksub, cp);
295 
296  if(verbose) {
297  clus.verbose = true;
298  printf ("Training all PQ slices at once\n");
299  }
300 
301  IndexFlatL2 index (dsub);
302  clus.train (n * M, x, index);
303  for (int m = 0; m < M; m++) {
304  set_params (clus.centroids.data(), m);
305  }
306 
307  }
308 }
309 
310 
311 void ProductQuantizer::compute_code (const float * x, uint8_t * code) const
312 {
313  float distances [ksub];
314  for (size_t m = 0; m < M; m++) {
315  float mindis = 1e20;
316  int idxm = -1;
317  const float * xsub = x + m * dsub;
318 
319  fvec_L2sqr_ny (distances, xsub, get_centroids(m, 0), dsub, ksub);
320 
321  /* Find best centroid */
322  size_t i;
323  for (i = 0; i < ksub; i++) {
324  float dis = distances [i];
325  if (dis < mindis) {
326  mindis = dis;
327  idxm = i;
328  }
329  }
330  switch (byte_per_idx) {
331  case 1: code[m] = (uint8_t) idxm; break;
332  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
333  }
334  }
335 
336 }
337 
338 void ProductQuantizer::decode (const uint8_t *code, float *x) const
339 {
340  if (byte_per_idx == 1) {
341  for (size_t m = 0; m < M; m++) {
342  memcpy (x + m * dsub, get_centroids(m, code[m]),
343  sizeof(float) * dsub);
344  }
345  } else {
346  const uint16_t *c = (const uint16_t*) code;
347  for (size_t m = 0; m < M; m++) {
348  memcpy (x + m * dsub, get_centroids(m, c[m]),
349  sizeof(float) * dsub);
350  }
351  }
352 }
353 
354 
355 void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
356 {
357  for (size_t i = 0; i < n; i++) {
358  this->decode (code + M * i, x + d * i);
359  }
360 }
361 
362 
364  uint8_t *code) const
365 {
366  for (size_t m = 0; m < M; m++) {
367  float mindis = 1e20;
368  int idxm = -1;
369 
370  /* Find best centroid */
371  for (size_t j = 0; j < ksub; j++) {
372  float dis = *tab++;
373  if (dis < mindis) {
374  mindis = dis;
375  idxm = j;
376  }
377  }
378  switch (byte_per_idx) {
379  case 1: code[m] = (uint8_t) idxm; break;
380  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
381  }
382  }
383 }
384 
385 void ProductQuantizer::compute_codes (const float * x,
386  uint8_t * codes,
387  size_t n) const
388 {
389  if (dsub < 16) { // simple direct computation
390 
391 #pragma omp parallel for
392  for (size_t i = 0; i < n; i++)
393  compute_code (x + i * d, codes + i * code_size);
394 
395  } else { // worthwile to use BLAS
396  float *dis_tables = new float [n * ksub * M];
397  compute_distance_tables (n, x, dis_tables);
398 
399 #pragma omp parallel for
400  for (size_t i = 0; i < n; i++) {
401  uint8_t * code = codes + i * code_size;
402  const float * tab = dis_tables + i * ksub * M;
403  compute_code_from_distance_table (tab, code);
404  }
405  delete [] dis_tables;
406  }
407 }
408 
409 
411  float * dis_table) const
412 {
413  size_t m;
414 
415  for (m = 0; m < M; m++) {
416  fvec_L2sqr_ny (dis_table + m * ksub,
417  x + m * dsub,
418  get_centroids(m, 0),
419  dsub,
420  ksub);
421  }
422 }
423 
424 void ProductQuantizer::compute_inner_prod_table (const float * x,
425  float * dis_table) const
426 {
427  size_t m;
428 
429  for (m = 0; m < M; m++) {
430  fvec_inner_products_ny (dis_table + m * ksub,
431  x + m * dsub,
432  get_centroids(m, 0),
433  dsub,
434  ksub);
435  }
436 }
437 
438 
440  size_t nx,
441  const float * x,
442  float * dis_tables) const
443 {
444 
445  if (dsub < 16) {
446 
447 #pragma omp parallel for
448  for (size_t i = 0; i < nx; i++) {
449  compute_distance_table (x + i * d, dis_tables + i * ksub * M);
450  }
451 
452  } else { // use BLAS
453 
454  for (int m = 0; m < M; m++) {
455  pairwise_L2sqr (dsub,
456  nx, x + dsub * m,
457  ksub, centroids.data() + m * dsub * ksub,
458  dis_tables + ksub * m,
459  d, dsub, ksub * M);
460  }
461  }
462 }
463 
464 void ProductQuantizer::compute_inner_prod_tables (
465  size_t nx,
466  const float * x,
467  float * dis_tables) const
468 {
469 
470  if (dsub < 16) {
471 
472 #pragma omp parallel for
473  for (size_t i = 0; i < nx; i++) {
474  compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
475  }
476 
477  } else { // use BLAS
478 
479  // compute distance tables
480  for (int m = 0; m < M; m++) {
481  FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
482  dsubi = dsub, di = d;
483  float one = 1.0, zero = 0;
484 
485  sgemm_ ("Transposed", "Not transposed",
486  &ksubi, &nxi, &dsubi,
487  &one, &centroids [m * dsub * ksub], &dsubi,
488  x + dsub * m, &di,
489  &zero, dis_tables + ksub * m, &ldc);
490  }
491 
492  }
493 }
494 
495 template <typename CT, class C>
496 static void pq_knn_search_with_tables (
497  const ProductQuantizer * pq,
498  const float *dis_tables,
499  const uint8_t * codes,
500  const size_t ncodes,
501  HeapArray<C> * res,
502  bool init_finalize_heap)
503 {
504  size_t k = res->k, nx = res->nh;
505  size_t ksub = pq->ksub, M = pq->M;
506 
507 
508 #pragma omp parallel for
509  for (size_t i = 0; i < nx; i++) {
510  /* query preparation for asymmetric search: compute look-up tables */
511  const float* dis_table = dis_tables + i * ksub * M;
512 
513  /* Compute distances and keep smallest values */
514  long * __restrict heap_ids = res->ids + i * k;
515  float * __restrict heap_dis = res->val + i * k;
516 
517  if (init_finalize_heap) {
518  heap_heapify<C> (k, heap_dis, heap_ids);
519  }
520 
521  pq_estimators_from_tables<CT, C> (pq,
522  (CT*)codes, ncodes,
523  dis_table,
524  k, heap_dis, heap_ids);
525  if (init_finalize_heap) {
526  heap_reorder<C> (k, heap_dis, heap_ids);
527  }
528  }
529 }
530 
531  /*
532 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
533  const CT * codes,
534  size_t ncodes,
535  const float * dis_table,
536  size_t k,
537  float * heap_dis,
538  long * heap_ids)
539  */
540 void ProductQuantizer::search (const float * __restrict x,
541  size_t nx,
542  const uint8_t * codes,
543  const size_t ncodes,
544  float_maxheap_array_t * res,
545  bool init_finalize_heap) const
546 {
547  float * dis_tables = new float [nx * ksub * M];
548  compute_distance_tables (nx, x, dis_tables);
549  FAISS_ASSERT(nx == res->nh);
550 
551  if (byte_per_idx == 1) {
552 
553  pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
554  this, dis_tables, codes, ncodes, res, init_finalize_heap);
555 
556  } else if (byte_per_idx == 2) {
557  pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
558  this, dis_tables, codes, ncodes, res, init_finalize_heap);
559 
560  }
561  delete [] dis_tables;
562 }
563 
564 void ProductQuantizer::search_ip (const float * __restrict x,
565  size_t nx,
566  const uint8_t * codes,
567  const size_t ncodes,
568  float_minheap_array_t * res,
569  bool init_finalize_heap) const
570 {
571  float * dis_tables = new float [nx * ksub * M];
572  compute_inner_prod_tables (nx, x, dis_tables);
573  FAISS_ASSERT(nx == res->nh);
574 
575  if (byte_per_idx == 1) {
576 
577  pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
578  this, dis_tables, codes, ncodes, res, init_finalize_heap);
579 
580  } else if (byte_per_idx == 2) {
581  pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
582  this, dis_tables, codes, ncodes, res, init_finalize_heap);
583  }
584  delete [] dis_tables;
585 }
586 
587 
588 
/** Square of x. */
static float sqr (float x) {
    const float squared = x * x;
    return squared;
}
592 
593 void ProductQuantizer::compute_sdc_table ()
594 {
595  sdc_table.resize (M * ksub * ksub);
596 
597  for (int m = 0; m < M; m++) {
598 
599  const float *cents = centroids.data() + m * ksub * dsub;
600  float * dis_tab = sdc_table.data() + m * ksub * ksub;
601 
602  // TODO optimize with BLAS
603  for (int i = 0; i < ksub; i++) {
604  const float *centi = cents + i * dsub;
605  for (int j = 0; j < ksub; j++) {
606  float accu = 0;
607  const float *centj = cents + j * dsub;
608  for (int k = 0; k < dsub; k++)
609  accu += sqr (centi[k] - centj[k]);
610  dis_tab [i + j * ksub] = accu;
611  }
612  }
613  }
614 }
615 
616 void ProductQuantizer::search_sdc (const uint8_t * qcodes,
617  size_t nq,
618  const uint8_t * bcodes,
619  const size_t nb,
620  float_maxheap_array_t * res,
621  bool init_finalize_heap) const
622 {
623  FAISS_ASSERT (sdc_table.size() == M * ksub * ksub);
624  size_t k = res->k;
625 
626  FAISS_ASSERT (byte_per_idx == 1);
627 
628 #pragma omp parallel for
629  for (size_t i = 0; i < nq; i++) {
630 
631  /* Compute distances and keep smallest values */
632  long * heap_ids = res->ids + i * k;
633  float * heap_dis = res->val + i * k;
634  const uint8_t * qcode = qcodes + i * code_size;
635 
636  if (init_finalize_heap)
637  maxheap_heapify (k, heap_dis, heap_ids);
638 
639  const uint8_t * bcode = bcodes;
640  for (size_t j = 0; j < nb; j++) {
641  float dis = 0;
642  const float * tab = sdc_table.data();
643  for (int m = 0; m < M; m++) {
644  dis += tab[bcode[m] + qcode[m] * ksub];
645  tab += ksub * ksub;
646  }
647  if (dis < heap_dis[0]) {
648  maxheap_pop (k, heap_dis, heap_ids);
649  maxheap_push (k, heap_dis, heap_ids, dis, j);
650  }
651  bcode += code_size;
652  }
653 
654  if (init_finalize_heap)
655  maxheap_reorder (k, heap_dis, heap_ids);
656  }
657 
658 }
659 
660 
661 
662 
663 
664 
665 } // namespace faiss
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
void set_derived_values()
compute derived values when d, M and nbits have been set
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
Definition: utils.cpp:1228
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.
size_t nh
number of heaps
Definition: Heap.h:355