faiss/tests/test_dealloc_invlists.cpp

183 lines
4.2 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <vector>
#include <gtest/gtest.h>
#include <faiss/IndexIVF.h>
#include <faiss/AutoTune.h>
#include <faiss/index_io.h>
#include <faiss/IVFlib.h>
using namespace faiss;
namespace {
// label type used for search results, borrowed from Index
using idx_t = Index::idx_t;
// vector dimensionality
int d{32};
// number of vectors used to train the index
size_t nt{5000};
// number of database vectors added to the index
size_t nb{1000};
// number of query vectors
size_t nq{200};
// Generate n vectors of dimension d, stored contiguously (n * d floats in
// total). Each component is the next value drawn from drand48().
std::vector<float> make_data(size_t n)
{
    const size_t total = n * d;
    std::vector<float> values(total);
    // front-to-back fill keeps the order in which drand48() is consumed
    for (float &component : values) {
        component = drand48();
    }
    return values;
}
std::unique_ptr<Index> make_trained_index(const char *index_type)
{
auto index = std::unique_ptr<Index>(index_factory(d, index_type));
auto xt = make_data(nt * d);
index->train(nt, xt.data());
ParameterSpace().set_index_parameter (index.get(), "nprobe", 4);
return index;
}
// Search the index with the nq query vectors stored in xq, requesting
// k = 10 results per query. Returns the k * nq result labels; the distances
// are computed but discarded.
std::vector<idx_t> search_index(Index *index, const float *xq) {
    const int k = 10;
    const size_t n_results = k * nq;
    std::vector<float> distances(n_results);
    std::vector<idx_t> labels(n_results);
    index->search (nq, xq, k, distances.data(), labels.data());
    return labels;
}
/*************************************************************
* Test functions for a given index type
*************************************************************/
struct EncapsulateInvertedLists: InvertedLists {
const InvertedLists *il;
EncapsulateInvertedLists(const InvertedLists *il):
InvertedLists(il->nlist, il->code_size),
il(il)
{}
static void * memdup (const void *m, size_t size) {
if (size == 0) return nullptr;
return memcpy (malloc(size), m, size);
}
size_t list_size(size_t list_no) const override {
return il->list_size (list_no);
}
const uint8_t * get_codes (size_t list_no) const override {
return (uint8_t*)memdup (il->get_codes(list_no),
list_size(list_no) * code_size);
}
const idx_t * get_ids (size_t list_no) const override {
return (idx_t*)memdup (il->get_ids(list_no),
list_size(list_no) * sizeof(idx_t));
}
void release_codes (size_t, const uint8_t *codes) const override {
free ((void*)codes);
}
void release_ids (size_t, const idx_t *ids) const override {
free ((void*)ids);
}
const uint8_t * get_single_code (size_t list_no, size_t offset)
const override {
return (uint8_t*)memdup (il->get_single_code(list_no, offset),
code_size);
}
size_t add_entries(size_t, size_t, const idx_t*, const uint8_t*) override {
assert(!"not implemented");
return 0;
}
void update_entries(size_t, size_t, size_t, const idx_t*, const uint8_t*)
override {
assert(!"not implemented");
}
void resize(size_t, size_t) override {
assert(!"not implemented");
}
~EncapsulateInvertedLists() override {}
};
int test_dealloc_invlists (const char *index_key) {
std::unique_ptr<Index> index = make_trained_index(index_key);
IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get());
auto xb = make_data (nb * d);
index->add(nb, xb.data());
auto xq = make_data (nq * d);
auto ref_res = search_index (index.get(), xq.data());
EncapsulateInvertedLists eil(index_ivf->invlists);
index_ivf->own_invlists = false;
index_ivf->replace_invlists (&eil, false);
// TEST: this could crash or leak mem
auto new_res = search_index (index.get(), xq.data());
// delete explicitly
delete eil.il;
// just to make sure
EXPECT_EQ (ref_res, new_res);
return 0;
}
} // anonymous namespace
/*************************************************************
* Test entry points
*************************************************************/
// Dealloc check on the index built from the factory string "IVF32,Flat".
TEST(TestIvlistDealloc, IVFFlat) {
    const char *index_key = "IVF32,Flat";
    test_dealloc_invlists (index_key);
}
// Dealloc check on the index built from the factory string "IVF32,SQ8".
TEST(TestIvlistDealloc, IVFSQ) {
    const char *index_key = "IVF32,SQ8";
    test_dealloc_invlists (index_key);
}
// Dealloc check on the index built from the factory string "IVF32,PQ4np".
TEST(TestIvlistDealloc, IVFPQ) {
    const char *index_key = "IVF32,PQ4np";
    test_dealloc_invlists (index_key);
}