14 #include "ProductQuantizer.h"
23 #include "FaissAssert.h"
24 #include "VectorTransform.h"
25 #include "IndexFlat.h"
// Prototype of the external `sgemm_` routine used by this file for its
// batched matrix product (called with transa/transb given as the strings
// "Transposed" / "Not transposed").
//
// Every argument is passed by pointer, including the scalar sizes m, n, k
// and the leading dimensions lda, ldb, ldc (all of type FINTEGER).
// `a` and `b` are read-only inputs, `c` is the writable output buffer.
//
// NOTE(review): presumably the Fortran BLAS SGEMM, i.e.
//   C = alpha * op(A) * op(B) + beta * C   with column-major storage
// -- confirm against the BLAS library that is actually linked.
// NOTE(review): `beta` is declared non-const here while `alpha` is const;
// looks like an oversight in the prototype rather than intent -- verify
// before relying on it.
33 int sgemm_ (
const char *transa,
const char *transb, FINTEGER *m, FINTEGER *
34 n, FINTEGER *k,
const float *alpha,
const float *a,
35 FINTEGER *lda,
const float *b, FINTEGER *
36 ldb,
float *beta,
float *c, FINTEGER *ldc);
47 template <
typename CT,
class C>
48 void pq_estimators_from_tables_Mmul4 (
int M,
const CT * codes,
50 const float * __restrict dis_table,
57 for (
size_t j = 0; j < ncodes; j++) {
59 const float *dt = dis_table;
61 for (
size_t m = 0; m < M; m+=4) {
63 dism = dt[*codes++]; dt += ksub;
64 dism += dt[*codes++]; dt += ksub;
65 dism += dt[*codes++]; dt += ksub;
66 dism += dt[*codes++]; dt += ksub;
70 if (C::cmp (heap_dis[0], dis)) {
71 heap_pop<C> (k, heap_dis, heap_ids);
72 heap_push<C> (k, heap_dis, heap_ids, dis, j);
78 template <
typename CT,
class C>
79 void pq_estimators_from_tables_M4 (
const CT * codes,
81 const float * __restrict dis_table,
88 for (
size_t j = 0; j < ncodes; j++) {
90 const float *dt = dis_table;
91 dis = dt[*codes++]; dt += ksub;
92 dis += dt[*codes++]; dt += ksub;
93 dis += dt[*codes++]; dt += ksub;
96 if (C::cmp (heap_dis[0], dis)) {
97 heap_pop<C> (k, heap_dis, heap_ids);
98 heap_push<C> (k, heap_dis, heap_ids, dis, j);
104 template <
typename CT,
class C>
105 static inline void pq_estimators_from_tables (
const ProductQuantizer * pq,
108 const float * dis_table,
116 pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
117 dis_table, pq->ksub, k,
122 if (pq->M % 4 == 0) {
123 pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
124 dis_table, pq->ksub, k,
130 const size_t M = pq->M;
131 const size_t ksub = pq->ksub;
132 for (
size_t j = 0; j < ncodes; j++) {
134 const float * __restrict dt = dis_table;
135 for (
int m = 0; m < M; m++) {
139 if (C::cmp (heap_dis[0], dis)) {
140 heap_pop<C> (k, heap_dis, heap_ids);
141 heap_push<C> (k, heap_dis, heap_ids, dis, j);
153 ProductQuantizer::ProductQuantizer (
size_t d,
size_t M,
size_t nbits):
154 d(d), M(M), nbits(nbits)
156 set_derived_values ();
159 ProductQuantizer::ProductQuantizer ():
162 set_derived_values ();
169 FAISS_ASSERT (d % M == 0);
171 byte_per_idx = (nbits + 7) / 8;
172 code_size = byte_per_idx * M;
174 centroids.resize (d * ksub);
176 train_type = Train_default;
182 memcpy (get_centroids(m, 0), centroids_,
183 ksub * dsub *
sizeof (centroids_[0]));
187 static void init_hypercube (
int d,
int nbits,
188 int n,
const float * x,
192 std::vector<float> mean (d);
193 for (
int i = 0; i < n; i++)
194 for (
int j = 0; j < d; j++)
195 mean [j] += x[i * d + j];
198 for (
int j = 0; j < d; j++) {
200 if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
203 for (
int i = 0; i < (1 << nbits); i++) {
204 float * cent = centroids + i * d;
205 for (
int j = 0; j < nbits; j++)
206 cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
207 for (
int j = nbits; j < d; j++)
214 static void init_hypercube_pca (
int d,
int nbits,
215 int n,
const float * x,
218 PCAMatrix pca (d, nbits);
222 for (
int i = 0; i < (1 << nbits); i++) {
223 float * cent = centroids + i * d;
224 for (
int j = 0; j < d; j++) {
225 cent[j] = pca.mean[j];
227 for (
int k = 0; k < nbits; k++)
229 sqrt (pca.eigenvalues [k]) *
230 (((i >> k) & 1) ? 1 : -1) *
231 pca.PCAMat [j + k * d];
237 void ProductQuantizer::train (
int n,
const float * x)
239 if (train_type != Train_shared) {
240 train_type_t final_train_type;
241 final_train_type = train_type;
242 if (train_type == Train_hypercube ||
243 train_type == Train_hypercube_pca) {
245 final_train_type = Train_default;
246 printf (
"cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
251 float * xslice =
new float[n * dsub];
252 for (
int m = 0; m < M; m++) {
253 for (
int j = 0; j < n; j++)
254 memcpy (xslice + j * dsub,
255 x + j * d + m * dsub,
256 dsub *
sizeof(
float));
258 Clustering clus (dsub, ksub, cp);
261 if (final_train_type != Train_default) {
262 clus.centroids.resize (dsub * ksub);
265 switch (final_train_type) {
266 case Train_hypercube:
267 init_hypercube (dsub, nbits, n, xslice,
268 clus.centroids.data ());
270 case Train_hypercube_pca:
271 init_hypercube_pca (dsub, nbits, n, xslice,
272 clus.centroids.data ());
274 case Train_hot_start:
275 memcpy (clus.centroids.data(),
276 get_centroids (m, 0),
277 dsub * ksub *
sizeof (float));
284 printf (
"Training PQ slice %d/%zd\n", m, M);
286 IndexFlatL2 index (dsub);
287 clus.train (n, xslice, index);
288 set_params (clus.centroids.data(), m);
294 Clustering clus (dsub, ksub, cp);
298 printf (
"Training all PQ slices at once\n");
301 IndexFlatL2 index (dsub);
302 clus.train (n * M, x, index);
303 for (
int m = 0; m < M; m++) {
304 set_params (clus.centroids.data(), m);
313 float distances [ksub];
314 for (
size_t m = 0; m < M; m++) {
317 const float * xsub = x + m * dsub;
319 fvec_L2sqr_ny (distances, xsub, get_centroids(m, 0), dsub, ksub);
323 for (i = 0; i < ksub; i++) {
324 float dis = distances [i];
330 switch (byte_per_idx) {
331 case 1: code[m] = (uint8_t) idxm;
break;
332 case 2: ((uint16_t *) code)[m] = (uint16_t) idxm;
break;
340 if (byte_per_idx == 1) {
341 for (
size_t m = 0; m < M; m++) {
342 memcpy (x + m * dsub, get_centroids(m, code[m]),
343 sizeof(
float) * dsub);
346 const uint16_t *c = (
const uint16_t*) code;
347 for (
size_t m = 0; m < M; m++) {
348 memcpy (x + m * dsub, get_centroids(m, c[m]),
349 sizeof(
float) * dsub);
357 for (
size_t i = 0; i < n; i++) {
358 this->decode (code + M * i, x + d * i);
366 for (
size_t m = 0; m < M; m++) {
371 for (
size_t j = 0; j < ksub; j++) {
378 switch (byte_per_idx) {
379 case 1: code[m] = (uint8_t) idxm;
break;
380 case 2: ((uint16_t *) code)[m] = (uint16_t) idxm;
break;
391 #pragma omp parallel for
392 for (
size_t i = 0; i < n; i++)
393 compute_code (x + i * d, codes + i * code_size);
396 float *dis_tables =
new float [n * ksub * M];
397 compute_distance_tables (n, x, dis_tables);
399 #pragma omp parallel for
400 for (
size_t i = 0; i < n; i++) {
401 uint8_t * code = codes + i * code_size;
402 const float * tab = dis_tables + i * ksub * M;
403 compute_code_from_distance_table (tab, code);
405 delete [] dis_tables;
411 float * dis_table)
const
415 for (m = 0; m < M; m++) {
416 fvec_L2sqr_ny (dis_table + m * ksub,
424 void ProductQuantizer::compute_inner_prod_table (
const float * x,
425 float * dis_table)
const
429 for (m = 0; m < M; m++) {
430 fvec_inner_products_ny (dis_table + m * ksub,
442 float * dis_tables)
const
447 #pragma omp parallel for
448 for (
size_t i = 0; i < nx; i++) {
449 compute_distance_table (x + i * d, dis_tables + i * ksub * M);
454 for (
int m = 0; m < M; m++) {
457 ksub, centroids.data() + m * dsub * ksub,
458 dis_tables + ksub * m,
464 void ProductQuantizer::compute_inner_prod_tables (
467 float * dis_tables)
const
472 #pragma omp parallel for
473 for (
size_t i = 0; i < nx; i++) {
474 compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
480 for (
int m = 0; m < M; m++) {
481 FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
482 dsubi = dsub, di = d;
483 float one = 1.0, zero = 0;
485 sgemm_ (
"Transposed",
"Not transposed",
486 &ksubi, &nxi, &dsubi,
487 &one, &centroids [m * dsub * ksub], &dsubi,
489 &zero, dis_tables + ksub * m, &ldc);
495 template <
typename CT,
class C>
496 static void pq_knn_search_with_tables (
497 const ProductQuantizer * pq,
498 const float *dis_tables,
499 const uint8_t * codes,
502 bool init_finalize_heap)
504 size_t k = res->k, nx = res->nh;
505 size_t ksub = pq->ksub, M = pq->M;
508 #pragma omp parallel for
509 for (
size_t i = 0; i < nx; i++) {
511 const float* dis_table = dis_tables + i * ksub * M;
514 long * __restrict heap_ids = res->ids + i * k;
515 float * __restrict heap_dis = res->val + i * k;
517 if (init_finalize_heap) {
518 heap_heapify<C> (k, heap_dis, heap_ids);
521 pq_estimators_from_tables<CT, C> (pq,
524 k, heap_dis, heap_ids);
525 if (init_finalize_heap) {
526 heap_reorder<C> (k, heap_dis, heap_ids);
542 const uint8_t * codes,
545 bool init_finalize_heap)
const
547 float * dis_tables =
new float [nx * ksub * M];
548 compute_distance_tables (nx, x, dis_tables);
549 FAISS_ASSERT(nx == res->
nh);
551 if (byte_per_idx == 1) {
553 pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
554 this, dis_tables, codes, ncodes, res, init_finalize_heap);
556 }
else if (byte_per_idx == 2) {
557 pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
558 this, dis_tables, codes, ncodes, res, init_finalize_heap);
561 delete [] dis_tables;
566 const uint8_t * codes,
569 bool init_finalize_heap)
const
571 float * dis_tables =
new float [nx * ksub * M];
572 compute_inner_prod_tables (nx, x, dis_tables);
573 FAISS_ASSERT(nx == res->
nh);
575 if (byte_per_idx == 1) {
577 pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
578 this, dis_tables, codes, ncodes, res, init_finalize_heap);
580 }
else if (byte_per_idx == 2) {
581 pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
582 this, dis_tables, codes, ncodes, res, init_finalize_heap);
584 delete [] dis_tables;
589 static float sqr (
float x) {
593 void ProductQuantizer::compute_sdc_table ()
595 sdc_table.resize (M * ksub * ksub);
597 for (
int m = 0; m < M; m++) {
599 const float *cents = centroids.data() + m * ksub * dsub;
600 float * dis_tab = sdc_table.data() + m * ksub * ksub;
603 for (
int i = 0; i < ksub; i++) {
604 const float *centi = cents + i * dsub;
605 for (
int j = 0; j < ksub; j++) {
607 const float *centj = cents + j * dsub;
608 for (
int k = 0; k < dsub; k++)
609 accu += sqr (centi[k] - centj[k]);
610 dis_tab [i + j * ksub] = accu;
616 void ProductQuantizer::search_sdc (
const uint8_t * qcodes,
618 const uint8_t * bcodes,
620 float_maxheap_array_t * res,
621 bool init_finalize_heap)
const
623 FAISS_ASSERT (sdc_table.size() == M * ksub * ksub);
626 FAISS_ASSERT (byte_per_idx == 1);
628 #pragma omp parallel for
629 for (
size_t i = 0; i < nq; i++) {
632 long * heap_ids = res->ids + i * k;
633 float * heap_dis = res->val + i * k;
634 const uint8_t * qcode = qcodes + i * code_size;
636 if (init_finalize_heap)
637 maxheap_heapify (k, heap_dis, heap_ids);
639 const uint8_t * bcode = bcodes;
640 for (
size_t j = 0; j < nb; j++) {
642 const float * tab = sdc_table.data();
643 for (
int m = 0; m < M; m++) {
644 dis += tab[bcode[m] + qcode[m] * ksub];
647 if (dis < heap_dis[0]) {
648 maxheap_pop (k, heap_dis, heap_ids);
649 maxheap_push (k, heap_dis, heap_ids, dis, j);
654 if (init_finalize_heap)
655 maxheap_reorder (k, heap_dis, heap_ids);
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
void decode(const uint8_t *code, float *x) const
Decode a vector from a given code (or n vectors if a third argument n is given).
void set_derived_values()
compute derived values when d, M and nbits have been set
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.