// faiss/IndexPQ.cpp
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexPQ.h>
#include <cstddef>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/utils/hamming.h>
namespace faiss {
/*********************************************************
* IndexPQ implementation
********************************************************/
/** Build a flat product-quantizer index.
 *  @param d      input vector dimensionality
 *  @param M      number of sub-quantizers
 *  @param nbits  bits per sub-quantizer code
 *  @param metric distance type used at search time
 */
IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric):
    Index(d, metric), pq(d, M, nbits)
{
    is_trained = false;
    do_polysemous_training = false;
    search_type = ST_PQ;
    encode_signs = false;
    // default Hamming threshold accepts everything (max distance is nbits * M)
    polysemous_ht = nbits * M + 1;
}
/// Default constructor: L2 metric, untrained, plain PQ search.
IndexPQ::IndexPQ ()
{
    metric_type = METRIC_L2;
    is_trained = false;
    do_polysemous_training = false;
    search_type = ST_PQ;
    encode_signs = false;
    // permissive default: threshold above the maximum possible Hamming distance
    polysemous_ht = pq.nbits * pq.M + 1;
}
/** Train the product quantizer on n vectors x.
 *  With polysemous training enabled, a subset of the training set
 *  (at most n/4 points) is reserved for the Hamming reordering step. */
void IndexPQ::train (idx_t n, const float *x)
{
    if (do_polysemous_training) {
        idx_t ntrain_perm = polysemous_training.ntrain_permutation;
        if (ntrain_perm > n / 4)
            ntrain_perm = n / 4;
        if (verbose) {
            printf ("PQ training on %ld points, remains %ld points: "
                    "training polysemous on %s\n",
                    n - ntrain_perm, ntrain_perm,
                    ntrain_perm == 0 ? "centroids" : "these");
        }
        pq.train(n - ntrain_perm, x);
        // reorder centroid numbering so that close codes have small
        // Hamming distance, using the reserved tail of the training set
        polysemous_training.optimize_pq_for_hamming (
            pq, ntrain_perm, x + (n - ntrain_perm) * d);
    } else { // standard training
        pq.train(n, x);
    }
    is_trained = true;
}
/// Encode n vectors x and append their codes to the index.
void IndexPQ::add (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (is_trained);
    // append at the end of the current code storage
    size_t offset = ntotal * pq.code_size;
    codes.resize (offset + n * pq.code_size);
    pq.compute_codes (x, codes.data() + offset, n);
    ntotal += n;
}
/** Remove the codes whose sequential ids are selected by sel.
 *  Compacts the code array in place; remaining ids are shifted down.
 *  @return number of codes removed */
size_t IndexPQ::remove_ids (const IDSelector & sel)
{
    idx_t nkeep = 0;   // number of codes kept so far
    for (idx_t i = 0; i < ntotal; i++) {
        if (!sel.is_member (i)) {
            // keep this code: move it down to its compacted position
            if (i > nkeep) {
                memmove (&codes[pq.code_size * nkeep],
                         &codes[pq.code_size * i],
                         pq.code_size);
            }
            nkeep++;
        }
    }
    size_t nremove = ntotal - nkeep;
    if (nremove > 0) {
        ntotal = nkeep;
        codes.resize (ntotal * pq.code_size);
    }
    return nremove;
}
/// Remove all stored codes; the trained quantizer is kept.
void IndexPQ::reset()
{
    ntotal = 0;
    codes.clear();
}
/** Decode the ni stored codes starting at id i0 into recons
 *  (size ni * d). */
void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
{
    FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
    for (idx_t i = i0; i < i0 + ni; i++) {
        pq.decode (&codes[i * pq.code_size], recons + (i - i0) * d);
    }
}
/// Decode the code of vector `key` into recons (size d).
void IndexPQ::reconstruct (idx_t key, float * recons) const
{
    FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
    const uint8_t * code = codes.data() + key * pq.code_size;
    pq.decode (code, recons);
}
namespace {
/** DistanceComputer over the codes of an IndexPQ.
 *  Asymmetric distances use a per-query precomputed table; symmetric
 *  distances use the PQ's SDC table. Hard-codes ksub == 256 (8-bit codes). */
struct PQDis: DistanceComputer {
    size_t d;
    Index::idx_t nb;
    const uint8_t *codes;      // borrowed from the index, not owned
    size_t code_size;
    const ProductQuantizer & pq;
    const float *sdc;          // symmetric distance table, M * 256 * 256
    std::vector<float> precomputed_table;  // per-query table, M * 256
    size_t ndis;               // number of asymmetric distances computed

