faiss/ProductQuantizer.cpp

666 lines
19 KiB
C++

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
/* Copyright 2004-present Facebook. All Rights Reserved.
Index based on product quantization.
*/
#include "ProductQuantizer.h"
#include <cstddef>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include "FaissAssert.h"
#include "VectorTransform.h"
#include "IndexFlat.h"
#include "utils.h"
extern "C" {
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
/* sgemm_: single-precision matrix-matrix product from BLAS. Every argument,
 * including the integer dimensions and leading dimensions, is passed by
 * pointer. In this file it is called by
 * ProductQuantizer::compute_inner_prod_tables. */
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
n, FINTEGER *k, const float *alpha, const float *a,
FINTEGER *lda, const float *b, FINTEGER *
ldb, float *beta, float *c, FINTEGER *ldc);
}
namespace faiss {
/* Estimate distances from look-up tables when the number of sub-quantizers
 * M is a multiple of 4.
 *
 * codes holds ncodes codes of M indices each (type CT). dis_table holds M
 * consecutive tables of ksub entries. For every code the table entries
 * selected by its indices are summed four at a time; when
 * C::cmp (heap_dis[0], dis) holds, the heap top is replaced by (dis, j).
 */
template <typename CT, class C>
void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
                                      size_t ncodes,
                                      const float * __restrict dis_table,
                                      size_t ksub,
                                      size_t k,
                                      float * heap_dis,
                                      long * heap_ids)
{
    const size_t nsub = M;
    const CT * code = codes;

    for (size_t j = 0; j < ncodes; j++) {
        const float * table = dis_table;
        float total = 0;

        // Consume four sub-quantizer indices per iteration.
        for (size_t m = 0; m < nsub; m += 4) {
            float partial = table[code[0]];
            partial += table[ksub + code[1]];
            partial += table[2 * ksub + code[2]];
            partial += table[3 * ksub + code[3]];
            total += partial;

            code += 4;
            table += 4 * ksub;
        }

        if (C::cmp (heap_dis[0], total)) {
            heap_pop<C> (k, heap_dis, heap_ids);
            heap_push<C> (k, heap_dis, heap_ids, total, j);
        }
    }
}
/* Specialisation of the table-based estimator for exactly 4 sub-quantizers.
 *
 * Each code is 4 indices of type CT; the estimate is the sum of one entry
 * from each of the 4 consecutive ksub-sized tables in dis_table. The heap
 * top is replaced by (dis, j) when C::cmp (heap_dis[0], dis) holds.
 */
template <typename CT, class C>
void pq_estimators_from_tables_M4 (const CT * codes,
                                   size_t ncodes,
                                   const float * __restrict dis_table,
                                   size_t ksub,
                                   size_t k,
                                   float * heap_dis,
                                   long * heap_ids)
{
    for (size_t j = 0; j < ncodes; j++) {
        const CT * code = codes + 4 * j;

        float dis = dis_table[code[0]];
        dis += dis_table[ksub + code[1]];
        dis += dis_table[2 * ksub + code[2]];
        dis += dis_table[3 * ksub + code[3]];

        if (C::cmp (heap_dis[0], dis)) {
            heap_pop<C> (k, heap_dis, heap_ids);
            heap_push<C> (k, heap_dis, heap_ids, dis, j);
        }
    }
}
/* Estimate distances for ncodes codes from one look-up table and push the
 * qualifying ones on the result heap.
 *
 * Dispatches to the M == 4 and M % 4 == 0 variants when possible; any other
 * M uses the generic (slower) loop below.
 */
