faiss/tests/test_ivf_index.cpp

245 lines
7.6 KiB
C++
Raw Normal View History

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <omp.h>
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <random>
#include <set>
#include <gtest/gtest.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/index_io.h>
namespace {
// Test-side storage for everything written through the inverted lists.
// The test hands a pointer to this object to the index as the opaque
// `context` argument; finding the expected data in here afterwards shows
// that the pointer reached InvertedLists::add_entry and the iterator.
class TestContext {
   public:
    TestContext() {}

    // Record one encoded vector under the next sequential id.
    //
    // list_no   - inverted list the vector was assigned to
    // code      - pointer to code_size bytes; the bytes are copied
    // code_size - number of bytes to copy from code
    //
    // The id comes from the internal counter, not from the caller.
    void save_code(size_t list_no, const uint8_t* code, size_t code_size) {
        list_nos.emplace(id, list_no);
        // Build the byte vector directly from the input range instead of
        // zero-filling it and then re-hashing the key for every byte.
        codes.emplace(id, std::vector<uint8_t>(code, code + code_size));
        id++;
    }

    // id to codes map
    std::unordered_map<faiss::idx_t, std::vector<uint8_t>> codes;
    // id to list_no map
    std::unordered_map<faiss::idx_t, size_t> list_nos;
    // next id to hand out; incremented once per save_code call
    faiss::idx_t id = 0;
    // list numbers for which TestInvertedLists::get_iterator was called
    std::set<size_t> lists_probed;
};
// the iterator that iterates over the codes stored in context object
class TestInvertedListIterator : public faiss::InvertedListsIterator {
public:
TestInvertedListIterator(size_t list_no, TestContext* context)
: list_no{list_no}, context{context} {
it = context->codes.cbegin();
seek_next();
}
~TestInvertedListIterator() override {}
// move the cursor to the first valid entry
void seek_next() {
while (it != context->codes.cend() &&
context->list_nos[it->first] != list_no) {
it++;
}
}
virtual bool is_available() const override {
return it != context->codes.cend();
}
virtual void next() override {
it++;
seek_next();
}
virtual std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes()
override {
if (it == context->codes.cend()) {
FAISS_THROW_MSG("invalid state");
}
return std::make_pair(it->first, it->second.data());
}
private:
size_t list_no;
TestContext* context;
decltype(context->codes.cbegin()) it;
};
class TestInvertedLists : public faiss::InvertedLists {
public:
TestInvertedLists(size_t nlist, size_t code_size)
: faiss::InvertedLists(nlist, code_size) {
use_iterator = true;
}
~TestInvertedLists() override {}
size_t list_size(size_t /*list_no*/) const override {
FAISS_THROW_MSG("unexpected call");
}
faiss::InvertedListsIterator* get_iterator(size_t list_no, void* context)
const override {
auto testContext = (TestContext*)context;
testContext->lists_probed.insert(list_no);
return new TestInvertedListIterator(list_no, testContext);
}
const uint8_t* get_codes(size_t /* list_no */) const override {
FAISS_THROW_MSG("unexpected call");
}
const faiss::idx_t* get_ids(size_t /* list_no */) const override {
FAISS_THROW_MSG("unexpected call");
}
// store the codes in context object
size_t add_entry(
size_t list_no,
faiss::idx_t /*theid*/,
const uint8_t* code,
void* context) override {
auto testContext = (TestContext*)context;
testContext->save_code(list_no, code, code_size);
return 0;
}
size_t add_entries(
size_t /*list_no*/,
size_t /*n_entry*/,
const faiss::idx_t* /*ids*/,
const uint8_t* /*code*/) override {
FAISS_THROW_MSG("unexpected call");
}
void update_entries(
size_t /*list_no*/,
size_t /*offset*/,
size_t /*n_entry*/,
const faiss::idx_t* /*ids*/,
const uint8_t* /*code*/) override {
FAISS_THROW_MSG("unexpected call");
}
void resize(size_t /*list_no*/, size_t /*new_size*/) override {
FAISS_THROW_MSG("unexpected call");
}
};
} // namespace
TEST(IVF, list_context) {
    // This test verifies that the context object is passed to
    // InvertedLists::add_entry and to the InvertedListsIterator.
    // The test InvertedLists and InvertedListsIterator read from and write
    // to the test context object, and the test then checks that the
    // context object was modified as expected.
    constexpr int d = 32;      // dimension
    constexpr int nb = 100000; // database size
    constexpr int nlist = 100;
    std::mt19937 rng; // default seed, so the data is reproducible
    std::uniform_real_distribution<> distrib;

    // Disable parallelism, otherwise the TestContext object would have to
    // be thread-safe. The OpenMP thread count is process-wide, so put the
    // previous value back when the test ends instead of leaving every
    // later test in this binary single-threaded.
    struct SingleThreadScope {
        int saved_threads;
        SingleThreadScope() : saved_threads(omp_get_max_threads()) {
            omp_set_num_threads(1);
        }
        ~SingleThreadScope() {
            omp_set_num_threads(saved_threads);
        }
        SingleThreadScope(const SingleThreadScope&) = delete;
        SingleThreadScope& operator=(const SingleThreadScope&) = delete;
    } single_thread_scope;

    faiss::IndexFlatL2 quantizer(d); // the other index
    faiss::IndexIVFFlat index(&quantizer, d, nlist);
    TestInvertedLists inverted_lists(nlist, index.code_size);
    index.replace_invlists(&inverted_lists);

    {
        // training
        constexpr size_t nt = 1500; // nb of training vectors
        std::vector<float> trainvecs(nt * d);
        for (float& value : trainvecs) {
            value = distrib(rng);
        }
        index.verbose = true;
        index.train(nt, trainvecs.data());
    }

    TestContext context;
    std::vector<float> query_vector;
    constexpr faiss::idx_t query_vector_id = 100;

    {
        // populating the database
        std::vector<float> database(nb * d);
        for (float& value : database) {
            value = distrib(rng);
        }

        // the query is a copy of database vector number query_vector_id
        const auto query_begin = database.cbegin() + query_vector_id * d;
        query_vector.assign(query_begin, query_begin + d);

        std::vector<faiss::idx_t> coarse_idx(nb);
        index.quantizer->assign(nb, database.data(), coarse_idx.data());

        // pass dummy ids, the actual ids are assigned in the TestContext
        // object
        std::vector<faiss::idx_t> xids(nb, 42);
        index.add_core(
                nb, database.data(), xids.data(), coarse_idx.data(), &context);

        // check that the context object got updated
        EXPECT_EQ(nb, context.id) << "should have added all ids";
        EXPECT_EQ(static_cast<size_t>(nb), context.codes.size())
                << "should have correct number of codes";
        EXPECT_EQ(static_cast<size_t>(nb), context.list_nos.size())
                << "should have correct number of list numbers";
    }

    {
        constexpr faiss::idx_t k = 100;
        constexpr size_t nprobe = 10;
        std::vector<float> distances(k);
        std::vector<faiss::idx_t> labels(k);

        faiss::SearchParametersIVF params;
        params.inverted_list_context = &context;
        params.nprobe = nprobe;
        index.search(
                1,
                query_vector.data(),
                k,
                distances.data(),
                labels.data(),
                &params);

        EXPECT_EQ(nprobe, context.lists_probed.size())
                << "should probe nprobe lists";

        // check that the result contains the query vector; the probability
        // of this failing should be low.
        // at() instead of operator[]: a missing id must surface as an
        // error rather than being default-inserted as list 0.
        const size_t query_vector_listno =
                context.list_nos.at(query_vector_id);
        // std::set has its own logarithmic lookup, no linear scan needed
        EXPECT_TRUE(context.lists_probed.count(query_vector_listno) != 0)
                << "should probe the list of the query vector";
        EXPECT_TRUE(
                std::find(labels.cbegin(), labels.cend(), query_vector_id) !=
                labels.cend())
                << "should return the query vector";
    }
}