15 #include <gtest/gtest.h>
17 #include <faiss/IndexIVF.h>
18 #include <faiss/IndexBinaryIVF.h>
19 #include <faiss/AutoTune.h>
20 #include <faiss/IVFlib.h>
22 using namespace faiss;
// make_data: returns n * d floats drawn from drand48(); used below for both
// the database (nb vectors) and the queries (nq vectors).
// NOTE(review): `d` is declared outside this extract. The leading integers on
// these lines look like line numbers fused in by extraction, and the opening
// brace, loop close and return statement are missing — confirm against the
// real source before relying on this listing.
40 std::vector<float> make_data(
size_t n)
42 std::vector <float> database (n * d);
43 for (
size_t i = 0; i < n * d; i++) {
44 database[i] = drand48();
// make_index: builds an index from the factory string `index_type` via
// index_factory(d, index_type, metric), then trains it on and adds the first
// nb vectors of `x` (presumably x holds at least nb * d floats — TODO confirm).
// NOTE(review): the `metric` parameter line and the return statement are not
// in this extract; `metric` is used below and is passed as the second
// argument at the call site in test_params_override.
49 std::unique_ptr<Index> make_index(
const char *index_type,
51 const std::vector<float> & x)
53 std::unique_ptr<Index> index(
index_factory(d, index_type, metric));
54 index->
train(nb, x.data());
55 index->
add(nb, x.data());
// search_index: searches nq query vectors read from `xq`, k results per query;
// labels are written to I and distances to D (both hold k * nq entries).
// NOTE(review): the return statement is not in this extract; given the
// std::vector<idx_t> return type it presumably returns I.
59 std::vector<idx_t> search_index(
Index *index,
const float *xq) {
61 std::vector<idx_t> I(k * nq);
62 std::vector<float> D(k * nq);
63 index->
search (nq, xq, k, D.data(), I.data());
// search_index_with_params: allocates the same k * nq label/distance buffers
// as search_index, but performs the search through
// ivflib::search_with_parameters, passing `params` explicitly.
// NOTE(review): the parameter list (original lines 68-69) and the return are
// missing from this extract; from the call sites the arguments are presumably
// (index, xq, &params).
67 std::vector<idx_t> search_index_with_params(
70 std::vector<idx_t> I(k * nq);
71 std::vector<float> D(k * nq);
72 ivflib::search_with_parameters (index, nq, xq, k,
73 D.data(), I.data(), params);
// test_params_override: builds an index of type `index_key` with `metric` on
// make_data(nb), draws nq queries, computes reference results with
// search_index (res2ref, res9ref), recomputes them with
// search_index_with_params (res2new, res9new) and compares each pair.
// Callers store the int result as err1/err2.
// NOTE(review): the lines that configure each search (presumably nprobe 2 and
// 9, judging by the variable names) and the statements executed on mismatch
// are missing from this extract — confirm against the real source.
84 int test_params_override (
const char *index_key,
MetricType metric) {
85 std::vector<float> xb = make_data(nb);
86 auto index = make_index(index_key, metric, xb);
89 std::vector<float> xq = make_data(nq);
92 auto res2ref = search_index(index.get(), xq.data());
94 auto res9ref = search_index(index.get(), xq.data());
// NOTE(review): `¶ms` below looks like `&params` mangled by HTML-entity
// decoding (`&para` -> `¶`); as written it will not compile.
100 auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms);
102 auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms);
104 if (res2ref != res2new)
107 if (res9ref != res9new)
// Bodies of three TEST cases whose TEST(...) headers are not in this extract.
// Each runs test_params_override for one IVF factory string ("IVF32,Flat",
// "IVF32,PQ8np", "IVF32,SQ8") under both METRIC_L2 and METRIC_INNER_PRODUCT.
// Whatever checks err1/err2 is also missing from this view.
122 int err1 = test_params_override (
"IVF32,Flat", METRIC_L2);
124 int err2 = test_params_override (
"IVF32,Flat", METRIC_INNER_PRODUCT);
129 int err1 = test_params_override (
"IVF32,PQ8np", METRIC_L2);
131 int err2 = test_params_override (
"IVF32,PQ8np", METRIC_INNER_PRODUCT);
136 int err1 = test_params_override (
"IVF32,SQ8", METRIC_L2);
138 int err2 = test_params_override (
"IVF32,SQ8", METRIC_INNER_PRODUCT);
// TEST(TPO, IVFFlatPP): runs test_params_override on the "PCA16,IVF32,SQ8"
// factory string (presumably a PCA pre-transform in front of the IVF index)
// with both METRIC_L2 and METRIC_INNER_PRODUCT.
// NOTE(review): despite "IVFFlat" in the name this uses SQ8 — confirm the
// intended test name. The end of the test body is not in this extract.
142 TEST(TPO, IVFFlatPP) {
143 int err1 = test_params_override (
"PCA16,IVF32,SQ8", METRIC_L2);
145 int err2 = test_params_override (
"PCA16,IVF32,SQ8", METRIC_INNER_PRODUCT);
// make_data_binary: returns n * d / 8 bytes filled from lrand48(), presumably
// n packed binary vectors of d bits each (assumes d is a multiple of 8 — TODO
// confirm). Assigning the long from lrand48() to uint8_t keeps its low 8 bits.
// NOTE(review): the return statement and closing braces are not in this extract.
156 std::vector<uint8_t> make_data_binary(
size_t n) {
157 std::vector <uint8_t> database (n * d / 8);
158 for (
size_t i = 0; i < n * d / 8; i++) {
159 database[i] = lrand48();
// make_index (binary overload): builds a binary index from `index_type` via
// index_binary_factory(d, index_type), downcasts it to IndexBinaryIVF, then
// trains it on and adds nb vectors from `x`.
// NOTE(review): dynamic_cast returns nullptr when the factory string does not
// produce an IndexBinaryIVF. The unique_ptr would then hold nullptr, the
// object returned by the factory would leak, and `index->train` would
// dereference null. No check is visible in this extract (original lines
// 166-167 and the return are missing), so consider checking the cast result.
164 std::unique_ptr<IndexBinaryIVF> make_index(
const char *index_type,
165 const std::vector<uint8_t> & x)
168 auto index = std::unique_ptr<IndexBinaryIVF>
169 (
dynamic_cast<IndexBinaryIVF*
>(index_binary_factory (d, index_type)));
170 index->
train(nb, x.data());
171 index->
add(nb, x.data());
// search_index (binary overload): searches nq binary queries read from `xq`,
// k results per query; labels go to I and int32_t distances to D (both hold
// k * nq entries).
// NOTE(review): the return statement is not in this extract; given the
// std::vector<idx_t> return type it presumably returns I.
175 std::vector<idx_t> search_index(
IndexBinaryIVF *index,
const uint8_t *xq) {
177 std::vector<idx_t> I(k * nq);
178 std::vector<int32_t> D(k * nq);
179 index->
search (nq, xq, k, D.data(), I.data());
// search_index_with_params (binary overload): allocates result buffers I/D of
// k * nq entries and buffers Iq/Dq of params->nprobe * nq entries (presumably
// per-query coarse assignments and their distances — TODO confirm).
// NOTE(review): most of this function is missing from the extract: the
// parameter list, the head of the call that fills Dq/Iq, and everything after
// it. Confirm against the real source.
183 std::vector<idx_t> search_index_with_params(
186 std::vector<idx_t> I(k * nq);
187 std::vector<int32_t> D(k * nq);
189 std::vector<idx_t> Iq(params->
nprobe * nq);
190 std::vector<int32_t> Dq(params->
nprobe * nq);
193 Dq.data(), Iq.data());
// test_params_override_binary: builds a binary IVF index for `index_key` on
// make_data_binary(nb), draws nq binary queries, computes reference results
// with search_index (res2ref, res9ref), recomputes them with
// search_index_with_params (res2new, res9new) and compares each pair.
// NOTE(review): the binary make_index already calls train(nb, ...) and
// add(nb, ...), so the train/add below add the same nb vectors a second time.
// The reference and override searches both run on this same index, so the
// comparison itself appears unaffected, but the second train/add looks
// redundant — confirm intent. The nprobe setup lines and the mismatch
// handling are not in this extract.
200 int test_params_override_binary (
const char *index_key) {
201 std::vector<uint8_t> xb = make_data_binary(nb);
202 auto index = make_index (index_key, xb);
203 index->
train(nb, xb.data());
204 index->
add(nb, xb.data());
205 std::vector<uint8_t> xq = make_data_binary(nq);
207 auto res2ref = search_index(index.get(), xq.data());
209 auto res9ref = search_index(index.get(), xq.data());
// NOTE(review): `¶ms` below looks like `&params` mangled by HTML-entity
// decoding (`&para` -> `¶`); as written it will not compile.
215 auto res2new = search_index_with_params(index.get(), xq.data(), ¶ms);
217 auto res9new = search_index_with_params(index.get(), xq.data(), ¶ms);
219 if (res2ref != res2new)
222 if (res9ref != res9new)
// Body of a TEST case whose TEST(...) header is not in this extract: runs
// test_params_override_binary on the "BIVF32" factory string. Whatever checks
// err1 is not visible here.
229 int err1 = test_params_override_binary (
"BIVF32");
void train(idx_t n, const float *x) override
virtual void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const =0
size_t nprobe
number of probes at query time
void search_preassigned(idx_t n, const uint8_t *x, idx_t k, const idx_t *assign, const int32_t *centroid_dis, int32_t *distances, idx_t *labels, bool store_pairs, const IVFSearchParameters *params=nullptr) const
virtual void train(idx_t n, const float *x)
IndexBinary * quantizer
quantizer that maps vectors to inverted lists
void train(idx_t n, const uint8_t *x) override
Trains the quantizer and calls train_residual to train sub-quantizers.
virtual void add(idx_t n, const float *x)=0
void add(idx_t n, const float *x) override
supported only for sub-indices that implement add_with_ids
long idx_t
all indices are this type
virtual void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels) const =0
size_t nprobe
number of probes at query time
void add(idx_t n, const uint8_t *x) override
Quantizes x and calls add_with_key.
virtual void set_index_parameter(Index *index, const std::string &name, double val) const
set one of the parameters
virtual void search(idx_t n, const uint8_t *x, idx_t k, int32_t *distances, idx_t *labels) const override
size_t max_codes
max nb of codes to visit to do a query
Index * index_factory(int d, const char *description_in, MetricType metric)
MetricType
Some algorithms support both an inner product version and a L2 search version.