template <typename CT, class C>
static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
                                              const CT * codes,
                                              size_t ncodes,
                                              const float * dis_table,
                                              size_t k,
                                              float * heap_dis,
                                              long * heap_ids)
{
    if (pq->M == 4) {
        pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
                                             dis_table, pq->ksub, k,
                                             heap_dis, heap_ids);
    } else if (pq->M % 4 == 0) {
        pq_estimators_from_tables_Mmul4<CT, C> (pq->M, codes, ncodes,
                                                dis_table, pq->ksub, k,
                                                heap_dis, heap_ids);
    } else {
        /* Default is relatively slow */
        const size_t nsub = pq->M;
        const size_t table_len = pq->ksub;
        const CT * code = codes;

        for (size_t j = 0; j < ncodes; j++) {
            float dis = 0;
            // One table entry per sub-quantizer, tables laid out back to back.
            for (size_t m = 0; m < nsub; m++) {
                dis += dis_table[m * table_len + code[m]];
            }
            code += nsub;

            if (C::cmp (heap_dis[0], dis)) {
                heap_pop<C> (k, heap_dis, heap_ids);
                heap_push<C> (k, heap_dis, heap_ids, dis, j);
            }
        }
    }
}
/*********************************************
* PQ implementation
*********************************************/
/* Build a quantizer for d-dimensional vectors split into M sub-vectors,
 * with nbits bits per sub-quantizer index. All remaining fields are filled
 * in by set_derived_values(). */
ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits)
    : d (d),
      M (M),
      nbits (nbits)
{
    set_derived_values ();
}
/* Default construction: dimension 0, a single sub-quantizer and 0 bits;
 * the derived fields are still computed by set_derived_values(). */
ProductQuantizer::ProductQuantizer ()
    : d (0),
      M (1),
      nbits (0)
{
    set_derived_values ();
}
/* Recompute every field that follows from d, M and nbits, resize the
 * centroid storage and reset the training options to their defaults. */
void ProductQuantizer::set_derived_values ()
{
    // The vector must split evenly into M sub-vectors.
    FAISS_ASSERT (d % M == 0);
    dsub = d / M;

    // Number of centroids per sub-quantizer.
    ksub = 1 << nbits;

    // Bytes needed for one index (rounded up) and for a whole code.
    byte_per_idx = (nbits + 7) / 8;
    code_size = M * byte_per_idx;

    // M tables of ksub centroids of dimension dsub == ksub * d floats.
    centroids.resize (ksub * d);

    verbose = false;
    train_type = Train_default;
}
/* Overwrite the centroids of sub-quantizer m with the ksub * dsub floats
 * stored at centroids_. */
void ProductQuantizer::set_params (const float * centroids_, int m)
{
    const size_t nbytes = ksub * dsub * sizeof (centroids_[0]);
    memcpy (get_centroids (m, 0), centroids_, nbytes);
}
/* Initialise 2^nbits centroids on the corners of a hypercube.
 *
 * The cube is centred on the mean of the n training vectors x (dimension
 * d). Its half-side is the largest absolute component of that mean. Only
 * the first nbits dimensions vary between centroids (bit j of the centroid
 * index picks the sign); the remaining dimensions are set to the mean.
 */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{
    // Accumulate, then normalise, the per-dimension mean.
    std::vector<float> mean (d);
    const float * row = x;
    for (int i = 0; i < n; i++, row += d) {
        for (int j = 0; j < d; j++) {
            mean [j] += row [j];
        }
    }

    float maxm = 0;
    for (float & mu : mean) {
        mu /= n;
        if (fabs(mu) > maxm) {
            maxm = fabs(mu);
        }
    }

    const int ncorner = 1 << nbits;
    for (int corner = 0; corner < ncorner; corner++) {
        float * cent = centroids + corner * d;
        int j = 0;
        // Dimensions driven by the bits of the corner index.
        for (; j < nbits; j++) {
            const bool bit_set = ((corner >> j) & 1) != 0;
            cent[j] = mean [j] + (bit_set ? 1 : -1) * maxm;
        }
        // Remaining dimensions stay at the mean.
        for (j = nbits; j < d; j++) {
            cent[j] = mean [j];
        }
    }
}
/* Initialise 2^nbits centroids on a hypercube aligned with PCA axes.
 *
 * A PCAMatrix (d -> nbits) is trained on the n vectors x. Every centroid
 * starts at the PCA mean and is displaced along each of the nbits
 * components by +/- sqrt(eigenvalue), the sign being given by the matching
 * bit of the centroid index.
 */
