Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
/tmp/faiss/ProductQuantizer.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // -*- c++ -*-
10 
11 #include "ProductQuantizer.h"
12 
13 
14 #include <cstddef>
15 #include <cstring>
16 #include <cstdio>
17 
18 #include <algorithm>
19 
20 #include "FaissAssert.h"
21 #include "VectorTransform.h"
22 #include "IndexFlat.h"
23 #include "utils.h"
24 
25 
26 extern "C" {
27 
28 /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
29 
30 int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
31  n, FINTEGER *k, const float *alpha, const float *a,
32  FINTEGER *lda, const float *b, FINTEGER *
33  ldb, float *beta, float *c, FINTEGER *ldc);
34 
35 }
36 
37 
38 namespace faiss {
39 
40 
41 /* compute an estimator using look-up tables for typical values of M */
42 template <typename CT, class C>
43 void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
44  size_t ncodes,
45  const float * __restrict dis_table,
46  size_t ksub,
47  size_t k,
48  float * heap_dis,
49  long * heap_ids)
50 {
51 
52  for (size_t j = 0; j < ncodes; j++) {
53  float dis = 0;
54  const float *dt = dis_table;
55 
56  for (size_t m = 0; m < M; m+=4) {
57  float dism = 0;
58  dism = dt[*codes++]; dt += ksub;
59  dism += dt[*codes++]; dt += ksub;
60  dism += dt[*codes++]; dt += ksub;
61  dism += dt[*codes++]; dt += ksub;
62  dis += dism;
63  }
64 
65  if (C::cmp (heap_dis[0], dis)) {
66  heap_pop<C> (k, heap_dis, heap_ids);
67  heap_push<C> (k, heap_dis, heap_ids, dis, j);
68  }
69  }
70 }
71 
72 
73 template <typename CT, class C>
74 void pq_estimators_from_tables_M4 (const CT * codes,
75  size_t ncodes,
76  const float * __restrict dis_table,
77  size_t ksub,
78  size_t k,
79  float * heap_dis,
80  long * heap_ids)
81 {
82 
83  for (size_t j = 0; j < ncodes; j++) {
84  float dis = 0;
85  const float *dt = dis_table;
86  dis = dt[*codes++]; dt += ksub;
87  dis += dt[*codes++]; dt += ksub;
88  dis += dt[*codes++]; dt += ksub;
89  dis += dt[*codes++];
90 
91  if (C::cmp (heap_dis[0], dis)) {
92  heap_pop<C> (k, heap_dis, heap_ids);
93  heap_push<C> (k, heap_dis, heap_ids, dis, j);
94  }
95  }
96 }
97 
98 
99 template <typename CT, class C>
100 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
101  const CT * codes,
102  size_t ncodes,
103  const float * dis_table,
104  size_t k,
105  float * heap_dis,
106  long * heap_ids)
107 {
108 
109  if (pq->M == 4) {
110 
111  pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
112  dis_table, pq->ksub, k,
113  heap_dis, heap_ids);
114  return;
115  }
116 
117  if (pq->M % 4 == 0) {
118  pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
119  dis_table, pq->ksub, k,
120  heap_dis, heap_ids);
121  return;
122  }
123 
124  /* Default is relatively slow */
125  const size_t M = pq->M;
126  const size_t ksub = pq->ksub;
127  for (size_t j = 0; j < ncodes; j++) {
128  float dis = 0;
129  const float * __restrict dt = dis_table;
130  for (int m = 0; m < M; m++) {
131  dis += dt[*codes++];
132  dt += ksub;
133  }
134  if (C::cmp (heap_dis[0], dis)) {
135  heap_pop<C> (k, heap_dis, heap_ids);
136  heap_push<C> (k, heap_dis, heap_ids, dis, j);
137  }
138  }
139 }
140 
141 
142 /*********************************************
143  * PQ implementation
144  *********************************************/
145 
146 
147 
148 ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits):
149  d(d), M(M), nbits(nbits), assign_index(nullptr)
150 {
151  set_derived_values ();
152 }
153 
154 ProductQuantizer::ProductQuantizer ():
155  d(0), M(1), nbits(0), assign_index(nullptr)
156 {
157  set_derived_values ();
158 }
159 
160 
161 
163  // quite a few derived values
164  FAISS_THROW_IF_NOT (d % M == 0);
165  dsub = d / M;
166  byte_per_idx = (nbits + 7) / 8;
168  ksub = 1 << nbits;
169  centroids.resize (d * ksub);
170  verbose = false;
171  train_type = Train_default;
172 }
173 
174 
175 void ProductQuantizer::set_params (const float * centroids_, int m)
176 {
177  memcpy (get_centroids(m, 0), centroids_,
178  ksub * dsub * sizeof (centroids_[0]));
179 }
180 
181 
/**
 * Initialize 2^nbits centroids on the corners of an nbits-dimensional
 * hypercube centered on the mean of the training points.
 *
 * The first nbits coordinates of centroid i are mean[j] +/- maxm, the sign
 * being taken from bit j of i; the remaining coordinates equal the mean.
 * maxm is the largest absolute value found in the mean vector.
 *
 * @param d          dimension of the points and of the centroids
 * @param nbits      number of hypercube dimensions
 * @param n          number of training points
 * @param x          training points, n * d floats
 * @param centroids  output, (1 << nbits) * d floats
 */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{
    // per-dimension sum of the training points (zero-initialized)
    std::vector<float> mean (d);
    for (int pt = 0; pt < n; pt++) {
        const float * row = x + pt * d;
        for (int dim = 0; dim < d; dim++) {
            mean [dim] += row [dim];
        }
    }

    // turn sums into means and record the largest magnitude
    float maxm = 0;
    for (int dim = 0; dim < d; dim++) {
        mean [dim] /= n;
        if (fabs (mean [dim]) > maxm) {
            maxm = fabs (mean [dim]);
        }
    }

    const int ncorners = 1 << nbits;
    for (int corner = 0; corner < ncorners; corner++) {
        float * cent = centroids + corner * d;
        for (int dim = 0; dim < d; dim++) {
            if (dim < nbits) {
                // bit `dim` of the corner index picks the side of the cube
                const int sign = ((corner >> dim) & 1) ? 1 : -1;
                cent [dim] = mean [dim] + sign * maxm;
            } else {
                cent [dim] = mean [dim];
            }
        }
    }
}
208 
209 static void init_hypercube_pca (int d, int nbits,
210  int n, const float * x,
211  float *centroids)
212 {
213  PCAMatrix pca (d, nbits);
214  pca.train (n, x);
215 
216 
217  for (int i = 0; i < (1 << nbits); i++) {
218  float * cent = centroids + i * d;
219  for (int j = 0; j < d; j++) {
220  cent[j] = pca.mean[j];
221  float f = 1.0;
222  for (int k = 0; k < nbits; k++)
223  cent[j] += f *
224  sqrt (pca.eigenvalues [k]) *
225  (((i >> k) & 1) ? 1 : -1) *
226  pca.PCAMat [j + k * d];
227  }
228  }
229 
230 }
231 
232 void ProductQuantizer::train (int n, const float * x)
233 {
234  if (train_type != Train_shared) {
235  train_type_t final_train_type;
236  final_train_type = train_type;
237  if (train_type == Train_hypercube ||
238  train_type == Train_hypercube_pca) {
239  if (dsub < nbits) {
240  final_train_type = Train_default;
241  printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
242  nbits, dsub);
243  }
244  }
245 
246  float * xslice = new float[n * dsub];
247  ScopeDeleter<float> del (xslice);
248  for (int m = 0; m < M; m++) {
249  for (int j = 0; j < n; j++)
250  memcpy (xslice + j * dsub,
251  x + j * d + m * dsub,
252  dsub * sizeof(float));
253 
254  Clustering clus (dsub, ksub, cp);
255 
256  // we have some initialization for the centroids
257  if (final_train_type != Train_default) {
258  clus.centroids.resize (dsub * ksub);
259  }
260 
261  switch (final_train_type) {
262  case Train_hypercube:
263  init_hypercube (dsub, nbits, n, xslice,
264  clus.centroids.data ());
265  break;
266  case Train_hypercube_pca:
267  init_hypercube_pca (dsub, nbits, n, xslice,
268  clus.centroids.data ());
269  break;
270  case Train_hot_start:
271  memcpy (clus.centroids.data(),
272  get_centroids (m, 0),
273  dsub * ksub * sizeof (float));
274  break;
275  default: ;
276  }
277 
278  if(verbose) {
279  clus.verbose = true;
280  printf ("Training PQ slice %d/%zd\n", m, M);
281  }
282  IndexFlatL2 index (dsub);
283  clus.train (n, xslice, assign_index ? *assign_index : index);
284  set_params (clus.centroids.data(), m);
285  }
286 
287 
288  } else {
289 
290  Clustering clus (dsub, ksub, cp);
291 
292  if(verbose) {
293  clus.verbose = true;
294  printf ("Training all PQ slices at once\n");
295  }
296 
297  IndexFlatL2 index (dsub);
298 
299  clus.train (n * M, x, assign_index ? *assign_index : index);
300  for (int m = 0; m < M; m++) {
301  set_params (clus.centroids.data(), m);
302  }
303 
304  }
305 }
306 
307 
308 void ProductQuantizer::compute_code (const float * x, uint8_t * code) const
309 {
310  float distances [ksub];
311  for (size_t m = 0; m < M; m++) {
312  float mindis = 1e20;
313  int idxm = -1;
314  const float * xsub = x + m * dsub;
315 
316  fvec_L2sqr_ny (distances, xsub, get_centroids(m, 0), dsub, ksub);
317 
318  /* Find best centroid */
319  size_t i;
320  for (i = 0; i < ksub; i++) {
321  float dis = distances [i];
322  if (dis < mindis) {
323  mindis = dis;
324  idxm = i;
325  }
326  }
327  switch (byte_per_idx) {
328  case 1: code[m] = (uint8_t) idxm; break;
329  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
330  }
331  }
332 
333 }
334 
335 void ProductQuantizer::decode (const uint8_t *code, float *x) const
336 {
337  if (byte_per_idx == 1) {
338  for (size_t m = 0; m < M; m++) {
339  memcpy (x + m * dsub, get_centroids(m, code[m]),
340  sizeof(float) * dsub);
341  }
342  } else {
343  const uint16_t *c = (const uint16_t*) code;
344  for (size_t m = 0; m < M; m++) {
345  memcpy (x + m * dsub, get_centroids(m, c[m]),
346  sizeof(float) * dsub);
347  }
348  }
349 }
350 
351 
352 void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
353 {
354  for (size_t i = 0; i < n; i++) {
355  this->decode (code + code_size * i, x + d * i);
356  }
357 }
358 
359 
361  uint8_t *code) const
362 {
363  for (size_t m = 0; m < M; m++) {
364  float mindis = 1e20;
365  int idxm = -1;
366 
367  /* Find best centroid */
368  for (size_t j = 0; j < ksub; j++) {
369  float dis = *tab++;
370  if (dis < mindis) {
371  mindis = dis;
372  idxm = j;
373  }
374  }
375  switch (byte_per_idx) {
376  case 1: code[m] = (uint8_t) idxm; break;
377  case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
378  }
379  }
380 }
381 
382 void ProductQuantizer::compute_codes (const float * x,
383  uint8_t * codes,
384  size_t n) const
385 {
386  if (dsub < 16) { // simple direct computation
387 
388 #pragma omp parallel for
389  for (size_t i = 0; i < n; i++)
390  compute_code (x + i * d, codes + i * code_size);
391 
392  } else { // worthwile to use BLAS
393  float *dis_tables = new float [n * ksub * M];
394  ScopeDeleter<float> del (dis_tables);
395  compute_distance_tables (n, x, dis_tables);
396 
397 #pragma omp parallel for
398  for (size_t i = 0; i < n; i++) {
399  uint8_t * code = codes + i * code_size;
400  const float * tab = dis_tables + i * ksub * M;
402  }
403  }
404 }
405 
406 
408  float * dis_table) const
409 {
410  size_t m;
411 
412  for (m = 0; m < M; m++) {
413  fvec_L2sqr_ny (dis_table + m * ksub,
414  x + m * dsub,
415  get_centroids(m, 0),
416  dsub,
417  ksub);
418  }
419 }
420 
421 void ProductQuantizer::compute_inner_prod_table (const float * x,
422  float * dis_table) const
423 {
424  size_t m;
425 
426  for (m = 0; m < M; m++) {
427  fvec_inner_products_ny (dis_table + m * ksub,
428  x + m * dsub,
429  get_centroids(m, 0),
430  dsub,
431  ksub);
432  }
433 }
434 
435 
437  size_t nx,
438  const float * x,
439  float * dis_tables) const
440 {
441 
442  if (dsub < 16) {
443 
444 #pragma omp parallel for
445  for (size_t i = 0; i < nx; i++) {
446  compute_distance_table (x + i * d, dis_tables + i * ksub * M);
447  }
448 
449  } else { // use BLAS
450 
451  for (int m = 0; m < M; m++) {
452  pairwise_L2sqr (dsub,
453  nx, x + dsub * m,
454  ksub, centroids.data() + m * dsub * ksub,
455  dis_tables + ksub * m,
456  d, dsub, ksub * M);
457  }
458  }
459 }
460 
461 void ProductQuantizer::compute_inner_prod_tables (
462  size_t nx,
463  const float * x,
464  float * dis_tables) const
465 {
466 
467  if (dsub < 16) {
468 
469 #pragma omp parallel for
470  for (size_t i = 0; i < nx; i++) {
471  compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
472  }
473 
474  } else { // use BLAS
475 
476  // compute distance tables
477  for (int m = 0; m < M; m++) {
478  FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
479  dsubi = dsub, di = d;
480  float one = 1.0, zero = 0;
481 
482  sgemm_ ("Transposed", "Not transposed",
483  &ksubi, &nxi, &dsubi,
484  &one, &centroids [m * dsub * ksub], &dsubi,
485  x + dsub * m, &di,
486  &zero, dis_tables + ksub * m, &ldc);
487  }
488 
489  }
490 }
491 
492 template <typename CT, class C>
493 static void pq_knn_search_with_tables (
494  const ProductQuantizer * pq,
495  const float *dis_tables,
496  const uint8_t * codes,
497  const size_t ncodes,
498  HeapArray<C> * res,
499  bool init_finalize_heap)
500 {
501  size_t k = res->k, nx = res->nh;
502  size_t ksub = pq->ksub, M = pq->M;
503 
504 
505 #pragma omp parallel for
506  for (size_t i = 0; i < nx; i++) {
507  /* query preparation for asymmetric search: compute look-up tables */
508  const float* dis_table = dis_tables + i * ksub * M;
509 
510  /* Compute distances and keep smallest values */
511  long * __restrict heap_ids = res->ids + i * k;
512  float * __restrict heap_dis = res->val + i * k;
513 
514  if (init_finalize_heap) {
515  heap_heapify<C> (k, heap_dis, heap_ids);
516  }
517 
518  pq_estimators_from_tables<CT, C> (pq,
519  (CT*)codes, ncodes,
520  dis_table,
521  k, heap_dis, heap_ids);
522  if (init_finalize_heap) {
523  heap_reorder<C> (k, heap_dis, heap_ids);
524  }
525  }
526 }
527 
528  /*
529 static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
530  const CT * codes,
531  size_t ncodes,
532  const float * dis_table,
533  size_t k,
534  float * heap_dis,
535  long * heap_ids)
536  */
537 void ProductQuantizer::search (const float * __restrict x,
538  size_t nx,
539  const uint8_t * codes,
540  const size_t ncodes,
541  float_maxheap_array_t * res,
542  bool init_finalize_heap) const
543 {
544  FAISS_THROW_IF_NOT (nx == res->nh);
545  float * dis_tables = new float [nx * ksub * M];
546  ScopeDeleter<float> del(dis_tables);
547  compute_distance_tables (nx, x, dis_tables);
548 
549  if (byte_per_idx == 1) {
550 
551  pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
552  this, dis_tables, codes, ncodes, res, init_finalize_heap);
553 
554  } else if (byte_per_idx == 2) {
555  pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
556  this, dis_tables, codes, ncodes, res, init_finalize_heap);
557 
558  }
559 
560 }
561 
562 void ProductQuantizer::search_ip (const float * __restrict x,
563  size_t nx,
564  const uint8_t * codes,
565  const size_t ncodes,
566  float_minheap_array_t * res,
567  bool init_finalize_heap) const
568 {
569  FAISS_THROW_IF_NOT (nx == res->nh);
570  float * dis_tables = new float [nx * ksub * M];
571  ScopeDeleter<float> del(dis_tables);
572  compute_inner_prod_tables (nx, x, dis_tables);
573 
574  if (byte_per_idx == 1) {
575 
576  pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
577  this, dis_tables, codes, ncodes, res, init_finalize_heap);
578 
579  } else if (byte_per_idx == 2) {
580  pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
581  this, dis_tables, codes, ncodes, res, init_finalize_heap);
582  }
583 
584 }
585 
586 
587 
/// Square of a float value.
static float sqr (float x)
{
    const float squared = x * x;
    return squared;
}
591 
592 void ProductQuantizer::compute_sdc_table ()
593 {
594  sdc_table.resize (M * ksub * ksub);
595 
596  for (int m = 0; m < M; m++) {
597 
598  const float *cents = centroids.data() + m * ksub * dsub;
599  float * dis_tab = sdc_table.data() + m * ksub * ksub;
600 
601  // TODO optimize with BLAS
602  for (int i = 0; i < ksub; i++) {
603  const float *centi = cents + i * dsub;
604  for (int j = 0; j < ksub; j++) {
605  float accu = 0;
606  const float *centj = cents + j * dsub;
607  for (int k = 0; k < dsub; k++)
608  accu += sqr (centi[k] - centj[k]);
609  dis_tab [i + j * ksub] = accu;
610  }
611  }
612  }
613 }
614 
615 void ProductQuantizer::search_sdc (const uint8_t * qcodes,
616  size_t nq,
617  const uint8_t * bcodes,
618  const size_t nb,
619  float_maxheap_array_t * res,
620  bool init_finalize_heap) const
621 {
622  FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
623  FAISS_THROW_IF_NOT (byte_per_idx == 1);
624  size_t k = res->k;
625 
626 
627 #pragma omp parallel for
628  for (size_t i = 0; i < nq; i++) {
629 
630  /* Compute distances and keep smallest values */
631  long * heap_ids = res->ids + i * k;
632  float * heap_dis = res->val + i * k;
633  const uint8_t * qcode = qcodes + i * code_size;
634 
635  if (init_finalize_heap)
636  maxheap_heapify (k, heap_dis, heap_ids);
637 
638  const uint8_t * bcode = bcodes;
639  for (size_t j = 0; j < nb; j++) {
640  float dis = 0;
641  const float * tab = sdc_table.data();
642  for (int m = 0; m < M; m++) {
643  dis += tab[bcode[m] + qcode[m] * ksub];
644  tab += ksub * ksub;
645  }
646  if (dis < heap_dis[0]) {
647  maxheap_pop (k, heap_dis, heap_ids);
648  maxheap_push (k, heap_dis, heap_ids, dis, j);
649  }
650  bcode += code_size;
651  }
652 
653  if (init_finalize_heap)
654  maxheap_reorder (k, heap_dis, heap_ids);
655  }
656 
657 }
658 
659 
660 } // namespace faiss
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
initialize centroids with nbits-D hypercube
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
size_t byte_per_idx
nb bytes per code component (1 or 2)
initialize centroids with nbits-D hypercube
void set_derived_values()
compute derived values when d, M and nbits have been set
std::vector< float > sdc_table
Symmetric Distance Table.
share dictionary across PQ segments
size_t dsub
dimensionality of each subvector
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
size_t code_size
byte per indexed vector
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
Definition: utils.cpp:1009
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.
the centroids are already initialized
ClusteringParameters cp
parameters used during clustering
size_t nh
number of heaps
Definition: Heap.h:354
size_t M
number of subquantizers
float * get_centroids(size_t m, size_t i)
return the centroids associated with subvector m
size_t d
size of the input vectors
bool verbose
verbose during training?
std::vector< float > centroids
Centroid table, size M * ksub * dsub.
train_type_t
initialization