    /// asymmetric distance: current query to stored code i
    float operator () (idx_t i) override
    {
        const uint8_t *code = codes + i * code_size;
        const float *tab = precomputed_table.data();
        float accu = 0;
        for (int m = 0; m < pq.M; m++) {
            accu += tab[code[m]];
            tab += 256;  // one 256-entry table per sub-quantizer
        }
        ndis++;
        return accu;
    }

    /// symmetric distance between stored codes i and j
    float symmetric_dis(idx_t i, idx_t j) override
    {
        const uint8_t *codei = codes + i * code_size;
        const uint8_t *codej = codes + j * code_size;
        const float *tab = sdc;
        float accu = 0;
        for (int m = 0; m < pq.M; m++) {
            accu += tab[codei[m] + codej[m] * 256];
            tab += 256 * 256;
        }
        return accu;
    }

    explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
        : pq(storage.pq) {
        precomputed_table.resize(pq.M * pq.ksub);
        nb = storage.ntotal;
        d = storage.d;
        codes = storage.codes.data();
        code_size = pq.code_size;
        // the table strides above hard-code 256 centroids per sub-quantizer
        FAISS_ASSERT(pq.ksub == 256);
        FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
        sdc = pq.sdc_table.data();
        ndis = 0;
    }

    void set_query(const float *x) override {
        // build the M distance tables for query x
        pq.compute_distance_table(x, precomputed_table.data());
    }
};
} // namespace
/// Return a DistanceComputer over the stored codes.
/// Only supported for 8-bit codes: PQDis hard-codes 256-entry tables.
DistanceComputer * IndexPQ::get_distance_computer() const {
FAISS_THROW_IF_NOT(pq.nbits == 8);
return new PQDis(*this);
}
/*****************************************
* IndexPQ polysemous search routines
******************************************/
/** Search dispatcher, depending on search_type:
 *  - ST_PQ: standard asymmetric PQ search (L2 or inner product)
 *  - ST_polysemous / ST_polysemous_generalize: Hamming-filtered search
 *  - ST_SDC: symmetric code-to-code PQ distances
 *  - ST_HE / ST_generalized_HE: pure Hamming distances between codes
 */
void IndexPQ::search (idx_t n, const float *x, idx_t k,
                      float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    if (search_type == ST_PQ) { // Simple PQ search
        if (metric_type == METRIC_L2) {
            float_maxheap_array_t res = {
                size_t(n), size_t(k), labels, distances };
            pq.search (x, n, codes.data(), ntotal, &res, true);
        } else {
            float_minheap_array_t res = {
                size_t(n), size_t(k), labels, distances };
            pq.search_ip (x, n, codes.data(), ntotal, &res, true);
        }
        indexPQ_stats.nq += n;
        indexPQ_stats.ncode += n * ntotal;
    } else if (search_type == ST_polysemous ||
               search_type == ST_polysemous_generalize) {
        FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
        search_core_polysemous (n, x, k, distances, labels);
    } else { // code-to-code distances
        uint8_t * q_codes = new uint8_t [n * pq.code_size];
        ScopeDeleter<uint8_t> del (q_codes);
        if (!encode_signs) {
            pq.compute_codes (x, q_codes, n);
        } else {
            // binarize: one bit per dimension, set when component > 0
            FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
            memset (q_codes, 0, n * pq.code_size);
            for (size_t i = 0; i < n; i++) {
                const float *xi = x + i * d;
                uint8_t *code = q_codes + i * pq.code_size;
                for (int j = 0; j < d; j++)
                    if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
            }
        }
        if (search_type == ST_SDC) {
            float_maxheap_array_t res = {
                size_t(n), size_t(k), labels, distances};
            pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);
        } else {
            int * idistances = new int [n * k];
            ScopeDeleter<int> del (idistances);
            int_maxheap_array_t res = {
                size_t (n), size_t (k), labels, idistances};
            if (search_type == ST_HE) {
                hammings_knn_hc (&res, q_codes, codes.data(),
                                 ntotal, pq.code_size, true);
            } else if (search_type == ST_generalized_HE) {
                generalized_hammings_knn_hc (&res, q_codes, codes.data(),
                                             ntotal, pq.code_size, true);
            }
            // convert distances to floats; idx_t (not int) avoids
            // overflow when n * k exceeds INT_MAX
            for (idx_t i = 0; i < k * n; i++)
                distances[i] = idistances[i];
        }
        indexPQ_stats.nq += n;
        indexPQ_stats.ncode += n * ntotal;
    }
}
/// Zero all search counters.
void IndexPQStats::reset()
{
    nq = 0;
    ncode = 0;
    n_hamming_pass = 0;
}