static void init_hypercube_pca (int d, int nbits,
                                int n, const float * x,
                                float *centroids)
{
    PCAMatrix pca (d, nbits);
    pca.train (n, x);

    const int ncorner = 1 << nbits;
    for (int corner = 0; corner < ncorner; corner++) {
        float * cent = centroids + corner * d;
        for (int j = 0; j < d; j++) {
            cent[j] = pca.mean[j];
            float f = 1.0;   // overall scale of the displacement
            for (int k = 0; k < nbits; k++) {
                const int sign = ((corner >> k) & 1) ? 1 : -1;
                cent[j] += f *
                    sqrt (pca.eigenvalues [k]) *
                    sign *
                    pca.PCAMat [j + k * d];
            }
        }
    }
}
/* Train the sub-quantizers on n training vectors x of dimension d.
 *
 * Unless train_type is Train_shared, each of the M sub-quantizers is
 * trained separately: the m-th slice (dsub floats) of every vector is
 * gathered into a contiguous buffer and clustered into ksub centroids.
 * The clustering may be seeded depending on train_type:
 *   - Train_hypercube / Train_hypercube_pca: hypercube initialisation
 *     (falls back to the default when dsub < nbits),
 *   - Train_hot_start: the current centroids of the slice,
 *   - otherwise: whatever Clustering does by default.
 *
 * With Train_shared a single clustering is run on x viewed as n * M
 * vectors of dimension dsub, and the result is copied to every
 * sub-quantizer.
 */
void ProductQuantizer::train (int n, const float * x)
{
    if (train_type != Train_shared) {
        train_type_t final_train_type = train_type;

        if (train_type == Train_hypercube ||
            train_type == Train_hypercube_pca) {
            if (dsub < nbits) {
                final_train_type = Train_default;
                // %zu with explicit casts: the values are sizes, and the
                // previous %ld did not match size_t on every platform.
                printf ("cannot train hypercube: nbits=%zu > log2(d=%zu)\n",
                        static_cast<size_t> (nbits),
                        static_cast<size_t> (dsub));
            }
        }

        // Owned by a vector so the buffer is released even if the
        // clustering below throws.
        std::vector<float> xslice (n * dsub);

        for (int m = 0; m < M; m++) {
            // Gather the m-th sub-vector of every training point.
            for (int j = 0; j < n; j++)
                memcpy (xslice.data() + j * dsub,
                        x + j * d + m * dsub,
                        dsub * sizeof(float));

            Clustering clus (dsub, ksub, cp);

            // we have some initialization for the centroids
            if (final_train_type != Train_default) {
                clus.centroids.resize (dsub * ksub);
            }

            switch (final_train_type) {
            case Train_hypercube:
                init_hypercube (dsub, nbits, n, xslice.data (),
                                clus.centroids.data ());
                break;
            case Train_hypercube_pca:
                init_hypercube_pca (dsub, nbits, n, xslice.data (),
                                    clus.centroids.data ());
                break;
            case Train_hot_start:
                memcpy (clus.centroids.data(),
                        get_centroids (m, 0),
                        dsub * ksub * sizeof (float));
                break;
            default: ;
            }

            if(verbose) {
                clus.verbose = true;
                printf ("Training PQ slice %d/%zd\n", m, M);
            }

            IndexFlatL2 index (dsub);
            clus.train (n, xslice.data (), index);
            set_params (clus.centroids.data(), m);
        }
    } else {
        Clustering clus (dsub, ksub, cp);

        if(verbose) {
            clus.verbose = true;
            printf ("Training all PQ slices at once\n");
        }

        IndexFlatL2 index (dsub);
        clus.train (n * M, x, index);

        // Every sub-quantizer shares the same set of centroids.
        for (int m = 0; m < M; m++) {
            set_params (clus.centroids.data(), m);
        }
    }
}
/* Encode one vector x (d floats) into code (code_size bytes).
 *
 * For each sub-quantizer m, the distances from the m-th sub-vector to the
 * ksub centroids are computed with fvec_L2sqr_ny and the index of the
 * smallest one is stored, on 1 or 2 bytes depending on byte_per_idx.
 * If no distance is below 1e20 the index stays at -1 and is stored after
 * conversion to the unsigned code type. Other values of byte_per_idx
 * write nothing.
 */
