add skip_storage flag to HNSW (#3487)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3487 Sometimes it is not useful to serialize the storage index along with a HNSW index. This diff adds a flag that supports skipping the storage of the index. Searching and adding to the index are not possible until a storage index is added back in. Reviewed By: junjieqi Differential Revision: D57911060 fbshipit-source-id: 5a4ceee4a8f53f6f746df59af3942b813a99c14f pull/3492/head
parent
22304340d2
commit
bf73e38d10
|
@ -5,8 +5,6 @@
|
||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// -*- c++ -*-
|
|
||||||
|
|
||||||
#include <faiss/IndexHNSW.h>
|
#include <faiss/IndexHNSW.h>
|
||||||
|
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
|
@ -251,7 +249,8 @@ void hnsw_search(
|
||||||
const SearchParameters* params_in) {
|
const SearchParameters* params_in) {
|
||||||
FAISS_THROW_IF_NOT_MSG(
|
FAISS_THROW_IF_NOT_MSG(
|
||||||
index->storage,
|
index->storage,
|
||||||
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
|
"No storage index, please use IndexHNSWFlat (or variants) "
|
||||||
|
"instead of IndexHNSW directly");
|
||||||
const SearchParametersHNSW* params = nullptr;
|
const SearchParametersHNSW* params = nullptr;
|
||||||
const HNSW& hnsw = index->hnsw;
|
const HNSW& hnsw = index->hnsw;
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,6 @@
|
||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// -*- c++ -*-
|
|
||||||
|
|
||||||
#include <faiss/index_io.h>
|
#include <faiss/index_io.h>
|
||||||
|
|
||||||
#include <faiss/impl/io_macros.h>
|
#include <faiss/impl/io_macros.h>
|
||||||
|
@ -531,7 +529,11 @@ Index* read_index(IOReader* f, int io_flags) {
|
||||||
Index* idx = nullptr;
|
Index* idx = nullptr;
|
||||||
uint32_t h;
|
uint32_t h;
|
||||||
READ1(h);
|
READ1(h);
|
||||||
if (h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
|
if (h == fourcc("null")) {
|
||||||
|
// denotes a missing index, useful for some cases
|
||||||
|
return nullptr;
|
||||||
|
} else if (
|
||||||
|
h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
|
||||||
IndexFlat* idxf;
|
IndexFlat* idxf;
|
||||||
if (h == fourcc("IxFI")) {
|
if (h == fourcc("IxFI")) {
|
||||||
idxf = new IndexFlatIP();
|
idxf = new IndexFlatIP();
|
||||||
|
@ -961,7 +963,7 @@ Index* read_index(IOReader* f, int io_flags) {
|
||||||
read_index_header(idxhnsw, f);
|
read_index_header(idxhnsw, f);
|
||||||
read_HNSW(&idxhnsw->hnsw, f);
|
read_HNSW(&idxhnsw->hnsw, f);
|
||||||
idxhnsw->storage = read_index(f, io_flags);
|
idxhnsw->storage = read_index(f, io_flags);
|
||||||
idxhnsw->own_fields = true;
|
idxhnsw->own_fields = idxhnsw->storage != nullptr;
|
||||||
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
|
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
|
||||||
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
|
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,6 @@
|
||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// -*- c++ -*-
|
|
||||||
|
|
||||||
#include <faiss/index_io.h>
|
#include <faiss/index_io.h>
|
||||||
|
|
||||||
#include <faiss/impl/io.h>
|
#include <faiss/impl/io.h>
|
||||||
|
@ -390,8 +388,12 @@ static void write_ivf_header(const IndexIVF* ivf, IOWriter* f) {
|
||||||
write_direct_map(&ivf->direct_map, f);
|
write_direct_map(&ivf->direct_map, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_index(const Index* idx, IOWriter* f) {
|
void write_index(const Index* idx, IOWriter* f, int io_flags) {
|
||||||
if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
|
if (idx == nullptr) {
|
||||||
|
// eg. for a storage component of HNSW that is set to nullptr
|
||||||
|
uint32_t h = fourcc("null");
|
||||||
|
WRITE1(h);
|
||||||
|
} else if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
|
||||||
uint32_t h =
|
uint32_t h =
|
||||||
fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI"
|
fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI"
|
||||||
: idxf->metric_type == METRIC_L2 ? "IxF2"
|
: idxf->metric_type == METRIC_L2 ? "IxF2"
|
||||||
|
@ -765,7 +767,12 @@ void write_index(const Index* idx, IOWriter* f) {
|
||||||
WRITE1(h);
|
WRITE1(h);
|
||||||
write_index_header(idxhnsw, f);
|
write_index_header(idxhnsw, f);
|
||||||
write_HNSW(&idxhnsw->hnsw, f);
|
write_HNSW(&idxhnsw->hnsw, f);
|
||||||
write_index(idxhnsw->storage, f);
|
if (io_flags & IO_FLAG_SKIP_STORAGE) {
|
||||||
|
uint32_t n4 = fourcc("null");
|
||||||
|
WRITE1(n4);
|
||||||
|
} else {
|
||||||
|
write_index(idxhnsw->storage, f);
|
||||||
|
}
|
||||||
} else if (const IndexNSG* idxnsg = dynamic_cast<const IndexNSG*>(idx)) {
|
} else if (const IndexNSG* idxnsg = dynamic_cast<const IndexNSG*>(idx)) {
|
||||||
uint32_t h = dynamic_cast<const IndexNSGFlat*>(idx) ? fourcc("INSf")
|
uint32_t h = dynamic_cast<const IndexNSGFlat*>(idx) ? fourcc("INSf")
|
||||||
: dynamic_cast<const IndexNSGPQ*>(idx) ? fourcc("INSp")
|
: dynamic_cast<const IndexNSGPQ*>(idx) ? fourcc("INSp")
|
||||||
|
@ -841,14 +848,14 @@ void write_index(const Index* idx, IOWriter* f) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_index(const Index* idx, FILE* f) {
|
void write_index(const Index* idx, FILE* f, int io_flags) {
|
||||||
FileIOWriter writer(f);
|
FileIOWriter writer(f);
|
||||||
write_index(idx, &writer);
|
write_index(idx, &writer, io_flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_index(const Index* idx, const char* fname) {
|
void write_index(const Index* idx, const char* fname, int io_flags) {
|
||||||
FileIOWriter writer(fname);
|
FileIOWriter writer(fname);
|
||||||
write_index(idx, &writer);
|
write_index(idx, &writer, io_flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_VectorTransform(const VectorTransform* vt, const char* fname) {
|
void write_VectorTransform(const VectorTransform* vt, const char* fname) {
|
||||||
|
|
|
@ -5,8 +5,6 @@
|
||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// -*- c++ -*-
|
|
||||||
|
|
||||||
// I/O code for indexes
|
// I/O code for indexes
|
||||||
|
|
||||||
#ifndef FAISS_INDEX_IO_H
|
#ifndef FAISS_INDEX_IO_H
|
||||||
|
@ -35,9 +33,12 @@ struct IOReader;
|
||||||
struct IOWriter;
|
struct IOWriter;
|
||||||
struct InvertedLists;
|
struct InvertedLists;
|
||||||
|
|
||||||
void write_index(const Index* idx, const char* fname);
|
/// skip the storage for graph-based indexes
|
||||||
void write_index(const Index* idx, FILE* f);
|
const int IO_FLAG_SKIP_STORAGE = 1;
|
||||||
void write_index(const Index* idx, IOWriter* writer);
|
|
||||||
|
void write_index(const Index* idx, const char* fname, int io_flags = 0);
|
||||||
|
void write_index(const Index* idx, FILE* f, int io_flags = 0);
|
||||||
|
void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
|
||||||
|
|
||||||
void write_index_binary(const IndexBinary* idx, const char* fname);
|
void write_index_binary(const IndexBinary* idx, const char* fname);
|
||||||
void write_index_binary(const IndexBinary* idx, FILE* f);
|
void write_index_binary(const IndexBinary* idx, FILE* f);
|
||||||
|
|
|
@ -292,10 +292,10 @@ IVFSearchParameters = SearchParametersIVF
|
||||||
###########################################
|
###########################################
|
||||||
|
|
||||||
|
|
||||||
def serialize_index(index):
|
def serialize_index(index, io_flags=0):
|
||||||
""" convert an index to a numpy uint8 array """
|
""" convert an index to a numpy uint8 array """
|
||||||
writer = VectorIOWriter()
|
writer = VectorIOWriter()
|
||||||
write_index(index, writer)
|
write_index(index, writer, io_flags)
|
||||||
return vector_to_array(writer.data)
|
return vector_to_array(writer.data)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -133,6 +133,42 @@ class TestHNSW(unittest.TestCase):
|
||||||
Dhnsw, Ihnsw = index.search(self.xq, 1)
|
Dhnsw, Ihnsw = index.search(self.xq, 1)
|
||||||
self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)
|
self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)
|
||||||
|
|
||||||
|
def test_io_no_storage(self):
|
||||||
|
d = self.xq.shape[1]
|
||||||
|
index = faiss.IndexHNSWFlat(d, 16)
|
||||||
|
index.add(self.xb)
|
||||||
|
|
||||||
|
Dref, Iref = index.search(self.xq, 5)
|
||||||
|
|
||||||
|
# test writing without storage
|
||||||
|
index2 = faiss.deserialize_index(
|
||||||
|
faiss.serialize_index(index, faiss.IO_FLAG_SKIP_STORAGE)
|
||||||
|
)
|
||||||
|
self.assertEquals(index2.storage, None)
|
||||||
|
self.assertRaises(
|
||||||
|
RuntimeError,
|
||||||
|
index2.search, self.xb, 1)
|
||||||
|
|
||||||
|
# make sure we can store an index with empty storage
|
||||||
|
index4 = faiss.deserialize_index(
|
||||||
|
faiss.serialize_index(index2))
|
||||||
|
|
||||||
|
# add storage afterwards
|
||||||
|
index.storage = faiss.clone_index(index.storage)
|
||||||
|
index.own_fields = True
|
||||||
|
|
||||||
|
Dnew, Inew = index.search(self.xq, 5)
|
||||||
|
np.testing.assert_array_equal(Dnew, Dref)
|
||||||
|
np.testing.assert_array_equal(Inew, Iref)
|
||||||
|
|
||||||
|
if False:
|
||||||
|
# test reading without storage
|
||||||
|
# not implemented because it is hard to skip over an index
|
||||||
|
index3 = faiss.deserialize_index(
|
||||||
|
faiss.serialize_index(index), faiss.IO_FLAG_SKIP_STORAGE
|
||||||
|
)
|
||||||
|
self.assertEquals(index3.storage, None)
|
||||||
|
|
||||||
|
|
||||||
class TestNSG(unittest.TestCase):
|
class TestNSG(unittest.TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue