Implement search methods for ProductAdditiveQuantizer (#2336)
Summary: Work in progress. This PR is going to implement the following search methods for ProductAdditiveQuantizer, including index factory and I/O:
- [x] IndexProductAdditiveQuantizer
- [x] IndexIVFProductAdditiveQuantizer
- [x] IndexProductAdditiveQuantizerFastScan
- [x] IndexIVFProductAdditiveQuantizerFastScan

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2336

Test Plan:
buck test //faiss/tests/:test_fast_scan
buck test //faiss/tests/:test_fast_scan_ivf
buck test //faiss/tests/:test_local_search_quantizer
buck test //faiss/tests/:test_residual_quantizer

Reviewed By: alexanderguzhva

Differential Revision: D37172745

Pulled By: mdouze

fbshipit-source-id: 6ff18bfc462525478c90cd42e21805ab8605bd0f

pull/2398/head
parent
ae0a9d86a7
commit
838f85cb52
|
@ -12,6 +12,8 @@ the Facebook Faiss team. Feel free to add entries here if you submit a PR.
|
|||
|
||||
- Added sparse k-means routines and moved the generic kmeans to contrib
|
||||
- Added FlatDistanceComputer for all FlatCodes indexes
|
||||
- Support for fast accumulation of 4-bit LSQ and RQ
|
||||
- Added product additive quantization
|
||||
|
||||
## [1.7.2] - 2021-12-15
|
||||
### Added
|
||||
|
|
|
@ -353,6 +353,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
|
|||
is_trained = true;
|
||||
}
|
||||
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexProductResidualQuantizer
|
||||
**************************************************************************************/
|
||||
|
||||
IndexProductResidualQuantizer::IndexProductResidualQuantizer(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of residual quantizers
|
||||
size_t Msub, ///< number of subquantizers per RQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric,
|
||||
Search_type_t search_type)
|
||||
: IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
|
||||
code_size = prq.code_size;
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
IndexProductResidualQuantizer::IndexProductResidualQuantizer()
|
||||
: IndexProductResidualQuantizer(0, 0, 0, 0) {}
|
||||
|
||||
void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
|
||||
prq.train(n, x);
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexProductLocalSearchQuantizer
|
||||
**************************************************************************************/
|
||||
|
||||
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of local search quantizers
|
||||
size_t Msub, ///< number of subquantizers per LSQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric,
|
||||
Search_type_t search_type)
|
||||
: IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
|
||||
code_size = plsq.code_size;
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
|
||||
: IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
|
||||
|
||||
void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
|
||||
plsq.train(n, x);
|
||||
is_trained = true;
|
||||
}
|
||||
|
||||
|
||||
/**************************************************************************************
|
||||
* AdditiveCoarseQuantizer
|
||||
**************************************************************************************/
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include <faiss/IndexFlatCodes.h>
|
||||
#include <faiss/impl/LocalSearchQuantizer.h>
|
||||
#include <faiss/impl/ProductAdditiveQuantizer.h>
|
||||
#include <faiss/impl/ResidualQuantizer.h>
|
||||
#include <faiss/impl/platform_macros.h>
|
||||
|
||||
|
@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
|
|||
using Search_type_t = AdditiveQuantizer::Search_type_t;
|
||||
|
||||
explicit IndexAdditiveQuantizer(
|
||||
idx_t d = 0,
|
||||
AdditiveQuantizer* aq = nullptr,
|
||||
idx_t d,
|
||||
AdditiveQuantizer* aq,
|
||||
MetricType metric = METRIC_L2);
|
||||
|
||||
void search(
|
||||
|
@ -100,6 +101,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
|
|||
void train(idx_t n, const float* x) override;
|
||||
};
|
||||
|
||||
/** Index based on a product residual quantizer.
|
||||
*/
|
||||
struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
|
||||
/// The product residual quantizer used to encode the vectors
|
||||
ProductResidualQuantizer prq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of residual quantizers
|
||||
* @param Msub number of subquantizers per RQ
|
||||
* @param nbits number of bits per subvector index
|
||||
*/
|
||||
IndexProductResidualQuantizer(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of residual quantizers
|
||||
size_t Msub, ///< number of subquantizers per RQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
|
||||
|
||||
IndexProductResidualQuantizer();
|
||||
|
||||
void train(idx_t n, const float* x) override;
|
||||
};
|
||||
|
||||
/** Index based on a product local search quantizer.
|
||||
*/
|
||||
struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
|
||||
/// The product local search quantizer used to encode the vectors
|
||||
ProductLocalSearchQuantizer plsq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of local search quantizers
|
||||
* @param Msub number of subquantizers per LSQ
|
||||
* @param nbits number of bits per subvector index
|
||||
*/
|
||||
IndexProductLocalSearchQuantizer(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of local search quantizers
|
||||
size_t Msub, ///< number of subquantizers per LSQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
|
||||
|
||||
IndexProductLocalSearchQuantizer();
|
||||
|
||||
void train(idx_t n, const float* x) override;
|
||||
};
|
||||
|
||||
/** A "virtual" index where the elements are the residual quantizer centroids.
|
||||
*
|
||||
* Intended for use as a coarse quantizer in an IndexIVF.
|
||||
|
|
|
@ -251,4 +251,46 @@ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() {
|
|||
aq = &lsq;
|
||||
}
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexProductResidualQuantizerFastScan
|
||||
**************************************************************************************/
|
||||
|
||||
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of residual quantizers
|
||||
size_t Msub, ///< number of subquantizers per RQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric,
|
||||
Search_type_t search_type,
|
||||
int bbs)
|
||||
: prq(d, nsplits, Msub, nbits, search_type) {
|
||||
init(&prq, metric, bbs);
|
||||
}
|
||||
|
||||
IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() {
|
||||
aq = &prq;
|
||||
}
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexProductLocalSearchQuantizerFastScan
|
||||
**************************************************************************************/
|
||||
|
||||
IndexProductLocalSearchQuantizerFastScan::
|
||||
IndexProductLocalSearchQuantizerFastScan(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of local search quantizers
|
||||
size_t Msub, ///< number of subquantizers per LSQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric,
|
||||
Search_type_t search_type,
|
||||
int bbs)
|
||||
: plsq(d, nsplits, Msub, nbits, search_type) {
|
||||
init(&plsq, metric, bbs);
|
||||
}
|
||||
|
||||
IndexProductLocalSearchQuantizerFastScan::
|
||||
IndexProductLocalSearchQuantizerFastScan() {
|
||||
aq = &plsq;
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include <faiss/IndexAdditiveQuantizer.h>
|
||||
#include <faiss/IndexFastScan.h>
|
||||
#include <faiss/impl/AdditiveQuantizer.h>
|
||||
#include <faiss/impl/ProductAdditiveQuantizer.h>
|
||||
#include <faiss/utils/AlignedTable.h>
|
||||
|
||||
namespace faiss {
|
||||
|
@ -109,6 +110,10 @@ struct IndexResidualQuantizerFastScan : IndexAdditiveQuantizerFastScan {
|
|||
IndexResidualQuantizerFastScan();
|
||||
};
|
||||
|
||||
/** Index based on a local search quantizer. Stored vectors are
|
||||
* approximated by local search quantization codes.
|
||||
* Can also be used as a codec
|
||||
*/
|
||||
struct IndexLocalSearchQuantizerFastScan : IndexAdditiveQuantizerFastScan {
|
||||
LocalSearchQuantizer lsq;
|
||||
|
||||
|
@ -131,4 +136,63 @@ struct IndexLocalSearchQuantizerFastScan : IndexAdditiveQuantizerFastScan {
|
|||
IndexLocalSearchQuantizerFastScan();
|
||||
};
|
||||
|
||||
/** Index based on a product residual quantizer. Stored vectors are
|
||||
* approximated by product residual quantization codes.
|
||||
* Can also be used as a codec
|
||||
*/
|
||||
struct IndexProductResidualQuantizerFastScan : IndexAdditiveQuantizerFastScan {
|
||||
/// The product residual quantizer used to encode the vectors
|
||||
ProductResidualQuantizer prq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of residual quantizers
|
||||
* @param Msub number of subquantizers per RQ
|
||||
* @param nbits number of bits per subvector index
|
||||
* @param metric metric type
|
||||
* @param search_type AQ search type
|
||||
*/
|
||||
IndexProductResidualQuantizerFastScan(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of residual quantizers
|
||||
size_t Msub, ///< number of subquantizers per RQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_norm_rq2x4,
|
||||
int bbs = 32);
|
||||
|
||||
IndexProductResidualQuantizerFastScan();
|
||||
};
|
||||
|
||||
/** Index based on a product local search quantizer. Stored vectors are
|
||||
* approximated by product local search quantization codes.
|
||||
* Can also be used as a codec
|
||||
*/
|
||||
struct IndexProductLocalSearchQuantizerFastScan
|
||||
: IndexAdditiveQuantizerFastScan {
|
||||
/// The product local search quantizer used to encode the vectors
|
||||
ProductLocalSearchQuantizer plsq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of local search quantizers
|
||||
* @param Msub number of subquantizers per LSQ
|
||||
* @param nbits number of bits per subvector index
|
||||
* @param metric metric type
|
||||
* @param search_type AQ search type
|
||||
*/
|
||||
IndexProductLocalSearchQuantizerFastScan(
|
||||
int d, ///< dimensionality of the input vectors
|
||||
size_t nsplits, ///< number of local search quantizers
|
||||
size_t Msub, ///< number of subquantizers per LSQ
|
||||
size_t nbits, ///< number of bit per subvector index
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_norm_rq2x4,
|
||||
int bbs = 32);
|
||||
|
||||
IndexProductLocalSearchQuantizerFastScan();
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -342,4 +342,50 @@ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
|
|||
|
||||
IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() {}
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexIVFProductResidualQuantizer
|
||||
**************************************************************************************/
|
||||
|
||||
IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric,
|
||||
Search_type_t search_type)
|
||||
: IndexIVFAdditiveQuantizer(&prq, quantizer, d, nlist, metric),
|
||||
prq(d, nsplits, Msub, nbits, search_type) {
|
||||
code_size = invlists->code_size = prq.code_size;
|
||||
}
|
||||
|
||||
IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer()
|
||||
: IndexIVFAdditiveQuantizer(&prq) {}
|
||||
|
||||
IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() {}
|
||||
|
||||
/**************************************************************************************
|
||||
* IndexIVFProductLocalSearchQuantizer
|
||||
**************************************************************************************/
|
||||
|
||||
IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric,
|
||||
Search_type_t search_type)
|
||||
: IndexIVFAdditiveQuantizer(&plsq, quantizer, d, nlist, metric),
|
||||
plsq(d, nsplits, Msub, nbits, search_type) {
|
||||
code_size = invlists->code_size = plsq.code_size;
|
||||
}
|
||||
|
||||
IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer()
|
||||
: IndexIVFAdditiveQuantizer(&plsq) {}
|
||||
|
||||
IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() {}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include <faiss/IndexIVF.h>
|
||||
#include <faiss/impl/LocalSearchQuantizer.h>
|
||||
#include <faiss/impl/ProductAdditiveQuantizer.h>
|
||||
#include <faiss/impl/ResidualQuantizer.h>
|
||||
#include <faiss/impl/platform_macros.h>
|
||||
|
||||
|
@ -118,6 +119,64 @@ struct IndexIVFLocalSearchQuantizer : IndexIVFAdditiveQuantizer {
|
|||
virtual ~IndexIVFLocalSearchQuantizer();
|
||||
};
|
||||
|
||||
/** IndexIVF based on a product residual quantizer. Stored vectors are
|
||||
* approximated by product residual quantization codes.
|
||||
*/
|
||||
struct IndexIVFProductResidualQuantizer : IndexIVFAdditiveQuantizer {
|
||||
/// The product residual quantizer used to encode the vectors
|
||||
ProductResidualQuantizer prq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of residual quantizers
|
||||
* @param Msub number of subquantizers per RQ
|
||||
* @param nbits number of bits per subvector index
|
||||
*/
|
||||
IndexIVFProductResidualQuantizer(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
|
||||
|
||||
IndexIVFProductResidualQuantizer();
|
||||
|
||||
virtual ~IndexIVFProductResidualQuantizer();
|
||||
};
|
||||
|
||||
/** IndexIVF based on a product local search quantizer. Stored vectors are
|
||||
* approximated by product local search quantization codes.
|
||||
*/
|
||||
struct IndexIVFProductLocalSearchQuantizer : IndexIVFAdditiveQuantizer {
|
||||
/// The product local search quantizer used to encode the vectors
|
||||
ProductLocalSearchQuantizer plsq;
|
||||
|
||||
/** Constructor.
|
||||
*
|
||||
* @param d dimensionality of the input vectors
|
||||
* @param nsplits number of local search quantizers
|
||||
* @param Msub number of subquantizers per LSQ
|
||||
* @param nbits number of bits per subvector index
|
||||
*/
|
||||
IndexIVFProductLocalSearchQuantizer(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_decompress);
|
||||
|
||||
IndexIVFProductLocalSearchQuantizer();
|
||||
|
||||
virtual ~IndexIVFProductLocalSearchQuantizer();
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
||||
#endif
|
||||
|
|
|
@ -499,8 +499,6 @@ IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan() {
|
|||
aq = &lsq;
|
||||
}
|
||||
|
||||
IndexIVFLocalSearchQuantizerFastScan::~IndexIVFLocalSearchQuantizerFastScan() {}
|
||||
|
||||
/********** IndexIVFResidualQuantizerFastScan ************/
|
||||
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan(
|
||||
Index* quantizer,
|
||||
|
@ -527,6 +525,62 @@ IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan() {
|
|||
aq = &rq;
|
||||
}
|
||||
|
||||
IndexIVFResidualQuantizerFastScan::~IndexIVFResidualQuantizerFastScan() {}
|
||||
/********** IndexIVFProductLocalSearchQuantizerFastScan ************/
|
||||
IndexIVFProductLocalSearchQuantizerFastScan::
|
||||
IndexIVFProductLocalSearchQuantizerFastScan(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric,
|
||||
Search_type_t search_type,
|
||||
int bbs)
|
||||
: IndexIVFAdditiveQuantizerFastScan(
|
||||
quantizer,
|
||||
nullptr,
|
||||
d,
|
||||
nlist,
|
||||
metric,
|
||||
bbs),
|
||||
plsq(d, nsplits, Msub, nbits, search_type) {
|
||||
FAISS_THROW_IF_NOT(nbits == 4);
|
||||
init(&plsq, nlist, metric, bbs);
|
||||
}
|
||||
|
||||
IndexIVFProductLocalSearchQuantizerFastScan::
|
||||
IndexIVFProductLocalSearchQuantizerFastScan() {
|
||||
aq = &plsq;
|
||||
}
|
||||
|
||||
/********** IndexIVFProductResidualQuantizerFastScan ************/
|
||||
IndexIVFProductResidualQuantizerFastScan::
|
||||
IndexIVFProductResidualQuantizerFastScan(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric,
|
||||
Search_type_t search_type,
|
||||
int bbs)
|
||||
: IndexIVFAdditiveQuantizerFastScan(
|
||||
quantizer,
|
||||
nullptr,
|
||||
d,
|
||||
nlist,
|
||||
metric,
|
||||
bbs),
|
||||
prq(d, nsplits, Msub, nbits, search_type) {
|
||||
FAISS_THROW_IF_NOT(nbits == 4);
|
||||
init(&prq, nlist, metric, bbs);
|
||||
}
|
||||
|
||||
IndexIVFProductResidualQuantizerFastScan::
|
||||
IndexIVFProductResidualQuantizerFastScan() {
|
||||
aq = &prq;
|
||||
}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
||||
#include <faiss/IndexIVFFastScan.h>
|
||||
#include <faiss/impl/AdditiveQuantizer.h>
|
||||
#include <faiss/impl/LocalSearchQuantizer.h>
|
||||
#include <faiss/impl/ProductAdditiveQuantizer.h>
|
||||
#include <faiss/utils/AlignedTable.h>
|
||||
|
||||
namespace faiss {
|
||||
|
@ -113,8 +113,6 @@ struct IndexIVFLocalSearchQuantizerFastScan
|
|||
int bbs = 32);
|
||||
|
||||
IndexIVFLocalSearchQuantizerFastScan();
|
||||
|
||||
~IndexIVFLocalSearchQuantizerFastScan();
|
||||
};
|
||||
|
||||
struct IndexIVFResidualQuantizerFastScan : IndexIVFAdditiveQuantizerFastScan {
|
||||
|
@ -131,8 +129,42 @@ struct IndexIVFResidualQuantizerFastScan : IndexIVFAdditiveQuantizerFastScan {
|
|||
int bbs = 32);
|
||||
|
||||
IndexIVFResidualQuantizerFastScan();
|
||||
};
|
||||
|
||||
~IndexIVFResidualQuantizerFastScan();
|
||||
struct IndexIVFProductLocalSearchQuantizerFastScan
|
||||
: IndexIVFAdditiveQuantizerFastScan {
|
||||
ProductLocalSearchQuantizer plsq;
|
||||
|
||||
IndexIVFProductLocalSearchQuantizerFastScan(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
|
||||
int bbs = 32);
|
||||
|
||||
IndexIVFProductLocalSearchQuantizerFastScan();
|
||||
};
|
||||
|
||||
struct IndexIVFProductResidualQuantizerFastScan
|
||||
: IndexIVFAdditiveQuantizerFastScan {
|
||||
ProductResidualQuantizer prq;
|
||||
|
||||
IndexIVFProductResidualQuantizerFastScan(
|
||||
Index* quantizer,
|
||||
size_t d,
|
||||
size_t nlist,
|
||||
size_t nsplits,
|
||||
size_t Msub,
|
||||
size_t nbits,
|
||||
MetricType metric = METRIC_L2,
|
||||
Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
|
||||
int bbs = 32);
|
||||
|
||||
IndexIVFProductResidualQuantizerFastScan();
|
||||
};
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -50,26 +50,21 @@ ProductAdditiveQuantizer::ProductAdditiveQuantizer(
|
|||
init(d, aqs, search_type);
|
||||
}
|
||||
|
||||
ProductAdditiveQuantizer::ProductAdditiveQuantizer() {}
|
||||
ProductAdditiveQuantizer::ProductAdditiveQuantizer()
|
||||
: ProductAdditiveQuantizer(0, {}) {}
|
||||
|
||||
void ProductAdditiveQuantizer::init(
|
||||
size_t d,
|
||||
const std::vector<AdditiveQuantizer*>& aqs,
|
||||
Search_type_t search_type) {
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
!aqs.empty(), "At least one additive quantizer is required.");
|
||||
for (size_t i = 0; i < aqs.size(); i++) {
|
||||
const auto& q = aqs[i];
|
||||
FAISS_THROW_IF_NOT(q->d == aqs[0]->d);
|
||||
FAISS_THROW_IF_NOT(q->M == aqs[0]->M);
|
||||
FAISS_THROW_IF_NOT(q->nbits[0] == aqs[0]->nbits[0]);
|
||||
}
|
||||
|
||||
// AdditiveQuantizer constructor
|
||||
this->d = d;
|
||||
this->search_type = search_type;
|
||||
M = aqs.size() * aqs[0]->M;
|
||||
nbits = std::vector<size_t>(M, aqs[0]->nbits[0]);
|
||||
M = 0;
|
||||
for (const auto& q : aqs) {
|
||||
M += q->M;
|
||||
nbits.insert(nbits.end(), q->nbits.begin(), q->nbits.end());
|
||||
}
|
||||
verbose = false;
|
||||
is_trained = false;
|
||||
norm_max = norm_min = NAN;
|
||||
|
@ -139,6 +134,15 @@ void ProductAdditiveQuantizer::train(size_t n, const float* x) {
|
|||
}
|
||||
|
||||
is_trained = true;
|
||||
|
||||
// train norm
|
||||
std::vector<int32_t> codes(n * M);
|
||||
compute_unpacked_codes(x, codes.data(), n);
|
||||
std::vector<float> x_recons(n * d);
|
||||
std::vector<float> norms(n);
|
||||
decode_unpacked(codes.data(), x_recons.data(), n);
|
||||
fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
|
||||
train_norm(n, norms.data());
|
||||
}
|
||||
|
||||
void ProductAdditiveQuantizer::compute_codes_add_centroids(
|
||||
|
@ -148,7 +152,17 @@ void ProductAdditiveQuantizer::compute_codes_add_centroids(
|
|||
const float* centroids) const {
|
||||
// size (n, M)
|
||||
std::vector<int32_t> unpacked_codes(n * M);
|
||||
compute_unpacked_codes(x, unpacked_codes.data(), n, centroids);
|
||||
|
||||
// pack
|
||||
pack_codes(n, unpacked_codes.data(), codes_out, -1, nullptr, centroids);
|
||||
}
|
||||
|
||||
void ProductAdditiveQuantizer::compute_unpacked_codes(
|
||||
const float* x,
|
||||
int32_t* unpacked_codes,
|
||||
size_t n,
|
||||
const float* centroids) const {
|
||||
/// TODO: actually we do not need to unpack and pack
|
||||
size_t offset_d = 0, offset_m = 0;
|
||||
std::vector<float> xsub;
|
||||
|
@ -183,9 +197,6 @@ void ProductAdditiveQuantizer::compute_codes_add_centroids(
|
|||
offset_d += q->d;
|
||||
offset_m += q->M;
|
||||
}
|
||||
|
||||
// pack
|
||||
pack_codes(n, unpacked_codes.data(), codes_out, -1, nullptr, centroids);
|
||||
}
|
||||
|
||||
void ProductAdditiveQuantizer::decode_unpacked(
|
||||
|
@ -318,22 +329,26 @@ ProductLocalSearchQuantizer::ProductLocalSearchQuantizer(
|
|||
size_t Msub,
|
||||
size_t nbits,
|
||||
Search_type_t search_type) {
|
||||
FAISS_THROW_IF_NOT(d % nsplits == 0);
|
||||
size_t dsub = d / nsplits;
|
||||
std::vector<AdditiveQuantizer*> aqs;
|
||||
|
||||
for (size_t i = 0; i < nsplits; i++) {
|
||||
auto lsq = new LocalSearchQuantizer(dsub, Msub, nbits, ST_decompress);
|
||||
aqs.push_back(lsq);
|
||||
if (nsplits > 0) {
|
||||
FAISS_THROW_IF_NOT(d % nsplits == 0);
|
||||
size_t dsub = d / nsplits;
|
||||
|
||||
for (size_t i = 0; i < nsplits; i++) {
|
||||
auto lsq =
|
||||
new LocalSearchQuantizer(dsub, Msub, nbits, ST_decompress);
|
||||
aqs.push_back(lsq);
|
||||
}
|
||||
}
|
||||
init(d, aqs, search_type);
|
||||
|
||||
for (auto& q : aqs) {
|
||||
delete q;
|
||||
}
|
||||
}
|
||||
|
||||
ProductLocalSearchQuantizer::ProductLocalSearchQuantizer() {}
|
||||
ProductLocalSearchQuantizer::ProductLocalSearchQuantizer()
|
||||
: ProductLocalSearchQuantizer(0, 0, 0, 0) {}
|
||||
|
||||
/*************************************
|
||||
* Product Residual Quantizer
|
||||
|
@ -345,21 +360,24 @@ ProductResidualQuantizer::ProductResidualQuantizer(
|
|||
size_t Msub,
|
||||
size_t nbits,
|
||||
Search_type_t search_type) {
|
||||
FAISS_THROW_IF_NOT(d % nsplits == 0);
|
||||
size_t dsub = d / nsplits;
|
||||
std::vector<AdditiveQuantizer*> aqs;
|
||||
|
||||
for (size_t i = 0; i < nsplits; i++) {
|
||||
auto rq = new ResidualQuantizer(dsub, Msub, nbits, ST_decompress);
|
||||
aqs.push_back(rq);
|
||||
if (nsplits > 0) {
|
||||
FAISS_THROW_IF_NOT(d % nsplits == 0);
|
||||
size_t dsub = d / nsplits;
|
||||
|
||||
for (size_t i = 0; i < nsplits; i++) {
|
||||
auto rq = new ResidualQuantizer(dsub, Msub, nbits, ST_decompress);
|
||||
aqs.push_back(rq);
|
||||
}
|
||||
}
|
||||
init(d, aqs, search_type);
|
||||
|
||||
for (auto& q : aqs) {
|
||||
delete q;
|
||||
}
|
||||
}
|
||||
|
||||
ProductResidualQuantizer::ProductResidualQuantizer() {}
|
||||
ProductResidualQuantizer::ProductResidualQuantizer()
|
||||
: ProductResidualQuantizer(0, 0, 0, 0) {}
|
||||
|
||||
} // namespace faiss
|
||||
|
|
|
@ -58,11 +58,6 @@ struct ProductAdditiveQuantizer : AdditiveQuantizer {
|
|||
///< Train the product additive quantizer
|
||||
void train(size_t n, const float* x) override;
|
||||
|
||||
void compute_codes(const float* x, uint8_t* codes, size_t n)
|
||||
const override {
|
||||
compute_codes_add_centroids(x, codes, n);
|
||||
}
|
||||
|
||||
/** Encode a set of vectors
|
||||
*
|
||||
* @param x vectors to encode, size n * d
|
||||
|
@ -75,6 +70,12 @@ struct ProductAdditiveQuantizer : AdditiveQuantizer {
|
|||
size_t n,
|
||||
const float* centroids = nullptr) const override;
|
||||
|
||||
void compute_unpacked_codes(
|
||||
const float* x,
|
||||
int32_t* codes,
|
||||
size_t n,
|
||||
const float* centroids = nullptr) const;
|
||||
|
||||
/** Decode a set of vectors in non-packed format
|
||||
*
|
||||
* @param codes codes to decode, size n * ld_codes
|
||||
|
|
|
@ -567,7 +567,12 @@ void ResidualQuantizer::compute_codes_add_centroids(
|
|||
}
|
||||
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
||||
size_t i1 = std::min(n, i0 + bs);
|
||||
compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
|
||||
const float* cent = nullptr;
|
||||
if (centroids != nullptr) {
|
||||
cent = centroids + i0 * d;
|
||||
}
|
||||
compute_codes_add_centroids(
|
||||
x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -311,6 +311,37 @@ static void read_LocalSearchQuantizer(LocalSearchQuantizer* lsq, IOReader* f) {
|
|||
READ1(lsq->update_codebooks_with_double);
|
||||
}
|
||||
|
||||
static void read_ProductAdditiveQuantizer(
|
||||
ProductAdditiveQuantizer* paq,
|
||||
IOReader* f) {
|
||||
read_AdditiveQuantizer(paq, f);
|
||||
READ1(paq->nsplits);
|
||||
}
|
||||
|
||||
static void read_ProductResidualQuantizer(
|
||||
ProductResidualQuantizer* prq,
|
||||
IOReader* f) {
|
||||
read_ProductAdditiveQuantizer(prq, f);
|
||||
|
||||
for (size_t i = 0; i < prq->nsplits; i++) {
|
||||
auto rq = new ResidualQuantizer();
|
||||
read_ResidualQuantizer(rq, f);
|
||||
prq->quantizers.push_back(rq);
|
||||
}
|
||||
}
|
||||
|
||||
static void read_ProductLocalSearchQuantizer(
|
||||
ProductLocalSearchQuantizer* plsq,
|
||||
IOReader* f) {
|
||||
read_ProductAdditiveQuantizer(plsq, f);
|
||||
|
||||
for (size_t i = 0; i < plsq->nsplits; i++) {
|
||||
auto lsq = new LocalSearchQuantizer();
|
||||
read_LocalSearchQuantizer(lsq, f);
|
||||
plsq->quantizers.push_back(lsq);
|
||||
}
|
||||
}
|
||||
|
||||
static void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f) {
|
||||
READ1(ivsc->qtype);
|
||||
READ1(ivsc->rangestat);
|
||||
|
@ -559,6 +590,20 @@ Index* read_index(IOReader* f, int io_flags) {
|
|||
READ1(idxr->code_size);
|
||||
READVECTOR(idxr->codes);
|
||||
idx = idxr;
|
||||
} else if (h == fourcc("IxPR")) {
|
||||
auto idxpr = new IndexProductResidualQuantizer();
|
||||
read_index_header(idxpr, f);
|
||||
read_ProductResidualQuantizer(&idxpr->prq, f);
|
||||
READ1(idxpr->code_size);
|
||||
READVECTOR(idxpr->codes);
|
||||
idx = idxpr;
|
||||
} else if (h == fourcc("IxPL")) {
|
||||
auto idxpl = new IndexProductLocalSearchQuantizer();
|
||||
read_index_header(idxpl, f);
|
||||
read_ProductLocalSearchQuantizer(&idxpl->plsq, f);
|
||||
READ1(idxpl->code_size);
|
||||
READVECTOR(idxpl->codes);
|
||||
idx = idxpl;
|
||||
} else if (h == fourcc("ImRQ")) {
|
||||
ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer();
|
||||
read_index_header(idxr, f);
|
||||
|
@ -566,20 +611,35 @@ Index* read_index(IOReader* f, int io_flags) {
|
|||
READ1(idxr->beam_factor);
|
||||
idxr->set_beam_factor(idxr->beam_factor);
|
||||
idx = idxr;
|
||||
} else if (h == fourcc("ILfs") || h == fourcc("IRfs")) {
|
||||
} else if (
|
||||
h == fourcc("ILfs") || h == fourcc("IRfs") || h == fourcc("IPRf") ||
|
||||
h == fourcc("IPLf")) {
|
||||
bool is_LSQ = h == fourcc("ILfs");
|
||||
bool is_RQ = h == fourcc("IRfs");
|
||||
bool is_PLSQ = h == fourcc("IPLf");
|
||||
|
||||
IndexAdditiveQuantizerFastScan* idxaqfs;
|
||||
if (is_LSQ) {
|
||||
idxaqfs = new IndexLocalSearchQuantizerFastScan();
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
idxaqfs = new IndexResidualQuantizerFastScan();
|
||||
} else if (is_PLSQ) {
|
||||
idxaqfs = new IndexProductLocalSearchQuantizerFastScan();
|
||||
} else {
|
||||
idxaqfs = new IndexProductResidualQuantizerFastScan();
|
||||
}
|
||||
read_index_header(idxaqfs, f);
|
||||
|
||||
if (is_LSQ) {
|
||||
read_LocalSearchQuantizer((LocalSearchQuantizer*)idxaqfs->aq, f);
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
read_ResidualQuantizer((ResidualQuantizer*)idxaqfs->aq, f);
|
||||
} else if (is_PLSQ) {
|
||||
read_ProductLocalSearchQuantizer(
|
||||
(ProductLocalSearchQuantizer*)idxaqfs->aq, f);
|
||||
} else {
|
||||
read_ProductResidualQuantizer(
|
||||
(ProductResidualQuantizer*)idxaqfs->aq, f);
|
||||
}
|
||||
|
||||
READ1(idxaqfs->implem);
|
||||
|
@ -599,20 +659,35 @@ Index* read_index(IOReader* f, int io_flags) {
|
|||
|
||||
READVECTOR(idxaqfs->codes);
|
||||
idx = idxaqfs;
|
||||
} else if (h == fourcc("IVLf") || h == fourcc("IVRf")) {
|
||||
} else if (
|
||||
h == fourcc("IVLf") || h == fourcc("IVRf") || h == fourcc("NPLf") ||
|
||||
h == fourcc("NPRf")) {
|
||||
bool is_LSQ = h == fourcc("IVLf");
|
||||
bool is_RQ = h == fourcc("IVRf");
|
||||
bool is_PLSQ = h == fourcc("NPLf");
|
||||
|
||||
IndexIVFAdditiveQuantizerFastScan* ivaqfs;
|
||||
if (is_LSQ) {
|
||||
ivaqfs = new IndexIVFLocalSearchQuantizerFastScan();
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
ivaqfs = new IndexIVFResidualQuantizerFastScan();
|
||||
} else if (is_PLSQ) {
|
||||
ivaqfs = new IndexIVFProductLocalSearchQuantizerFastScan();
|
||||
} else {
|
||||
ivaqfs = new IndexIVFProductResidualQuantizerFastScan();
|
||||
}
|
||||
read_ivf_header(ivaqfs, f);
|
||||
|
||||
if (is_LSQ) {
|
||||
read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f);
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f);
|
||||
} else if (is_PLSQ) {
|
||||
read_ProductLocalSearchQuantizer(
|
||||
(ProductLocalSearchQuantizer*)ivaqfs->aq, f);
|
||||
} else {
|
||||
read_ProductResidualQuantizer(
|
||||
(ProductResidualQuantizer*)ivaqfs->aq, f);
|
||||
}
|
||||
|
||||
READ1(ivaqfs->by_residual);
|
||||
|
@ -712,20 +787,34 @@ Index* read_index(IOReader* f, int io_flags) {
|
|||
}
|
||||
read_InvertedLists(ivsc, f, io_flags);
|
||||
idx = ivsc;
|
||||
} else if (h == fourcc("IwLS") || h == fourcc("IwRQ")) {
|
||||
} else if (
|
||||
h == fourcc("IwLS") || h == fourcc("IwRQ") || h == fourcc("IwPL") ||
|
||||
h == fourcc("IwPR")) {
|
||||
bool is_LSQ = h == fourcc("IwLS");
|
||||
bool is_RQ = h == fourcc("IwRQ");
|
||||
bool is_PLSQ = h == fourcc("IwPL");
|
||||
IndexIVFAdditiveQuantizer* iva;
|
||||
if (is_LSQ) {
|
||||
iva = new IndexIVFLocalSearchQuantizer();
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
iva = new IndexIVFResidualQuantizer();
|
||||
} else if (is_PLSQ) {
|
||||
iva = new IndexIVFProductLocalSearchQuantizer();
|
||||
} else {
|
||||
iva = new IndexIVFProductResidualQuantizer();
|
||||
}
|
||||
read_ivf_header(iva, f);
|
||||
READ1(iva->code_size);
|
||||
if (is_LSQ) {
|
||||
read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
|
||||
} else if (is_PLSQ) {
|
||||
read_ProductLocalSearchQuantizer(
|
||||
(ProductLocalSearchQuantizer*)iva->aq, f);
|
||||
} else {
|
||||
read_ProductResidualQuantizer(
|
||||
(ProductResidualQuantizer*)iva->aq, f);
|
||||
}
|
||||
READ1(iva->by_residual);
|
||||
READ1(iva->use_precomputed_table);
|
||||
|
|
|
@ -207,6 +207,33 @@ static void write_LocalSearchQuantizer(
|
|||
WRITE1(lsq->update_codebooks_with_double);
|
||||
}
|
||||
|
||||
static void write_ProductAdditiveQuantizer(
|
||||
const ProductAdditiveQuantizer* paq,
|
||||
IOWriter* f) {
|
||||
write_AdditiveQuantizer(paq, f);
|
||||
WRITE1(paq->nsplits);
|
||||
}
|
||||
|
||||
static void write_ProductResidualQuantizer(
|
||||
const ProductResidualQuantizer* prq,
|
||||
IOWriter* f) {
|
||||
write_ProductAdditiveQuantizer(prq, f);
|
||||
for (const auto aq : prq->quantizers) {
|
||||
auto rq = dynamic_cast<const ResidualQuantizer*>(aq);
|
||||
write_ResidualQuantizer(rq, f);
|
||||
}
|
||||
}
|
||||
|
||||
static void write_ProductLocalSearchQuantizer(
|
||||
const ProductLocalSearchQuantizer* plsq,
|
||||
IOWriter* f) {
|
||||
write_ProductAdditiveQuantizer(plsq, f);
|
||||
for (const auto aq : plsq->quantizers) {
|
||||
auto lsq = dynamic_cast<const LocalSearchQuantizer*>(aq);
|
||||
write_LocalSearchQuantizer(lsq, f);
|
||||
}
|
||||
}
|
||||
|
||||
static void write_ScalarQuantizer(const ScalarQuantizer* ivsc, IOWriter* f) {
|
||||
WRITE1(ivsc->qtype);
|
||||
WRITE1(ivsc->rangestat);
|
||||
|
@ -394,28 +421,62 @@ void write_index(const Index* idx, IOWriter* f) {
|
|||
write_LocalSearchQuantizer(&idxr->lsq, f);
|
||||
WRITE1(idxr->code_size);
|
||||
WRITEVECTOR(idxr->codes);
|
||||
} else if (
|
||||
const IndexProductResidualQuantizer* idxpr =
|
||||
dynamic_cast<const IndexProductResidualQuantizer*>(idx)) {
|
||||
uint32_t h = fourcc("IxPR");
|
||||
WRITE1(h);
|
||||
write_index_header(idx, f);
|
||||
write_ProductResidualQuantizer(&idxpr->prq, f);
|
||||
WRITE1(idxpr->code_size);
|
||||
WRITEVECTOR(idxpr->codes);
|
||||
} else if (
|
||||
const IndexProductLocalSearchQuantizer* idxpl =
|
||||
dynamic_cast<const IndexProductLocalSearchQuantizer*>(
|
||||
idx)) {
|
||||
uint32_t h = fourcc("IxPL");
|
||||
WRITE1(h);
|
||||
write_index_header(idx, f);
|
||||
write_ProductLocalSearchQuantizer(&idxpl->plsq, f);
|
||||
WRITE1(idxpl->code_size);
|
||||
WRITEVECTOR(idxpl->codes);
|
||||
} else if (
|
||||
auto* idxaqfs =
|
||||
dynamic_cast<const IndexAdditiveQuantizerFastScan*>(idx)) {
|
||||
auto idxlsqfs =
|
||||
dynamic_cast<const IndexLocalSearchQuantizerFastScan*>(idx);
|
||||
auto idxrqfs = dynamic_cast<const IndexResidualQuantizerFastScan*>(idx);
|
||||
FAISS_THROW_IF_NOT(idxlsqfs || idxrqfs);
|
||||
auto idxplsqfs =
|
||||
dynamic_cast<const IndexProductLocalSearchQuantizerFastScan*>(
|
||||
idx);
|
||||
auto idxprqfs =
|
||||
dynamic_cast<const IndexProductResidualQuantizerFastScan*>(idx);
|
||||
FAISS_THROW_IF_NOT(idxlsqfs || idxrqfs || idxplsqfs || idxprqfs);
|
||||
|
||||
if (idxlsqfs) {
|
||||
uint32_t h = fourcc("ILfs");
|
||||
WRITE1(h);
|
||||
} else {
|
||||
} else if (idxrqfs) {
|
||||
uint32_t h = fourcc("IRfs");
|
||||
WRITE1(h);
|
||||
} else if (idxplsqfs) {
|
||||
uint32_t h = fourcc("IPLf");
|
||||
WRITE1(h);
|
||||
} else if (idxprqfs) {
|
||||
uint32_t h = fourcc("IPRf");
|
||||
WRITE1(h);
|
||||
}
|
||||
|
||||
write_index_header(idxaqfs, f);
|
||||
|
||||
if (idxlsqfs) {
|
||||
write_LocalSearchQuantizer(&idxlsqfs->lsq, f);
|
||||
} else {
|
||||
} else if (idxrqfs) {
|
||||
write_ResidualQuantizer(&idxrqfs->rq, f);
|
||||
} else if (idxplsqfs) {
|
||||
write_ProductLocalSearchQuantizer(&idxplsqfs->plsq, f);
|
||||
} else if (idxprqfs) {
|
||||
write_ProductResidualQuantizer(&idxprqfs->prq, f);
|
||||
}
|
||||
WRITE1(idxaqfs->implem);
|
||||
WRITE1(idxaqfs->bbs);
|
||||
|
@ -441,22 +502,37 @@ void write_index(const Index* idx, IOWriter* f) {
|
|||
dynamic_cast<const IndexIVFLocalSearchQuantizerFastScan*>(idx);
|
||||
auto ivrqfs =
|
||||
dynamic_cast<const IndexIVFResidualQuantizerFastScan*>(idx);
|
||||
FAISS_THROW_IF_NOT(ivlsqfs || ivrqfs);
|
||||
auto ivplsqfs = dynamic_cast<
|
||||
const IndexIVFProductLocalSearchQuantizerFastScan*>(idx);
|
||||
auto ivprqfs =
|
||||
dynamic_cast<const IndexIVFProductResidualQuantizerFastScan*>(
|
||||
idx);
|
||||
FAISS_THROW_IF_NOT(ivlsqfs || ivrqfs || ivplsqfs || ivprqfs);
|
||||
|
||||
if (ivlsqfs) {
|
||||
uint32_t h = fourcc("IVLf");
|
||||
WRITE1(h);
|
||||
} else {
|
||||
} else if (ivrqfs) {
|
||||
uint32_t h = fourcc("IVRf");
|
||||
WRITE1(h);
|
||||
} else if (ivplsqfs) {
|
||||
uint32_t h = fourcc("NPLf"); // N means IV ...
|
||||
WRITE1(h);
|
||||
} else {
|
||||
uint32_t h = fourcc("NPRf");
|
||||
WRITE1(h);
|
||||
}
|
||||
|
||||
write_ivf_header(ivaqfs, f);
|
||||
|
||||
if (ivlsqfs) {
|
||||
write_LocalSearchQuantizer(&ivlsqfs->lsq, f);
|
||||
} else {
|
||||
} else if (ivrqfs) {
|
||||
write_ResidualQuantizer(&ivrqfs->rq, f);
|
||||
} else if (ivplsqfs) {
|
||||
write_ProductLocalSearchQuantizer(&ivplsqfs->plsq, f);
|
||||
} else {
|
||||
write_ProductResidualQuantizer(&ivprqfs->prq, f);
|
||||
}
|
||||
|
||||
WRITE1(ivaqfs->by_residual);
|
||||
|
@ -550,14 +626,33 @@ void write_index(const Index* idx, IOWriter* f) {
|
|||
write_InvertedLists(ivsc->invlists, f);
|
||||
} else if (auto iva = dynamic_cast<const IndexIVFAdditiveQuantizer*>(idx)) {
|
||||
bool is_LSQ = dynamic_cast<const IndexIVFLocalSearchQuantizer*>(iva);
|
||||
uint32_t h = fourcc(is_LSQ ? "IwLS" : "IwRQ");
|
||||
bool is_RQ = dynamic_cast<const IndexIVFResidualQuantizer*>(iva);
|
||||
bool is_PLSQ =
|
||||
dynamic_cast<const IndexIVFProductLocalSearchQuantizer*>(iva);
|
||||
uint32_t h;
|
||||
if (is_LSQ) {
|
||||
h = fourcc("IwLS");
|
||||
} else if (is_RQ) {
|
||||
h = fourcc("IwRQ");
|
||||
} else if (is_PLSQ) {
|
||||
h = fourcc("IwPL");
|
||||
} else {
|
||||
h = fourcc("IwPR");
|
||||
}
|
||||
|
||||
WRITE1(h);
|
||||
write_ivf_header(iva, f);
|
||||
WRITE1(iva->code_size);
|
||||
if (is_LSQ) {
|
||||
write_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
|
||||
} else {
|
||||
} else if (is_RQ) {
|
||||
write_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
|
||||
} else if (is_PLSQ) {
|
||||
write_ProductLocalSearchQuantizer(
|
||||
(ProductLocalSearchQuantizer*)iva->aq, f);
|
||||
} else {
|
||||
write_ProductResidualQuantizer(
|
||||
(ProductResidualQuantizer*)iva->aq, f);
|
||||
}
|
||||
WRITE1(iva->by_residual);
|
||||
WRITE1(iva->use_precomputed_table);
|
||||
|
|
|
@ -159,6 +159,8 @@ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
|
|||
const std::string aq_norm_pattern =
|
||||
"(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4|_Nlsq2x4|_Nrq2x4)";
|
||||
|
||||
const std::string paq_def_pattern = "([0-9]+)x([0-9]+)x([0-9]+)";
|
||||
|
||||
AdditiveQuantizer::Search_type_t aq_parse_search_type(
|
||||
std::string stok,
|
||||
MetricType metric) {
|
||||
|
@ -345,6 +347,21 @@ IndexIVF* parse_IndexIVF(
|
|||
}
|
||||
return index_ivf;
|
||||
}
|
||||
if (match("(PRQ|PLSQ)" + paq_def_pattern + aq_norm_pattern)) {
|
||||
int nsplits = mres_to_int(sm[2]);
|
||||
int Msub = mres_to_int(sm[3]);
|
||||
int nbit = mres_to_int(sm[4]);
|
||||
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
||||
IndexIVF* index_ivf;
|
||||
if (sm[1].str() == "PRQ") {
|
||||
index_ivf = new IndexIVFProductResidualQuantizer(
|
||||
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
|
||||
} else {
|
||||
index_ivf = new IndexIVFProductLocalSearchQuantizer(
|
||||
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
|
||||
}
|
||||
return index_ivf;
|
||||
}
|
||||
if (match("(RQ|LSQ)([0-9]+)x4fs(r?)(_[0-9]+)?" + aq_norm_pattern)) {
|
||||
int M = std::stoi(sm[2].str());
|
||||
int bbs = mres_to_int(sm[4], 32, 1);
|
||||
|
@ -360,6 +377,23 @@ IndexIVF* parse_IndexIVF(
|
|||
index_ivf->by_residual = (sm[3].str() == "r");
|
||||
return index_ivf;
|
||||
}
|
||||
if (match("(PRQ|PLSQ)([0-9]+)x([0-9]+)x4fs(r?)(_[0-9]+)?" +
|
||||
aq_norm_pattern)) {
|
||||
int nsplits = std::stoi(sm[2].str());
|
||||
int Msub = std::stoi(sm[3].str());
|
||||
int bbs = mres_to_int(sm[5], 32, 1);
|
||||
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
||||
IndexIVFAdditiveQuantizerFastScan* index_ivf;
|
||||
if (sm[1].str() == "PRQ") {
|
||||
index_ivf = new IndexIVFProductResidualQuantizerFastScan(
|
||||
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
|
||||
} else {
|
||||
index_ivf = new IndexIVFProductLocalSearchQuantizerFastScan(
|
||||
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
|
||||
}
|
||||
index_ivf->by_residual = (sm[4].str() == "r");
|
||||
return index_ivf;
|
||||
}
|
||||
if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
|
||||
int outdim = mres_to_int(sm[2], d); // is also the number of bits
|
||||
std::unique_ptr<VectorTransform> vt;
|
||||
|
@ -550,6 +584,26 @@ Index* parse_other_indexes(
|
|||
return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
|
||||
}
|
||||
|
||||
// IndexProductResidualQuantizer
|
||||
if (match("PRQ" + paq_def_pattern + aq_norm_pattern)) {
|
||||
int nsplits = mres_to_int(sm[1]);
|
||||
int Msub = mres_to_int(sm[2]);
|
||||
int nbit = mres_to_int(sm[3]);
|
||||
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
||||
return new IndexProductResidualQuantizer(
|
||||
d, nsplits, Msub, nbit, metric, st);
|
||||
}
|
||||
|
||||
// IndexProductLocalSearchQuantizer
|
||||
if (match("PLSQ" + paq_def_pattern + aq_norm_pattern)) {
|
||||
int nsplits = mres_to_int(sm[1]);
|
||||
int Msub = mres_to_int(sm[2]);
|
||||
int nbit = mres_to_int(sm[3]);
|
||||
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
||||
return new IndexProductLocalSearchQuantizer(
|
||||
d, nsplits, Msub, nbit, metric, st);
|
||||
}
|
||||
|
||||
// IndexAdditiveQuantizerFastScan
|
||||
// RQ{M}x4fs_{bbs}_{search_type}
|
||||
pattern = "(LSQ|RQ)([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
|
||||
|
@ -566,6 +620,24 @@ Index* parse_other_indexes(
|
|||
}
|
||||
}
|
||||
|
||||
// IndexProductAdditiveQuantizerFastScan
|
||||
// PRQ{nsplits}x{Msub}x4fs_{bbs}_{search_type}
|
||||
pattern = "(PLSQ|PRQ)([0-9]+)x([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
|
||||
if (match(pattern)) {
|
||||
int nsplits = std::stoi(sm[2].str());
|
||||
int Msub = std::stoi(sm[3].str());
|
||||
int bbs = mres_to_int(sm[4], 32, 1);
|
||||
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
||||
|
||||
if (sm[1].str() == "PRQ") {
|
||||
return new IndexProductResidualQuantizerFastScan(
|
||||
d, nsplits, Msub, 4, metric, st, bbs);
|
||||
} else if (sm[1].str() == "PLSQ") {
|
||||
return new IndexProductLocalSearchQuantizerFastScan(
|
||||
d, nsplits, Msub, 4, metric, st, bbs);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -573,8 +573,12 @@ void gpu_sync_all_devices()
|
|||
DOWNCAST ( IndexIVFScalarQuantizer )
|
||||
DOWNCAST ( IndexIVFResidualQuantizer )
|
||||
DOWNCAST ( IndexIVFLocalSearchQuantizer )
|
||||
DOWNCAST ( IndexIVFProductResidualQuantizer )
|
||||
DOWNCAST ( IndexIVFProductLocalSearchQuantizer )
|
||||
DOWNCAST ( IndexIVFResidualQuantizerFastScan )
|
||||
DOWNCAST ( IndexIVFLocalSearchQuantizerFastScan )
|
||||
DOWNCAST ( IndexIVFProductResidualQuantizerFastScan )
|
||||
DOWNCAST ( IndexIVFProductLocalSearchQuantizerFastScan )
|
||||
DOWNCAST ( IndexIVFFlatDedup )
|
||||
DOWNCAST ( IndexIVFFlat )
|
||||
DOWNCAST ( IndexIVF )
|
||||
|
@ -587,8 +591,12 @@ void gpu_sync_all_devices()
|
|||
DOWNCAST ( IndexLocalSearchQuantizer )
|
||||
DOWNCAST ( IndexResidualQuantizerFastScan )
|
||||
DOWNCAST ( IndexLocalSearchQuantizerFastScan )
|
||||
DOWNCAST ( IndexProductResidualQuantizerFastScan )
|
||||
DOWNCAST ( IndexProductLocalSearchQuantizerFastScan )
|
||||
DOWNCAST ( ResidualCoarseQuantizer )
|
||||
DOWNCAST ( LocalSearchCoarseQuantizer )
|
||||
DOWNCAST ( IndexProductResidualQuantizer )
|
||||
DOWNCAST ( IndexProductLocalSearchQuantizer )
|
||||
DOWNCAST ( IndexScalarQuantizer )
|
||||
DOWNCAST ( IndexLSH )
|
||||
DOWNCAST ( IndexLattice )
|
||||
|
|
|
@ -471,7 +471,6 @@ class TestAQFastScan(unittest.TestCase):
|
|||
Compare IndexAdditiveQuantizerFastScan with IndexAQ (qint8)
|
||||
"""
|
||||
d = 16
|
||||
# ds = datasets.SyntheticDataset(d, 1000, 2000, 1000, metric_type)
|
||||
ds = datasets.SyntheticDataset(d, 1000, 1000, 500, metric_type)
|
||||
gt = ds.get_groundtruth(k=1)
|
||||
|
||||
|
@ -632,3 +631,70 @@ def add_TestAQFastScan_subtest_from_idxaq(implem, metric):
|
|||
for implem in 2, 3, 4:
|
||||
add_TestAQFastScan_subtest_from_idxaq(implem, 'L2')
|
||||
add_TestAQFastScan_subtest_from_idxaq(implem, 'IP')
|
||||
|
||||
|
||||
class TestPAQFastScan(unittest.TestCase):
    """Factory, accuracy and I/O checks for the fast-scan product additive
    quantizer indexes ('PLSQ...fs' / 'PRQ...fs' factory strings)."""

    def subtest_accuracy(self, paq):
        """
        Compare IndexPAQFastScan with IndexPAQ (qint8)

        `paq` is the factory prefix ('PLSQ' or 'PRQ'). The 1-recall of the
        fast-scan index must be within 0.05 of the reference index.
        """
        d = 16
        ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
        gt = ds.get_groundtruth(k=1)

        # reference: regular index with qint8 norm encoding
        index = faiss.index_factory(d, f'{paq}2x3x4_Nqint8')
        index.train(ds.get_train())
        index.add(ds.get_database())
        _, Iref = index.search(ds.get_queries(), 1)

        # fast-scan counterpart built on the same data
        indexfs = faiss.index_factory(d, f'{paq}2x3x4fs_Nlsq2x4')
        indexfs.train(ds.get_train())
        indexfs.add(ds.get_database())
        _, Ia = indexfs.search(ds.get_queries(), 1)

        nq = Iref.shape[0]
        recall_ref = (Iref == gt).sum() / nq
        recall = (Ia == gt).sum() / nq

        # unittest assertion: still active under `python -O` and reports
        # both values on failure
        self.assertLess(abs(recall_ref - recall), 0.05)

    def test_accuracy_PLSQ(self):
        self.subtest_accuracy("PLSQ")

    def test_accuracy_PRQ(self):
        self.subtest_accuracy("PRQ")

    def subtest_factory(self, paq):
        """Check that the factory string is parsed into the expected
        nsplits / per-split M values."""
        index = faiss.index_factory(16, f'{paq}2x3x4fs_Nlsq2x4')
        q = faiss.downcast_Quantizer(index.aq)
        self.assertEqual(q.nsplits, 2)
        self.assertEqual(q.subquantizer(0).M, 3)

    def test_factory(self):
        self.subtest_factory('PRQ')
        self.subtest_factory('PLSQ')

    def subtest_io(self, factory_str):
        """Write the trained index to a temporary file, read it back and
        check that both indexes return the same nearest neighbors."""
        d = 8
        ds = datasets.SyntheticDataset(d, 1000, 500, 100)

        index = faiss.index_factory(d, factory_str)
        index.train(ds.get_train())
        index.add(ds.get_database())
        _, I1 = index.search(ds.get_queries(), 1)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)
            index2 = faiss.read_index(fname)
            _, I2 = index2.search(ds.get_queries(), 1)
            np.testing.assert_array_equal(I1, I2)
        finally:
            # always remove the temporary file, even when the check fails
            if os.path.exists(fname):
                os.unlink(fname)

    def test_io(self):
        self.subtest_io('PLSQ2x3x4fs_Nlsq2x4')
        self.subtest_io('PRQ2x3x4fs_Nrq2x4')
|
||||
|
|
|
@ -548,7 +548,6 @@ class TestIVFAQFastScan(unittest.TestCase):
|
|||
ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
|
||||
gt = ds.get_groundtruth(k=1)
|
||||
|
||||
# if metric_type == 'L2':
|
||||
metric = faiss.METRIC_L2
|
||||
postfix1 = '_Nqint8'
|
||||
postfix2 = f'_N{st}2x4'
|
||||
|
@ -631,18 +630,18 @@ class TestIVFAQFastScan(unittest.TestCase):
|
|||
|
||||
assert index.nlist == nlist
|
||||
assert index.bbs == bbs
|
||||
aq = faiss.downcast_AdditiveQuantizer(index.aq)
|
||||
assert aq.M == M
|
||||
q = faiss.downcast_Quantizer(index.aq)
|
||||
assert q.M == M
|
||||
|
||||
if aq == 'LSQ':
|
||||
assert isinstance(aq, faiss.LocalSearchQuantizer)
|
||||
assert isinstance(q, faiss.LocalSearchQuantizer)
|
||||
if aq == 'RQ':
|
||||
assert isinstance(aq, faiss.ResidualQuantizer)
|
||||
assert isinstance(q, faiss.ResidualQuantizer)
|
||||
|
||||
if st == 'lsq':
|
||||
assert aq.search_type == AQ.ST_norm_lsq2x4
|
||||
assert q.search_type == AQ.ST_norm_lsq2x4
|
||||
if st == 'rq':
|
||||
assert aq.search_type == AQ.ST_norm_rq2x4
|
||||
assert q.search_type == AQ.ST_norm_rq2x4
|
||||
|
||||
assert index.by_residual == (r == 'r')
|
||||
|
||||
|
@ -710,3 +709,78 @@ for byr in True, False:
|
|||
|
||||
add_TestIVFAQFastScan_subtest_rescale_accuracy('LSQ', 'lsq', byr, implem)
|
||||
add_TestIVFAQFastScan_subtest_rescale_accuracy('RQ', 'rq', byr, implem)
|
||||
|
||||
|
||||
class TestIVFPAQFastScan(unittest.TestCase):
    """Factory, accuracy and I/O checks for the IVF fast-scan product
    additive quantizer indexes ('IVF...,PLSQ...fs' / 'IVF...,PRQ...fs')."""

    def subtest_accuracy(self, paq):
        """
        Compare IndexIVFAdditiveQuantizerFastScan with
        IndexIVFAdditiveQuantizer

        `paq` is the factory prefix ('PLSQ' or 'PRQ'). Both indexes are
        searched with nprobe = 4 and their 1-recalls must differ by less
        than 0.05.
        """
        nlist, d = 16, 8
        ds = datasets.SyntheticDataset(d, 1000, 1000, 500)
        gt = ds.get_groundtruth(k=1)

        # reference: regular IVF index with qint8 norm encoding
        index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4_Nqint8')
        index.train(ds.get_train())
        index.add(ds.get_database())
        index.nprobe = 4
        _, Iref = index.search(ds.get_queries(), 1)

        # fast-scan counterpart ("fsr" sets by_residual, see subtest_factory)
        indexfs = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
        indexfs.train(ds.get_train())
        indexfs.add(ds.get_database())
        indexfs.nprobe = 4
        _, I1 = indexfs.search(ds.get_queries(), 1)

        nq = Iref.shape[0]
        recall_ref = (Iref == gt).sum() / nq
        recall1 = (I1 == gt).sum() / nq

        # the recalls are reported only on failure instead of being printed
        # on every run
        self.assertLess(
            abs(recall_ref - recall1), 0.05,
            msg=f'{paq}: recall_ref={recall_ref} recall={recall1}')

    def test_accuracy_PLSQ(self):
        self.subtest_accuracy("PLSQ")

    def test_accuracy_PRQ(self):
        self.subtest_accuracy("PRQ")

    def subtest_factory(self, paq):
        """Check that the factory string is parsed into the expected
        nlist / nsplits / per-split M / by_residual values."""
        nlist, d = 128, 16
        index = faiss.index_factory(d, f'IVF{nlist},{paq}2x3x4fsr_Nlsq2x4')
        q = faiss.downcast_Quantizer(index.aq)

        self.assertEqual(index.nlist, nlist)
        self.assertEqual(q.nsplits, 2)
        self.assertEqual(q.subquantizer(0).M, 3)
        self.assertTrue(index.by_residual)

    def test_factory(self):
        self.subtest_factory('PLSQ')
        self.subtest_factory('PRQ')

    def subtest_io(self, factory_str):
        """Write the trained index to a temporary file, read it back and
        check that both indexes return the same nearest neighbors."""
        d = 8
        ds = datasets.SyntheticDataset(d, 1000, 2000, 1000)

        index = faiss.index_factory(d, factory_str)
        index.train(ds.get_train())
        index.add(ds.get_database())
        _, I1 = index.search(ds.get_queries(), 1)

        fd, fname = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index(index, fname)
            index2 = faiss.read_index(fname)
            _, I2 = index2.search(ds.get_queries(), 1)
            np.testing.assert_array_equal(I1, I2)
        finally:
            # always remove the temporary file, even when the check fails
            if os.path.exists(fname):
                os.unlink(fname)

    def test_io(self):
        self.subtest_io('IVF16,PLSQ2x3x4fsr_Nlsq2x4')
        self.subtest_io('IVF16,PRQ2x3x4fs_Nrq2x4')
|
||||
|
|
|
@ -573,3 +573,101 @@ class TestProductLocalSearchQuantizer(unittest.TestCase):
|
|||
|
||||
# max rtol in OSX: 2.87e-6
|
||||
np.testing.assert_allclose(lut, lut_ref, rtol=5e-06)
|
||||
|
||||
|
||||
class TestIndexProductLocalSearchQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the flat PLSQ index."""

    def test_accuracy1(self):
        """check that the error is in the same ballpark as LSQ."""
        recall1 = self.eval_index_accuracy("PLSQ4x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("LSQ12x5_Nqint8")
        self.assertGreaterEqual(recall1, recall2)  # 622 vs 551

    def test_accuracy2(self):
        """when nsplits = 1, PLSQ should be almost the same as LSQ"""
        recall1 = self.eval_index_accuracy("PLSQ1x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("LSQ3x5_Nqint8")
        diff = abs(recall1 - recall2)  # 273 vs 275 in OSX
        # same bound as before, written so the failure message reads
        # "diff not less than or equal to 5"
        self.assertLessEqual(diff, 5)

    def eval_index_accuracy(self, index_key):
        """Build the index described by `index_key`, search it with k=10 and
        return faiss.eval_intersection of the result with the ground truth.

        Also round-trips the index through serialize_index /
        deserialize_index and requires identical search results.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, index_key)

        index.train(ds.get_train())
        index.add(ds.get_database())
        D, I = index.search(ds.get_queries(), 10)
        inter = faiss.eval_intersection(I, ds.get_groundtruth(10))

        # do a little I/O test
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_factory(self):
        """Check the type and quantizer parameters produced by the factory."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"PLSQ{ns}x{Msub}x{nbits}_Nqint8")
        self.assertIsInstance(index, faiss.IndexProductLocalSearchQuantizer)
        self.assertEqual(index.plsq.nsplits, ns)
        self.assertEqual(index.plsq.subquantizer(0).M, Msub)
        self.assertEqual(index.plsq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.plsq.search_type, AQ.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus 1
        # (presumably the qint8 norm byte -- TODO confirm)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.plsq.code_size, code_size)
|
||||
|
||||
|
||||
class TestIndexIVFProductLocalSearchQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the IVF PLSQ index."""

    def eval_index_accuracy(self, factory_key):
        """Build the index described by `factory_key` and search it with
        k=10 for increasing nprobe values.

        Asserts that the intersection with the ground truth never decreases
        as nprobe grows, round-trips the index through serialize_index /
        deserialize_index, and returns the intersection measured at the
        largest nprobe.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, factory_key)

        index.train(ds.get_train())
        index.add(ds.get_database())

        inters = []
        D = I = inter = None  # hold the results of the last nprobe setting
        for nprobe in 1, 2, 4, 8, 16:
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            inters.append(inter)

        inters = np.array(inters)
        self.assertTrue(np.all(inters[1:] >= inters[:-1]))

        # do a little I/O test: the deserialized copy must reproduce the
        # results of the last search above exactly
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_index_accuracy(self):
        self.eval_index_accuracy("IVF32,PLSQ2x2x5_Nqint8")

    def test_index_accuracy2(self):
        """check that the error is in the same ballpark as LSQ."""
        inter1 = self.eval_index_accuracy("IVF32,PLSQ2x2x5_Nqint8")
        inter2 = self.eval_index_accuracy("IVF32,LSQ4x5_Nqint8")
        # print(inter1, inter2)  # 381 vs 374
        self.assertGreaterEqual(inter1 * 1.1, inter2)

    def test_factory(self):
        """Check the type and quantizer parameters produced by the factory."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"IVF32,PLSQ{ns}x{Msub}x{nbits}_Nqint8")
        self.assertIsInstance(index, faiss.IndexIVFProductLocalSearchQuantizer)
        self.assertEqual(index.nlist, 32)
        self.assertEqual(index.plsq.nsplits, ns)
        self.assertEqual(index.plsq.subquantizer(0).M, Msub)
        self.assertEqual(index.plsq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.plsq.search_type, AQ.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus 1
        # (presumably the qint8 norm byte -- TODO confirm)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.plsq.code_size, code_size)
|
||||
|
|
|
@ -1177,3 +1177,100 @@ class TestProductResidualQuantizer(unittest.TestCase):
|
|||
|
||||
print(err_prq, err_rq)
|
||||
self.assertEqual(err_prq, err_rq)
|
||||
|
||||
|
||||
class TestIndexProductResidualQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the flat PRQ index."""

    def test_accuracy1(self):
        """check that the error is in the same ballpark as RQ."""
        recall1 = self.eval_index_accuracy("PRQ4x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("RQ12x5_Nqint8")
        self.assertGreaterEqual(recall1 * 1.1, recall2)  # 657 vs 665

    def test_accuracy2(self):
        """when nsplits = 1, PRQ should be the same as RQ"""
        recall1 = self.eval_index_accuracy("PRQ1x3x5_Nqint8")
        recall2 = self.eval_index_accuracy("RQ3x5_Nqint8")
        self.assertEqual(recall1, recall2)

    def eval_index_accuracy(self, index_key):
        """Build the index described by `index_key`, search it with k=10 and
        return faiss.eval_intersection of the result with the ground truth.

        Also round-trips the index through serialize_index /
        deserialize_index and requires identical search results.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, index_key)

        index.train(ds.get_train())
        index.add(ds.get_database())
        D, I = index.search(ds.get_queries(), 10)
        inter = faiss.eval_intersection(I, ds.get_groundtruth(10))

        # do a little I/O test
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_factory(self):
        """Check the type and quantizer parameters produced by the factory."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"PRQ{ns}x{Msub}x{nbits}_Nqint8")
        self.assertIsInstance(index, faiss.IndexProductResidualQuantizer)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(index.prq.subquantizer(0).M, Msub)
        self.assertEqual(index.prq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.prq.search_type, AQ.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus 1
        # (presumably the qint8 norm byte -- TODO confirm)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, code_size)
|
||||
|
||||
|
||||
class TestIndexIVFProductResidualQuantizer(unittest.TestCase):
    """Accuracy, serialization and factory checks for the IVF PRQ index."""

    def eval_index_accuracy(self, factory_key):
        """Build the index described by `factory_key` and search it with
        k=10 for increasing nprobe values.

        Asserts that the intersection with the ground truth never decreases
        as nprobe grows, round-trips the index through serialize_index /
        deserialize_index, and returns the intersection measured at the
        largest nprobe.
        """
        ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
        index = faiss.index_factory(ds.d, factory_key)

        index.train(ds.get_train())
        index.add(ds.get_database())

        inters = []
        D = I = inter = None  # hold the results of the last nprobe setting
        for nprobe in 1, 2, 5, 10, 20, 50:
            index.nprobe = nprobe
            D, I = index.search(ds.get_queries(), 10)
            inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
            inters.append(inter)

        inters = np.array(inters)
        self.assertTrue(np.all(inters[1:] >= inters[:-1]))

        # do a little I/O test: the deserialized copy must reproduce the
        # results of the last search above exactly
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        D2, I2 = index2.search(ds.get_queries(), 10)
        np.testing.assert_array_equal(I2, I)
        np.testing.assert_array_equal(D2, D)

        return inter

    def test_index_accuracy(self):
        self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")

    def test_index_accuracy2(self):
        """check that the error is in the same ballpark as RQ."""
        inter1 = self.eval_index_accuracy("IVF100,PRQ2x2x5_Nqint8")
        inter2 = self.eval_index_accuracy("IVF100,RQ4x5_Nqint8")
        # print(inter1, inter2)  # 392 vs 374
        self.assertGreaterEqual(inter1 * 1.1, inter2)

    def test_factory(self):
        """Check the type and quantizer parameters produced by the factory."""
        AQ = faiss.AdditiveQuantizer
        ns, Msub, nbits = 2, 4, 8
        index = faiss.index_factory(64, f"IVF100,PRQ{ns}x{Msub}x{nbits}_Nqint8")
        self.assertIsInstance(index, faiss.IndexIVFProductResidualQuantizer)
        self.assertEqual(index.nlist, 100)
        self.assertEqual(index.prq.nsplits, ns)
        self.assertEqual(index.prq.subquantizer(0).M, Msub)
        self.assertEqual(index.prq.subquantizer(0).nbits.at(0), nbits)
        self.assertEqual(index.prq.search_type, AQ.ST_norm_qint8)

        # packed code bits rounded up to bytes, plus 1
        # (presumably the qint8 norm byte -- TODO confirm)
        code_size = (ns * Msub * nbits + 7) // 8 + 1
        self.assertEqual(index.prq.code_size, code_size)
|
||||
|
|
Loading…
Reference in New Issue