void ProductQuantizer::compute_code (const float * x, uint8_t * code) const
{
    // Heap buffer rather than a variable-length array: VLAs are not
    // standard C++ and ksub floats on the stack can be large.
    std::vector<float> distances (ksub);

    for (size_t m = 0; m < M; m++) {
        float mindis = 1e20;
        int idxm = -1;
        const float * xsub = x + m * dsub;

        fvec_L2sqr_ny (distances.data (), xsub, get_centroids(m, 0),
                       dsub, ksub);

        /* Find best centroid (first one wins on ties) */
        for (size_t i = 0; i < ksub; i++) {
            float dis = distances [i];
            if (dis < mindis) {
                mindis = dis;
                idxm = i;
            }
        }

        switch (byte_per_idx) {
        case 1: code[m] = (uint8_t) idxm; break;
        case 2: ((uint16_t *) code)[m] = (uint16_t) idxm; break;
        }
    }
}
/* Reconstruct one vector from its code: for every sub-quantizer, copy the
 * centroid selected by the stored index into the matching slice of x.
 * Indices are read as 8-bit values when byte_per_idx == 1 and as 16-bit
 * values otherwise. */
void ProductQuantizer::decode (const uint8_t *code, float *x) const
{
    const size_t slice_bytes = sizeof(float) * dsub;
    const bool one_byte = (byte_per_idx == 1);
    const uint16_t * code16 = (const uint16_t*) code;

    for (size_t m = 0; m < M; m++) {
        const size_t idx = one_byte ? code[m] : code16[m];
        memcpy (x + m * dsub, get_centroids(m, idx), slice_bytes);
    }
}
void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
{
for (size_t i = 0; i < n; i++) {
this->decode (code + M * i, x + d * i);
}
}
/* Encode a vector from its precomputed distance table.
 *
 * tab holds M consecutive rows of ksub distances. For each row the index of
 * the smallest entry below 1e20 (first one on ties) is written to code on
 * 1 or 2 bytes according to byte_per_idx; if no entry qualifies the index
 * is -1 before conversion to the code type.
 */
void ProductQuantizer::compute_code_from_distance_table (const float *tab,
                                                         uint8_t *code) const
{
    for (size_t m = 0; m < M; m++) {
        const float * row = tab + m * ksub;
        float best = 1e20;
        int best_idx = -1;

        /* Find best centroid */
        for (size_t j = 0; j < ksub; j++) {
            if (row[j] < best) {
                best = row[j];
                best_idx = j;
            }
        }

        if (byte_per_idx == 1) {
            code[m] = (uint8_t) best_idx;
        } else if (byte_per_idx == 2) {
            ((uint16_t *) code)[m] = (uint16_t) best_idx;
        }
    }
}
/* Encode n vectors x (n * d floats) into codes (n * code_size bytes).
 *
 * Small sub-vectors (dsub < 16) are encoded one vector at a time. For
 * larger ones all distance tables are computed first with
 * compute_distance_tables and each code is then derived from its table.
 */
void ProductQuantizer::compute_codes (const float * x,
                                      uint8_t * codes,
                                      size_t n) const
{
    if (dsub < 16) { // simple direct computation

#pragma omp parallel for
        for (size_t i = 0; i < n; i++)
            compute_code (x + i * d, codes + i * code_size);

    } else { // worthwhile to use BLAS
        // Owned by a vector so the tables are released even if the
        // distance computation throws.
        std::vector<float> table_storage (n * ksub * M);
        float * dis_tables = table_storage.data ();
        compute_distance_tables (n, x, dis_tables);

#pragma omp parallel for
        for (size_t i = 0; i < n; i++) {
            uint8_t * code = codes + i * code_size;
            const float * tab = dis_tables + i * ksub * M;
            compute_code_from_distance_table (tab, code);
        }
    }
}
void ProductQuantizer::compute_distance_table (const float * x,
float * dis_table) const
{
size_t m;
for (m = 0; m < M; m++) {
fvec_L2sqr_ny (dis_table + m * ksub,
x + m * dsub,
get_centroids(m, 0),
dsub,
ksub);
}
}
void ProductQuantizer::compute_inner_prod_table (const float * x,
float * dis_table) const
{
size_t m;
for (m = 0; m < M; m++) {
fvec_inner_products_ny (dis_table + m * ksub,
x + m * dsub,
get_centroids(m, 0),
dsub,
ksub);
}
}
/* Compute the distance tables of nx queries x (nx * d floats).
 *
 * dis_tables receives nx tables of ksub * M floats. Large sub-vectors
 * (dsub >= 16) go through pairwise_L2sqr, one call per sub-quantizer;
 * smaller ones are handled query by query with compute_distance_table.
 */