// global statistics object updated by IndexPQ searches
IndexPQStats indexPQ_stats;
/** Scan all database codes for one query (polysemous filtering).
 *
 * Codes whose Hamming distance to the query code is strictly below
 * index.polysemous_ht get their exact asymmetric PQ distance computed
 * from dis_table_qi and are candidates for the result max-heap.
 *
 * @param dis_table_qi  PQ distance table of the query, size M * ksub
 * @param q_code        PQ code of the query, size code_size
 * @param k             result heap size
 * @param heap_dis      heap distances (max-heap: heap_dis[0] is worst kept)
 * @param heap_ids      heap ids
 * @return number of database codes that passed the Hamming test
 */
template <class HammingComputer>
static size_t polysemous_inner_loop (
const IndexPQ & index,
const float *dis_table_qi, const uint8_t *q_code,
size_t k, float *heap_dis, int64_t *heap_ids)
{
int M = index.pq.M;
int code_size = index.pq.code_size;
int ksub = index.pq.ksub;
size_t ntotal = index.ntotal;
int ht = index.polysemous_ht;
const uint8_t *b_code = index.codes.data();
size_t n_pass_i = 0;
HammingComputer hc (q_code, code_size);
for (int64_t bi = 0; bi < ntotal; bi++) {
int hd = hc.hamming (b_code);
if (hd < ht) {
n_pass_i ++;
// Hamming filter passed: compute exact PQ distance by table
// lookups, one ksub-entry table per sub-quantizer
float dis = 0;
const float * dis_table = dis_table_qi;
for (int m = 0; m < M; m++) {
dis += dis_table [b_code[m]];
dis_table += ksub;
}
if (dis < heap_dis[0]) {
// better than the current worst result: replace heap top
maxheap_pop (k, heap_dis, heap_ids);
maxheap_push (k, heap_dis, heap_ids, dis, bi);
}
}
b_code += code_size;
}
return n_pass_i;
}
/** Polysemous search: filter database codes by Hamming distance to the
 *  query code, then compute exact PQ distances only for the codes that
 *  pass the filter. Dispatches on code size to a specialized Hamming
 *  computer. (The former dead `if (false)` branch that re-encoded
 *  queries with pq.compute_codes has been removed.) */
void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
                   float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (pq.nbits == 8);

    // PQ distance tables, one (M * ksub) table per query
    float * dis_tables = new float [n * pq.ksub * pq.M];
    ScopeDeleter<float> del (dis_tables);
    pq.compute_distance_tables (n, x, dis_tables);

    // Hamming embedding of the queries: the code of each query is
    // derived from its own distance table (cheaper than re-encoding)
    uint8_t * q_codes = new uint8_t [n * pq.code_size];
    ScopeDeleter<uint8_t> del2 (q_codes);

#pragma omp parallel for
    for (idx_t qi = 0; qi < n; qi++) {
        pq.compute_code_from_distance_table
            (dis_tables + qi * pq.M * pq.ksub,
             q_codes + qi * pq.code_size);
    }

    size_t n_pass = 0;

