Implement search methods for ProductAdditiveQuantizer (#2336)

Summary:
Work in progress.

This PR is going to implement the following search methods for ProductAdditiveQuantizer, including index factory and I/O:

- [x] IndexProductAdditiveQuantizer
- [x] IndexIVFProductAdditiveQuantizer
- [x] IndexProductAdditiveQuantizerFastScan
- [x] IndexIVFProductAdditiveQuantizerFastScan

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2336

Test Plan:
buck test //faiss/tests/:test_fast_scan
buck test //faiss/tests/:test_fast_scan_ivf
buck test //faiss/tests/:test_local_search_quantizer
buck test //faiss/tests/:test_residual_quantizer

Reviewed By: alexanderguzhva

Differential Revision: D37172745

Pulled By: mdouze

fbshipit-source-id: 6ff18bfc462525478c90cd42e21805ab8605bd0f
pull/2398/head
Check Deng 2022-07-27 05:32:15 -07:00 committed by Facebook GitHub Bot
parent ae0a9d86a7
commit 838f85cb52
20 changed files with 1095 additions and 69 deletions

View File

@ -12,6 +12,8 @@ the Facebook Faiss team. Feel free to add entries here if you submit a PR.
- Added sparse k-means routines and moved the generic kmeans to contrib
- Added FlatDistanceComputer for all FlatCodes indexes
- Support for fast accumulation of 4-bit LSQ and RQ
- Added product additive quantization
## [1.7.2] - 2021-12-15
### Added

View File

@ -353,6 +353,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
is_trained = true;
}
/**************************************************************************************
* IndexProductResidualQuantizer
**************************************************************************************/
/// Build a flat-codes index whose codec is the embedded product residual
/// quantizer `prq`. The index starts untrained and takes its code size from
/// the quantizer.
///
/// @param d            dimensionality of the input vectors
/// @param nsplits      number of residual quantizers
/// @param Msub         number of subquantizers per RQ
/// @param nbits        number of bits per subvector index
/// @param metric       metric type, forwarded to the IndexAdditiveQuantizer base
/// @param search_type  search type, forwarded to the embedded quantizer
IndexProductResidualQuantizer::IndexProductResidualQuantizer(
        int d,
        size_t nsplits,
        size_t Msub,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type)
        : IndexAdditiveQuantizer(d, &prq, metric),
          prq(d, nsplits, Msub, nbits, search_type) {
    // The base class receives &prq before prq itself has been constructed
    // (bases are initialized before members).
    // NOTE(review): the base constructor must therefore only store the
    // pointer, not dereference it -- confirm.
    is_trained = false;
    code_size = prq.code_size;
}
/// Default constructor: delegates to the main constructor with d, nsplits,
/// Msub and nbits all set to 0; metric and search type take the defaults
/// declared on that constructor. read_index() builds the object this way
/// before loading its fields from a stream.
IndexProductResidualQuantizer::IndexProductResidualQuantizer()
: IndexProductResidualQuantizer(0, 0, 0, 0) {}
/// Train the embedded product residual quantizer and mark the index trained.
///
/// @param n  number of training vectors
/// @param x  training vectors -- presumably n * d floats; TODO confirm
void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
prq.train(n, x);
// set only after prq.train() has returned
is_trained = true;
}
/**************************************************************************************
* IndexProductLocalSearchQuantizer
**************************************************************************************/
/// Build a flat-codes index whose codec is the embedded product local search
/// quantizer `plsq`. The index starts untrained and takes its code size from
/// the quantizer.
///
/// @param d            dimensionality of the input vectors
/// @param nsplits      number of local search quantizers
/// @param Msub         number of subquantizers per LSQ
/// @param nbits        number of bits per subvector index
/// @param metric       metric type, forwarded to the IndexAdditiveQuantizer base
/// @param search_type  search type, forwarded to the embedded quantizer
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
        int d,
        size_t nsplits,
        size_t Msub,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type)
        : IndexAdditiveQuantizer(d, &plsq, metric),
          plsq(d, nsplits, Msub, nbits, search_type) {
    // The base class receives &plsq before plsq itself has been constructed
    // (bases are initialized before members).
    // NOTE(review): the base constructor must therefore only store the
    // pointer, not dereference it -- confirm.
    is_trained = false;
    code_size = plsq.code_size;
}
/// Default constructor: delegates to the main constructor with d, nsplits,
/// Msub and nbits all set to 0; metric and search type take the defaults
/// declared on that constructor. read_index() builds the object this way
/// before loading its fields from a stream.
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
: IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
/// Train the embedded product local search quantizer and mark the index
/// trained.
///
/// @param n  number of training vectors
/// @param x  training vectors -- presumably n * d floats; TODO confirm
void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
plsq.train(n, x);
// set only after plsq.train() has returned
is_trained = true;
}
/**************************************************************************************
* AdditiveCoarseQuantizer
**************************************************************************************/

View File

@ -15,6 +15,7 @@
#include <faiss/IndexFlatCodes.h>
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/ProductAdditiveQuantizer.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/platform_macros.h>
@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
using Search_type_t = AdditiveQuantizer::Search_type_t;
explicit IndexAdditiveQuantizer(
idx_t d = 0,
AdditiveQuantizer* aq = nullptr,
idx_t d,
AdditiveQuantizer* aq,
MetricType metric = METRIC_L2);
void search(
@ -100,6 +101,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
void train(idx_t n, const float* x) override;
};
/** Index based on a product residual quantizer.
 *
 * The quantizer is stored by value in `prq`; the constructor hands the
 * address of that member to the IndexAdditiveQuantizer base.
 */
struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
/// The product residual quantizer used to encode the vectors
ProductResidualQuantizer prq;
/** Constructor.
 *
 * @param d            dimensionality of the input vectors
 * @param nsplits      number of residual quantizers
 * @param Msub         number of subquantizers per RQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type, forwarded to the base index
 * @param search_type  AQ search type, forwarded to `prq`
 */
IndexProductResidualQuantizer(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of residual quantizers
size_t Msub, ///< number of subquantizers per RQ
size_t nbits, ///< number of bits per subvector index
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
/// Default constructor; equivalent to passing 0 for d, nsplits, Msub, nbits.
IndexProductResidualQuantizer();
/// Trains `prq` on the n vectors in x, then marks the index as trained.
void train(idx_t n, const float* x) override;
};
/** Index based on a product local search quantizer.
 *
 * The quantizer is stored by value in `plsq`; the constructor hands the
 * address of that member to the IndexAdditiveQuantizer base.
 */
struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
/// The product local search quantizer used to encode the vectors
ProductLocalSearchQuantizer plsq;
/** Constructor.
 *
 * @param d            dimensionality of the input vectors
 * @param nsplits      number of local search quantizers
 * @param Msub         number of subquantizers per LSQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type, forwarded to the base index
 * @param search_type  AQ search type, forwarded to `plsq`
 */
IndexProductLocalSearchQuantizer(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of local search quantizers
size_t Msub, ///< number of subquantizers per LSQ
size_t nbits, ///< number of bits per subvector index
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
/// Default constructor; equivalent to passing 0 for d, nsplits, Msub, nbits.
IndexProductLocalSearchQuantizer();
/// Trains `plsq` on the n vectors in x, then marks the index as trained.
void train(idx_t n, const float* x) override;
};
/** A "virtual" index where the elements are the residual quantizer centroids.
*
* Intended for use as a coarse quantizer in an IndexIVF.

View File

@ -251,4 +251,46 @@ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() {
aq = &lsq;
}
/**************************************************************************************
* IndexProductResidualQuantizerFastScan
**************************************************************************************/
/// Construct the embedded product residual quantizer from the given shape,
/// then initialize the fast-scan index on top of it via init().
///
/// `metric` and `bbs` are passed straight through to init().
/// NOTE(review): bbs is presumably the fast-scan block size -- confirm.
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of residual quantizers
size_t Msub, ///< number of subquantizers per RQ
size_t nbits, ///< number of bits per subvector index
MetricType metric,
Search_type_t search_type,
int bbs)
: prq(d, nsplits, Msub, nbits, search_type) {
// prq is fully constructed at this point, so init() may use it
init(&prq, metric, bbs);
}
/// Default constructor: only points the base-class `aq` pointer at the
/// default-constructed `prq`; init() is not called. read_index() relies on
/// `aq` being set so it can read the quantizer into it.
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() {
aq = &prq;
}
/**************************************************************************************
* IndexProductLocalSearchQuantizerFastScan
**************************************************************************************/
/// Construct the embedded product local search quantizer from the given
/// shape, then initialize the fast-scan index on top of it via init().
///
/// `metric` and `bbs` are passed straight through to init().
/// NOTE(review): bbs is presumably the fast-scan block size -- confirm.
IndexProductLocalSearchQuantizerFastScan::
IndexProductLocalSearchQuantizerFastScan(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of local search quantizers
size_t Msub, ///< number of subquantizers per LSQ
size_t nbits, ///< number of bits per subvector index
MetricType metric,
Search_type_t search_type,
int bbs)
: plsq(d, nsplits, Msub, nbits, search_type) {
// plsq is fully constructed at this point, so init() may use it
init(&plsq, metric, bbs);
}
/// Default constructor: only points the base-class `aq` pointer at the
/// default-constructed `plsq`; init() is not called. read_index() relies on
/// `aq` being set so it can read the quantizer into it.
IndexProductLocalSearchQuantizerFastScan::
IndexProductLocalSearchQuantizerFastScan() {
aq = &plsq;
}
} // namespace faiss

