10 #include "ProductQuantizer.h"
20 #include "FaissAssert.h"
21 #include "VectorTransform.h"
22 #include "IndexFlat.h"
30 int sgemm_ (
const char *transa,
const char *transb, FINTEGER *m, FINTEGER *
31 n, FINTEGER *k,
const float *alpha,
const float *a,
32 FINTEGER *lda,
const float *b, FINTEGER *
33 ldb,
float *beta,
float *c, FINTEGER *ldc);
42 template <
typename CT,
class C>
43 void pq_estimators_from_tables_Mmul4 (
int M,
const CT * codes,
45 const float * __restrict dis_table,
52 for (
size_t j = 0; j < ncodes; j++) {
54 const float *dt = dis_table;
56 for (
size_t m = 0; m < M; m+=4) {
58 dism = dt[*codes++]; dt += ksub;
59 dism += dt[*codes++]; dt += ksub;
60 dism += dt[*codes++]; dt += ksub;
61 dism += dt[*codes++]; dt += ksub;
65 if (C::cmp (heap_dis[0], dis)) {
66 heap_pop<C> (k, heap_dis, heap_ids);
67 heap_push<C> (k, heap_dis, heap_ids, dis, j);
73 template <
typename CT,
class C>
74 void pq_estimators_from_tables_M4 (
const CT * codes,
76 const float * __restrict dis_table,
83 for (
size_t j = 0; j < ncodes; j++) {
85 const float *dt = dis_table;
86 dis = dt[*codes++]; dt += ksub;
87 dis += dt[*codes++]; dt += ksub;
88 dis += dt[*codes++]; dt += ksub;
91 if (C::cmp (heap_dis[0], dis)) {
92 heap_pop<C> (k, heap_dis, heap_ids);
93 heap_push<C> (k, heap_dis, heap_ids, dis, j);
99 template <
typename CT,
class C>
100 static inline void pq_estimators_from_tables (
const ProductQuantizer& pq,
103 const float * dis_table,
111 pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
112 dis_table, pq.ksub, k,
118 pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
119 dis_table, pq.ksub, k,
125 const size_t M = pq.M;
126 const size_t ksub = pq.ksub;
127 for (
size_t j = 0; j < ncodes; j++) {
129 const float * __restrict dt = dis_table;
130 for (
int m = 0; m < M; m++) {
134 if (C::cmp (heap_dis[0], dis)) {
135 heap_pop<C> (k, heap_dis, heap_ids);
136 heap_push<C> (k, heap_dis, heap_ids, dis, j);
142 static inline void pq_estimators_from_tables_generic(
const ProductQuantizer& pq,
144 const uint8_t *codes,
146 const float *dis_table,
151 const size_t M = pq.M;
152 const size_t ksub = pq.ksub;
153 for (
size_t j = 0; j < ncodes; ++j) {
155 codes + j * pq.code_size, nbits
158 const float * __restrict dt = dis_table;
159 for (
size_t m = 0; m < M; m++) {
160 uint64_t c = decoder.decode();
165 if (C::cmp(heap_dis[0], dis)) {
166 heap_pop<C>(k, heap_dis, heap_ids);
167 heap_push<C>(k, heap_dis, heap_ids, dis, j);
178 ProductQuantizer::ProductQuantizer (
size_t d,
size_t M,
size_t nbits):
179 d(d), M(M), nbits(nbits), assign_index(nullptr)
181 set_derived_values ();
184 ProductQuantizer::ProductQuantizer ()
185 : ProductQuantizer(0, 1, 0) {}
189 FAISS_THROW_IF_NOT (
d % M == 0);
195 train_type = Train_default;
201 ksub *
dsub *
sizeof (centroids_[0]));
205 static void init_hypercube (
int d,
int nbits,
206 int n,
const float * x,
210 std::vector<float> mean (d);
211 for (
int i = 0; i < n; i++)
212 for (
int j = 0; j < d; j++)
213 mean [j] += x[i * d + j];
216 for (
int j = 0; j < d; j++) {
218 if (fabs(mean[j]) > maxm) maxm = fabs(mean[j]);
221 for (
int i = 0; i < (1 << nbits); i++) {
222 float * cent = centroids + i * d;
223 for (
int j = 0; j < nbits; j++)
224 cent[j] = mean [j] + (((i >> j) & 1) ? 1 : -1) * maxm;
225 for (
int j = nbits; j < d; j++)
232 static void init_hypercube_pca (
int d,
int nbits,
233 int n,
const float * x,
236 PCAMatrix pca (d, nbits);
240 for (
int i = 0; i < (1 << nbits); i++) {
241 float * cent = centroids + i * d;
242 for (
int j = 0; j < d; j++) {
243 cent[j] = pca.mean[j];
245 for (
int k = 0; k < nbits; k++)
247 sqrt (pca.eigenvalues [k]) *
248 (((i >> k) & 1) ? 1 : -1) *
249 pca.PCAMat [j + k * d];
255 void ProductQuantizer::train (
int n,
const float * x)
259 final_train_type = train_type;
263 final_train_type = Train_default;
264 printf (
"cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
269 float * xslice =
new float[n *
dsub];
270 ScopeDeleter<float> del (xslice);
271 for (
int m = 0; m <
M; m++) {
272 for (
int j = 0; j < n; j++)
273 memcpy (xslice + j *
dsub,
274 x + j * d + m * dsub,
275 dsub *
sizeof(
float));
277 Clustering clus (dsub, ksub,
cp);
280 if (final_train_type != Train_default) {
281 clus.centroids.resize (dsub * ksub);
284 switch (final_train_type) {
286 init_hypercube (dsub, nbits, n, xslice,
287 clus.centroids.data ());
290 init_hypercube_pca (dsub, nbits, n, xslice,
291 clus.centroids.data ());
294 memcpy (clus.centroids.data(),
296 dsub * ksub *
sizeof (float));
303 printf (
"Training PQ slice %d/%zd\n", m, M);
305 IndexFlatL2 index (dsub);
313 Clustering clus (dsub, ksub,
cp);
317 printf (
"Training all PQ slices at once\n");
320 IndexFlatL2 index (dsub);
323 for (
int m = 0; m <
M; m++) {
330 template<
class PQEncoder>
331 void compute_code(
const ProductQuantizer& pq,
const float *x, uint8_t *code) {
332 float distances [pq.ksub];
333 PQEncoder encoder(code, pq.nbits);
334 for (
size_t m = 0; m < pq.M; m++) {
337 const float * xsub = x + m * pq.dsub;
339 fvec_L2sqr_ny(distances, xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
342 for (
size_t i = 0; i < pq.ksub; i++) {
343 float dis = distances[i];
350 encoder.encode(idxm);
357 faiss::compute_code<PQEncoder8>(*
this, x, code);
361 faiss::compute_code<PQEncoder16>(*
this, x, code);
365 faiss::compute_code<PQEncoderGeneric>(*
this, x, code);
370 template<
class PQDecoder>
373 PQDecoder decoder(code, pq.
nbits);
374 for (
size_t m = 0; m < pq.
M; m++) {
375 uint64_t c = decoder.decode();
384 faiss::decode<PQDecoder8>(*
this, code, x);
388 faiss::decode<PQDecoder16>(*
this, code, x);
392 faiss::decode<PQDecoderGeneric>(*
this, code, x);
400 for (
size_t i = 0; i < n; i++) {
410 for (
size_t m = 0; m <
M; m++) {
415 for (
size_t j = 0; j <
ksub; j++) {
423 encoder.encode(idxm);
434 for (
size_t m = 0; m <
M; m++) {
438 float * xslice =
new float[bs *
dsub];
440 idx_t *assign =
new idx_t[bs];
443 for (
size_t i0 = 0; i0 < n; i0 += bs) {
444 size_t i1 = std::min(i0 + bs, n);
446 for (
size_t i = i0; i < i1; i++) {
447 memcpy (xslice + (i - i0) * dsub,
448 x + i * d + m * dsub,
449 dsub *
sizeof(
float));
456 for (
size_t i = i0; i < i1; i++) {
460 }
else if (nbits == 16) {
461 uint16_t *c = (uint16_t*)(codes +
code_size * i0 + m * 2);
462 for (
size_t i = i0; i < i1; i++) {
467 for (
size_t i = i0; i < i1; ++i) {
469 uint8_t offset = (m *
nbits) % 8;
470 uint64_t ass = assign[i - i0];
487 size_t bs = 256 * 1024;
489 for (
size_t i0 = 0; i0 < n; i0 += bs) {
490 size_t i1 = std::min(i0 + bs, n);
498 #pragma omp parallel for
499 for (
size_t i = 0; i < n; i++)
503 float *dis_tables =
new float [n * ksub *
M];
507 #pragma omp parallel for
508 for (
size_t i = 0; i < n; i++) {
510 const float * tab = dis_tables + i * ksub *
M;
518 float * dis_table)
const
522 for (m = 0; m <
M; m++) {
523 fvec_L2sqr_ny (dis_table + m * ksub,
531 void ProductQuantizer::compute_inner_prod_table (
const float * x,
532 float * dis_table)
const
536 for (m = 0; m <
M; m++) {
537 fvec_inner_products_ny (dis_table + m * ksub,
549 float * dis_tables)
const
554 #pragma omp parallel for
555 for (
size_t i = 0; i < nx; i++) {
561 for (
int m = 0; m <
M; m++) {
564 ksub, centroids.data() + m * dsub *
ksub,
565 dis_tables + ksub * m,
571 void ProductQuantizer::compute_inner_prod_tables (
574 float * dis_tables)
const
579 #pragma omp parallel for
580 for (
size_t i = 0; i < nx; i++) {
581 compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
587 for (
int m = 0; m <
M; m++) {
588 FINTEGER ldc = ksub *
M, nxi = nx, ksubi =
ksub,
589 dsubi =
dsub, di =
d;
590 float one = 1.0, zero = 0;
592 sgemm_ (
"Transposed",
"Not transposed",
593 &ksubi, &nxi, &dsubi,
594 &one, ¢roids [m * dsub * ksub], &dsubi,
596 &zero, dis_tables + ksub * m, &ldc);
603 static void pq_knn_search_with_tables (
604 const ProductQuantizer& pq,
606 const float *dis_tables,
607 const uint8_t * codes,
610 bool init_finalize_heap)
612 size_t k = res->k, nx = res->nh;
613 size_t ksub = pq.ksub, M = pq.M;
616 #pragma omp parallel for
617 for (
size_t i = 0; i < nx; i++) {
619 const float* dis_table = dis_tables + i * ksub * M;
622 long * __restrict heap_ids = res->ids + i * k;
623 float * __restrict heap_dis = res->val + i * k;
625 if (init_finalize_heap) {
626 heap_heapify<C> (k, heap_dis, heap_ids);
631 pq_estimators_from_tables<uint8_t, C> (pq,
634 k, heap_dis, heap_ids);
638 pq_estimators_from_tables<uint16_t, C> (pq,
639 (uint16_t*)codes, ncodes,
641 k, heap_dis, heap_ids);
645 pq_estimators_from_tables_generic<C> (pq,
649 k, heap_dis, heap_ids);
653 if (init_finalize_heap) {
654 heap_reorder<C> (k, heap_dis, heap_ids);
661 const uint8_t * codes,
664 bool init_finalize_heap)
const
666 FAISS_THROW_IF_NOT (nx == res->
nh);
667 std::unique_ptr<float[]> dis_tables(
new float [nx * ksub * M]);
670 pq_knn_search_with_tables<CMax<float, long>> (
671 *
this,
nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
676 const uint8_t * codes,
679 bool init_finalize_heap)
const
681 FAISS_THROW_IF_NOT (nx == res->
nh);
682 std::unique_ptr<float[]> dis_tables(
new float [nx * ksub * M]);
683 compute_inner_prod_tables (nx, x, dis_tables.get());
685 pq_knn_search_with_tables<CMin<float, long> > (
686 *
this,
nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
691 static float sqr (
float x) {
695 void ProductQuantizer::compute_sdc_table ()
699 for (
int m = 0; m <
M; m++) {
701 const float *cents = centroids.data() + m * ksub *
dsub;
705 for (
int i = 0; i <
ksub; i++) {
706 const float *centi = cents + i *
dsub;
707 for (
int j = 0; j <
ksub; j++) {
709 const float *centj = cents + j *
dsub;
710 for (
int k = 0; k <
dsub; k++)
711 accu += sqr (centi[k] - centj[k]);
712 dis_tab [i + j *
ksub] = accu;
718 void ProductQuantizer::search_sdc (
const uint8_t * qcodes,
720 const uint8_t * bcodes,
722 float_maxheap_array_t * res,
723 bool init_finalize_heap)
const
726 FAISS_THROW_IF_NOT (nbits == 8);
730 #pragma omp parallel for
731 for (
size_t i = 0; i < nq; i++) {
734 long * heap_ids = res->ids + i * k;
735 float * heap_dis = res->val + i * k;
736 const uint8_t * qcode = qcodes + i *
code_size;
738 if (init_finalize_heap)
739 maxheap_heapify (k, heap_dis, heap_ids);
741 const uint8_t * bcode = bcodes;
742 for (
size_t j = 0; j < nb; j++) {
745 for (
int m = 0; m <
M; m++) {
746 dis += tab[bcode[m] + qcode[m] *
ksub];
749 if (dis < heap_dis[0]) {
750 maxheap_pop (k, heap_dis, heap_ids);
751 maxheap_push (k, heap_dis, heap_ids, dis, j);
756 if (init_finalize_heap)
757 maxheap_reorder (k, heap_dis, heap_ids);
763 ProductQuantizer::PQEncoderGeneric::PQEncoderGeneric(uint8_t *code,
int nbits,
765 : code(code), offset(offset), nbits(nbits), reg(0) {
768 reg = (*code & ((1 << offset) - 1));
772 void ProductQuantizer::PQEncoderGeneric::encode(uint64_t x) {
773 reg |= (uint8_t)(x << offset);
775 if (offset + nbits >= 8) {
778 for (
int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
779 *code++ = (uint8_t)x;
791 ProductQuantizer::PQEncoderGeneric::~PQEncoderGeneric() {
798 ProductQuantizer::PQEncoder8::PQEncoder8(uint8_t *code,
int nbits)
803 void ProductQuantizer::PQEncoder8::encode(uint64_t x) {
804 *code++ = (uint8_t)x;
808 ProductQuantizer::PQEncoder16::PQEncoder16(uint8_t *code,
int nbits)
809 : code((uint16_t *)code) {
813 void ProductQuantizer::PQEncoder16::encode(uint64_t x) {
814 *code++ = (uint16_t)x;
818 ProductQuantizer::PQDecoderGeneric::PQDecoderGeneric(
const uint8_t *code,
823 mask((1ull << nbits) - 1),
828 uint64_t ProductQuantizer::PQDecoderGeneric::decode() {
832 uint64_t c = (reg >> offset);
834 if (offset + nbits >= 8) {
835 uint64_t e = 8 - offset;
837 for (
int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
838 c |= ((uint64_t)(*code++) << e);
846 c |= ((uint64_t)reg << e);
856 ProductQuantizer::PQDecoder8::PQDecoder8(
const uint8_t *code,
int nbits)
861 uint64_t ProductQuantizer::PQDecoder8::decode() {
862 return (uint64_t)(*code++);
866 ProductQuantizer::PQDecoder16::PQDecoder16(
const uint8_t *code,
int nbits)
867 : code((uint16_t *)code) {
871 uint64_t ProductQuantizer::PQDecoder16::decode() {
872 return (uint64_t)(*code++);
void set_params(const float *centroids, int m)
Define the centroids for subquantizer m.
initialize centroids with nbits-D hypercube
size_t nbits
number of bits per quantization index
void decode(const uint8_t *code, float *x) const
decode a vector from a given code (or n vectors if third argument)
virtual void reset()=0
removes all elements from the database.
initialize centroids with nbits-D hypercube
void assign(idx_t n, const float *x, idx_t *labels, idx_t k=1)
void set_derived_values()
compute derived values when d, M and nbits have been set
std::vector< float > sdc_table
Symmetric Distance Table.
shares the dictionary across PQ segments
size_t dsub
dimensionality of each subvector
void compute_distance_tables(size_t nx, const float *x, float *dis_tables) const
void compute_code_from_distance_table(const float *tab, uint8_t *code) const
void compute_codes(const float *x, uint8_t *codes, size_t n) const
same as compute_code for several vectors
void compute_distance_table(const float *x, float *dis_table) const
void search(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_maxheap_array_t *res, bool init_finalize_heap=true) const
size_t code_size
bytes per indexed vector
virtual void add(idx_t n, const float *x)=0
const int nbits
number of bits per subquantizer index
size_t ksub
number of centroids for each subquantizer
void search_ip(const float *x, size_t nx, const uint8_t *codes, const size_t ncodes, float_minheap_array_t *res, bool init_finalize_heap=true) const
void pairwise_L2sqr(long d, long nq, const float *xq, long nb, const float *xb, float *dis, long ldq, long ldb, long ldd)
void compute_code(const float *x, uint8_t *code) const
Quantize one vector with the product quantizer.
the centroids are already initialized
ClusteringParameters cp
parameters used during clustering
size_t M
number of subquantizers
void compute_codes_with_assign_index(const float *x, uint8_t *codes, size_t n)
float * get_centroids(size_t m, size_t i)
return the centroids associated with subvector m
size_t d
size of the input vectors
bool verbose
verbose during training?
std::vector< float > centroids
Centroid table, size M * ksub * dsub.
train_type_t
initialization