#pragma omp parallel for reduction (+: n_pass)
    for (idx_t qi = 0; qi < n; qi++) {
        const uint8_t * q_code = q_codes + qi * pq.code_size;
        const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;

        int64_t * heap_ids = labels + qi * k;
        float *heap_dis = distances + qi * k;
        maxheap_heapify (k, heap_dis, heap_ids);

        if (search_type == ST_polysemous) {
            // plain Hamming distance filter
            switch (pq.code_size) {
            case 4:
                n_pass += polysemous_inner_loop<HammingComputer4>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 8:
                n_pass += polysemous_inner_loop<HammingComputer8>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 16:
                n_pass += polysemous_inner_loop<HammingComputer16>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 32:
                n_pass += polysemous_inner_loop<HammingComputer32>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 20:
                n_pass += polysemous_inner_loop<HammingComputer20>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            default:
                if (pq.code_size % 8 == 0) {
                    n_pass += polysemous_inner_loop<HammingComputerM8>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else if (pq.code_size % 4 == 0) {
                    n_pass += polysemous_inner_loop<HammingComputerM4>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else {
                    FAISS_THROW_FMT(
                        "code size %zd not supported for polysemous",
                        pq.code_size);
                }
                break;
            }
        } else {
            // generalized Hamming distance filter
            switch (pq.code_size) {
            case 8:
                n_pass += polysemous_inner_loop<GenHammingComputer8>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 16:
                n_pass += polysemous_inner_loop<GenHammingComputer16>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 32:
                n_pass += polysemous_inner_loop<GenHammingComputer32>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            default:
                if (pq.code_size % 8 == 0) {
                    n_pass += polysemous_inner_loop<GenHammingComputerM8>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else {
                    FAISS_THROW_FMT(
                        "code size %zd not supported for polysemous",
                        pq.code_size);
                }
                break;
            }
        }
        maxheap_reorder (k, heap_dis, heap_ids);
    }

    indexPQ_stats.nq += n;
    indexPQ_stats.ncode += n * ntotal;
    indexPQ_stats.n_hamming_pass += n_pass;
}
/* The standalone codec interface (just remaps to the PQ functions) */
/// size in bytes of one serialized code (one PQ code per vector)
size_t IndexPQ::sa_code_size () const
{
return pq.code_size;
}
/// standalone codec: encode n vectors x into bytes with the trained PQ
void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
{
pq.compute_codes (x, bytes, n);
}
/// standalone codec: decode n codes from bytes into vectors x
void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
{
pq.decode (bytes, x, n);
}
/*****************************************
* Stats of IndexPQ codes
******************************************/
/** Compute the full n x ntotal table of Hamming distances between the
 * encoded queries x and the codes stored in the index.
 * @param dis output, size n * ntotal */
void IndexPQ::hamming_distance_table (idx_t n, const float *x,
int32_t *dis) const
{
uint8_t * q_codes = new uint8_t [n * pq.code_size];
ScopeDeleter<uint8_t> del (q_codes);
pq.compute_codes (x, q_codes, n);
hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
}
/** Histogram of Hamming distances between the n encoded queries x and
 * a database: either nb vectors xb (encoded on the fly), or the codes
 * stored in the index when xb == nullptr.
 * @param hist output, size pq.M * pq.nbits + 1, hist[h] counts pairs
 *             at Hamming distance h */
void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
idx_t nb, const float *xb,
int64_t *hist)
{
FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
FAISS_THROW_IF_NOT (pq.nbits == 8);
// Hamming embedding queries
uint8_t * q_codes = new uint8_t [n * pq.code_size];
ScopeDeleter <uint8_t> del (q_codes);
pq.compute_codes (x, q_codes, n);
uint8_t * b_codes ;
ScopeDeleter <uint8_t> del_b_codes;
if (xb) {
// encode the provided database vectors
b_codes = new uint8_t [nb * pq.code_size];
del_b_codes.set (b_codes);
pq.compute_codes (xb, b_codes, nb);
} else {
// use the codes already stored in the index
nb = ntotal;
b_codes = codes.data();
}
int nbits = pq.M * pq.nbits;
memset (hist, 0, sizeof(*hist) * (nbits + 1));
size_t bs = 256; // block of queries processed per iteration
#pragma omp parallel
{
// per-thread histogram, merged under the critical section below
std::vector<int64_t> histi (nbits + 1);
hamdis_t *distances = new hamdis_t [nb * bs];
ScopeDeleter<hamdis_t> del (distances);
#pragma omp for
for (size_t q0 = 0; q0 < n; q0 += bs) {
// printf ("dis stats: %ld/%ld\n", q0, n);
size_t q1 = q0 + bs;
if (q1 > n) q1 = n;
hammings (q_codes + q0 * pq.code_size, b_codes,
q1 - q0, nb,
pq.code_size, distances);
for (size_t i = 0; i < nb * (q1 - q0); i++)
histi [distances [i]]++;
}
#pragma omp critical
{
for (int i = 0; i <= nbits; i++)
hist[i] += histi[i];
}
}
}
/*****************************************
* MultiIndexQuantizer
******************************************/
namespace {
/** Adapter over an array that is already sorted ascending: order
 *  statistics are plain array accesses and ranks map to themselves. */
template <typename T>
struct PreSortedArray {
    const T * x;  // external array, assumed sorted ascending
    int N;        // number of elements

    explicit PreSortedArray (int N) {
        this->N = N;
    }
    void init (const T *x) {
        this->x = x;
    }
    // smallest value
    T get_0 () {
        return x[0];
    }
    // difference between the n-smallest and the (n-1)-smallest values
    T get_diff (int n) {
        return x[n] - x[n - 1];
    }
    // position of the n-th order statistic in the original array
    int get_ord (int n) {
        return n;
    }
};
/** Comparison functor ordering indices i, j by the values x[i], x[j];
 *  used with std::sort to argsort an array. */
template <typename T>
struct ArgSort {
    const T * x;  // values being indirectly sorted
    bool operator() (size_t i, size_t j) {
        return x[i] < x[j];
    }
};
/** Array that maintains a permutation of its elements so that the
* array's elements are sorted
*/
/** Maintains a permutation of an external array so that its elements
 *  can be accessed in ascending order. */
template <typename T>
struct SortedArray {
    const T * x;            // external values, not owned
    int N;                  // number of elements
    std::vector<int> perm;  // perm[rank] = index in x

    explicit SortedArray (int N): N(N), perm(N) {
    }
    void init (const T *x) {
        this->x = x;
        for (int i = 0; i < N; i++) {
            perm[i] = i;
        }
        // argsort the permutation by value
        std::sort (perm.begin(), perm.end(),
                   [x](int i, int j) { return x[i] < x[j]; });
    }
    // smallest value
    T get_0 () {
        return x[perm[0]];
    }
    // difference between the n-smallest and the (n-1)-smallest values
    T get_diff (int n) {
        return x[perm[n]] - x[perm[n - 1]];
    }
    // position of the n-th order statistic in the original array
    int get_ord (int n) {
        return perm[n];
    }
};
/** Array has n values. Sort the k first ones and copy the other ones
* into elements k..n-1
*/
/** Heap-based partial selection: after the call, vals[perm[0..k-1]]
 * are the k best of the n input elements, in sorted order, and
 * perm[k..n-1] holds the remaining indices in arbitrary order.
 * C is the heap comparator class (CMax => ascending output). */
template <class C>
void partial_sort (int k, int n,
const typename C::T * vals, typename C::TI * perm) {
// insert first k elts in heap
for (int i = 1; i < k; i++) {
indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
}
// insert next n - k elts in heap
for (int i = k; i < n; i++) {
typename C::TI id = perm[i];
typename C::TI top = perm[0];
if (C::cmp(vals[top], vals[id])) {
// id beats the current heap worst: swap it into the heap and
// park the evicted element in the unsorted tail
indirect_heap_pop<C> (k, vals, perm);
indirect_heap_push<C> (k, vals, perm, id);
perm[i] = top;
} else {
// nothing, elt at i is good where it is.
}
}
// order the k first elements in heap (heap-sort the selection)
for (int i = k - 1; i > 0; i--) {
typename C::TI top = perm[0];
indirect_heap_pop<C> (i + 1, vals, perm);
perm[i] = top;
}
}
/** same as SortedArray, but only the k first elements are sorted */
/** Same as SortedArray, but only the k smallest elements are kept
 *  sorted; the sorted prefix grows lazily (by roughly k_factor) when
 *  larger order statistics are requested. */
template <typename T>
struct SemiSortedArray {
    const T * x;            // external values, not owned
    int N;                  // total number of elements

    // type of the heap: CMax = sort ascending
    typedef CMax<T, int> HC;
    std::vector<int> perm;  // perm[0..k-1] sorted, rest unsorted

    int k;                  // number of sorted elements
    int initial_k;          // initial sorted-prefix size
    int k_factor;           // growth factor of the sorted prefix

    explicit SemiSortedArray (int N) {
        this->N = N;
        // fix: perm was resized twice; once is enough
        perm.resize (N);
        initial_k = 3;
        k_factor = 4;
    }

    void init (const T*x) {
        this->x = x;
        for (int n = 0; n < N; n++)
            perm[n] = n;
        k = 0;
        grow (initial_k);
    }

    /// grow the sorted part of the array to size next_k
    void grow (int next_k) {
        if (next_k < N) {
            partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
            k = next_k;
        } else { // full sort of remainder of array
            ArgSort<T> cmp = {x };
            std::sort (perm.begin() + k, perm.end(), cmp);
            k = N;
        }
    }

    // get smallest value
    T get_0 () {
        return x[perm[0]];
    }

    // get delta between n-smallest and n-1 -smallest
    T get_diff (int n) {
        if (n >= k) {
            // lazily extend the sorted prefix; keeps sizes of the
            // form m * k_factor - 1
            int next_k = (k + 1) * k_factor - 1;
            grow (next_k);
        }
        return x[perm[n]] - x[perm[n - 1]];
    }

    // remap orders counted from smallest to indices in array
    int get_ord (int n) {
        assert (n < k);
        return perm[n];
    }
};
/*****************************************
* Find the k smallest sums of M terms, where each term is taken in a
* table x of n values.
*
* A combination of terms is encoded as a scalar 0 <= t < n^M. The
* combination t0 ... t(M-1) that correspond to the sum
*
* sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
*
* is encoded as
*
* t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
*
* MinSumK is an object rather than a function, so that storage can be
* re-used over several computations with the same sizes. use_seen is
* good when there may be ties in the x array and it is a concern if
* occasionally several t's are returned.
*
* @param x size M * n, values to add up
* @parms k nb of results to retrieve
* @param M nb of terms
* @param n nb of distinct values
* @param sums output, size k, sorted
* @prarm terms output, size k, with encoding as above
*
******************************************/
/** Find the K smallest sums of M terms, each taken from a table of N
 *  values (see the comment block above for the encoding of term
 *  combinations). SSA is the sorted-array adapter (SortedArray /
 *  SemiSortedArray / PreSortedArray); use_seen de-duplicates results
 *  when ties can make the same combination surface several times. */
template <typename T, class SSA, bool use_seen>
struct MinSumK {
    int K;    ///< nb of sums to return
    int M;    ///< nb of elements to sum up
    int nbit; ///< nb of bits to encode one entry
    int N;    ///< nb of possible elements for each of the M terms

    /** the heap.
     * We use a heap to maintain a queue of sums, with the associated
     * terms involved in the sum.
     */
    typedef CMin<T, int64_t> HC;
    size_t heap_capacity, heap_size;
    T *bh_val;
    int64_t *bh_ids;

    std::vector <SSA> ssx;

    // all results get pushed several times. When there are ties, they
    // are popped interleaved with others, so it is not easy to
    // identify them. Therefore, this bit array just marks elements
    // that were seen before.
    std::vector <uint8_t> seen;

    MinSumK (int K, int M, int nbit, int N):
        K(K), M(M), nbit(nbit), N(N) {
        heap_capacity = K * M;
        assert (N <= (1 << nbit));
        // we'll do k steps, each step pushes at most M vals
        bh_val = new T[heap_capacity];
        bh_ids = new int64_t[heap_capacity];
        if (use_seen) {
            int64_t n_ids = weight(M);
            seen.resize ((n_ids + 7) / 8);
        }
        for (int m = 0; m < M; m++)
            ssx.push_back (SSA(N));
    }

    /// 2^(i * nbit). Computed in 64 bits: the previous `1 << (i*nbit)`
    /// overflowed a 32-bit int for nbit = 8 and i >= 4 (weight(M) is
    /// used above to size the `seen` bitmap).
    int64_t weight (int i) {
        return int64_t(1) << (i * nbit);
    }

    bool is_seen (int64_t i) {
        return (seen[i >> 3] >> (i & 7)) & 1;
    }

    void mark_seen (int64_t i) {
        if (use_seen)
            seen [i >> 3] |= 1 << (i & 7);
    }

    /** Compute the K smallest sums.
     * @param x     the M value tables; table m starts at x + m * ldx
     * @param ldx   stride between consecutive tables
     * @param sums  output, size K, sums in ascending order
     * @param terms output, size K, encoded term combinations
     */
    void run (const T *x, int64_t ldx,
              T * sums, int64_t * terms) {
        heap_size = 0;

        for (int m = 0; m < M; m++) {
            ssx[m].init(x);
            x += ldx;
        }

        { // initial result: take min for all elements
            T sum = 0;
            terms[0] = 0;
            mark_seen (0);
            for (int m = 0; m < M; m++) {
                sum += ssx[m].get_0();
            }
            sums[0] = sum;
            // seed the heap with the M direct successors of the minimum
            for (int m = 0; m < M; m++) {
                heap_push<HC> (++heap_size, bh_val, bh_ids,
                               sum + ssx[m].get_diff(1),
                               weight(m));
            }
        }

        for (int k = 1; k < K; k++) {
            // pop smallest value from heap
            if (use_seen) { // skip already seen elements
                while (is_seen (bh_ids[0])) {
                    assert (heap_size > 0);
                    heap_pop<HC> (heap_size--, bh_val, bh_ids);
                }
            }
            assert (heap_size > 0);

            T sum = sums[k] = bh_val[0];
            int64_t ti = terms[k] = bh_ids[0];

            if (use_seen) {
                mark_seen (ti);
                heap_pop<HC> (heap_size--, bh_val, bh_ids);
            } else {
                // pop all duplicate pushes of this combination
                do {
                    heap_pop<HC> (heap_size--, bh_val, bh_ids);
                } while (heap_size > 0 && bh_ids[0] == ti);
            }

            // enqueue followers: for each term, advance it by one rank
            int64_t ii = ti;
            for (int m = 0; m < M; m++) {
                int64_t n = ii & ((int64_t(1) << nbit) - 1);
                ii >>= nbit;
                if (n + 1 >= N) continue;
                enqueue_follower (ti, m, n, sum);
            }
        }

        /*
        for (int k = 0; k < K; k++)
            for (int l = k + 1; l < K; l++)
                assert (terms[k] != terms[l]);
        */

        // convert rank-encoded combinations to original array indices
        // by applying each SSA's permutation
        for (int k = 0; k < K; k++) {
            int64_t ii = terms[k];
            if (use_seen) {
                // clear seen for reuse at next loop
                seen[ii >> 3] = 0;
            }
            int64_t ti = 0;
            for (int m = 0; m < M; m++) {
                int64_t n = ii & ((int64_t(1) << nbit) - 1);
                ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
                ii >>= nbit;
            }
            terms[k] = ti;
        }
    }

    void enqueue_follower (int64_t ti, int m, int n, T sum) {
        T next_sum = sum + ssx[m].get_diff(n + 1);
        int64_t next_ti = ti + weight(m);
        heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
    }

    ~MinSumK () {
        delete [] bh_ids;
        delete [] bh_val;
    }
};
} // anonymous namespace
/** Index over the ksub^M virtual elements formed by all combinations
 *  of sub-quantizer centroids. Always uses L2. */
MultiIndexQuantizer::MultiIndexQuantizer (int d,
                                          size_t M,
                                          size_t nbits):
    Index(d, METRIC_L2), pq(d, M, nbits)
{
    pq.verbose = verbose;  // propagate verbosity to the quantizer
    is_trained = false;
}
/// Train the product quantizer and count the virtual elements.
void MultiIndexQuantizer::train(idx_t n, const float *x)
{
    pq.verbose = verbose;
    pq.train (n, x);
    is_trained = true;
    // the index virtually contains ksub^M centroid combinations
    ntotal = 1;
    for (int m = 0; m < pq.M; m++) {
        ntotal *= pq.ksub;
    }
}
/** Search the k nearest virtual centroids for each of the n queries.
 * Large batches are processed in blocks of 32768 queries to bound the
 * distance-table allocation. k == 1 takes a direct per-table argmin;
 * k > 1 enumerates combinations with MinSumK. */
void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
float *distances, idx_t *labels) const {
if (n == 0) return;
// the allocation just below can be severe...
idx_t bs = 32768;
if (n > bs) {
// recurse on blocks of at most bs queries
for (idx_t i0 = 0; i0 < n; i0 += bs) {
idx_t i1 = std::min(i0 + bs, n);
if (verbose) {
printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
i0, i1, n);
}
search (i1 - i0, x + i0 * d, k,
distances + i0 * k,
labels + i0 * k);
}
return;
}
float * dis_tables = new float [n * pq.ksub * pq.M];
ScopeDeleter<float> del (dis_tables);
pq.compute_distance_tables (n, x, dis_tables);
if (k == 1) {
// simple version that just finds the min in each table
#pragma omp parallel for
for (int i = 0; i < n; i++) {
const float * dis_table = dis_tables + i * pq.ksub * pq.M;
float dis = 0;
idx_t label = 0;
for (int s = 0; s < pq.M; s++) {
float vmin = HUGE_VALF;
idx_t lmin = -1;
for (idx_t j = 0; j < pq.ksub; j++) {
if (dis_table[j] < vmin) {
vmin = dis_table[j];
lmin = j;
}
}
dis += vmin;
// pack the argmin of sub-quantizer s into the label
label |= lmin << (s * pq.nbits);
dis_table += pq.ksub;
}
distances [i] = dis;
labels [i] = label;
}
} else {
#pragma omp parallel if(n > 1)
{
// one MinSumK per thread: enumerates the k smallest sums of M
// terms taken in the per-sub-quantizer distance tables
MinSumK <float, SemiSortedArray<float>, false>
msk(k, pq.M, pq.nbits, pq.ksub);
#pragma omp for
for (int i = 0; i < n; i++) {
msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
distances + i * k, labels + i * k);
}
}
}
}
/** Reconstruct the virtual element `key` by concatenating the M
 *  centroids its code designates (nbits per sub-quantizer, least
 *  significant first). Now validates key, consistently with
 *  IndexPQ::reconstruct. */
