faiss/faiss/IndexIVFPQFastScan.cpp
Junjie Qi 14b8af6e73 Fix IVFPQFastScan decode function (#3312)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3312

As [issue #3258](https://github.com/facebookresearch/faiss/issues/3258) mentions, IVFPQFastScan should produce the same decoding result as IVFPQ. However, the current result is not as expected.

In this PR/Diff, we are going to fix the decoding function

Reviewed By: mdouze

Differential Revision: D55264781

fbshipit-source-id: dfdae9eabceadfc5a3ebb851930d71ce3c1c654d
2024-03-25 11:19:40 -07:00

314 lines
9.1 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/IndexIVFPQFastScan.h>
#include <cassert>
#include <cinttypes>
#include <cstdio>
#include <omp.h>
#include <memory>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/simdlib.h>
#include <faiss/utils/utils.h>
#include <faiss/invlists/BlockInvertedLists.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/simd_result_handlers.h>
#include <faiss/utils/quantize_lut.h>
namespace faiss {
using namespace simd_result_handlers;
/// Round `a` up to the nearest multiple of `b` (integer arithmetic only;
/// returns 0 when a == 0).
inline size_t roundup(size_t a, size_t b) {
    const size_t n_blocks = (a + b - 1) / b;
    return n_blocks * b;
}
/// Build an IVF index with a PQ fast-scan encoder.
/// @param quantizer coarse quantizer that maps vectors to inverted lists
/// @param d         dimensionality of the input vectors
/// @param nlist     number of inverted lists (coarse clusters)
/// @param M         number of PQ sub-quantizers
/// @param nbits     bits per PQ sub-code (fast-scan expects 4 — see the
///                  nbits check in the IndexIVFPQ conversion constructor)
/// @param metric    distance type (L2 or inner product)
/// @param bbs       block size of the packed code layout
IndexIVFPQFastScan::IndexIVFPQFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        int bbs)
        : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
    by_residual = false; // set to false by default because it's faster
    init_fastscan(M, nbits, nlist, metric, bbs);
}
/// Default constructor (used e.g. for deserialization): leaves the index
/// untrained with zeroed fast-scan layout parameters.
IndexIVFPQFastScan::IndexIVFPQFastScan() {
    by_residual = false;
    bbs = 0;
    // NOTE(review): M2 is presumably M padded to a SIMD-friendly multiple
    // (it sizes the packed codes below as nb2 * M2 / 2) — set by
    // init_fastscan in the other constructors.
    M2 = 0;
}
/// Convert a regular IndexIVFPQ into the packed fast-scan layout.
/// Copies the trained PQ and precomputed tables, then repacks every
/// inverted list into bbs-sized blocks of interleaved 4-bit codes.
IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
        : IndexIVFFastScan(
                  orig.quantizer,
                  orig.d,
                  orig.nlist,
                  orig.pq.code_size,
                  orig.metric_type),
          pq(orig.pq) {
    // the fast-scan SIMD kernels only handle 4-bit PQ codes
    FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
    init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
    by_residual = orig.by_residual;
    ntotal = orig.ntotal;
    is_trained = orig.is_trained;
    nprobe = orig.nprobe;
    // carry over the (possibly empty) precomputed distance tables
    precomputed_table.resize(orig.precomputed_table.size());
    if (precomputed_table.nbytes() > 0) {
        memcpy(precomputed_table.get(),
               orig.precomputed_table.data(),
               precomputed_table.nbytes());
    }
    // repack each inverted list: entries are padded up to a multiple of
    // bbs and re-laid-out for SIMD lookup by pq4_pack_codes
    for (size_t i = 0; i < nlist; i++) {
        size_t nb = orig.invlists->list_size(i);
        size_t nb2 = roundup(nb, bbs);
        // two 4-bit codes per byte, hence the / 2
        AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
        pq4_pack_codes(
                InvertedLists::ScopedCodes(orig.invlists, i).get(),
                nb,
                M,
                nb2,
                bbs,
                M2,
                tmp.get());
        invlists->add_entries(
                i,
                nb,
                InvertedLists::ScopedIds(orig.invlists, i).get(),
                tmp.get());
    }
    // keep a handle on the original (non-packed) inverted lists
    orig_invlists = orig.invlists;
}
/*********************************************************
* Training
*********************************************************/
/// Train the PQ codebooks on the sample vectors; the coarse assignment
/// parameter is unused here. In the residual + L2 configuration the
/// precomputed tables are (re)built after training.
void IndexIVFPQFastScan::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* assign) {
    // propagate verbosity to the sub-quantizer before training it
    pq.verbose = verbose;
    pq.train(n, x);

    const bool needs_tables = by_residual && (metric_type == METRIC_L2);
    if (needs_tables) {
        precompute_table();
    }
}
/// Number of training vectors to sample for the PQ: a per-centroid budget
/// times the number of centroids per sub-quantizer (pq.ksub).
idx_t IndexIVFPQFastScan::train_encoder_num_vectors() const {
    return pq.cp.max_points_per_centroid * pq.ksub;
}
/// (Re)build the precomputed tables used to speed up residual L2
/// distance computation; delegates to the shared IVFPQ helper, which
/// reads/updates use_precomputed_table to select the table variant.
void IndexIVFPQFastScan::precompute_table() {
    initialize_IVFPQ_precomputed_table(
            use_precomputed_table,
            quantizer,
            pq,
            precomputed_table,
            by_residual,
            verbose);
}
/*********************************************************
* Code management functions
*********************************************************/
/// Encode vectors with the PQ (encoding residuals to the assigned
/// centroid when by_residual is set).
/// @param x        input vectors, size n * d
/// @param list_nos coarse assignment per vector; a negative entry means
///                 "unassigned" and encodes a zero residual
/// @param codes    output buffer; n * code_size bytes, or
///                 n * (coarse_code_size() + code_size) with listnos
/// @param include_listnos when true, each PQ code is prefixed with the
///                 encoded inverted-list number
void IndexIVFPQFastScan::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    if (by_residual) {
        AlignedTable<float> residuals(n * d);
        for (size_t i = 0; i < n; i++) {
            if (list_nos[i] < 0) {
                // unassigned: encode a zero residual
                memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
            } else {
                quantizer->compute_residual(
                        x + i * d, residuals.data() + i * d, list_nos[i]);
            }
        }
        pq.compute_codes(residuals.data(), codes, n);
    } else {
        pq.compute_codes(x, codes, n);
    }
    if (include_listnos) {
        size_t coarse_size = coarse_code_size();
        // expand in place back-to-front so that not-yet-moved compact
        // codes are not overwritten by the widened entries
        for (idx_t i = n - 1; i >= 0; i--) {
            uint8_t* code = codes + i * (coarse_size + code_size);
            memmove(code + coarse_size, codes + i * code_size, code_size);
            encode_listno(list_nos[i], code);
        }
    }
}
/*********************************************************
* Look-Up Table functions
*********************************************************/
void fvec_madd_simd(
size_t n,
const float* a,
float bf,
const float* b,
float* c) {
assert(is_aligned_pointer(a));
assert(is_aligned_pointer(b));
assert(is_aligned_pointer(c));
assert(n % 8 == 0);
simd8float32 bf8(bf);
n /= 8;
for (size_t i = 0; i < n; i++) {
simd8float32 ai(a);
simd8float32 bi(b);
simd8float32 ci = fmadd(bf8, bi, ai);
ci.store(c);
c += 8;
a += 8;
b += 8;
}
}
/// The look-up tables are 3D (one table per (query, probed list) pair)
/// only in the residual + L2 case — see compute_LUT, which sizes the
/// tables n * nprobe * M * ksub there and n * M * ksub otherwise.
bool IndexIVFPQFastScan::lookup_table_is_3d() const {
    return by_residual && metric_type == METRIC_L2;
}
/// Compute the PQ look-up tables (and additive biases) for a batch of
/// queries, given the coarse quantization results in cq.
/// Shapes depend on the configuration (see lookup_table_is_3d):
///  - residual + L2: one table per (query, probed list) pair,
///    dis_tables is n * nprobe * (M * ksub); biases hold the coarse
///    distances (or 0 when tables are computed from explicit residuals)
///  - residual + IP, or non-residual: one table per query,
///    dis_tables is n * (M * ksub); for residual IP the coarse inner
///    products are carried in the biases
void IndexIVFPQFastScan::compute_LUT(
        size_t n,
        const float* x,
        const CoarseQuantized& cq,
        AlignedTable<float>& dis_tables,
        AlignedTable<float>& biases) const {
    size_t dim12 = pq.ksub * pq.M; // size of one look-up table
    size_t d = pq.d;
    size_t nprobe = this->nprobe;
    if (by_residual) {
        if (metric_type == METRIC_L2) {
            dis_tables.resize(n * nprobe * dim12);
            if (use_precomputed_table == 1) {
                // table = precomputed term - 2 * <query, pq centroids>,
                // computed per probed list via fvec_madd_simd
                biases.resize(n * nprobe);
                memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
                AlignedTable<float> ip_table(n * dim12);
                pq.compute_inner_prod_tables(n, x, ip_table.get());
#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe; // query index for this pair
                    float* tab = dis_tables.get() + ij * dim12;
                    idx_t cij = cq.ids[ij];
                    if (cij >= 0) {
                        fvec_madd_simd(
                                dim12,
                                precomputed_table.get() + cij * dim12,
                                -2,
                                ip_table.get() + i * dim12,
                                tab);
                    } else {
                        // fill with NaNs so that they are ignored during
                        // LUT quantization (0xFF bytes decode as NaN)
                        memset(tab, -1, sizeof(float) * dim12);
                    }
                }
            } else {
                // no precomputed table: form explicit residuals and
                // build full L2 distance tables on them
                std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
                biases.resize(n * nprobe);
                memset(biases.get(), 0, sizeof(float) * n * nprobe);
#pragma omp parallel for if (n * nprobe > 8000)
                for (idx_t ij = 0; ij < n * nprobe; ij++) {
                    idx_t i = ij / nprobe; // query index for this pair
                    float* xij = &xrel[ij * d];
                    idx_t cij = cq.ids[ij];
                    if (cij >= 0) {
                        quantizer->compute_residual(x + i * d, xij, cij);
                    } else {
                        // will fill with NaNs
                        memset(xij, -1, sizeof(float) * d);
                    }
                }
                pq.compute_distance_tables(
                        n * nprobe, xrel.get(), dis_tables.get());
            }
        } else if (metric_type == METRIC_INNER_PRODUCT) {
            // IP with residuals: the table depends only on the query;
            // the coarse inner product is added back via the biases
            dis_tables.resize(n * dim12);
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
            // compute_inner_prod_tables(pq, n, x, dis_tables.get());
            biases.resize(n * nprobe);
            memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }
    } else {
        // non-residual case: one table per query, no biases needed
        dis_tables.resize(n * dim12);
        if (metric_type == METRIC_L2) {
            pq.compute_distance_tables(n, x, dis_tables.get());
        } else if (metric_type == METRIC_INNER_PRODUCT) {
            pq.compute_inner_prod_tables(n, x, dis_tables.get());
        } else {
            FAISS_THROW_FMT("metric %d not supported", metric_type);
        }
    }
}
/// Decode standalone codes back to float vectors (inverse of
/// encode_vectors with include_listnos = true).
/// Each code is [encoded list number | PQ code]; when by_residual is
/// set, the reconstructed coarse centroid is added to the PQ decode.
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    size_t coarse_size = coarse_code_size();
#pragma omp parallel if (n > 1)
    {
        // per-thread scratch buffer for the coarse centroid
        std::vector<float> residual(d);
#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* code = codes + i * (code_size + coarse_size);
            int64_t list_no = decode_listno(code);
            float* xi = x + i * d;
            // decode the PQ part into the output...
            pq.decode(code + coarse_size, xi);
            if (by_residual) {
                // ...then add back the coarse centroid
                quantizer->reconstruct(list_no, residual.data());
                for (size_t j = 0; j < d; j++) {
                    xi[j] += residual[j];
                }
            }
        }
    }
}
} // namespace faiss