Demo of residual quantizer distance computer for LaserKNN (#2283)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2283

This is a demonstration for:

- how to use a distance computer to compute query-to-code distances with a residual quantizer

- how to construct a ResidualCoarseQuantizer that uses a prefix of residalquantizer codes

See related doc https://docs.google.com/document/d/1g97lrMXVYh5FcQzw23v_sUE22ybHfCFxtbHyFJwxKKE/edit?usp=sharing

Reviewed By: alexanderguzhva

Differential Revision: D34958088

fbshipit-source-id: edb06ee350de67f855e96ae57a3862fbf14f6e54
pull/2296/head
Matthijs Douze 2022-04-06 12:42:24 -07:00 committed by Facebook GitHub Bot
parent 1806c6af27
commit bb4c987b5c
4 changed files with 421 additions and 4 deletions

View File

@ -21,3 +21,6 @@ target_link_libraries(demo_sift1M PRIVATE faiss)
add_executable(demo_weighted_kmeans EXCLUDE_FROM_ALL demo_weighted_kmeans.cpp)
target_link_libraries(demo_weighted_kmeans PRIVATE faiss)
add_executable(demo_residual_quantizer EXCLUDE_FROM_ALL demo_residual_quantizer.cpp)
target_link_libraries(demo_residual_quantizer PRIVATE faiss)

View File

@ -0,0 +1,292 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#include <climits>
#include <cstdio>
#include <memory>
#include <faiss/IVFlib.h>
#include <faiss/IndexAdditiveQuantizer.h>
#include <faiss/IndexIVFAdditiveQuantizer.h>
#include <faiss/MetricType.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>
#include <faiss/utils/utils.h>
/* This demo file shows how to:
* - use a DistanceComputer to compute distances with encoded vectors
* - in the context of an IVF, how to split an additive quantizer into an
* AdditiveCoarseQuantizer and a ResidualQuantizer, in two different ways, with
* and without storing the prefix.
*/
int main() {
/******************************************
* Generate a test dataset
******************************************/
using idx_t = faiss::Index::idx_t;
size_t d = 128;
size_t nt = 10000;
size_t nb = 10000;
size_t nq = 100;
double t0 = faiss::getmillisecs();
auto tic = [t0]() {
printf("[%.3f s] ", (faiss::getmillisecs() - t0) / 1000);
};
tic();
printf("samping dataset of %zd dim vectors, Q %zd B %zd T %zd\n",
d,
nq,
nb,
nt);
std::vector<float> buf(d * (nq + nt + nb));
faiss::rand_smooth_vectors(nq + nt + nb, d, buf.data(), 1234);
const float* xt = buf.data();
const float* xb = buf.data() + nt * d;
const float* xq = buf.data() + (nt + nb) * d;
idx_t k = 10;
std::vector<idx_t> gt(k * nq);
std::vector<float> unused(k * nq);
tic();
printf("compute ground truth, k=%zd\n", k);
faiss::knn_L2sqr(xq, xb, d, nq, nb, k, unused.data(), gt.data());
// a function to compute the accuracy
auto accuracy = [&](const idx_t* I) {
idx_t accu = 0;
for (idx_t q = 0; q < nq; q++) {
accu += faiss::ranklist_intersection_size(
k, gt.data() + q * k, k, I + q * k);
}
return double(accu) / (k * nq);
};
/******************************************
* Prepare the residual quantizer
******************************************/
faiss::ResidualQuantizer rq(
d, 7, 6, faiss::AdditiveQuantizer::ST_norm_qint8);
// do cheap an inaccurate training
rq.cp.niter = 5;
rq.max_beam_size = 5;
rq.train_type = 0;
tic();
printf("training the residual quantizer beam_size=%d\n", rq.max_beam_size);
rq.train(nt, xt);
tic();
printf("encoding the database, code_size=%zd\n", rq.code_size);
size_t code_size = rq.code_size;
std::vector<uint8_t> raw_codes(nb * code_size);
rq.compute_codes(xb, raw_codes.data(), nb);
/****************************************************************
* Make an index that uses that residual quantizer
* Verify that a distance computer gives the same distances
****************************************************************/
{
faiss::IndexResidualQuantizer index(
rq.d, rq.nbits, faiss::METRIC_L2, rq.search_type);
// override trained index
index.rq = rq;
index.is_trained = true;
// override vectors
index.codes = raw_codes;
index.ntotal = nb;
tic();
printf("IndexResidualQuantizer ready, searching\n");
std::vector<float> D(k * nq);
std::vector<idx_t> I(k * nq);
index.search(nq, xq, k, D.data(), I.data());
tic();
printf("Accuracy (intersection @ %zd): %.3f\n", k, accuracy(I.data()));
std::unique_ptr<faiss::FlatCodesDistanceComputer> dc(
index.get_FlatCodesDistanceComputer());
float max_diff12 = 0, max_diff13 = 0;
for (idx_t q = 0; q < nq; q++) {
const float* query = xq + q * d;
dc->set_query(query);
for (int i = 0; i < k; i++) {
// 3 ways of computing the same distance
// distance returned by the index
float dis1 = D[q * k + i];
// distance returned by the DistanceComputer that accesses the
// index
idx_t db_index = I[q * k + i];
float dis2 = (*dc)(db_index);
// distance computer from a code that does not belong to the
// index
const uint8_t* code = raw_codes.data() + code_size * db_index;
float dis3 = dc->distance_to_code(code);
max_diff12 = std::max(std::abs(dis1 - dis2), max_diff12);
max_diff13 = std::max(std::abs(dis1 - dis3), max_diff13);
}
}
tic();
printf("Max DistanceComputer discrepancy 1-2: %g 1-3: %g\n",
max_diff12,
max_diff13);
}
/****************************************************************
* Make an IVF index that uses the first 2 levels as a coarse quantizer
* The IVF codes contain the full code (ie. redundant with the coarse
*quantizer code)
****************************************************************/
{
// build a coarse quantizer from the 2 first levels of the RQ
std::vector<size_t> nbits(2);
std::copy(rq.nbits.begin(), rq.nbits.begin() + 2, nbits.begin());
faiss::ResidualCoarseQuantizer rcq(rq.d, nbits);
// set the coarse quantizer from the 2 first quantizers
rcq.rq.initialize_from(rq);
rcq.is_trained = true;
rcq.ntotal = (idx_t)1 << rcq.rq.tot_bits;
// settings for exhaustive search in RCQ
rcq.centroid_norms.resize(rcq.ntotal);
rcq.aq->compute_centroid_norms(rcq.centroid_norms.data());
rcq.beam_factor = -1.0; // use exact search
size_t nlist = rcq.ntotal;
tic();
printf("RCQ nlist = %zd tot_bits=%zd\n", nlist, rcq.rq.tot_bits);
// build a IVFResidualQuantizer from that
faiss::IndexIVFResidualQuantizer index(
&rcq, rcq.d, nlist, rq.nbits, faiss::METRIC_L2, rq.search_type);
index.by_residual = false;
index.rq = rq;
index.is_trained = true;
// there are 3 ways of filling up the index...
for (std::string filled_with : {"add", "manual", "derived"}) {
tic();
printf("filling up the index with %s, code_size=%zd\n",
filled_with.c_str(),
index.code_size);
index.reset();
if (filled_with == "add") {
// standard add method
index.add(nb, xb);
} else if (filled_with == "manual") {
// compute inverted lists and add elements manually
// fill in the inverted index manually
faiss::InvertedLists& invlists = *index.invlists;
// assign vectors to inverted lists
std::vector<idx_t> listnos(nb);
std::vector<float> unused(nb);
rcq.search(nb, xb, 1, unused.data(), listnos.data());
// populate inverted lists
for (idx_t i = 0; i < nb; i++) {
invlists.add_entry(
listnos[i], i, &raw_codes[i * code_size]);
}
index.ntotal = nb;
} else if (filled_with == "derived") {
// Since we have the raw codes precomputed, their prefix is the
// inverted list index, so let's use that.
faiss::InvertedLists& invlists = *index.invlists;
// populate inverted lists
for (idx_t i = 0; i < nb; i++) {
const uint8_t* code = &raw_codes[i * code_size];
faiss::BitstringReader rd(code, code_size);
idx_t list_no =
rd.read(rcq.rq.tot_bits); // read the list number
invlists.add_entry(list_no, i, code);
}
index.ntotal = nb;
}
tic();
printf("Index filled in\n");
for (int nprobe : {1, 4, 16, 64, int(nlist)}) {
printf("setting nprobe=%-4d", nprobe);
index.nprobe = nprobe;
std::vector<float> D(k * nq);
std::vector<idx_t> I(k * nq);
index.search(nq, xq, k, D.data(), I.data());
tic();
printf("Accuracy (intersection @ %zd): %.3f\n",
k,
accuracy(I.data()));
}
}
}
/****************************************************************
* Make an IVF index that uses the first 2 levels as a coarse
* quantizer, but this time does not store the code prefix from the index
****************************************************************/
{
// build a coarse quantizer from the 2 first levels of the RQ
int nlevel = 2;
std::unique_ptr<faiss::IndexIVFResidualQuantizer> index(
faiss::ivflib::ivf_residual_from_quantizer(rq, nlevel));
// there are 2 ways of filling up the index...
for (std::string filled_with : {"add", "derived"}) {
tic();
printf("filling up the IVF index with %s, code_size=%zd\n",
filled_with.c_str(),
index->code_size);
index->reset();
if (filled_with == "add") {
// standard add method
index->add(nb, xb);
} else if (filled_with == "derived") {
faiss::ivflib::ivf_residual_add_from_flat_codes(
index.get(), nb, raw_codes.data(), rq.code_size);
}
tic();
printf("Index filled in\n");
for (int nprobe : {1, 4, 16, 64, int(index->nlist)}) {
printf("setting nprobe=%-4d", nprobe);
index->nprobe = nprobe;
std::vector<float> D(k * nq);
std::vector<idx_t> I(k * nq);
index->search(nq, xq, k, D.data(), I.data());
tic();
printf("Accuracy (intersection @ %zd): %.3f\n",
k,
accuracy(I.data()));
}
}
}
return 0;
}

View File

@ -5,15 +5,18 @@
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IVFlib.h>
#include <omp.h>
#include <memory>
#include <faiss/IndexAdditiveQuantizer.h>
#include <faiss/IndexIVFAdditiveQuantizer.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/MetaIndexes.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>
namespace faiss {
@ -406,5 +409,100 @@ void range_search_with_parameters(
}
}
IndexIVFResidualQuantizer* ivf_residual_from_quantizer(
const ResidualQuantizer& rq,
int nlevel) {
FAISS_THROW_IF_NOT(nlevel > 0 && nlevel + 1 < rq.M);
std::vector<size_t> nbits(nlevel);
std::copy(rq.nbits.begin(), rq.nbits.begin() + nlevel, nbits.begin());
std::unique_ptr<ResidualCoarseQuantizer> rcq(
new ResidualCoarseQuantizer(rq.d, nbits));
// set the coarse quantizer from the 2 first quantizers
rcq->rq.initialize_from(rq);
rcq->is_trained = true;
rcq->ntotal = (idx_t)1 << rcq->rq.tot_bits;
// settings for exhaustive search in RCQ
rcq->centroid_norms.resize(rcq->ntotal);
rcq->aq->compute_centroid_norms(rcq->centroid_norms.data());
rcq->beam_factor = -1.0; // use exact search
size_t nlist = rcq->ntotal;
// build a IVFResidualQuantizer from that
std::vector<size_t> nbits_refined;
for (int i = nlevel; i < rq.M; i++) {
nbits_refined.push_back(rq.nbits[i]);
}
std::unique_ptr<IndexIVFResidualQuantizer> index(
new IndexIVFResidualQuantizer(
rcq.get(),
rq.d,
nlist,
nbits_refined,
faiss::METRIC_L2,
rq.search_type));
index->own_fields = true;
rcq.release();
index->by_residual = true;
index->rq.initialize_from(rq, nlevel);
index->is_trained = true;
return index.release();
}
void ivf_residual_add_from_flat_codes(
IndexIVFResidualQuantizer* index,
size_t nb,
const uint8_t* raw_codes,
int64_t code_size) {
const ResidualCoarseQuantizer* rcq =
dynamic_cast<const faiss::ResidualCoarseQuantizer*>(
index->quantizer);
FAISS_THROW_IF_NOT_MSG(rcq, "the coarse quantizer must be a RCQ");
if (code_size < 0) {
code_size = index->code_size;
}
InvertedLists& invlists = *index->invlists;
const ResidualQuantizer& rq = index->rq;
// populate inverted lists
#pragma omp parallel if (nb > 10000)
{
std::vector<uint8_t> tmp_code(index->code_size);
std::vector<float> tmp(rq.d);
int nt = omp_get_num_threads();
int rank = omp_get_thread_num();
#pragma omp for
for (idx_t i = 0; i < nb; i++) {
const uint8_t* code = &raw_codes[i * code_size];
BitstringReader rd(code, code_size);
idx_t list_no = rd.read(rcq->rq.tot_bits);
if (list_no % nt ==
rank) { // each thread takes care of 1/nt of the invlists
// copy AQ indexes one by one
BitstringWriter wr(tmp_code.data(), tmp_code.size());
for (int j = 0; j < rq.M; j++) {
int nbit = rq.nbits[j];
wr.write(rd.read(nbit), nbit);
}
// we need to recompute the norm
// decode first, does not use the norm component, so that's
// ok
index->rq.decode(tmp_code.data(), tmp.data(), 1);
float norm = fvec_norm_L2sqr(tmp.data(), rq.d);
wr.write(rq.encode_norm(norm), rq.norm_bits);
// add code to the inverted list
invlists.add_entry(list_no, i, tmp_code.data());
}
}
}
index->ntotal += nb;
}
} // namespace ivflib
} // namespace faiss