void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
{
    FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
    int64_t jj = key;
    for (int m = 0; m < pq.M; m++) {
        int64_t n = jj & ((1L << pq.nbits) - 1);
        jj >>= pq.nbits;
        memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
        recons += pq.dsub;
    }
}
/// not supported: the elements of this index are virtual (all
/// combinations of centroids), nothing can be added explicitly
void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
FAISS_THROW_MSG(
"This index has virtual elements, "
"it does not support add");
}
/// not supported: virtual elements cannot be removed
void MultiIndexQuantizer::reset ()
{
FAISS_THROW_MSG ( "This index has virtual elements, "
"it does not support reset");
}
/*****************************************
* MultiIndexQuantizer2
******************************************/
/** MultiIndexQuantizer that delegates centroid assignment to M
 *  externally provided sub-indexes (one per sub-quantizer).
 *  Each sub-index must operate in dimension pq.dsub. */
MultiIndexQuantizer2::MultiIndexQuantizer2 (
        int d, size_t M, size_t nbits,
        Index **indexes):
    MultiIndexQuantizer (d, M, nbits)
{
    assign_indexes.resize (M);
    for (int m = 0; m < M; m++) {
        FAISS_THROW_IF_NOT_MSG(
            indexes[m]->d == pq.dsub,
            "Provided sub-index has incorrect size");
        assign_indexes[m] = indexes[m];
    }
    own_fields = false;  // caller keeps ownership of the sub-indexes
}
/// Convenience constructor for the common M == 2 case.
MultiIndexQuantizer2::MultiIndexQuantizer2 (
        int d, size_t nbits,
        Index *assign_index_0,
        Index *assign_index_1):
    MultiIndexQuantizer (d, 2, nbits)
{
    FAISS_THROW_IF_NOT_MSG(
        assign_index_0->d == pq.dsub &&
        assign_index_1->d == pq.dsub,
        "Provided sub-index has incorrect size");
    assign_indexes = {assign_index_0, assign_index_1};
    own_fields = false;  // caller keeps ownership of the sub-indexes
}
/// Train the PQ, then load the trained centroids into the sub-indexes
/// so they can be used for assignment at search time.
void MultiIndexQuantizer2::train(idx_t n, const float* x)
{
    MultiIndexQuantizer::train(n, x);
    // add centroids to sub-indexes
    for (int m = 0; m < pq.M; m++) {
        assign_indexes[m]->add(pq.ksub, pq.get_centroids(m, 0));
    }
}
/** Search using the external assignment indexes: first retrieve the
 * k2 nearest centroids of every sub-space, then combine them (direct
 * argmin for K == 1, MinSumK otherwise) and remap the per-sub-space
 * result ranks back to centroid ids. */
