From a4401c13d88d3a0a6e98297f03601f3c2223223f Mon Sep 17 00:00:00 2001
From: Kaival Parikh <46070017+kaivalnp@users.noreply.github.com>
Date: Tue, 1 Apr 2025 11:05:29 -0700
Subject: [PATCH] Allow using custom index readers and writers (#4180)

Summary:
### Description

- Create custom readers and writers for index IO, which take function pointers as input
- Also expose these from the C_API

This is helpful for FFI use, where calling processes would pass upcall stubs for streamlined IO

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4180

Reviewed By: gtwang01

Differential Revision: D71208266

Pulled By: mnorris11

fbshipit-source-id: ab82397d4780a2a07c7bfdc52329968377f42af4
---
Reviewer notes (placed after the `---` marker, so they are not part of the
commit message and are ignored when the patch is applied):

* c_api/impl/io_c.cpp: CustomIOReader / CustomIOWriter derive from
  faiss::IOReader / faiss::IOWriter. Their operator() forwards
  (ptr, size, nitems) to the stored function pointer and returns the
  callback's result unchanged.
* faiss_CustomIOReader_new / faiss_CustomIOWriter_new allocate the wrapper
  with `new` and store it in *p_out; the matching *_free function deletes
  it. Neither p_out nor func_in is checked for NULL in this patch.
* The four *_custom entry points in c_api/index_io_c.cpp cast the
  FaissIOReader* / FaissIOWriter* handle to IOReader* / IOWriter* and
  delegate to faiss::read_index, faiss::write_index,
  faiss::read_index_binary and faiss::write_index_binary.
  faiss_write_index_custom takes io_flags; faiss_write_index_binary_custom
  does not.
* NOTE(review): the constructors hand out FaissCustomIOReader* /
  FaissCustomIOWriter*, whereas the *_custom functions accept
  FaissIOReader* / FaissIOWriter*. No conversion helper is declared in this
  patch, so callers presumably cast between the handle types -- confirm.
* NOTE(review): the include targets <faiss/impl/io.h> (io_c.cpp) and
  <stddef.h> (io_c.h) are reconstructed, and the reinterpret_cast target
  types are derived from the adjacent declarations -- confirm against
  upstream commit a4401c13.

 c_api/CMakeLists.txt |  1 +
 c_api/impl/io_c.cpp  | 76 ++++++++++++++++++++++++++++++++++++++++++++
 c_api/impl/io_c.h    | 50 +++++++++++++++++++++++++++++
 c_api/index_io_c.cpp | 50 +++++++++++++++++++++++++++++
 c_api/index_io_c.h   | 28 ++++++++++++++++
 5 files changed, 205 insertions(+)
 create mode 100644 c_api/impl/io_c.cpp
 create mode 100644 c_api/impl/io_c.h

diff --git a/c_api/CMakeLists.txt b/c_api/CMakeLists.txt
index 5ce7f5640..6cfc50039 100644
--- a/c_api/CMakeLists.txt
+++ b/c_api/CMakeLists.txt
@@ -29,6 +29,7 @@ set(FAISS_C_SRC
   index_factory_c.cpp
   index_io_c.cpp
   impl/AuxIndexStructures_c.cpp
+  impl/io_c.cpp
   utils/distances_c.cpp
   utils/utils_c.cpp
 )
diff --git a/c_api/impl/io_c.cpp b/c_api/impl/io_c.cpp
new file mode 100644
index 000000000..58597b97f
--- /dev/null
+++ b/c_api/impl/io_c.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include "io_c.h"
+#include <faiss/impl/io.h>
+#include "../macros_impl.h"
+
+using faiss::IOReader;
+using faiss::IOWriter;
+
+struct CustomIOReader : IOReader {
+    size_t (*func)(void* ptr, size_t size, size_t nitems) = nullptr;
+
+    CustomIOReader(size_t (*func_in)(void* ptr, size_t size, size_t nitems));
+
+    size_t operator()(void* ptr, size_t size, size_t nitems) override;
+};
+
+CustomIOReader::CustomIOReader(
+        size_t (*func_in)(void* ptr, size_t size, size_t nitems))
+        : func(func_in) {}
+
+size_t CustomIOReader::operator()(void* ptr, size_t size, size_t nitems) {
+    return func(ptr, size, nitems);
+}
+
+int faiss_CustomIOReader_new(
+        FaissCustomIOReader** p_out,
+        size_t (*func_in)(void* ptr, size_t size, size_t nitems)) {
+    try {
+        *p_out = reinterpret_cast<FaissCustomIOReader*>(
+                new CustomIOReader(func_in));
+    }
+    CATCH_AND_HANDLE
+}
+
+void faiss_CustomIOReader_free(FaissCustomIOReader* obj) {
+    delete reinterpret_cast<CustomIOReader*>(obj);
+}
+
+struct CustomIOWriter : IOWriter {
+    size_t (*func)(const void* ptr, size_t size, size_t nitems) = nullptr;
+
+    CustomIOWriter(
+            size_t (*func_in)(const void* ptr, size_t size, size_t nitems));
+
+    size_t operator()(const void* ptr, size_t size, size_t nitems) override;
+};
+
+CustomIOWriter::CustomIOWriter(
+        size_t (*func_in)(const void* ptr, size_t size, size_t nitems))
+        : func(func_in) {}
+
+size_t CustomIOWriter::operator()(const void* ptr, size_t size, size_t nitems) {
+    return func(ptr, size, nitems);
+}
+
+int faiss_CustomIOWriter_new(
+        FaissCustomIOWriter** p_out,
+        size_t (*func_in)(const void* ptr, size_t size, size_t nitems)) {
+    try {
+        *p_out = reinterpret_cast<FaissCustomIOWriter*>(
+                new CustomIOWriter(func_in));
+    }
+    CATCH_AND_HANDLE
+}
+
+void faiss_CustomIOWriter_free(FaissCustomIOWriter* obj) {
+    delete reinterpret_cast<CustomIOWriter*>(obj);
+}
diff --git a/c_api/impl/io_c.h b/c_api/impl/io_c.h
new file mode 100644
index 000000000..94a604828
--- /dev/null
+++ b/c_api/impl/io_c.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c -*-
+
+#ifndef FAISS_IO_C_H
+#define FAISS_IO_C_H
+
+#include <stddef.h>
+#include "../faiss_c.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+FAISS_DECLARE_CLASS(IOReader)
+FAISS_DECLARE_DESTRUCTOR(IOReader)
+
+FAISS_DECLARE_CLASS(IOWriter)
+FAISS_DECLARE_DESTRUCTOR(IOWriter)
+
+/*******************************************************
+ * Custom reader + writer
+ *
+ * Reader and writer which wraps a function pointer,
+ * primarily for FFI use.
+ *******************************************************/
+
+FAISS_DECLARE_CLASS(CustomIOReader)
+FAISS_DECLARE_DESTRUCTOR(CustomIOReader)
+
+int faiss_CustomIOReader_new(
+        FaissCustomIOReader** p_out,
+        size_t (*func_in)(void* ptr, size_t size, size_t nitems));
+
+FAISS_DECLARE_CLASS(CustomIOWriter)
+FAISS_DECLARE_DESTRUCTOR(CustomIOWriter)
+
+int faiss_CustomIOWriter_new(
+        FaissCustomIOWriter** p_out,
+        size_t (*func_in)(const void* ptr, size_t size, size_t nitems));
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/c_api/index_io_c.cpp b/c_api/index_io_c.cpp
index ea3068e96..b0d6b4777 100644
--- a/c_api/index_io_c.cpp
+++ b/c_api/index_io_c.cpp
@@ -14,6 +14,8 @@
 
 using faiss::Index;
 using faiss::IndexBinary;
+using faiss::IOReader;
+using faiss::IOWriter;
 using faiss::VectorTransform;
 
 int faiss_write_index(const FaissIndex* idx, FILE* f) {
@@ -30,6 +32,19 @@ int faiss_write_index_fname(const FaissIndex* idx, const char* fname) {
     CATCH_AND_HANDLE
 }
 
+int faiss_write_index_custom(
+        const FaissIndex* idx,
+        FaissIOWriter* io_writer,
+        int io_flags) {
+    try {
+        faiss::write_index(
+                reinterpret_cast<const Index*>(idx),
+                reinterpret_cast<IOWriter*>(io_writer),
+                io_flags);
+    }
+    CATCH_AND_HANDLE
+}
+
 int faiss_read_index(FILE* f, int io_flags, FaissIndex** p_out) {
     try {
         auto out = faiss::read_index(f, io_flags);
@@ -49,6 +64,18 @@ int faiss_read_index_fname(
     CATCH_AND_HANDLE
 }
 
+int faiss_read_index_custom(
+        FaissIOReader* io_reader,
+        int io_flags,
+        FaissIndex** p_out) {
+    try {
+        auto out = faiss::read_index(
+                reinterpret_cast<IOReader*>(io_reader), io_flags);
+        *p_out = reinterpret_cast<FaissIndex*>(out);
+    }
+    CATCH_AND_HANDLE
+}
+
 int faiss_write_index_binary(const FaissIndexBinary* idx, FILE* f) {
     try {
         faiss::write_index_binary(reinterpret_cast<const IndexBinary*>(idx), f);
@@ -66,6 +93,17 @@ int faiss_write_index_binary_fname(
     CATCH_AND_HANDLE
 }
 
+int faiss_write_index_binary_custom(
+        const FaissIndexBinary* idx,
+        FaissIOWriter* io_writer) {
+    try {
+        faiss::write_index_binary(
+                reinterpret_cast<const IndexBinary*>(idx),
+                reinterpret_cast<IOWriter*>(io_writer));
+    }
+    CATCH_AND_HANDLE
+}
+
 int faiss_read_index_binary(FILE* f, int io_flags, FaissIndexBinary** p_out) {
     try {
         auto out = faiss::read_index_binary(f, io_flags);
@@ -85,6 +123,18 @@ int faiss_read_index_binary_fname(
     CATCH_AND_HANDLE
 }
 
+int faiss_read_index_binary_custom(
+        FaissIOReader* io_reader,
+        int io_flags,
+        FaissIndexBinary** p_out) {
+    try {
+        auto out = faiss::read_index_binary(
+                reinterpret_cast<IOReader*>(io_reader), io_flags);
+        *p_out = reinterpret_cast<FaissIndexBinary*>(out);
+    }
+    CATCH_AND_HANDLE
+}
+
 int faiss_read_VectorTransform_fname(
         const char* fname,
         FaissVectorTransform** p_out) {
diff --git a/c_api/index_io_c.h b/c_api/index_io_c.h
index fd4da615e..8e390dc92 100644
--- a/c_api/index_io_c.h
+++ b/c_api/index_io_c.h
@@ -16,6 +16,7 @@
 #include "Index_c.h"
 #include "VectorTransform_c.h"
 #include "faiss_c.h"
+#include "impl/io_c.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -32,6 +33,13 @@ int faiss_write_index(const FaissIndex* idx, FILE* f);
  */
 int faiss_write_index_fname(const FaissIndex* idx, const char* fname);
 
+/** Write index to a custom writer.
+ */
+int faiss_write_index_custom(
+        const FaissIndex* idx,
+        FaissIOWriter* io_writer,
+        int io_flags);
+
 #define FAISS_IO_FLAG_MMAP 1
 #define FAISS_IO_FLAG_READ_ONLY 2
 
@@ -45,6 +53,13 @@ int faiss_read_index(FILE* f, int io_flags, FaissIndex** p_out);
  */
 int faiss_read_index_fname(const char* fname, int io_flags, FaissIndex** p_out);
 
+/** Read index from a custom reader.
+ */
+int faiss_read_index_custom(
+        FaissIOReader* io_reader,
+        int io_flags,
+        FaissIndex** p_out);
+
 /** Write index to a file.
  * This is equivalent to `faiss::write_index_binary` when a file descriptor is
  * provided.
@@ -59,6 +74,12 @@ int faiss_write_index_binary_fname(
         const FaissIndexBinary* idx,
         const char* fname);
 
+/** Write binary index to a custom writer.
+ */
+int faiss_write_index_binary_custom(
+        const FaissIndexBinary* idx,
+        FaissIOWriter* io_writer);
+
 /** Read index from a file.
  * This is equivalent to `faiss:read_index_binary` when a file descriptor is
  * given.
@@ -73,6 +94,13 @@ int faiss_read_index_binary_fname(
         int io_flags,
         FaissIndexBinary** p_out);
 
+/** Read binary index from a custom reader.
+ */
+int faiss_read_index_binary_custom(
+        FaissIOReader* io_reader,
+        int io_flags,
+        FaissIndexBinary** p_out);
+
 /** Read vector transform from a file.
  * This is equivalent to `faiss:read_VectorTransform` when a file path is given.
  */