View File

@ -10,6 +10,7 @@
#include <faiss/IndexAdditiveQuantizer.h>
#include <faiss/IndexFastScan.h>
#include <faiss/impl/AdditiveQuantizer.h>
#include <faiss/impl/ProductAdditiveQuantizer.h>
#include <faiss/utils/AlignedTable.h>
namespace faiss {
@ -109,6 +110,10 @@ struct IndexResidualQuantizerFastScan : IndexAdditiveQuantizerFastScan {
IndexResidualQuantizerFastScan();
};
/** Index based on a local search quantizer. Stored vectors are
* approximated by local search quantization codes.
* Can also be used as a codec
*/
struct IndexLocalSearchQuantizerFastScan : IndexAdditiveQuantizerFastScan {
LocalSearchQuantizer lsq;
@ -131,4 +136,63 @@ struct IndexLocalSearchQuantizerFastScan : IndexAdditiveQuantizerFastScan {
IndexLocalSearchQuantizerFastScan();
};
/** Index based on a product residual quantizer. Stored vectors are
 * approximated by product residual quantization codes.
 * Can also be used as a codec.
 *
 * The quantizer is stored by value in `prq`; both constructors make the
 * base class refer to that member.
 */
struct IndexProductResidualQuantizerFastScan : IndexAdditiveQuantizerFastScan {
/// The product residual quantizer used to encode the vectors
ProductResidualQuantizer prq;
/** Constructor.
 *
 * @param d            dimensionality of the input vectors
 * @param nsplits      number of residual quantizers
 * @param Msub         number of subquantizers per RQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type
 * @param search_type  AQ search type
 * @param bbs          forwarded to init(); presumably the fast-scan block
 *                     size -- TODO confirm
 */
IndexProductResidualQuantizerFastScan(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of residual quantizers
size_t Msub, ///< number of subquantizers per RQ
size_t nbits, ///< number of bits per subvector index
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_norm_rq2x4,
int bbs = 32);
/// Default constructor: sets `aq` to `&prq` and nothing else.
IndexProductResidualQuantizerFastScan();
};
/** Index based on a product local search quantizer. Stored vectors are
 * approximated by product local search quantization codes.
 * Can also be used as a codec.
 *
 * The quantizer is stored by value in `plsq`; both constructors make the
 * base class refer to that member.
 */
struct IndexProductLocalSearchQuantizerFastScan
: IndexAdditiveQuantizerFastScan {
/// The product local search quantizer used to encode the vectors
ProductLocalSearchQuantizer plsq;
/** Constructor.
 *
 * @param d            dimensionality of the input vectors
 * @param nsplits      number of local search quantizers
 * @param Msub         number of subquantizers per LSQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type
 * @param search_type  AQ search type
 * @param bbs          forwarded to init(); presumably the fast-scan block
 *                     size -- TODO confirm
 *
 * NOTE(review): the default search_type is ST_norm_rq2x4, whereas
 * IndexIVFProductLocalSearchQuantizerFastScan defaults to ST_norm_lsq2x4.
 * Confirm which default is intended for an LSQ-based codec.
 */
IndexProductLocalSearchQuantizerFastScan(
int d, ///< dimensionality of the input vectors
size_t nsplits, ///< number of local search quantizers
size_t Msub, ///< number of subquantizers per LSQ
size_t nbits, ///< number of bits per subvector index
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_norm_rq2x4,
int bbs = 32);
/// Default constructor: sets `aq` to `&plsq` and nothing else.
IndexProductLocalSearchQuantizerFastScan();
};
} // namespace faiss

View File

@ -342,4 +342,50 @@ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() {}
/**************************************************************************************
* IndexIVFProductResidualQuantizer
**************************************************************************************/
/// Build an IVF index whose codec is the embedded product residual
/// quantizer `prq`.
///
/// @param quantizer    coarse quantizer, forwarded to the IVF base class
/// @param d            dimensionality of the input vectors
/// @param nlist        forwarded to the IVF base class
/// @param nsplits      number of residual quantizers
/// @param Msub         number of subquantizers per RQ
/// @param nbits        number of bits per subvector index
/// @param metric       metric type, forwarded to the IVF base class
/// @param search_type  search type, forwarded to `prq`
IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t nsplits,
        size_t Msub,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type)
        : IndexIVFAdditiveQuantizer(&prq, quantizer, d, nlist, metric),
          prq(d, nsplits, Msub, nbits, search_type) {
    // The inverted lists and the index itself must agree on the code size
    // produced by the quantizer.
    invlists->code_size = prq.code_size;
    code_size = invlists->code_size;
}
/// Default constructor: only hands the address of the default-constructed
/// `prq` to the base class. read_index() builds the object this way before
/// loading its fields from a stream.
IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer()
: IndexIVFAdditiveQuantizer(&prq) {}
/// Destructor: no work of its own; `prq` is a by-value member and is
/// destroyed automatically.
IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() {}
/**************************************************************************************
* IndexIVFProductLocalSearchQuantizer
**************************************************************************************/
/// Build an IVF index whose codec is the embedded product local search
/// quantizer `plsq`.
///
/// @param quantizer    coarse quantizer, forwarded to the IVF base class
/// @param d            dimensionality of the input vectors
/// @param nlist        forwarded to the IVF base class
/// @param nsplits      number of local search quantizers
/// @param Msub         number of subquantizers per LSQ
/// @param nbits        number of bits per subvector index
/// @param metric       metric type, forwarded to the IVF base class
/// @param search_type  search type, forwarded to `plsq`
IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t nsplits,
        size_t Msub,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type)
        : IndexIVFAdditiveQuantizer(&plsq, quantizer, d, nlist, metric),
          plsq(d, nsplits, Msub, nbits, search_type) {
    // The inverted lists and the index itself must agree on the code size
    // produced by the quantizer.
    invlists->code_size = plsq.code_size;
    code_size = invlists->code_size;
}
/// Default constructor: only hands the address of the default-constructed
/// `plsq` to the base class. read_index() builds the object this way before
/// loading its fields from a stream.
IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer()
: IndexIVFAdditiveQuantizer(&plsq) {}
/// Destructor: no work of its own; `plsq` is a by-value member and is
/// destroyed automatically.
IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() {}
} // namespace faiss

View File

@ -15,6 +15,7 @@
#include <faiss/IndexIVF.h>
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/ProductAdditiveQuantizer.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/platform_macros.h>
@ -118,6 +119,64 @@ struct IndexIVFLocalSearchQuantizer : IndexIVFAdditiveQuantizer {
virtual ~IndexIVFLocalSearchQuantizer();
};
/** IndexIVF based on a product residual quantizer. Stored vectors are
 * approximated by product residual quantization codes.
 *
 * The quantizer is stored by value in `prq`; both constructors pass the
 * address of that member to IndexIVFAdditiveQuantizer.
 */
