/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include using namespace faiss; namespace { typedef Index::idx_t idx_t; // dimension of the vectors to index int d = 32; // nb of training vectors size_t nt = 5000; // size of the database points per window step size_t nb = 1000; // nb of queries size_t nq = 200; std::mt19937 rng; std::vector make_data(size_t n) { std::vector database (n * d); std::uniform_real_distribution<> distrib; for (size_t i = 0; i < n * d; i++) { database[i] = distrib(rng); } return database; } std::unique_ptr make_trained_index(const char *index_type) { auto index = std::unique_ptr(index_factory(d, index_type)); auto xt = make_data(nt * d); index->train(nt, xt.data()); ParameterSpace().set_index_parameter (index.get(), "nprobe", 4); return index; } std::vector search_index(Index *index, const float *xq) { int k = 10; std::vector I(k * nq); std::vector D(k * nq); index->search (nq, xq, k, D.data(), I.data()); return I; } /************************************************************* * Test functions for a given index type *************************************************************/ struct EncapsulateInvertedLists: InvertedLists { const InvertedLists *il; EncapsulateInvertedLists(const InvertedLists *il): InvertedLists(il->nlist, il->code_size), il(il) {} static void * memdup (const void *m, size_t size) { if (size == 0) return nullptr; return memcpy (malloc(size), m, size); } size_t list_size(size_t list_no) const override { return il->list_size (list_no); } const uint8_t * get_codes (size_t list_no) const override { return (uint8_t*)memdup (il->get_codes(list_no), list_size(list_no) * code_size); } const idx_t * get_ids (size_t list_no) const override { return (idx_t*)memdup (il->get_ids(list_no), list_size(list_no) * sizeof(idx_t)); } void release_codes (size_t, const uint8_t *codes) const override { free ((void*)codes); } void release_ids (size_t, const idx_t *ids) const override { free ((void*)ids); } const uint8_t * get_single_code (size_t list_no, size_t offset) const override { return (uint8_t*)memdup (il->get_single_code(list_no, offset), code_size); } size_t add_entries(size_t, size_t, const idx_t*, const uint8_t*) override { assert(!"not implemented"); return 0; } void update_entries(size_t, size_t, size_t, const idx_t*, const uint8_t*) override { assert(!"not implemented"); } void resize(size_t, size_t) override { assert(!"not implemented"); } ~EncapsulateInvertedLists() override {} }; int test_dealloc_invlists (const char *index_key) { std::unique_ptr index = make_trained_index(index_key); IndexIVF * index_ivf = ivflib::extract_index_ivf (index.get()); auto xb = make_data (nb * d); index->add(nb, xb.data()); auto xq = make_data (nq * d); auto ref_res = search_index (index.get(), xq.data()); EncapsulateInvertedLists eil(index_ivf->invlists); index_ivf->own_invlists = false; index_ivf->replace_invlists (&eil, false); // TEST: this could crash or leak mem auto new_res = search_index (index.get(), xq.data()); // delete explicitly delete eil.il; // just to make sure EXPECT_EQ (ref_res, new_res); return 0; } } // anonymous namespace /************************************************************* * Test entry points *************************************************************/ TEST(TestIvlistDealloc, IVFFlat) { test_dealloc_invlists ("IVF32,Flat"); } TEST(TestIvlistDealloc, IVFSQ) { test_dealloc_invlists ("IVF32,SQ8"); } TEST(TestIvlistDealloc, IVFPQ) { test_dealloc_invlists ("IVF32,PQ4np"); }