void ProductQuantizer::compute_distance_tables (
        size_t nx,
        const float * x,
        float * dis_tables) const
{
    const size_t table_size = ksub * M;

    if (dsub >= 16) { // use BLAS
        for (size_t m = 0; m < M; m++) {
            pairwise_L2sqr (dsub,
                            nx, x + dsub * m,
                            ksub, centroids.data() + m * dsub * ksub,
                            dis_tables + ksub * m,
                            d, dsub, table_size);
        }
        return;
    }

#pragma omp parallel for
    for (size_t i = 0; i < nx; i++) {
        compute_distance_table (x + i * d, dis_tables + i * table_size);
    }
}
void ProductQuantizer::compute_inner_prod_tables (
size_t nx,
const float * x,
float * dis_tables) const
{
if (dsub < 16) {
#pragma omp parallel for
for (size_t i = 0; i < nx; i++) {
compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
}
} else { // use BLAS
// compute distance tables
for (int m = 0; m < M; m++) {
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
dsubi = dsub, di = d;
float one = 1.0, zero = 0;
sgemm_ ("Transposed", "Not transposed",
&ksubi, &nxi, &dsubi,
&one, &centroids [m * dsub * ksub], &dsubi,
x + dsub * m, &di,
&zero, dis_tables + ksub * m, &ldc);
}
}
}
template <typename CT, class C>
static void pq_knn_search_with_tables (
const ProductQuantizer * pq,
const float *dis_tables,
const uint8_t * codes,
const size_t ncodes,
HeapArray<C> * res,
bool init_finalize_heap)
{
size_t k = res->k, nx = res->nh;
size_t ksub = pq->ksub, M = pq->M;
#pragma omp parallel for
for (size_t i = 0; i < nx; i++) {
/* query preparation for asymmetric search: compute look-up tables */
const float* dis_table = dis_tables + i * ksub * M;
/* Compute distances and keep smallest values */
long * __restrict heap_ids = res->ids + i * k;
float * __restrict heap_dis = res->val + i * k;
if (init_finalize_heap) {
heap_heapify<C> (k, heap_dis, heap_ids);
}
pq_estimators_from_tables<CT, C> (pq,
(CT*)codes, ncodes,
dis_table,
k, heap_dis, heap_ids);
if (init_finalize_heap) {
heap_reorder<C> (k, heap_dis, heap_ids);
}
}
}
/*
static inline void pq_estimators_from_tables (const ProductQuantizer * pq,
const CT * codes,
size_t ncodes,
const float * dis_table,
size_t k,
float * heap_dis,
long * heap_ids)
*/
/* Search nx queries x against ncodes database codes using distance tables.
 *
 * res must hold one heap per query (res->nh == nx). Tables are built with
 * compute_distance_tables and scanned with pq_knn_search_with_tables using
 * the CMax comparator. Codes are read as 8- or 16-bit indices according to
 * byte_per_idx; any other value leaves res untouched.
 * init_finalize_heap is forwarded to the scan.
 */
void ProductQuantizer::search (const float * __restrict x,
                               size_t nx,
                               const uint8_t * codes,
                               const size_t ncodes,
                               float_maxheap_array_t * res,
                               bool init_finalize_heap) const
{
    // Validate before allocating and computing the tables.
    FAISS_ASSERT(nx == res->nh);

    // Owned by a vector so the tables are released on every exit path.
    std::vector<float> table_storage (nx * ksub * M);
    float * dis_tables = table_storage.data ();
    compute_distance_tables (nx, x, dis_tables);

    if (byte_per_idx == 1) {
        pq_knn_search_with_tables<uint8_t, CMax<float, long> > (
             this, dis_tables, codes, ncodes, res, init_finalize_heap);
    } else if (byte_per_idx == 2) {
        pq_knn_search_with_tables<uint16_t, CMax<float, long> > (
             this, dis_tables, codes, ncodes, res, init_finalize_heap);
    }
}
/* Search nx queries x against ncodes database codes using inner-product
 * tables.
 *
 * res must hold one heap per query (res->nh == nx). Tables are built with
 * compute_inner_prod_tables and scanned with pq_knn_search_with_tables
 * using the CMin comparator. Codes are read as 8- or 16-bit indices
 * according to byte_per_idx; any other value leaves res untouched.
 * init_finalize_heap is forwarded to the scan.
 */