struct IndexIVFProductResidualQuantizer : IndexIVFAdditiveQuantizer {
/// The product residual quantizer used to encode the vectors
ProductResidualQuantizer prq;
/** Constructor.
 *
 * @param quantizer    coarse quantizer, forwarded to the IVF base class
 * @param d            dimensionality of the input vectors
 * @param nlist        forwarded to the IVF base class; presumably the
 *                     number of inverted lists -- TODO confirm
 * @param nsplits      number of residual quantizers
 * @param Msub         number of subquantizers per RQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type
 * @param search_type  AQ search type, forwarded to `prq`
 */
IndexIVFProductResidualQuantizer(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
/// Default constructor: only wires `&prq` into the base class.
IndexIVFProductResidualQuantizer();
virtual ~IndexIVFProductResidualQuantizer();
};
/** IndexIVF based on a product local search quantizer. Stored vectors are
 * approximated by product local search quantization codes.
 *
 * The quantizer is stored by value in `plsq`; both constructors pass the
 * address of that member to IndexIVFAdditiveQuantizer.
 */
struct IndexIVFProductLocalSearchQuantizer : IndexIVFAdditiveQuantizer {
/// The product local search quantizer used to encode the vectors
ProductLocalSearchQuantizer plsq;
/** Constructor.
 *
 * @param quantizer    coarse quantizer, forwarded to the IVF base class
 * @param d            dimensionality of the input vectors
 * @param nlist        forwarded to the IVF base class; presumably the
 *                     number of inverted lists -- TODO confirm
 * @param nsplits      number of local search quantizers
 * @param Msub         number of subquantizers per LSQ
 * @param nbits        number of bits per subvector index
 * @param metric       metric type
 * @param search_type  AQ search type, forwarded to `plsq`
 */
IndexIVFProductLocalSearchQuantizer(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
/// Default constructor: only wires `&plsq` into the base class.
IndexIVFProductLocalSearchQuantizer();
virtual ~IndexIVFProductLocalSearchQuantizer();
};
} // namespace faiss
#endif

View File

@ -499,8 +499,6 @@ IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan() {
aq = &lsq;
}
IndexIVFLocalSearchQuantizerFastScan::~IndexIVFLocalSearchQuantizerFastScan() {}
/********** IndexIVFResidualQuantizerFastScan ************/
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan(
Index* quantizer,
@ -527,6 +525,62 @@ IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan() {
aq = &rq;
}
IndexIVFResidualQuantizerFastScan::~IndexIVFResidualQuantizerFastScan() {}
/********** IndexIVFProductLocalSearchQuantizerFastScan ************/
/// Build an IVF fast-scan index over an embedded product local search
/// quantizer.
///
/// The base class is first constructed with a null additive quantizer
/// because `plsq` does not exist yet at that point; once `plsq` has been
/// built and `nbits` validated, init() is called with the real quantizer.
///
/// Fails via FAISS_THROW_IF_NOT unless nbits == 4.
/// NOTE(review): presumably because fast-scan works on 4-bit codes -- confirm.
IndexIVFProductLocalSearchQuantizerFastScan::
IndexIVFProductLocalSearchQuantizerFastScan(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric,
Search_type_t search_type,
int bbs)
: IndexIVFAdditiveQuantizerFastScan(
quantizer,
nullptr,
d,
nlist,
metric,
bbs),
plsq(d, nsplits, Msub, nbits, search_type) {
FAISS_THROW_IF_NOT(nbits == 4);
init(&plsq, nlist, metric, bbs);
}
/// Default constructor: only points the base-class `aq` pointer at the
/// default-constructed `plsq`; init() is not called. read_index() relies on
/// `aq` being set so it can read the quantizer into it.
IndexIVFProductLocalSearchQuantizerFastScan::
IndexIVFProductLocalSearchQuantizerFastScan() {
aq = &plsq;
}
/********** IndexIVFProductResidualQuantizerFastScan ************/
/// Build an IVF fast-scan index over an embedded product residual quantizer.
///
/// The base class is first constructed with a null additive quantizer
/// because `prq` does not exist yet at that point; once `prq` has been built
/// and `nbits` validated, init() is called with the real quantizer.
///
/// Fails via FAISS_THROW_IF_NOT unless nbits == 4.
/// NOTE(review): presumably because fast-scan works on 4-bit codes -- confirm.
IndexIVFProductResidualQuantizerFastScan::
IndexIVFProductResidualQuantizerFastScan(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric,
Search_type_t search_type,
int bbs)
: IndexIVFAdditiveQuantizerFastScan(
quantizer,
nullptr,
d,
nlist,
metric,
bbs),
prq(d, nsplits, Msub, nbits, search_type) {
FAISS_THROW_IF_NOT(nbits == 4);
init(&prq, nlist, metric, bbs);
}
/// Default constructor: only points the base-class `aq` pointer at the
/// default-constructed `prq`; init() is not called. read_index() relies on
/// `aq` being set so it can read the quantizer into it.
IndexIVFProductResidualQuantizerFastScan::
IndexIVFProductResidualQuantizerFastScan() {
aq = &prq;
}
} // namespace faiss

View File

@ -12,7 +12,7 @@
#include <faiss/IndexIVFAdditiveQuantizer.h>
#include <faiss/IndexIVFFastScan.h>
#include <faiss/impl/AdditiveQuantizer.h>
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/ProductAdditiveQuantizer.h>
#include <faiss/utils/AlignedTable.h>
namespace faiss {
@ -113,8 +113,6 @@ struct IndexIVFLocalSearchQuantizerFastScan
int bbs = 32);
IndexIVFLocalSearchQuantizerFastScan();
~IndexIVFLocalSearchQuantizerFastScan();
};
struct IndexIVFResidualQuantizerFastScan : IndexIVFAdditiveQuantizerFastScan {
@ -131,8 +129,42 @@ struct IndexIVFResidualQuantizerFastScan : IndexIVFAdditiveQuantizerFastScan {
int bbs = 32);
IndexIVFResidualQuantizerFastScan();
};
~IndexIVFResidualQuantizerFastScan();
/** IVF fast-scan index based on a product local search quantizer.
 *
 * The quantizer is stored by value in `plsq`; both constructors make the
 * base class refer to that member.
 */
struct IndexIVFProductLocalSearchQuantizerFastScan
: IndexIVFAdditiveQuantizerFastScan {
/// The product local search quantizer used to encode the vectors
ProductLocalSearchQuantizer plsq;
/** Constructor.
 *
 * @param quantizer    coarse quantizer, forwarded to the base class
 * @param d            dimensionality of the input vectors
 * @param nlist        forwarded to the base class and to init()
 * @param nsplits      number of local search quantizers
 * @param Msub         number of subquantizers per LSQ
 * @param nbits        number of bits per subvector index; the constructor
 *                     rejects any value other than 4
 * @param metric       metric type
 * @param search_type  AQ search type, forwarded to `plsq`
 * @param bbs          forwarded to the base class and to init()
 */
IndexIVFProductLocalSearchQuantizerFastScan(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
int bbs = 32);
/// Default constructor: sets `aq` to `&plsq` and nothing else.
IndexIVFProductLocalSearchQuantizerFastScan();
};
/** IVF fast-scan index based on a product residual quantizer.
 *
 * The quantizer is stored by value in `prq`; both constructors make the
 * base class refer to that member.
 */
struct IndexIVFProductResidualQuantizerFastScan
: IndexIVFAdditiveQuantizerFastScan {
/// The product residual quantizer used to encode the vectors
ProductResidualQuantizer prq;
/** Constructor.
 *
 * @param quantizer    coarse quantizer, forwarded to the base class
 * @param d            dimensionality of the input vectors
 * @param nlist        forwarded to the base class and to init()
 * @param nsplits      number of residual quantizers
 * @param Msub         number of subquantizers per RQ
 * @param nbits        number of bits per subvector index; the constructor
 *                     rejects any value other than 4
 * @param metric       metric type
 * @param search_type  AQ search type, forwarded to `prq`
 * @param bbs          forwarded to the base class and to init()
 *
 * NOTE(review): the default search_type is ST_norm_lsq2x4 although the
 * codec is residual-quantizer based; the non-IVF
 * IndexProductResidualQuantizerFastScan defaults to ST_norm_rq2x4.
 * Confirm which default is intended.
 */
IndexIVFProductResidualQuantizerFastScan(
Index* quantizer,
size_t d,
size_t nlist,
size_t nsplits,
size_t Msub,
size_t nbits,
MetricType metric = METRIC_L2,
Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
int bbs = 32);
/// Default constructor: sets `aq` to `&prq` and nothing else.
IndexIVFProductResidualQuantizerFastScan();
};
} // namespace faiss

View File