View File

@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#ifndef FAISS_IVFLIB_H
#define FAISS_IVFLIB_H
@ -20,6 +18,11 @@
#include <vector>
namespace faiss {
struct IndexIVFResidualQuantizer;
struct IndexResidualQuantizer;
struct ResidualQuantizer;
namespace ivflib {
/** check if two indexes have the same parameters and are trained in
@ -145,6 +148,27 @@ void range_search_with_parameters(
size_t* nb_dis = nullptr,
double* ms_per_stage = nullptr);
/** Build an IndexIVFResidualQuantizer from an ResidualQuantizer, using the
* nlevel first components as coarse quantizer and the rest as codes in invlists
*/
IndexIVFResidualQuantizer* ivf_residual_from_quantizer(
const ResidualQuantizer&,
int nlevel);
/** add from codes. NB that the norm component is not used, so the code_size can
* be provided.
*
* @param ivfrq index to populate with the codes
* @param codes codes to add, size (ncode, code_size)
* @param code_size override the ivfrq's code_size, useful if the norm encoding
* is different
*/
void ivf_residual_add_from_flat_codes(
IndexIVFResidualQuantizer* ivfrq,
size_t ncode,
const uint8_t* codes,
int64_t code_size = -1);
} // namespace ivflib
} // namespace faiss