void ProductQuantizer::search_ip (const float * __restrict x,
                                  size_t nx,
                                  const uint8_t * codes,
                                  const size_t ncodes,
                                  float_minheap_array_t * res,
                                  bool init_finalize_heap) const
{
    // Validate before allocating and computing the tables.
    FAISS_ASSERT(nx == res->nh);

    // Owned by a vector so the tables are released on every exit path.
    std::vector<float> table_storage (nx * ksub * M);
    float * dis_tables = table_storage.data ();
    compute_inner_prod_tables (nx, x, dis_tables);

    if (byte_per_idx == 1) {
        pq_knn_search_with_tables<uint8_t, CMin<float, long> > (
             this, dis_tables, codes, ncodes, res, init_finalize_heap);
    } else if (byte_per_idx == 2) {
        pq_knn_search_with_tables<uint16_t, CMin<float, long> > (
             this, dis_tables, codes, ncodes, res, init_finalize_heap);
    }
}
/* Square of x. */
static float sqr (float x) {
    const float squared = x * x;
    return squared;
}
/* Build sdc_table: for every sub-quantizer, the ksub x ksub table of
 * squared L2 distances between all pairs of its centroids. */
void ProductQuantizer::compute_sdc_table ()
{
    sdc_table.resize (M * ksub * ksub);

    for (size_t m = 0; m < M; m++) {
        const float * cents = centroids.data() + m * ksub * dsub;
        float * dis_tab = sdc_table.data() + m * ksub * ksub;

        // TODO optimize with BLAS
        for (size_t i = 0; i < ksub; i++) {
            const float * centi = cents + i * dsub;
            for (size_t j = 0; j < ksub; j++) {
                const float * centj = cents + j * dsub;
                float accu = 0;
                for (size_t c = 0; c < dsub; c++) {
                    accu += sqr (centi[c] - centj[c]);
                }
                dis_tab [i + j * ksub] = accu;
            }
        }
    }
}
/* Search with both queries and database given as codes.
 *
 * Requires sdc_table to have M * ksub * ksub entries and 1-byte indices
 * (both asserted). For each of the nq query codes, the distance to each of
 * the nb database codes is the sum over sub-quantizers of the sdc_table
 * entry for the two indices; a result replaces the heap top when it is
 * smaller. When init_finalize_heap is set the heap of each query is
 * heapified before and reordered after the scan.
 */
void ProductQuantizer::search_sdc (const uint8_t * qcodes,
                                   size_t nq,
                                   const uint8_t * bcodes,
                                   const size_t nb,
                                   float_maxheap_array_t * res,
                                   bool init_finalize_heap) const
{
    FAISS_ASSERT (sdc_table.size() == M * ksub * ksub);
    size_t k = res->k;
    FAISS_ASSERT (byte_per_idx == 1);

    const size_t sub_table_size = ksub * ksub;

#pragma omp parallel for
    for (size_t i = 0; i < nq; i++) {
        /* Result heap and code of query i */
        long * heap_ids = res->ids + i * k;
        float * heap_dis = res->val + i * k;
        const uint8_t * qcode = qcodes + i * code_size;

        if (init_finalize_heap)
            maxheap_heapify (k, heap_dis, heap_ids);

        /* Compute distances and keep smallest values */
        for (size_t j = 0; j < nb; j++) {
            const uint8_t * bcode = bcodes + j * code_size;
            float dis = 0;
            for (size_t m = 0; m < M; m++) {
                const float * tab = sdc_table.data() + m * sub_table_size;
                dis += tab[bcode[m] + qcode[m] * ksub];
            }
            if (dis < heap_dis[0]) {
                maxheap_pop (k, heap_dis, heap_ids);
                maxheap_push (k, heap_dis, heap_ids, dis, j);
            }
        }

        if (init_finalize_heap)
            maxheap_reorder (k, heap_dis, heap_ids);
    }
}
} // namespace faiss