@ -50,26 +50,21 @@ ProductAdditiveQuantizer::ProductAdditiveQuantizer(
init(d, aqs, search_type);
}
ProductAdditiveQuantizer::ProductAdditiveQuantizer() {}
ProductAdditiveQuantizer::ProductAdditiveQuantizer()
: ProductAdditiveQuantizer(0, {}) {}
void ProductAdditiveQuantizer::init(
size_t d,
const std::vector<AdditiveQuantizer*>& aqs,
Search_type_t search_type) {
FAISS_THROW_IF_NOT_MSG(
!aqs.empty(), "At least one additive quantizer is required.");
for (size_t i = 0; i < aqs.size(); i++) {
const auto& q = aqs[i];
FAISS_THROW_IF_NOT(q->d == aqs[0]->d);
FAISS_THROW_IF_NOT(q->M == aqs[0]->M);
FAISS_THROW_IF_NOT(q->nbits[0] == aqs[0]->nbits[0]);
}
// AdditiveQuantizer constructor
this->d = d;
this->search_type = search_type;
M = aqs.size() * aqs[0]->M;
nbits = std::vector<size_t>(M, aqs[0]->nbits[0]);
M = 0;
for (const auto& q : aqs) {
M += q->M;
nbits.insert(nbits.end(), q->nbits.begin(), q->nbits.end());
}
verbose = false;
is_trained = false;
norm_max = norm_min = NAN;
@ -139,6 +134,15 @@ void ProductAdditiveQuantizer::train(size_t n, const float* x) {
}
is_trained = true;
// train norm
std::vector<int32_t> codes(n * M);
compute_unpacked_codes(x, codes.data(), n);
std::vector<float> x_recons(n * d);
std::vector<float> norms(n);
decode_unpacked(codes.data(), x_recons.data(), n);
fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
train_norm(n, norms.data());
}
void ProductAdditiveQuantizer::compute_codes_add_centroids(
@ -148,7 +152,17 @@ void ProductAdditiveQuantizer::compute_codes_add_centroids(
const float* centroids) const {
// size (n, M)
std::vector<int32_t> unpacked_codes(n * M);
compute_unpacked_codes(x, unpacked_codes.data(), n, centroids);
// pack
pack_codes(n, unpacked_codes.data(), codes_out, -1, nullptr, centroids);
}
void ProductAdditiveQuantizer::compute_unpacked_codes(
const float* x,
int32_t* unpacked_codes,
size_t n,
const float* centroids) const {
/// TODO: actually we do not need to unpack and pack
size_t offset_d = 0, offset_m = 0;
std::vector<float> xsub;
@ -183,9 +197,6 @@ void ProductAdditiveQuantizer::compute_codes_add_centroids(
offset_d += q->d;
offset_m += q->M;
}
// pack
pack_codes(n, unpacked_codes.data(), codes_out, -1, nullptr, centroids);
}
void ProductAdditiveQuantizer::decode_unpacked(
@ -318,22 +329,26 @@ ProductLocalSearchQuantizer::ProductLocalSearchQuantizer(
size_t Msub,
size_t nbits,
Search_type_t search_type) {
FAISS_THROW_IF_NOT(d % nsplits == 0);
size_t dsub = d / nsplits;
std::vector<AdditiveQuantizer*> aqs;
for (size_t i = 0; i < nsplits; i++) {
auto lsq = new LocalSearchQuantizer(dsub, Msub, nbits, ST_decompress);
aqs.push_back(lsq);
if (nsplits > 0) {
FAISS_THROW_IF_NOT(d % nsplits == 0);
size_t dsub = d / nsplits;
for (size_t i = 0; i < nsplits; i++) {
auto lsq =
new LocalSearchQuantizer(dsub, Msub, nbits, ST_decompress);
aqs.push_back(lsq);
}
}
init(d, aqs, search_type);
for (auto& q : aqs) {
delete q;
}
}
ProductLocalSearchQuantizer::ProductLocalSearchQuantizer() {}
ProductLocalSearchQuantizer::ProductLocalSearchQuantizer()
: ProductLocalSearchQuantizer(0, 0, 0, 0) {}
/*************************************
* Product Residual Quantizer
@ -345,21 +360,24 @@ ProductResidualQuantizer::ProductResidualQuantizer(
size_t Msub,
size_t nbits,
Search_type_t search_type) {
FAISS_THROW_IF_NOT(d % nsplits == 0);
size_t dsub = d / nsplits;
std::vector<AdditiveQuantizer*> aqs;
for (size_t i = 0; i < nsplits; i++) {
auto rq = new ResidualQuantizer(dsub, Msub, nbits, ST_decompress);
aqs.push_back(rq);
if (nsplits > 0) {
FAISS_THROW_IF_NOT(d % nsplits == 0);
size_t dsub = d / nsplits;
for (size_t i = 0; i < nsplits; i++) {
auto rq = new ResidualQuantizer(dsub, Msub, nbits, ST_decompress);
aqs.push_back(rq);
}
}
init(d, aqs, search_type);
for (auto& q : aqs) {
delete q;
}
}
ProductResidualQuantizer::ProductResidualQuantizer() {}
ProductResidualQuantizer::ProductResidualQuantizer()
: ProductResidualQuantizer(0, 0, 0, 0) {}
} // namespace faiss

View File

@ -58,11 +58,6 @@ struct ProductAdditiveQuantizer : AdditiveQuantizer {
///< Train the product additive quantizer
void train(size_t n, const float* x) override;
void compute_codes(const float* x, uint8_t* codes, size_t n)
const override {
compute_codes_add_centroids(x, codes, n);
}
/** Encode a set of vectors
*
* @param x vectors to encode, size n * d
@ -75,6 +70,12 @@ struct ProductAdditiveQuantizer : AdditiveQuantizer {
size_t n,
const float* centroids = nullptr) const override;
void compute_unpacked_codes(
const float* x,
int32_t* codes,
size_t n,
const float* centroids = nullptr) const;
/** Decode a set of vectors in non-packed format
*
* @param codes codes to decode, size n * ld_codes

View File

@ -567,7 +567,12 @@ void ResidualQuantizer::compute_codes_add_centroids(
}
for (size_t i0 = 0; i0 < n; i0 += bs) {
size_t i1 = std::min(n, i0 + bs);
compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
const float* cent = nullptr;
if (centroids != nullptr) {
cent = centroids + i0 * d;
}
compute_codes_add_centroids(
x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
}
return;
}

View File

@ -311,6 +311,37 @@ static void read_LocalSearchQuantizer(LocalSearchQuantizer* lsq, IOReader* f) {
READ1(lsq->update_codebooks_with_double);
}
/// Read the fields shared by every product additive quantizer.
///
/// Loads the AdditiveQuantizer part through read_AdditiveQuantizer(), then
/// the `nsplits` field. The sub-quantizers themselves are not read here; the
/// type-specific readers (read_ProductResidualQuantizer,
/// read_ProductLocalSearchQuantizer) take care of them.
static void read_ProductAdditiveQuantizer(
ProductAdditiveQuantizer* paq,
IOReader* f) {
read_AdditiveQuantizer(paq, f);
// same field order as write_ProductAdditiveQuantizer
READ1(paq->nsplits);
}
/// Read a ProductResidualQuantizer from a stream.
///
/// Reads the common product-quantizer fields first (which sets
/// `prq->nsplits`), then reads `nsplits` ResidualQuantizer objects and
/// appends them to `prq->quantizers`.
///
/// @param prq  quantizer to fill; its `quantizers` vector is appended to
/// @param f    stream to read from
static void read_ProductResidualQuantizer(
        ProductResidualQuantizer* prq,
        IOReader* f) {
    read_ProductAdditiveQuantizer(prq, f);
    for (size_t i = 0; i < prq->nsplits; i++) {
        auto rq = new ResidualQuantizer();
        try {
            read_ResidualQuantizer(rq, f);
        } catch (...) {
            // rq is not yet referenced by prq->quantizers, so nobody else
            // would free it if the read fails (e.g. on a truncated file)
            delete rq;
            throw;
        }
        prq->quantizers.push_back(rq);
    }
}
/// Read a ProductLocalSearchQuantizer from a stream.
///
/// Reads the common product-quantizer fields first (which sets
/// `plsq->nsplits`), then reads `nsplits` LocalSearchQuantizer objects and
/// appends them to `plsq->quantizers`.
///
/// @param plsq  quantizer to fill; its `quantizers` vector is appended to
/// @param f     stream to read from
static void read_ProductLocalSearchQuantizer(
        ProductLocalSearchQuantizer* plsq,
        IOReader* f) {
    read_ProductAdditiveQuantizer(plsq, f);
    for (size_t i = 0; i < plsq->nsplits; i++) {
        auto lsq = new LocalSearchQuantizer();
        try {
            read_LocalSearchQuantizer(lsq, f);
        } catch (...) {
            // lsq is not yet referenced by plsq->quantizers, so nobody else
            // would free it if the read fails (e.g. on a truncated file)
            delete lsq;
            throw;
        }
        plsq->quantizers.push_back(lsq);
    }
}
static void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f) {
READ1(ivsc->qtype);
READ1(ivsc->rangestat);
@ -559,6 +590,20 @@ Index* read_index(IOReader* f, int io_flags) {
READ1(idxr->code_size);
READVECTOR(idxr->codes);
idx = idxr;
} else if (h == fourcc("IxPR")) {
auto idxpr = new IndexProductResidualQuantizer();
read_index_header(idxpr, f);
read_ProductResidualQuantizer(&idxpr->prq, f);
READ1(idxpr->code_size);
READVECTOR(idxpr->codes);
idx = idxpr;
} else if (h == fourcc("IxPL")) {
auto idxpl = new IndexProductLocalSearchQuantizer();
read_index_header(idxpl, f);
read_ProductLocalSearchQuantizer(&idxpl->plsq, f);
READ1(idxpl->code_size);
READVECTOR(idxpl->codes);
idx = idxpl;
} else if (h == fourcc("ImRQ")) {
ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer();
read_index_header(idxr, f);
@ -566,20 +611,35 @@ Index* read_index(IOReader* f, int io_flags) {
READ1(idxr->beam_factor);
idxr->set_beam_factor(idxr->beam_factor);
idx = idxr;
} else if (h == fourcc("ILfs") || h == fourcc("IRfs")) {
} else if (
h == fourcc("ILfs") || h == fourcc("IRfs") || h == fourcc("IPRf") ||
h == fourcc("IPLf")) {
bool is_LSQ = h == fourcc("ILfs");
bool is_RQ = h == fourcc("IRfs");
bool is_PLSQ = h == fourcc("IPLf");
IndexAdditiveQuantizerFastScan* idxaqfs;
if (is_LSQ) {
idxaqfs = new IndexLocalSearchQuantizerFastScan();
} else {
} else if (is_RQ) {
idxaqfs = new IndexResidualQuantizerFastScan();
} else if (is_PLSQ) {
idxaqfs = new IndexProductLocalSearchQuantizerFastScan();
} else {
idxaqfs = new IndexProductResidualQuantizerFastScan();
}
read_index_header(idxaqfs, f);
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)idxaqfs->aq, f);
} else {
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)idxaqfs->aq, f);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)idxaqfs->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)idxaqfs->aq, f);
}
READ1(idxaqfs->implem);
@ -599,20 +659,35 @@ Index* read_index(IOReader* f, int io_flags) {
READVECTOR(idxaqfs->codes);
idx = idxaqfs;
} else if (h == fourcc("IVLf") || h == fourcc("IVRf")) {
} else if (
h == fourcc("IVLf") || h == fourcc("IVRf") || h == fourcc("NPLf") ||
h == fourcc("NPRf")) {
bool is_LSQ = h == fourcc("IVLf");
bool is_RQ = h == fourcc("IVRf");
bool is_PLSQ = h == fourcc("NPLf");
IndexIVFAdditiveQuantizerFastScan* ivaqfs;
if (is_LSQ) {
ivaqfs = new IndexIVFLocalSearchQuantizerFastScan();
} else {
} else if (is_RQ) {
ivaqfs = new IndexIVFResidualQuantizerFastScan();
} else if (is_PLSQ) {
ivaqfs = new IndexIVFProductLocalSearchQuantizerFastScan();
} else {
ivaqfs = new IndexIVFProductResidualQuantizerFastScan();
}
read_ivf_header(ivaqfs, f);
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f);
} else {
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)ivaqfs->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)ivaqfs->aq, f);
}
READ1(ivaqfs->by_residual);
@ -712,20 +787,34 @@ Index* read_index(IOReader* f, int io_flags) {
}
read_InvertedLists(ivsc, f, io_flags);
idx = ivsc;
} else if (h == fourcc("IwLS") || h == fourcc("IwRQ")) {
} else if (
h == fourcc("IwLS") || h == fourcc("IwRQ") || h == fourcc("IwPL") ||
h == fourcc("IwPR")) {
bool is_LSQ = h == fourcc("IwLS");
bool is_RQ = h == fourcc("IwRQ");
bool is_PLSQ = h == fourcc("IwPL");
IndexIVFAdditiveQuantizer* iva;
if (is_LSQ) {
iva = new IndexIVFLocalSearchQuantizer();
} else {
} else if (is_RQ) {
iva = new IndexIVFResidualQuantizer();
} else if (is_PLSQ) {
iva = new IndexIVFProductLocalSearchQuantizer();
} else {
iva = new IndexIVFProductResidualQuantizer();
}
read_ivf_header(iva, f);
READ1(iva->code_size);
if (is_LSQ) {
read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
} else {
} else if (is_RQ) {
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
} else if (is_PLSQ) {
read_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)iva->aq, f);
} else {
read_ProductResidualQuantizer(
(ProductResidualQuantizer*)iva->aq, f);
}
READ1(iva->by_residual);
READ1(iva->use_precomputed_table);

View File

@ -207,6 +207,33 @@ static void write_LocalSearchQuantizer(
WRITE1(lsq->update_codebooks_with_double);
}
/// Write the fields shared by every product additive quantizer.
///
/// Emits the AdditiveQuantizer part through write_AdditiveQuantizer(), then
/// the `nsplits` field. The sub-quantizers are written by the type-specific
/// writers (write_ProductResidualQuantizer,
/// write_ProductLocalSearchQuantizer).
static void write_ProductAdditiveQuantizer(
const ProductAdditiveQuantizer* paq,
IOWriter* f) {
write_AdditiveQuantizer(paq, f);
// same field order as read_ProductAdditiveQuantizer
WRITE1(paq->nsplits);
}
static void write_ProductResidualQuantizer(
const ProductResidualQuantizer* prq,
IOWriter* f) {
write_ProductAdditiveQuantizer(prq, f);
for (const auto aq : prq->quantizers) {
auto rq = dynamic_cast<const ResidualQuantizer*>(aq);
write_ResidualQuantizer(rq, f);
}
}
static void write_ProductLocalSearchQuantizer(
const ProductLocalSearchQuantizer* plsq,
IOWriter* f) {
write_ProductAdditiveQuantizer(plsq, f);
for (const auto aq : plsq->quantizers) {
auto lsq = dynamic_cast<const LocalSearchQuantizer*>(aq);
write_LocalSearchQuantizer(lsq, f);
}
}
static void write_ScalarQuantizer(const ScalarQuantizer* ivsc, IOWriter* f) {
WRITE1(ivsc->qtype);
WRITE1(ivsc->rangestat);
@ -394,28 +421,62 @@ void write_index(const Index* idx, IOWriter* f) {
write_LocalSearchQuantizer(&idxr->lsq, f);
WRITE1(idxr->code_size);
WRITEVECTOR(idxr->codes);
} else if (
const IndexProductResidualQuantizer* idxpr =
dynamic_cast<const IndexProductResidualQuantizer*>(idx)) {
uint32_t h = fourcc("IxPR");
WRITE1(h);
write_index_header(idx, f);
write_ProductResidualQuantizer(&idxpr->prq, f);
WRITE1(idxpr->code_size);
WRITEVECTOR(idxpr->codes);
} else if (
const IndexProductLocalSearchQuantizer* idxpl =
dynamic_cast<const IndexProductLocalSearchQuantizer*>(
idx)) {
uint32_t h = fourcc("IxPL");
WRITE1(h);
write_index_header(idx, f);
write_ProductLocalSearchQuantizer(&idxpl->plsq, f);
WRITE1(idxpl->code_size);
WRITEVECTOR(idxpl->codes);
} else if (
auto* idxaqfs =
dynamic_cast<const IndexAdditiveQuantizerFastScan*>(idx)) {
auto idxlsqfs =
dynamic_cast<const IndexLocalSearchQuantizerFastScan*>(idx);
auto idxrqfs = dynamic_cast<const IndexResidualQuantizerFastScan*>(idx);
FAISS_THROW_IF_NOT(idxlsqfs || idxrqfs);
auto idxplsqfs =
dynamic_cast<const IndexProductLocalSearchQuantizerFastScan*>(
idx);
auto idxprqfs =
dynamic_cast<const IndexProductResidualQuantizerFastScan*>(idx);
FAISS_THROW_IF_NOT(idxlsqfs || idxrqfs || idxplsqfs || idxprqfs);
if (idxlsqfs) {
uint32_t h = fourcc("ILfs");
WRITE1(h);
} else {
} else if (idxrqfs) {
uint32_t h = fourcc("IRfs");
WRITE1(h);
} else if (idxplsqfs) {
uint32_t h = fourcc("IPLf");
WRITE1(h);
} else if (idxprqfs) {
uint32_t h = fourcc("IPRf");
WRITE1(h);
}
write_index_header(idxaqfs, f);
if (idxlsqfs) {
write_LocalSearchQuantizer(&idxlsqfs->lsq, f);
} else {
} else if (idxrqfs) {
write_ResidualQuantizer(&idxrqfs->rq, f);
} else if (idxplsqfs) {
write_ProductLocalSearchQuantizer(&idxplsqfs->plsq, f);
} else if (idxprqfs) {
write_ProductResidualQuantizer(&idxprqfs->prq, f);
}
WRITE1(idxaqfs->implem);
WRITE1(idxaqfs->bbs);
@ -441,22 +502,37 @@ void write_index(const Index* idx, IOWriter* f) {
dynamic_cast<const IndexIVFLocalSearchQuantizerFastScan*>(idx);
auto ivrqfs =
dynamic_cast<const IndexIVFResidualQuantizerFastScan*>(idx);
FAISS_THROW_IF_NOT(ivlsqfs || ivrqfs);
auto ivplsqfs = dynamic_cast<
const IndexIVFProductLocalSearchQuantizerFastScan*>(idx);
auto ivprqfs =
dynamic_cast<const IndexIVFProductResidualQuantizerFastScan*>(
idx);
FAISS_THROW_IF_NOT(ivlsqfs || ivrqfs || ivplsqfs || ivprqfs);
if (ivlsqfs) {
uint32_t h = fourcc("IVLf");
WRITE1(h);
} else {
} else if (ivrqfs) {
uint32_t h = fourcc("IVRf");
WRITE1(h);
} else if (ivplsqfs) {
uint32_t h = fourcc("NPLf"); // N means IV ...
WRITE1(h);
} else {
uint32_t h = fourcc("NPRf");
WRITE1(h);
}
write_ivf_header(ivaqfs, f);
if (ivlsqfs) {
write_LocalSearchQuantizer(&ivlsqfs->lsq, f);
} else {
} else if (ivrqfs) {
write_ResidualQuantizer(&ivrqfs->rq, f);
} else if (ivplsqfs) {
write_ProductLocalSearchQuantizer(&ivplsqfs->plsq, f);
} else {
write_ProductResidualQuantizer(&ivprqfs->prq, f);
}
WRITE1(ivaqfs->by_residual);
@ -550,14 +626,33 @@ void write_index(const Index* idx, IOWriter* f) {
write_InvertedLists(ivsc->invlists, f);
} else if (auto iva = dynamic_cast<const IndexIVFAdditiveQuantizer*>(idx)) {
bool is_LSQ = dynamic_cast<const IndexIVFLocalSearchQuantizer*>(iva);
uint32_t h = fourcc(is_LSQ ? "IwLS" : "IwRQ");
bool is_RQ = dynamic_cast<const IndexIVFResidualQuantizer*>(iva);
bool is_PLSQ =
dynamic_cast<const IndexIVFProductLocalSearchQuantizer*>(iva);
uint32_t h;
if (is_LSQ) {
h = fourcc("IwLS");
} else if (is_RQ) {
h = fourcc("IwRQ");
} else if (is_PLSQ) {
h = fourcc("IwPL");
} else {
h = fourcc("IwPR");
}
WRITE1(h);
write_ivf_header(iva, f);
WRITE1(iva->code_size);
if (is_LSQ) {
write_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
} else {
} else if (is_RQ) {
write_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
} else if (is_PLSQ) {
write_ProductLocalSearchQuantizer(
(ProductLocalSearchQuantizer*)iva->aq, f);
} else {
write_ProductResidualQuantizer(
(ProductResidualQuantizer*)iva->aq, f);
}
WRITE1(iva->by_residual);
WRITE1(iva->use_precomputed_table);

View File

@ -159,6 +159,8 @@ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
const std::string aq_norm_pattern =
"(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4|_Nlsq2x4|_Nrq2x4)";
const std::string paq_def_pattern = "([0-9]+)x([0-9]+)x([0-9]+)";
AdditiveQuantizer::Search_type_t aq_parse_search_type(
std::string stok,
MetricType metric) {
@ -345,6 +347,21 @@ IndexIVF* parse_IndexIVF(
}
return index_ivf;
}
// IVF product additive quantizers. The description is the quantizer name
// (PRQ or PLSQ), then the three numeric fields captured by
// paq_def_pattern, then a norm suffix matched by aq_norm_pattern
// (which also accepts the empty string).
if (match("(PRQ|PLSQ)" + paq_def_pattern + aq_norm_pattern)) {
// sm[1] is the quantizer name, so the three numeric groups are sm[2..4].
int nsplits = mres_to_int(sm[2]);
int Msub = mres_to_int(sm[3]);
int nbit = mres_to_int(sm[4]);
// The last capture group is the norm suffix; it selects the search type.
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
IndexIVF* index_ivf;
if (sm[1].str() == "PRQ") {
index_ivf = new IndexIVFProductResidualQuantizer(
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
} else {
// The regex only admits PRQ or PLSQ, so this branch is PLSQ.
index_ivf = new IndexIVFProductLocalSearchQuantizer(
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
}
return index_ivf;
}
if (match("(RQ|LSQ)([0-9]+)x4fs(r?)(_[0-9]+)?" + aq_norm_pattern)) {
int M = std::stoi(sm[2].str());
int bbs = mres_to_int(sm[4], 32, 1);
@ -360,6 +377,23 @@ IndexIVF* parse_IndexIVF(
index_ivf->by_residual = (sm[3].str() == "r");
return index_ivf;
}
// IVF product additive quantizers, FastScan flavour. Capture groups:
//   sm[1] quantizer name (PRQ or PLSQ)
//   sm[2] nsplits, sm[3] Msub (the "x4fs" part is literal)
//   sm[4] optional "r" flag
//   sm[5] optional "_<number>" group, used for bbs
//   last group: norm suffix, used to pick the search type
if (match("(PRQ|PLSQ)([0-9]+)x([0-9]+)x4fs(r?)(_[0-9]+)?" +
aq_norm_pattern)) {
int nsplits = std::stoi(sm[2].str());
int Msub = std::stoi(sm[3].str());
// NOTE(review): presumably 32 is the default used when the optional group
// is absent and 1 skips the leading underscore -- confirm in mres_to_int.
int bbs = mres_to_int(sm[5], 32, 1);
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
IndexIVFAdditiveQuantizerFastScan* index_ivf;
// The literal 4 is passed in the slot where the non-FastScan branch passes
// nbit; it mirrors the "x4" that the pattern requires.
if (sm[1].str() == "PRQ") {
index_ivf = new IndexIVFProductResidualQuantizerFastScan(
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
} else {
index_ivf = new IndexIVFProductLocalSearchQuantizerFastScan(
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
}
// by_residual is enabled only when the "r" flag was present.
index_ivf->by_residual = (sm[4].str() == "r");
return index_ivf;
}
if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
int outdim = mres_to_int(sm[2], d); // is also the number of bits
std::unique_ptr<VectorTransform> vt;
@ -550,6 +584,26 @@ Index* parse_other_indexes(
return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
}
// IndexProductResidualQuantizer
// "PRQ" is a literal prefix here (no capture group), so the three numeric
// groups of paq_def_pattern are sm[1..3]; the last group is the norm suffix.
if (match("PRQ" + paq_def_pattern + aq_norm_pattern)) {
int nsplits = mres_to_int(sm[1]);
int Msub = mres_to_int(sm[2]);
int nbit = mres_to_int(sm[3]);
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
return new IndexProductResidualQuantizer(
d, nsplits, Msub, nbit, metric, st);
}
// IndexProductLocalSearchQuantizer
// "PLSQ" is a literal prefix here (no capture group), so the three numeric
// groups of paq_def_pattern are sm[1..3]; the last group is the norm suffix.
if (match("PLSQ" + paq_def_pattern + aq_norm_pattern)) {
int nsplits = mres_to_int(sm[1]);
int Msub = mres_to_int(sm[2]);
int nbit = mres_to_int(sm[3]);
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
return new IndexProductLocalSearchQuantizer(
d, nsplits, Msub, nbit, metric, st);
}
// IndexAdditiveQuantizerFastScan
// RQ{M}x4fs_{bbs}_{search_type}
pattern = "(LSQ|RQ)([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
@ -566,6 +620,24 @@ Index* parse_other_indexes(
}
}
// IndexProductAdditiveQuantizerFastScan
// PRQ{nsplits}x{Msub}x4fs_{bbs}_{search_type}
// Capture groups: sm[1] quantizer name, sm[2] nsplits, sm[3] Msub,
// sm[4] optional "_<number>" group used for bbs, last group the norm suffix.
pattern = "(PLSQ|PRQ)([0-9]+)x([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
if (match(pattern)) {
int nsplits = std::stoi(sm[2].str());
int Msub = std::stoi(sm[3].str());
// NOTE(review): presumably 32 is the default used when the optional group
// is absent and 1 skips the leading underscore -- confirm in mres_to_int.
int bbs = mres_to_int(sm[4], 32, 1);
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
// The literal 4 mirrors the "x4" required by the pattern.
if (sm[1].str() == "PRQ") {
return new IndexProductResidualQuantizerFastScan(
d, nsplits, Msub, 4, metric, st, bbs);
} else if (sm[1].str() == "PLSQ") {
return new IndexProductLocalSearchQuantizerFastScan(
d, nsplits, Msub, 4, metric, st, bbs);
}
// No final else: if neither name matched, control would leave this block
// without returning. The regex only admits PLSQ or PRQ.
}
return nullptr;
}

View File

@ -573,8 +573,12 @@ void gpu_sync_all_devices()
DOWNCAST ( IndexIVFScalarQuantizer )
DOWNCAST ( IndexIVFResidualQuantizer )
DOWNCAST ( IndexIVFLocalSearchQuantizer )
DOWNCAST ( IndexIVFProductResidualQuantizer )
DOWNCAST ( IndexIVFProductLocalSearchQuantizer )
DOWNCAST ( IndexIVFResidualQuantizerFastScan )
DOWNCAST ( IndexIVFLocalSearchQuantizerFastScan )
DOWNCAST ( IndexIVFProductResidualQuantizerFastScan )
DOWNCAST ( IndexIVFProductLocalSearchQuantizerFastScan )
DOWNCAST ( IndexIVFFlatDedup )
DOWNCAST ( IndexIVFFlat )
DOWNCAST ( IndexIVF )
@ -587,8 +591,12 @@ void gpu_sync_all_devices()
DOWNCAST ( IndexLocalSearchQuantizer )
DOWNCAST ( IndexResidualQuantizerFastScan )
DOWNCAST ( IndexLocalSearchQuantizerFastScan )
DOWNCAST ( IndexProductResidualQuantizerFastScan )
DOWNCAST ( IndexProductLocalSearchQuantizerFastScan )
DOWNCAST ( ResidualCoarseQuantizer )
DOWNCAST ( LocalSearchCoarseQuantizer )
DOWNCAST ( IndexProductResidualQuantizer )
DOWNCAST ( IndexProductLocalSearchQuantizer )
DOWNCAST ( IndexScalarQuantizer )
DOWNCAST ( IndexLSH )
DOWNCAST ( IndexLattice )

View File

@ -471,7 +471,6 @@ class TestAQFastScan(unittest.TestCase):
Compare IndexAdditiveQuantizerFastScan with IndexAQ (qint8)
"""
d = 16
# ds = datasets.SyntheticDataset(d, 1000, 2000, 1000, metric_type)
ds = datasets.SyntheticDataset(d, 1000, 1000, 500, metric_type)
gt = ds.get_groundtruth(k=1)
@ -632,3 +631,70 @@ def add_TestAQFastScan_subtest_from_idxaq(implem, metric):
for implem in 2, 3, 4:
add_TestAQFastScan_subtest_from_idxaq(implem, 'L2')
add_TestAQFastScan_subtest_from_idxaq(implem, 'IP')
class TestPAQFastScan(unittest.TestCase):
def subtest_accuracy(self, paq):
"""
Compare IndexPAQFastScan with IndexPAQ (qint8)
"""
d = 16
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
gt = ds.get_groundtruth(k=1)
index = faiss.index_factory(d, f'{paq}2x3x4_Nqint8')
index.train(ds.get_train())
index.add(ds.get_database())
Dref, Iref = index.search(ds.get_queries(), 1)
indexfs = faiss.index_factory(d, f'{paq}2x3x4fs_Nlsq2x4')
indexfs.train(ds.get_train())
indexfs.add(ds.get_database())
Da, Ia = indexfs.search(ds.get_queries(), 1)
nq = Iref.shape[0]
recall_ref = (Iref == gt).sum() / nq
recall = (Ia == gt).sum() / nq
assert abs(recall_ref - recall) < 0.05
def test_accuracy_PLSQ(self):
self.subtest_accuracy("PLSQ")
def test_accuracy_PRQ(self):
self.subtest_accuracy("PRQ")
def subtest_factory(self, paq):
index = faiss.index_factory(16, f'{paq}2x3x4fs_Nlsq2x4')
q = faiss.downcast_Quantizer(index.aq)
self.assertEqual(q.nsplits, 2)
self.assertEqual(q.subquantizer(0).M, 3)
def test_factory(self):
self.subtest_factory('PRQ')
self.subtest_factory('PLSQ')
def subtest_io(self, factory_str):
d = 8
ds = datasets.SyntheticDataset(d, 1000, 500, 100)
index = faiss.index_factory(d, factory_str)
index.train(ds.get_train())
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)
fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
def test_io(self):
self.subtest_io('PLSQ2x3x4fs_Nlsq2x4')
self.subtest_io('PRQ2x3x4fs_Nrq2x4')

View File

@ -548,7 +548,6 @@ class TestIVFAQFastScan(unittest.TestCase):
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
gt = ds.get_groundtruth(k=1)
# if metric_type == 'L2':
metric = faiss.METRIC_L2
postfix1 = '_Nqint8'
postfix2 = f'_N{st}2x4'
@ -631,18 +630,18 @@ class TestIVFAQFastScan(unittest.TestCase):
assert index.nlist == nlist
assert index.bbs == bbs
aq = faiss.downcast_AdditiveQuantizer(index.aq)
assert aq.M == M
q = faiss.downcast_Quantizer(index.aq)
assert q.M == M
if aq == 'LSQ':
assert isinstance(aq, faiss.LocalSearchQuantizer)
assert isinstance(q, faiss.LocalSearchQuantizer)
if aq == 'RQ':
assert isinstance(aq, faiss.ResidualQuantizer)
assert isinstance(q, faiss.ResidualQuantizer)
if st == 'lsq':
assert aq.search_type == AQ.ST_norm_lsq2x4
assert q.search_type == AQ.ST_norm_lsq2x4
if st == 'rq':
assert aq.search_type == AQ.ST_norm_rq2x4
assert q.search_type == AQ.ST_norm_rq2x4
assert index.by_residual == (r == 'r')
@ -710,3 +709,78 @@ for byr in True, False:
add_TestIVFAQFastScan_subtest_rescale_accuracy('LSQ', 'lsq', byr, implem)
add_TestIVFAQFastScan_subtest_rescale_accuracy('RQ', 'rq', byr, implem)
class TestIVFPAQFastScan(unittest.TestCase):
def subtest_accuracy(self, paq):
"""
Compare IndexIVFAdditiveQuantizerFastScan with
IndexIVFAdditiveQuantizer
"""
nlist, d = 16, 8
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
gt = ds.get_groundtruth(k=1)
index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4_Nqint8')
index.train(ds.get_train())
index.add(ds.get_database())
index.nprobe = 4
Dref, Iref = index.search(ds.get_queries(), 1)
indexfs = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
indexfs.train(ds.get_train())
indexfs.add(ds.get_database())
indexfs.nprobe = 4
D1, I1 = indexfs.search(ds.get_queries(), 1)
nq = Iref.shape[0]
recall_ref = (Iref == gt).sum() / nq
recall1 = (I1 == gt).sum() / nq
print(paq, recall_ref, recall1)
assert abs(recall_ref - recall1) < 0.05
def test_accuracy_PLSQ(self):
self.subtest_accuracy("PLSQ")
def test_accuracy_PRQ(self):
self.subtest_accuracy("PRQ")
def subtest_factory(self, paq):
nlist, d = 128, 16
index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
q = faiss.downcast_Quantizer(index.aq)
self.assertEqual(index.nlist, nlist)
self.assertEqual(q.nsplits, 2)
self.assertEqual(q.subquantizer(0).M, 3)
self.assertTrue(index.by_residual)
def test_factory(self):
self.subtest_factory('PLSQ')
self.subtest_factory('PRQ')
def subtest_io(self, factory_str):
d = 8
ds = datasets.SyntheticDataset(d, 1000, 2000, 1000)
index = faiss.index_factory(d, factory_str)
index.train(ds.get_train())
index.add(ds.get_database())
D1, I1 = index.search(ds.get_queries(), 1)
fd, fname = tempfile.mkstemp()
os.close(fd)
try:
faiss.write_index(index, fname)
index2 = faiss.read_index(fname)
D2, I2 = index2.search(ds.get_queries(), 1)
np.testing.assert_array_equal(I1, I2)
finally:
if os.path.exists(fname):
os.unlink(fname)
def test_io(self):
self.subtest_io('IVF16,PLSQ2x3x4fsr_Nlsq2x4')
self.subtest_io('IVF16,PRQ2x3x4fs_Nrq2x4')

View File

@ -573,3 +573,101 @@ class TestProductLocalSearchQuantizer(unittest.TestCase):
# max rtoal in OSX: 2.87e-6
np.testing.assert_allclose(lut, lut_ref, rtol=5e-06)
class TestIndexProductLocalSearchQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the PLSQ index."""

    def eval_index_accuracy(self, index_key):
        """Build the index described by `index_key` on a synthetic
        dataset and return its 10-NN intersection with the ground truth.

        Also verifies that a serialize / deserialize round trip gives
        exactly the same search results.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        index = faiss.index_factory(ds.d, index_key)
        index.train(ds.get_train())
        index.add(ds.get_database())

        D, I = index.search(ds.get_queries(), 10)
        n_inter = faiss.eval_intersection(I, ds.get_groundtruth(10))

        # small I/O check: the reloaded index must behave identically
        reloaded = faiss.deserialize_index(faiss.serialize_index(index))
        D_reloaded, I_reloaded = reloaded.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I_reloaded, I)
        np.testing.assert_array_equal(D_reloaded, D)

        return n_inter

    def test_accuracy1(self):
        """PLSQ should do at least as well as LSQ with the same budget."""
        plsq_recall = self.eval_index_accuracy("PLSQ4x3x5_Nqint8")
        lsq_recall = self.eval_index_accuracy("LSQ12x5_Nqint8")
        # values noted by the original author: 622 vs 551
        self.assertGreaterEqual(plsq_recall, lsq_recall)

    def test_accuracy2(self):
        """With a single split, PLSQ should be almost the same as LSQ."""
        plsq_recall = self.eval_index_accuracy("PLSQ1x3x5_Nqint8")
        lsq_recall = self.eval_index_accuracy("LSQ3x5_Nqint8")
        # values noted by the original author: 273 vs 275 on OSX
        self.assertGreaterEqual(5, abs(plsq_recall - lsq_recall))

    def test_factory(self):
        """The factory string must be reflected in the quantizer fields."""
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"PLSQ{ns}x{Msub}x{nbits}_Nqint8")
        assert isinstance(index, faiss.IndexProductLocalSearchQuantizer)

        first_sub = index.plsq.subquantizer(0)
        self.assertEqual(index.plsq.nsplits, ns)
        self.assertEqual(first_sub.M, Msub)
        self.assertEqual(first_sub.nbits.at(0), nbits)
        self.assertEqual(
            index.plsq.search_type, faiss.AdditiveQuantizer.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus one more byte
        # (presumably the qint8 norm -- TODO confirm)
        expected_code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.plsq.code_size, expected_code_size)
class TestIndexIVFProductLocalSearchQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the IVF PLSQ index."""

    def eval_index_accuracy(self, factory_key):
        """Build the index described by `factory_key`, check that the
        10-NN intersection never decreases as nprobe grows, and return the
        intersection measured at the largest nprobe.

        Also verifies that a serialize / deserialize round trip gives
        exactly the same results as the last search.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        index = faiss.index_factory(ds.d, factory_key)
        index.train(ds.get_train())
        index.add(ds.get_database())

        per_nprobe = []
        for nprobe in (1, 2, 4, 8, 16):
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            last_inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            per_nprobe.append(last_inter)

        # the intersection must be non-decreasing in nprobe
        per_nprobe = np.array(per_nprobe)
        self.assertTrue(np.all(per_nprobe[:-1] <= per_nprobe[1:]))

        # small I/O check against the results of the last search
        reloaded = faiss.deserialize_index(faiss.serialize_index(index))
        D_reloaded, I_reloaded = reloaded.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I_reloaded, I)
        np.testing.assert_array_equal(D_reloaded, D)

        return last_inter

    def test_index_accuracy(self):
        self.eval_index_accuracy("IVF32,PLSQ2x2x5_Nqint8")

    def test_index_accuracy2(self):
        """The PLSQ result should be in the same ballpark as LSQ."""
        plsq_inter = self.eval_index_accuracy("IVF32,PLSQ2x2x5_Nqint8")
        lsq_inter = self.eval_index_accuracy("IVF32,LSQ4x5_Nqint8")
        # values noted by the original author: 381 vs 374
        self.assertGreaterEqual(plsq_inter * 1.1, lsq_inter)

    def test_factory(self):
        """The factory string must be reflected in the index fields."""
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"IVF32,PLSQ{ns}x{Msub}x{nbits}_Nqint8")
        assert isinstance(index, faiss.IndexIVFProductLocalSearchQuantizer)

        self.assertEqual(index.nlist, 32)
        first_sub = index.plsq.subquantizer(0)
        self.assertEqual(index.plsq.nsplits, ns)
        self.assertEqual(first_sub.M, Msub)
        self.assertEqual(first_sub.nbits.at(0), nbits)
        self.assertEqual(
            index.plsq.search_type, faiss.AdditiveQuantizer.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus one more byte
        # (presumably the qint8 norm -- TODO confirm)
        expected_code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.plsq.code_size, expected_code_size)

View File

@ -1177,3 +1177,100 @@ class TestProductResidualQuantizer(unittest.TestCase):
print(err_prq, err_rq)
self.assertEqual(err_prq, err_rq)
class TestIndexProductResidualQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the PRQ index."""

    def eval_index_accuracy(self, index_key):
        """Build the index described by `index_key` on a synthetic
        dataset and return its 10-NN intersection with the ground truth.

        Also verifies that a serialize / deserialize round trip gives
        exactly the same search results.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        index = faiss.index_factory(ds.d, index_key)
        index.train(ds.get_train())
        index.add(ds.get_database())

        D, I = index.search(ds.get_queries(), 10)
        n_inter = faiss.eval_intersection(I, ds.get_groundtruth(10))

        # small I/O check: the reloaded index must behave identically
        reloaded = faiss.deserialize_index(faiss.serialize_index(index))
        D_reloaded, I_reloaded = reloaded.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I_reloaded, I)
        np.testing.assert_array_equal(D_reloaded, D)

        return n_inter

    def test_accuracy1(self):
        """The PRQ result should be in the same ballpark as RQ."""
        prq_recall = self.eval_index_accuracy("PRQ4x3x5_Nqint8")
        rq_recall = self.eval_index_accuracy("RQ12x5_Nqint8")
        # values noted by the original author: 657 vs 665
        self.assertGreaterEqual(prq_recall * 1.1, rq_recall)

    def test_accuracy2(self):
        """With a single split, PRQ should give the same result as RQ."""
        prq_recall = self.eval_index_accuracy("PRQ1x3x5_Nqint8")
        rq_recall = self.eval_index_accuracy("RQ3x5_Nqint8")
        self.assertEqual(prq_recall, rq_recall)

    def test_factory(self):
        """The factory string must be reflected in the quantizer fields."""
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"PRQ{ns}x{Msub}x{nbits}_Nqint8")
        assert isinstance(index, faiss.IndexProductResidualQuantizer)

        first_sub = index.prq.subquantizer(0)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(first_sub.M, Msub)
        self.assertEqual(first_sub.nbits.at(0), nbits)
        self.assertEqual(
            index.prq.search_type, faiss.AdditiveQuantizer.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus one more byte
        # (presumably the qint8 norm -- TODO confirm)
        expected_code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, expected_code_size)
class TestIndexIVFProductResidualQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the IVF PRQ index."""

    def eval_index_accuracy(self, factory_key):
        """Build the index described by `factory_key`, check that the
        10-NN intersection never decreases as nprobe grows, and return the
        intersection measured at the largest nprobe.

        Also verifies that a serialize / deserialize round trip gives
        exactly the same results as the last search.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)

        index = faiss.index_factory(ds.d, factory_key)
        index.train(ds.get_train())
        index.add(ds.get_database())

        per_nprobe = []
        for nprobe in (1, 2, 5, 10, 20, 50):
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            last_inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            per_nprobe.append(last_inter)

        # the intersection must be non-decreasing in nprobe
        per_nprobe = np.array(per_nprobe)
        self.assertTrue(np.all(per_nprobe[:-1] <= per_nprobe[1:]))

        # small I/O check against the results of the last search
        reloaded = faiss.deserialize_index(faiss.serialize_index(index))
        D_reloaded, I_reloaded = reloaded.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I_reloaded, I)
        np.testing.assert_array_equal(D_reloaded, D)

        return last_inter

    def test_index_accuracy(self):
        self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")

    def test_index_accuracy2(self):
        """The PRQ result should be in the same ballpark as RQ."""
        prq_inter = self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")
        rq_inter = self.eval_index_accuracy("IVF100,RQ4x5_Nqint8")
        # values noted by the original author: 392 vs 374
        self.assertGreaterEqual(prq_inter * 1.1, rq_inter)

    def test_factory(self):
        """The factory string must be reflected in the index fields."""
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"IVF100,PRQ{ns}x{Msub}x{nbits}_Nqint8")
        assert isinstance(index, faiss.IndexIVFProductResidualQuantizer)

        self.assertEqual(index.nlist, 100)
        first_sub = index.prq.subquantizer(0)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(first_sub.M, Msub)
        self.assertEqual(first_sub.nbits.at(0), nbits)
        self.assertEqual(
            index.prq.search_type, faiss.AdditiveQuantizer.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus one more byte
        # (presumably the qint8 norm -- TODO confirm)
        expected_code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, expected_code_size)