void MultiIndexQuantizer2::search(
idx_t n, const float* x, idx_t K,
float* distances, idx_t* labels) const
{
if (n == 0) return;
// number of centroids retrieved per sub-space
int k2 = std::min(K, int64_t(pq.ksub));
int64_t M = pq.M;
int64_t dsub = pq.dsub, ksub = pq.ksub;
// size (M, n, k2)
std::vector<idx_t> sub_ids(n * M * k2);
std::vector<float> sub_dis(n * M * k2);
std::vector<float> xsub(n * dsub);
for (int m = 0; m < M; m++) {
// gather the m-th sub-vector of every query into xsub
float *xdest = xsub.data();
const float *xsrc = x + m * dsub;
for (int j = 0; j < n; j++) {
memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
xsrc += d;
xdest += dsub;
}
assign_indexes[m]->search(
n, xsub.data(), k2,
&sub_dis[k2 * n * m],
&sub_ids[k2 * n * m]);
}
if (K == 1) {
// simple version that just finds the min in each table
assert (k2 == 1);
for (int i = 0; i < n; i++) {
float dis = 0;
idx_t label = 0;
for (int m = 0; m < M; m++) {
float vmin = sub_dis[i + m * n];
idx_t lmin = sub_ids[i + m * n];
dis += vmin;
label |= lmin << (m * pq.nbits);
}
distances [i] = dis;
labels [i] = label;
}
} else {
#pragma omp parallel if(n > 1)
{
// MinSumK enumerates combinations of ranks 0..k2-1 per sub-space
MinSumK <float, PreSortedArray<float>, false>
msk(K, pq.M, pq.nbits, k2);
#pragma omp for
for (int i = 0; i < n; i++) {
idx_t *li = labels + i * K;
msk.run (&sub_dis[i * k2], k2 * n,
distances + i * K, li);
// remap ids: convert the per-sub-space ranks in li back to
// the centroid ids returned by the assignment indexes
const idx_t *idmap0 = sub_ids.data() + i * k2;
int64_t ld_idmap = k2 * n;
int64_t mask1 = ksub - 1L;
for (int k = 0; k < K; k++) {
const idx_t *idmap = idmap0;
int64_t vin = li[k];
int64_t vout = 0;
int bs = 0;
for (int m = 0; m < M; m++) {
int64_t s = vin & mask1;
vin >>= pq.nbits;
vout |= idmap[s] << bs;
bs += pq.nbits;
idmap += ld_idmap;
}
li[k] = vout;
}
}
}
}
}
} // namespace faiss