Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/data/users/matthijs/github_faiss/faiss/ProductQuantizer.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 /* Copyright 2004-present Facebook. All Rights Reserved.
10  Index based on product quantization.
11 */
12 
13 #include "ProductQuantizer.h"
14 
15 
16 #include <cstddef>
17 #include <cstring>
18 #include <cstdio>
19 
20 #include <algorithm>
21 
22 #include "FaissAssert.h"
23 #include "VectorTransform.h"
24 #include "IndexFlat.h"
25 #include "utils.h"
26 
27 
28 extern "C" {
29 
30 /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
31 
32 int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
33  n, FINTEGER *k, const float *alpha, const float *a,
34  FINTEGER *lda, const float *b, FINTEGER *
35  ldb, float *beta, float *c, FINTEGER *ldc);
36 
37 }
38 
39 
40 namespace faiss {
41 
42 
43 
44 
45 /* compute an estimator using look-up tables for typical values of M */
46 template <typename CT, class C>
47 void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
48  size_t ncodes,
49  const float * __restrict dis_table,
50  size_t ksub,
51  size_t k,
52  float * heap_dis,
53  long * heap_ids)
54 {
55 
56  for (size_t j = 0; j < ncodes; j++) {
57  float dis = 0;
58  const float *dt = dis_table;
59 
60  for (size_t m = 0; m < M; m+=4) {
61  float dism = 0;
62  dism = dt[*codes++]; dt += ksub;
63  dism += dt[*codes++]; dt += ksub;
64  dism += dt[*codes++]; dt += ksub;
65  dism += dt[*codes++]; dt += ksub;
66  dis += dism;
67  }
68 
69  if (C::cmp (heap_dis[0], dis)) {
70  heap_pop<C> (k, heap_dis, heap_ids);
71  heap_push<C> (k, heap_dis, heap_ids, dis, j);
72  }
73  }
74 }
75 
76 
77 template <typename CT, class C>
78 void pq_estimators_from_tables_M4 (const CT * codes,
79  size_t ncodes,
80  const float * __restrict dis_table,
81  size_t ksub,
82  size_t k,
83  float * heap_dis,
84  long * heap_ids)
85 {
86 
87  for (size_t j = 0; j < ncodes; j++) {
88  float dis = 0;
89  const float *dt = dis_table;
90  dis = dt[*codes++]; dt += ksub;
91  dis += dt[*codes++]; dt += ksub;
92  dis += dt[*codes++]; dt += ksub;
93  dis += dt[*codes++];
94 
95  if (C::cmp (heap_dis[0], dis)) {
96  heap_pop<C> (k, heap_dis, heap_ids);
97  heap_push<C> (k, heap_dis, heap_ids, dis, j);
98  }
99  }
100 }
101 
102 
103 template <typename CT, class C>
104 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
105  const CT * codes,
106  size_t ncodes,
107  const float * dis_table,
108  size_t k,
109  float * heap_dis,
110  long * heap_ids)
111 {
112 
113  if (pq->M == 4) {
114 
115  pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
116  dis_table, pq->ksub, k,
117  heap_dis, heap_ids);
118  return;
119  }
120 
121  if (pq->M % 4 == 0) {
122  pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
123  dis_table, pq->ksub, k,
124  heap_dis, heap_ids);
125  return;
126  }
127 
128  /* Default is relatively slow */
129  const size_t M = pq->M;
130  const size_t ksub = pq->ksub;
131  for (size_t j = 0; j < ncodes; j++) {
132  float dis = 0;
133  const float * __restrict dt = dis_table;
134  for (int m = 0; m < M; m++) {
135  dis += dt[*codes++];
136  dt += ksub;
137  }
138  if (C::cmp (heap_dis[0], dis)) {
139  heap_pop<C> (k, heap_dis, heap_ids);
140  heap_push<C> (k, heap_dis, heap_ids, dis, j);
141  }
142  }
143 }
144 
145 
146 /*********************************************
147  * PQ implementation
148  *********************************************/
149 
150 
151 
152 ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
153  d(d), M(M), nbits(nbits), assign_index(nullptr)
154 {
155  set_derived_values ();
156 }
157 
158 ProductQuantizer::ProductQuantizer ():
159  d(0), M(1), nbits(0), assign_index(nullptr)
160 {
161  set_derived_values ();
162 }
163 
164 
165 
167  // quite a few derived values
168  FAISS_THROW_IF_NOT (d % M == 0);
169  dsub = d / M;
170  byte_per_idx = (nbits + 7) / 8;
172  ksub = 1 << nbits;
173  centroids.resize (d * ksub);
174  verbose = false;
175  train_type = Train_default;
176 }
177 
178 
179 void ProductQuantizer::set_params (const float * centroids_, int m)
180 {
181  memcpy (get_centroids(m, 0), centroids_,
182  ksub * dsub * sizeof (centroids_[0]));
183 }
184 
185 
/** Initialize centroids on the corners of an nbits-D hypercube.
 *
 * The cube is centred on the mean of the training vectors and its half-side
 * is the largest absolute coordinate of that mean. Only the first nbits
 * coordinates vary between centroids; the others are set to the mean.
 *
 * @param d          dimension of the vectors (the caller ensures nbits <= d)
 * @param nbits      number of varying coordinates; 2^nbits centroids written
 * @param n          number of training vectors
 * @param x          training vectors, n * d floats
 * @param centroids  output, (1 << nbits) * d floats
 */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{

    std::vector<float> mean (d);
    for (int i = 0; i < n; i++) {
        // offset computed in size_t: i * d can overflow an int
        const float * xi = x + static_cast<size_t> (i) * d;
        for (int j = 0; j < d; j++)
            mean [j] += xi[j];
    }

    float maxm = 0;
    for (int j = 0; j < d; j++) {
        // with no training vector the mean stays at 0 instead of becoming
        // NaN through a division by zero
        if (n > 0)
            mean [j] /= n;
        const float magnitude = fabs (mean[j]);
        if (magnitude > maxm) maxm = magnitude;
    }

    for (int i = 0; i < (1 << nbits); i++) {
        float * cent = centroids + static_cast<size_t> (i) * d;
        // bit j of the centroid number selects the side of the cube
        for (int j = 0; j < nbits; j++)
            cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
        for (int j = nbits; j < d; j++)
            cent[j] = mean [j];
    }

}
212 
213 static void init_hypercube_pca (int d, int nbits,
214  int n, const float * x,
215  float *centroids)
216 {
217  PCAMatrix pca (d, nbits);
218  pca.train (n, x);
219 
220 
221  for (int i = 0; i < (1 << nbits); i++) {
222  float * cent = centroids + i * d;
223  for (int j = 0; j < d; j++) {
224  cent[j] = pca.mean[j];
225  float f = 1.0;
226  for (int k = 0; k < nbits; k++)
227  cent[j] += f *
228  sqrt (pca.eigenvalues [k]) *
229  (((i >> k) & 1) ? 1 : -1) *
230  pca.PCAMat [j + k * d];
231  }
232  }
233 
234 }
235 
236 void ProductQuantizer::train (int n, const float * x)
237 {
238  if (train_type != Train_shared) {
239  train_type_t final_train_type;
240  final_train_type = train_type;
241  if (train_type == Train_hypercube ||
242  train_type == Train_hypercube_pca) {
243  if (dsub < nbits) {
244  final_train_type = Train_default;
245  printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
246  nbits, dsub);
247  }
248  }
249 
250  float * xslice = new float[n * dsub];
251  ScopeDeleter<float> del (xslice);
252  for (int m = 0; m < M; m++) {
253  for (int j = 0; j < n; j++)
254  memcpy (xslice + j * dsub,
255  x + j * d + m * dsub,
256  dsub * sizeof(float));
257 
258  Clustering clus (dsub, ksub, cp);
259 
260  // we have some initialization for the centroids
261  if (final_train_type != Train_default) {
262  clus.centroids.resize (dsub * ksub);
263  }
264 
265  switch (final_train_type) {
266  case Train_hypercube:
267  init_hypercube (dsub, nbits, n, xslice,
268  clus.centroids.data ());
269  break;
270  case Train_hypercube_pca:
271  init_hypercube_pca (dsub, nbits, n, xslice,
272  clus.centroids.data ());
273  break;
274  case Train_hot_start:
275  memcpy (clus.centroids.data(),
276  get_centroids (m, 0),
277  dsub * ksub * sizeof (float));
278  break;
279  default: ;
280  }
281 
282  if(verbose) {
283  clus.verbose = true;
284  printf ("Training PQ slice %d/%zd\n", m, M);
285  }
286  IndexFlatL2 index (dsub);
287  clus.train (n, xslice, assign_index ? *assign_index : index);
288  set_params (clus.centroids.data(), m);
289  }
290 
291 
292  } else {
293 
294  Clustering clus (dsub, ksub, cp);
295 
296  if(verbose) {
297  clus.verbose = true;
298  printf ("Training all PQ slices at once\n");
299  }
300 
301  IndexFlatL2 index (dsub);
302 
303  clus.train (n * M, x, assign_index ? *assign_index : index);
304  for (int m = 0; m < M; m++) {
305  set_params (clus.centroids.data(), m);
306  }
307 
308  }
309 }
310 
311 
312 void ProductQuantizer::compute_code (const float * x, uint8_t * code) const
313 {
314  float distances [ksub];
315  for (size_t m = 0; m < M; m++) {
316  float mindis = 1e20;
317  int idxm = -1;
318  const float * xsub = x + m * dsub;
319 
320  fvec_L2sqr_ny (distances, xsub, get_centroids(m, 0), dsub, ksub);
321 
322  /* Find best centroid */
323  size_t i;
324  for (i = 0; i < ksub; i++) {
325  float dis = distances [i];
326  if (dis < mindis) {
327  mindis = dis;
328  idxm = i;
329  }
330  }
331  switch (byte_per_idx) {
332  case 1: code[m] = (uint8_t) idxm; break;
333  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
334  }
335  }
336 
337 }
338 
339 void ProductQuantizer::decode (const uint8_t *code, float *x) const
340 {
341  if (byte_per_idx == 1) {
342  for (size_t m = 0; m < M; m++) {
343  memcpy (x + m * dsub, get_centroids(m, code[m]),
344  sizeof(float) * dsub);
345  }
346  } else {
347  const uint16_t *c = (const uint16_t*) code;
348  for (size_t m = 0; m < M; m++) {
349  memcpy (x + m * dsub, get_centroids(m, c[m]),
350  sizeof(float) * dsub);
351  }
352  }
353 }
354 
355 
356 void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
357 {
358  for (size_t i = 0; i < n; i++) {
359  this->decode (code + code_size * i, x + d * i);
360  }
361 }
362 
363 
365  uint8_t *code) const
366 {
367  for (size_t m = 0; m < M; m++) {
368  float mindis = 1e20;
369  int idxm = -1;
370 
371  /* Find best centroid */
372  for (size_t j = 0; j < ksub; j++) {
373  float dis = *tab++;
374  if (dis < mindis) {
375  mindis = dis;
376  idxm = j;
377  }
378  }
379  switch (byte_per_idx) {
380  case 1: code[m] = (uint8_t) idxm; break;
381  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
382  }
383  }
384 }
385 
386 void ProductQuantizer::compute_codes (const float * x,
387  uint8_t * codes,
388  size_t n) const
389 {
390  if (dsub < 16) { // simple direct computation
391 
392 #pragma omp parallel for
393  for (size_t i = 0; i < n; i++)
394  compute_code (x + i * d, codes + i * code_size);
395 
396  } else { // worthwile to use BLAS
397  float *dis_tables = new float [n * ksub * M];
398  ScopeDeleter<float> del (dis_tables);
399  compute_distance_tables (n, x, dis_tables);
400 
401 #pragma omp parallel for
402  for (size_t i = 0; i < n; i++) {
403  uint8_t * code = codes + i * code_size;
404  const float * tab = dis_tables + i * ksub * M;
406  }
407  }
408 }
409 
410 
412  float * dis_table) const
413 {
414  size_t m;
415 
416  for (m = 0; m < M; m++) {
417  fvec_L2sqr_ny (dis_table + m * ksub,
418  x + m * dsub,
419  get_centroids(m, 0),
420  dsub,
421  ksub);
422  }
423 }
424 
425 void ProductQuantizer::compute_inner_prod_table (const float * x,
426  float * dis_table) const
427 {
428  size_t m;
429 
430  for (m = 0; m < M; m++) {
431  fvec_inner_products_ny (dis_table + m * ksub,
432  x + m * dsub,
433  get_centroids(m, 0),
434  dsub,
435  ksub);
436  }
437 }
438 
439 
441  size_t nx,
442  const float * x,
443  float * dis_tables) const
444 {
445 
446  if (dsub < 16) {
447 
448 #pragma omp parallel for
449  for (size_t i = 0; i < nx; i++) {
450  compute_distance_table (x + i * d, dis_tables + i * ksub * M);
451  }
452 
453  } else { // use BLAS
454 
455  for (int m = 0; m < M; m++) {
456  pairwise_L2sqr (dsub,
457  nx, x + dsub * m,
458  ksub, centroids.data() + m * dsub * ksub,
459  dis_tables + ksub * m,
460  d, dsub, ksub * M);
461  }
462  }
463 }
464 
465 void ProductQuantizer::compute_inner_prod_tables (
466  size_t nx,
467  const float * x,
468  float * dis_tables) const
469 {
470 
471  if (dsub < 16) {
472 
473 #pragma omp parallel for
474  for (size_t i = 0; i < nx; i++) {
475  compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
476  }
477 
478  } else { // use BLAS
479 
480  // compute distance tables
481  for (int m = 0; m < M; m++) {
482  FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
483  dsubi = dsub, di = d;
484  float one = 1.0, zero = 0;
485 
486  sgemm_ ("Transposed", "Not transposed",
487  &ksubi, &nxi, &dsubi,
488  &one, &centroids [m * dsub * ksub], &dsubi,
489  x + dsub * m, &di,
490  &zero, dis_tables + ksub * m, &ldc);
491  }
492 
493  }
494 }
495 
496 template <typename CT, class C>
497 static void pq_knn_search_with_tables (
498  const ProductQuantizer * pq,
499  const float *dis_tables,
500  const uint8_t * codes,
501  const size_t ncodes,
502  HeapArray<C> * res,
503  bool init_finalize_heap)
504 {
505  size_t k = res->k, nx = res->nh;
506  size_t ksub = pq->ksub, M = pq->M;
507 
508 
509 #pragma omp parallel for
510  for (size_t i = 0; i < nx; i++) {
511  /* query preparation for asymmetric search: compute look-up tables */
512  const float* dis_table = dis_tables + i * ksub * M;
513 
514  /* Compute distances and keep smallest values */
515  long * __restrict heap_ids = res->ids + i * k;
516  float * __restrict heap_dis = res->val + i * k;
517 
518  if (init_finalize_heap) {
519  heap_heapify<C> (k, heap_dis, heap_ids);
520  }
521 
522  pq_estimators_from_tables<CT, C> (pq,
523  (CT*)codes, ncodes,
524  dis_table,
525  k, heap_dis, heap_ids);
526  if (init_finalize_heap) {
527  heap_reorder<C> (k, heap_dis, heap_ids);
528  }
529  }
530 }
531 
532  /*
533 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
534  const CT * codes,
535  size_t ncodes,
536  const float * dis_table,
537  size_t k,
538  float * heap_dis,
539  long * heap_ids)
540  */
541 void ProductQuantizer::search (const float * __restrict x,
542  size_t nx,
543  const uint8_t * codes,
544  const size_t ncodes,
545  float_maxheap_array_t * res,
546  bool init_finalize_heap) const
547 {
548  FAISS_THROW_IF_NOT (nx == res->nh);
549  float * dis_tables = new float [nx * ksub * M];
550  ScopeDeleter<float> del(dis_tables);
551  compute_distance_tables (nx, x, dis_tables);
552 
553  if (byte_per_idx == 1) {
554 
555  pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
556  this, dis_tables, codes, ncodes, res, init_finalize_heap);
557 
558  } else if (byte_per_idx == 2) {
559  pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
560  this, dis_tables, codes, ncodes, res, init_finalize_heap);
561 
562  }
563 
564 }
565 
566 void ProductQuantizer::search_ip (const float * __restrict x,
567  size_t nx,
568  const uint8_t * codes,
569  const size_t ncodes,
570  float_minheap_array_t * res,
571  bool init_finalize_heap) const
572 {
573  FAISS_THROW_IF_NOT (nx == res->nh);
574  float * dis_tables = new float [nx * ksub * M];
575  ScopeDeleter<float> del(dis_tables);
576  compute_inner_prod_tables (nx, x, dis_tables);
577 
578  if (byte_per_idx == 1) {
579 
580  pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
581  this, dis_tables, codes, ncodes, res, init_finalize_heap);
582 
583  } else if (byte_per_idx == 2) {
584  pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
585  this, dis_tables, codes, ncodes, res, init_finalize_heap);
586  }
587 
588 }
589 
590 
591 
/// Square of x.
static float sqr (float x)
{
    const float squared = x * x;
    return squared;
}
595 
596 void ProductQuantizer::compute_sdc_table ()
597 {
598  sdc_table.resize (M * ksub * ksub);
599 
600  for (int m = 0; m < M; m++) {
601 
602  const float *cents = centroids.data() + m * ksub * dsub;
603  float * dis_tab = sdc_table.data() + m * ksub * ksub;
604 
605  // TODO optimize with BLAS
606  for (int i = 0; i < ksub; i++) {
607  const float *centi = cents + i * dsub;
608  for (int j = 0; j < ksub; j++) {
609  float accu = 0;
610  const float *centj = cents + j * dsub;
611  for (int k = 0; k < dsub; k++)
612  accu += sqr (centi[k] - centj[k]);
613  dis_tab [i + j * ksub] = accu;
614  }
615  }
616  }
617 }
618 
619 void ProductQuantizer::search_sdc (const uint8_t * qcodes,
620  size_t nq,
621  const uint8_t * bcodes,
622  const size_t nb,
623  float_maxheap_array_t * res,
624  bool init_finalize_heap) const
625 {
626  FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
627  FAISS_THROW_IF_NOT (byte_per_idx == 1);
628  size_t k = res->k;
629 
630 
631 #pragma omp parallel for
632  for (size_t i = 0; i < nq; i++) {
633 
634  /* Compute distances and keep smallest values */
635  long * heap_ids = res->ids + i * k;
636  float * heap_dis = res->val + i * k;
637  const uint8_t * qcode = qcodes + i * code_size;
638 
639  if (init_finalize_heap)
640  maxheap_heapify (k, heap_dis, heap_ids);
641 
642  const uint8_t * bcode = bcodes;
643  for (size_t j = 0; j < nb; j++) {
644  float dis = 0;
645  const float * tab = sdc_table.data();
646  for (int m = 0; m < M; m++) {
647  dis += tab[bcode[m] + qcode[m] * ksub];
648  tab += ksub * ksub;
649  }
650  if (dis < heap_dis[0]) {
651  maxheap_pop (k, heap_dis, heap_ids);
652  maxheap_push (k, heap_dis, heap_ids, dis, j);
653  }
654  bcode += code_size;
655  }
656 
657  if (init_finalize_heap)
658  maxheap_reorder (k, heap_dis, heap_ids);
659  }
660 
661 }
662 
663 
664 
665 
666 
667 
668 } // namespace faiss
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
initialize centroids with nbits-D hypercube
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
size_t byte_per_idx
nb bytes per code component (1 or 2)
initialize centroids with nbits-D hypercube
void set_derived_values()
compute derived values when d, M and nbits have been set
std::vector< float > sdc_table
Symmetric Distance Table.
share dictionary across PQ segments
size_t dsub
dimensionality of each subvector
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
size_t code_size
byte per indexed vector
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
Definition: utils.cpp:1344
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.
the centroids are already initialized
ClusteringParameters cp
parameters used during clustering
size_t nh
number of heaps
Definition: Heap.h:354
size_t M
number of subquantizers
float * get_centroids(size_t m, size_t i)
return the centroids associated with subvector m
size_t d
size of the input vectors
bool verbose
verbose during training?
std::vector< float > centroids
Centroid table, size M * ksub * dsub.
train_type_t
initialization