Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/ProductQuantizer.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  Index based on product quantization.
11 */
12 
13 #include "ProductQuantizer.h"
14 
15 
16 #include <cstddef>
17 #include <cstring>
18 #include <cstdio>
19 
20 #include <algorithm>
21 
22 #include "FaissAssert.h"
23 #include "VectorTransform.h"
24 #include "IndexFlat.h"
25 #include "utils.h"
26 
27 
28 extern "C" {
29 
30 /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
31 
32 int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
33  n, FINTEGER *k, const float *alpha, const float *a,
34  FINTEGER *lda, const float *b, FINTEGER *
35  ldb, float *beta, float *c, FINTEGER *ldc);
36 
37 }
38 
39 
40 namespace faiss {
41 
42 
43 
44 
45 /* compute an estimator using look-up tables for typical values of M */
46 template <typename CT, class C>
47 void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
48  size_t ncodes,
49  const float * __restrict dis_table,
50  size_t ksub,
51  size_t k,
52  float * heap_dis,
53  long * heap_ids)
54 {
55 
56  for (size_t j = 0; j < ncodes; j++) {
57  float dis = 0;
58  const float *dt = dis_table;
59 
60  for (size_t m = 0; m < M; m+=4) {
61  float dism = 0;
62  dism = dt[*codes++]; dt += ksub;
63  dism += dt[*codes++]; dt += ksub;
64  dism += dt[*codes++]; dt += ksub;
65  dism += dt[*codes++]; dt += ksub;
66  dis += dism;
67  }
68 
69  if (C::cmp (heap_dis[0], dis)) {
70  heap_pop<C> (k, heap_dis, heap_ids);
71  heap_push<C> (k, heap_dis, heap_ids, dis, j);
72  }
73  }
74 }
75 
76 
77 template <typename CT, class C>
78 void pq_estimators_from_tables_M4 (const CT * codes,
79  size_t ncodes,
80  const float * __restrict dis_table,
81  size_t ksub,
82  size_t k,
83  float * heap_dis,
84  long * heap_ids)
85 {
86 
87  for (size_t j = 0; j < ncodes; j++) {
88  float dis = 0;
89  const float *dt = dis_table;
90  dis = dt[*codes++]; dt += ksub;
91  dis += dt[*codes++]; dt += ksub;
92  dis += dt[*codes++]; dt += ksub;
93  dis += dt[*codes++];
94 
95  if (C::cmp (heap_dis[0], dis)) {
96  heap_pop<C> (k, heap_dis, heap_ids);
97  heap_push<C> (k, heap_dis, heap_ids, dis, j);
98  }
99  }
100 }
101 
102 
103 template <typename CT, class C>
104 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
105  const CT * codes,
106  size_t ncodes,
107  const float * dis_table,
108  size_t k,
109  float * heap_dis,
110  long * heap_ids)
111 {
112 
113  if (pq->M == 4) {
114 
115  pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
116  dis_table, pq->ksub, k,
117  heap_dis, heap_ids);
118  return;
119  }
120 
121  if (pq->M % 4 == 0) {
122  pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
123  dis_table, pq->ksub, k,
124  heap_dis, heap_ids);
125  return;
126  }
127 
128  /* Default is relatively slow */
129  const size_t M = pq->M;
130  const size_t ksub = pq->ksub;
131  for (size_t j = 0; j < ncodes; j++) {
132  float dis = 0;
133  const float * __restrict dt = dis_table;
134  for (int m = 0; m < M; m++) {
135  dis += dt[*codes++];
136  dt += ksub;
137  }
138  if (C::cmp (heap_dis[0], dis)) {
139  heap_pop<C> (k, heap_dis, heap_ids);
140  heap_push<C> (k, heap_dis, heap_ids, dis, j);
141  }
142  }
143 }
144 
145 
146 /*********************************************
147  * PQ implementation
148  *********************************************/
149 
150 
151 
152 ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
153  d(d), M(M), nbits(nbits)
154 {
155  set_derived_values ();
156 }
157 
158 ProductQuantizer::ProductQuantizer ():
159  d(0), M(1), nbits(0)
160 {
161  set_derived_values ();
162 }
163 
164 
165 
167  // quite a few derived values
168  FAISS_THROW_IF_NOT (d % M == 0);
169  dsub = d / M;
170  byte_per_idx = (nbits + 7) / 8;
172  ksub = 1 << nbits;
173  centroids.resize (d * ksub);
174  verbose = false;
175  train_type = Train_default;
176 }
177 
178 
179 void ProductQuantizer::set_params (const float * centroids_, int m)
180 {
181  memcpy (get_centroids(m, 0), centroids_,
182  ksub * dsub * sizeof (centroids_[0]));
183 }
184 
185 
/**
 * Initialize 2^nbits centroids on the corners of an nbits-dimensional
 * hypercube centered on the mean of the training points.
 *
 * The first nbits coordinates of centroid i are mean[j] +/- maxm, the sign
 * being given by bit j of i; the remaining coordinates are set to the mean.
 * maxm is the largest absolute value found in the mean vector.
 *
 * @param d          dimension of the points
 * @param nbits      number of hypercube dimensions (callers ensure nbits <= d)
 * @param n          number of training points
 * @param x          training points, size n * d
 * @param centroids  output, size (1 << nbits) * d
 */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{

    std::vector<float> mean (d);
    for (int i = 0; i < n; i++) {
        // offset computed in size_t: i * d overflows int on large inputs
        const float * xi = x + static_cast<size_t> (i) * d;
        for (int j = 0; j < d; j++)
            mean [j] += xi [j];
    }

    float maxm = 0;
    for (int j = 0; j < d; j++) {
        mean [j] /= n;
        if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
    }

    for (int i = 0; i < (1 << nbits); i++) {
        float * cent = centroids + static_cast<size_t> (i) * d;
        for (int j = 0; j < nbits; j++)
            cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
        for (int j = nbits; j < d; j++)
            cent[j] = mean [j];
    }

}
212 
213 static void init_hypercube_pca (int d, int nbits,
214  int n, const float * x,
215  float *centroids)
216 {
217  PCAMatrix pca (d, nbits);
218  pca.train (n, x);
219 
220 
221  for (int i = 0; i < (1 << nbits); i++) {
222  float * cent = centroids + i * d;
223  for (int j = 0; j < d; j++) {
224  cent[j] = pca.mean[j];
225  float f = 1.0;
226  for (int k = 0; k < nbits; k++)
227  cent[j] += f *
228  sqrt (pca.eigenvalues [k]) *
229  (((i >> k) & 1) ? 1 : -1) *
230  pca.PCAMat [j + k * d];
231  }
232  }
233 
234 }
235 
236 void ProductQuantizer::train (int n, const float * x)
237 {
238  if (train_type != Train_shared) {
239  train_type_t final_train_type;
240  final_train_type = train_type;
241  if (train_type == Train_hypercube ||
242  train_type == Train_hypercube_pca) {
243  if (dsub < nbits) {
244  final_train_type = Train_default;
245  printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
246  nbits, dsub);
247  }
248  }
249 
250  float * xslice = new float[n * dsub];
251  ScopeDeleter<float> del (xslice);
252  for (int m = 0; m < M; m++) {
253  for (int j = 0; j < n; j++)
254  memcpy (xslice + j * dsub,
255  x + j * d + m * dsub,
256  dsub * sizeof(float));
257 
258  Clustering clus (dsub, ksub, cp);
259 
260  // we have some initialization for the centroids
261  if (final_train_type != Train_default) {
262  clus.centroids.resize (dsub * ksub);
263  }
264 
265  switch (final_train_type) {
266  case Train_hypercube:
267  init_hypercube (dsub, nbits, n, xslice,
268  clus.centroids.data ());
269  break;
270  case Train_hypercube_pca:
271  init_hypercube_pca (dsub, nbits, n, xslice,
272  clus.centroids.data ());
273  break;
274  case Train_hot_start:
275  memcpy (clus.centroids.data(),
276  get_centroids (m, 0),
277  dsub * ksub * sizeof (float));
278  break;
279  default: ;
280  }
281 
282  if(verbose) {
283  clus.verbose = true;
284  printf ("Training PQ slice %d/%zd\n", m, M);
285  }
286  IndexFlatL2 index (dsub);
287  clus.train (n, xslice, index);
288  set_params (clus.centroids.data(), m);
289  }
290 
291 
292  } else {
293 
294  Clustering clus (dsub, ksub, cp);
295 
296  if(verbose) {
297  clus.verbose = true;
298  printf ("Training all PQ slices at once\n");
299  }
300 
301  IndexFlatL2 index (dsub);
302  clus.train (n * M, x, index);
303  for (int m = 0; m < M; m++) {
304  set_params (clus.centroids.data(), m);
305  }
306 
307  }
308 }
309 
310 
311 void ProductQuantizer::compute_code (const float * x, uint8_t * code) const
312 {
313  float distances [ksub];
314  for (size_t m = 0; m < M; m++) {
315  float mindis = 1e20;
316  int idxm = -1;
317  const float * xsub = x + m * dsub;
318 
319  fvec_L2sqr_ny (distances, xsub, get_centroids(m, 0), dsub, ksub);
320 
321  /* Find best centroid */
322  size_t i;
323  for (i = 0; i < ksub; i++) {
324  float dis = distances [i];
325  if (dis < mindis) {
326  mindis = dis;
327  idxm = i;
328  }
329  }
330  switch (byte_per_idx) {
331  case 1: code[m] = (uint8_t) idxm; break;
332  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
333  }
334  }
335 
336 }
337 
338 void ProductQuantizer::decode (const uint8_t *code, float *x) const
339 {
340  if (byte_per_idx == 1) {
341  for (size_t m = 0; m < M; m++) {
342  memcpy (x + m * dsub, get_centroids(m, code[m]),
343  sizeof(float) * dsub);
344  }
345  } else {
346  const uint16_t *c = (const uint16_t*) code;
347  for (size_t m = 0; m < M; m++) {
348  memcpy (x + m * dsub, get_centroids(m, c[m]),
349  sizeof(float) * dsub);
350  }
351  }
352 }
353 
354 
355 void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
356 {
357  for (size_t i = 0; i < n; i++) {
358  this->decode (code + M * i, x + d * i);
359  }
360 }
361 
362 
364  uint8_t *code) const
365 {
366  for (size_t m = 0; m < M; m++) {
367  float mindis = 1e20;
368  int idxm = -1;
369 
370  /* Find best centroid */
371  for (size_t j = 0; j < ksub; j++) {
372  float dis = *tab++;
373  if (dis < mindis) {
374  mindis = dis;
375  idxm = j;
376  }
377  }
378  switch (byte_per_idx) {
379  case 1: code[m] = (uint8_t) idxm; break;
380  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
381  }
382  }
383 }
384 
385 void ProductQuantizer::compute_codes (const float * x,
386  uint8_t * codes,
387  size_t n) const
388 {
389  if (dsub < 16) { // simple direct computation
390 
391 #pragma omp parallel for
392  for (size_t i = 0; i < n; i++)
393  compute_code (x + i * d, codes + i * code_size);
394 
395  } else { // worthwile to use BLAS
396  float *dis_tables = new float [n * ksub * M];
397  ScopeDeleter<float> del (dis_tables);
398  compute_distance_tables (n, x, dis_tables);
399 
400 #pragma omp parallel for
401  for (size_t i = 0; i < n; i++) {
402  uint8_t * code = codes + i * code_size;
403  const float * tab = dis_tables + i * ksub * M;
405  }
406  }
407 }
408 
409 
411  float * dis_table) const
412 {
413  size_t m;
414 
415  for (m = 0; m < M; m++) {
416  fvec_L2sqr_ny (dis_table + m * ksub,
417  x + m * dsub,
418  get_centroids(m, 0),
419  dsub,
420  ksub);
421  }
422 }
423 
424 void ProductQuantizer::compute_inner_prod_table (const float * x,
425  float * dis_table) const
426 {
427  size_t m;
428 
429  for (m = 0; m < M; m++) {
430  fvec_inner_products_ny (dis_table + m * ksub,
431  x + m * dsub,
432  get_centroids(m, 0),
433  dsub,
434  ksub);
435  }
436 }
437 
438 
440  size_t nx,
441  const float * x,
442  float * dis_tables) const
443 {
444 
445  if (dsub < 16) {
446 
447 #pragma omp parallel for
448  for (size_t i = 0; i < nx; i++) {
449  compute_distance_table (x + i * d, dis_tables + i * ksub * M);
450  }
451 
452  } else { // use BLAS
453 
454  for (int m = 0; m < M; m++) {
455  pairwise_L2sqr (dsub,
456  nx, x + dsub * m,
457  ksub, centroids.data() + m * dsub * ksub,
458  dis_tables + ksub * m,
459  d, dsub, ksub * M);
460  }
461  }
462 }
463 
464 void ProductQuantizer::compute_inner_prod_tables (
465  size_t nx,
466  const float * x,
467  float * dis_tables) const
468 {
469 
470  if (dsub < 16) {
471 
472 #pragma omp parallel for
473  for (size_t i = 0; i < nx; i++) {
474  compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
475  }
476 
477  } else { // use BLAS
478 
479  // compute distance tables
480  for (int m = 0; m < M; m++) {
481  FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
482  dsubi = dsub, di = d;
483  float one = 1.0, zero = 0;
484 
485  sgemm_ ("Transposed", "Not transposed",
486  &ksubi, &nxi, &dsubi,
487  &one, &centroids [m * dsub * ksub], &dsubi,
488  x + dsub * m, &di,
489  &zero, dis_tables + ksub * m, &ldc);
490  }
491 
492  }
493 }
494 
495 template <typename CT, class C>
496 static void pq_knn_search_with_tables (
497  const ProductQuantizer * pq,
498  const float *dis_tables,
499  const uint8_t * codes,
500  const size_t ncodes,
501  HeapArray<C> * res,
502  bool init_finalize_heap)
503 {
504  size_t k = res->k, nx = res->nh;
505  size_t ksub = pq->ksub, M = pq->M;
506 
507 
508 #pragma omp parallel for
509  for (size_t i = 0; i < nx; i++) {
510  /* query preparation for asymmetric search: compute look-up tables */
511  const float* dis_table = dis_tables + i * ksub * M;
512 
513  /* Compute distances and keep smallest values */
514  long * __restrict heap_ids = res->ids + i * k;
515  float * __restrict heap_dis = res->val + i * k;
516 
517  if (init_finalize_heap) {
518  heap_heapify<C> (k, heap_dis, heap_ids);
519  }
520 
521  pq_estimators_from_tables<CT, C> (pq,
522  (CT*)codes, ncodes,
523  dis_table,
524  k, heap_dis, heap_ids);
525  if (init_finalize_heap) {
526  heap_reorder<C> (k, heap_dis, heap_ids);
527  }
528  }
529 }
530 
531  /*
532 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
533  const CT * codes,
534  size_t ncodes,
535  const float * dis_table,
536  size_t k,
537  float * heap_dis,
538  long * heap_ids)
539  */
540 void ProductQuantizer::search (const float * __restrict x,
541  size_t nx,
542  const uint8_t * codes,
543  const size_t ncodes,
544  float_maxheap_array_t * res,
545  bool init_finalize_heap) const
546 {
547  FAISS_THROW_IF_NOT (nx == res->nh);
548  float * dis_tables = new float [nx * ksub * M];
549  ScopeDeleter<float> del(dis_tables);
550  compute_distance_tables (nx, x, dis_tables);
551 
552  if (byte_per_idx == 1) {
553 
554  pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
555  this, dis_tables, codes, ncodes, res, init_finalize_heap);
556 
557  } else if (byte_per_idx == 2) {
558  pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
559  this, dis_tables, codes, ncodes, res, init_finalize_heap);
560 
561  }
562 
563 }
564 
565 void ProductQuantizer::search_ip (const float * __restrict x,
566  size_t nx,
567  const uint8_t * codes,
568  const size_t ncodes,
569  float_minheap_array_t * res,
570  bool init_finalize_heap) const
571 {
572  FAISS_THROW_IF_NOT (nx == res->nh);
573  float * dis_tables = new float [nx * ksub * M];
574  ScopeDeleter<float> del(dis_tables);
575  compute_inner_prod_tables (nx, x, dis_tables);
576 
577  if (byte_per_idx == 1) {
578 
579  pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
580  this, dis_tables, codes, ncodes, res, init_finalize_heap);
581 
582  } else if (byte_per_idx == 2) {
583  pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
584  this, dis_tables, codes, ncodes, res, init_finalize_heap);
585  }
586 
587 }
588 
589 
590 
/** Return the square of x. */
static float sqr (float x) {
    const float squared = x * x;
    return squared;
}
594 
595 void ProductQuantizer::compute_sdc_table ()
596 {
597  sdc_table.resize (M * ksub * ksub);
598 
599  for (int m = 0; m < M; m++) {
600 
601  const float *cents = centroids.data() + m * ksub * dsub;
602  float * dis_tab = sdc_table.data() + m * ksub * ksub;
603 
604  // TODO optimize with BLAS
605  for (int i = 0; i < ksub; i++) {
606  const float *centi = cents + i * dsub;
607  for (int j = 0; j < ksub; j++) {
608  float accu = 0;
609  const float *centj = cents + j * dsub;
610  for (int k = 0; k < dsub; k++)
611  accu += sqr (centi[k] - centj[k]);
612  dis_tab [i + j * ksub] = accu;
613  }
614  }
615  }
616 }
617 
618 void ProductQuantizer::search_sdc (const uint8_t * qcodes,
619  size_t nq,
620  const uint8_t * bcodes,
621  const size_t nb,
622  float_maxheap_array_t * res,
623  bool init_finalize_heap) const
624 {
625  FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
626  FAISS_THROW_IF_NOT (byte_per_idx == 1);
627  size_t k = res->k;
628 
629 
630 #pragma omp parallel for
631  for (size_t i = 0; i < nq; i++) {
632 
633  /* Compute distances and keep smallest values */
634  long * heap_ids = res->ids + i * k;
635  float * heap_dis = res->val + i * k;
636  const uint8_t * qcode = qcodes + i * code_size;
637 
638  if (init_finalize_heap)
639  maxheap_heapify (k, heap_dis, heap_ids);
640 
641  const uint8_t * bcode = bcodes;
642  for (size_t j = 0; j < nb; j++) {
643  float dis = 0;
644  const float * tab = sdc_table.data();
645  for (int m = 0; m < M; m++) {
646  dis += tab[bcode[m] + qcode[m] * ksub];
647  tab += ksub * ksub;
648  }
649  if (dis < heap_dis[0]) {
650  maxheap_pop (k, heap_dis, heap_ids);
651  maxheap_push (k, heap_dis, heap_ids, dis, j);
652  }
653  bcode += code_size;
654  }
655 
656  if (init_finalize_heap)
657  maxheap_reorder (k, heap_dis, heap_ids);
658  }
659 
660 }
661 
662 
663 
664 
665 
666 
667 } // namespace faiss
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
initialize centroids with nbits-D hypercube
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
size_t byte_per_idx
nb bytes per code component (1 or 2)
initialize centroids with nbits-D hypercube
void set_derived_values()
compute derived values when d, M and nbits have been set
std::vector< float > sdc_table
Symmetric Distance Table.
share dictionary across PQ segments
size_t dsub
dimensionality of each subvector
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
size_t code_size
byte per indexed vector
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
Definition: utils.cpp:1311
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.
the centroids are already initialized
ClusteringParameters cp
parameters used during clustering
size_t nh
number of heaps
Definition: Heap.h:354
size_t M
number of subquantizers
float * get_centroids(size_t m, size_t i)
return the centroids associated with subvector m
size_t d
size of the input vectors
bool verbose
verbose during training?
std::vector< float > centroids
Centroid table, size M * ksub * dsub.
train_type_t
initialization