From bf73e38d10ae6818d7e5d7250a55bb0c9944a9ef Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 31 May 2024 14:48:13 -0700 Subject: [PATCH] add skip_storage flag to HNSW (#3487) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3487 Sometimes it is not useful to serialize the storage index along with a HNSW index. This diff adds a flag that supports skipping the storage of the index. Searching and adding to the index are not possible until a storage index is added back in. Reviewed By: junjieqi Differential Revision: D57911060 fbshipit-source-id: 5a4ceee4a8f53f6f746df59af3942b813a99c14f --- faiss/IndexHNSW.cpp | 5 ++--- faiss/impl/index_read.cpp | 10 ++++++---- faiss/impl/index_write.cpp | 25 ++++++++++++++++--------- faiss/index_io.h | 11 ++++++----- faiss/python/__init__.py | 4 ++-- tests/test_graph_based.py | 36 ++++++++++++++++++++++++++++++++++++ 6 files changed, 68 insertions(+), 23 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 068691721..94798c1b4 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #include #include @@ -251,7 +249,8 @@ void hnsw_search( const SearchParameters* params_in) { FAISS_THROW_IF_NOT_MSG( index->storage, - "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly"); + "No storage index, please use IndexHNSWFlat (or variants) " + "instead of IndexHNSW directly"); const SearchParametersHNSW* params = nullptr; const HNSW& hnsw = index->hnsw; diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 8d80329bf..ce4b1e76b 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree.
*/ -// -*- c++ -*- - #include #include @@ -531,7 +529,11 @@ Index* read_index(IOReader* f, int io_flags) { Index* idx = nullptr; uint32_t h; READ1(h); - if (h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) { + if (h == fourcc("null")) { + // denotes a missing index, useful for some cases + return nullptr; + } else if ( + h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) { IndexFlat* idxf; if (h == fourcc("IxFI")) { idxf = new IndexFlatIP(); @@ -961,7 +963,7 @@ Index* read_index(IOReader* f, int io_flags) { read_index_header(idxhnsw, f); read_HNSW(&idxhnsw->hnsw, f); idxhnsw->storage = read_index(f, io_flags); - idxhnsw->own_fields = true; + idxhnsw->own_fields = idxhnsw->storage != nullptr; if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) { dynamic_cast(idxhnsw->storage)->pq.compute_sdc_table(); } diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index b2808d717..01e5ae725 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. */ -// -*- c++ -*- - #include #include @@ -390,8 +388,12 @@ static void write_ivf_header(const IndexIVF* ivf, IOWriter* f) { write_direct_map(&ivf->direct_map, f); } -void write_index(const Index* idx, IOWriter* f) { - if (const IndexFlat* idxf = dynamic_cast(idx)) { +void write_index(const Index* idx, IOWriter* f, int io_flags) { + if (idx == nullptr) { + // eg. for a storage component of HNSW that is set to nullptr + uint32_t h = fourcc("null"); + WRITE1(h); + } else if (const IndexFlat* idxf = dynamic_cast(idx)) { uint32_t h = fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" : idxf->metric_type == METRIC_L2 ? 
"IxF2" @@ -765,7 +767,12 @@ void write_index(const Index* idx, IOWriter* f) { WRITE1(h); write_index_header(idxhnsw, f); write_HNSW(&idxhnsw->hnsw, f); - write_index(idxhnsw->storage, f); + if (io_flags & IO_FLAG_SKIP_STORAGE) { + uint32_t n4 = fourcc("null"); + WRITE1(n4); + } else { + write_index(idxhnsw->storage, f); + } } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { uint32_t h = dynamic_cast(idx) ? fourcc("INSf") : dynamic_cast(idx) ? fourcc("INSp") @@ -841,14 +848,14 @@ void write_index(const Index* idx, IOWriter* f) { } } -void write_index(const Index* idx, FILE* f) { +void write_index(const Index* idx, FILE* f, int io_flags) { FileIOWriter writer(f); - write_index(idx, &writer); + write_index(idx, &writer, io_flags); } -void write_index(const Index* idx, const char* fname) { +void write_index(const Index* idx, const char* fname, int io_flags) { FileIOWriter writer(fname); - write_index(idx, &writer); + write_index(idx, &writer, io_flags); } void write_VectorTransform(const VectorTransform* vt, const char* fname) { diff --git a/faiss/index_io.h b/faiss/index_io.h index f73cd073b..3e77d0227 100644 --- a/faiss/index_io.h +++ b/faiss/index_io.h @@ -5,8 +5,6 @@ * LICENSE file in the root directory of this source tree. 
*/ -// -*- c++ -*- - // I/O code for indexes #ifndef FAISS_INDEX_IO_H @@ -35,9 +33,12 @@ struct IOReader; struct IOWriter; struct InvertedLists; -void write_index(const Index* idx, const char* fname); -void write_index(const Index* idx, FILE* f); -void write_index(const Index* idx, IOWriter* writer); +/// skip the storage for graph-based indexes +const int IO_FLAG_SKIP_STORAGE = 1; + +void write_index(const Index* idx, const char* fname, int io_flags = 0); +void write_index(const Index* idx, FILE* f, int io_flags = 0); +void write_index(const Index* idx, IOWriter* writer, int io_flags = 0); void write_index_binary(const IndexBinary* idx, const char* fname); void write_index_binary(const IndexBinary* idx, FILE* f); diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 0562d1dd8..ce4b42c61 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -292,10 +292,10 @@ IVFSearchParameters = SearchParametersIVF ########################################### -def serialize_index(index): +def serialize_index(index, io_flags=0): """ convert an index to a numpy uint8 array """ writer = VectorIOWriter() - write_index(index, writer) + write_index(index, writer, io_flags) return vector_to_array(writer.data) diff --git a/tests/test_graph_based.py b/tests/test_graph_based.py index d5ddbeec3..95925d7ae 100644 --- a/tests/test_graph_based.py +++ b/tests/test_graph_based.py @@ -133,6 +133,42 @@ class TestHNSW(unittest.TestCase): Dhnsw, Ihnsw = index.search(self.xq, 1) self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch) + def test_io_no_storage(self): + d = self.xq.shape[1] + index = faiss.IndexHNSWFlat(d, 16) + index.add(self.xb) + + Dref, Iref = index.search(self.xq, 5) + + # test writing without storage + index2 = faiss.deserialize_index( + faiss.serialize_index(index, faiss.IO_FLAG_SKIP_STORAGE) + ) + self.assertEquals(index2.storage, None) + self.assertRaises( + RuntimeError, + index2.search, self.xb, 1) + + # make sure we can store an 
index with empty storage + index4 = faiss.deserialize_index( + faiss.serialize_index(index2)) + + # add storage afterwards + index.storage = faiss.clone_index(index.storage) + index.own_fields = True + + Dnew, Inew = index.search(self.xq, 5) + np.testing.assert_array_equal(Dnew, Dref) + np.testing.assert_array_equal(Inew, Iref) + + if False: + # test reading without storage + # not implemented because it is hard to skip over an index + index3 = faiss.deserialize_index( + faiss.serialize_index(index), faiss.IO_FLAG_SKIP_STORAGE + ) + self.assertEquals(index3.storage, None) + class TestNSG(unittest.TestCase):