diff --git a/IndexIVF.cpp b/IndexIVF.cpp index 2b0f62e7d..d370362a4 100644 --- a/IndexIVF.cpp +++ b/IndexIVF.cpp @@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type, } - IndexIVF::~IndexIVF() { if (own_invlists) { diff --git a/index_io.cpp b/index_io.cpp index 4fa769ed3..b5730d12d 100644 --- a/index_io.cpp +++ b/index_io.cpp @@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) { **************************************************************/ -#define WRITEANDCHECK(ptr, n) { \ - size_t ret = fwrite (ptr, sizeof (* (ptr)), n, f); \ - FAISS_THROW_IF_NOT_MSG (ret == (n), "write error"); \ +#define WRITEANDCHECK(ptr, n) { \ + size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \ + FAISS_THROW_IF_NOT_MSG(ret == (n), "write error"); \ } -#define READANDCHECK(ptr, n) { \ - size_t ret = fread (ptr, sizeof (* (ptr)), n, f); \ - FAISS_THROW_IF_NOT_MSG (ret == (n), "read error"); \ +#define READANDCHECK(ptr, n) { \ + size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \ + FAISS_THROW_IF_NOT_MSG(ret == (n), "read error"); \ } #define WRITE1(x) WRITEANDCHECK(&(x), 1) @@ -106,15 +106,41 @@ struct ScopeFileCloser { ~ScopeFileCloser () {fclose (f); } }; +namespace { + +struct FileIOReader: IOReader { + FILE *f = nullptr; + + FileIOReader(FILE *rf): f(rf) {} + + ~FileIOReader() = default; + + virtual size_t operator()( + void *ptr, size_t size, size_t nitems) override { + return fread(ptr, size, nitems, f); + } +}; + +struct FileIOWriter: IOWriter { + FILE *f = nullptr; + + FileIOWriter(FILE *wf): f(wf) {} + ~FileIOWriter() = default; + + virtual size_t operator()( + const void *ptr, size_t size, size_t nitems) override { + return fwrite(ptr, size, nitems, f); + } +}; +} // namespace /************************************************************* * Write **************************************************************/ - -static void write_index_header (const Index *idx, FILE *f) { +static void write_index_header (const Index *idx, IOWriter *f) { WRITE1 (idx->d); WRITE1 (idx->ntotal); Index::idx_t dummy = 1 << 20; @@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) { WRITE1 (idx->metric_type); } -void write_VectorTransform (const VectorTransform *vt, FILE *f) { +void write_VectorTransform (const VectorTransform *vt, IOWriter *f) { if (const LinearTransform * lt = dynamic_cast < const LinearTransform *> (vt)) { if (dynamic_cast(lt)) { @@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) { WRITE1 (vt->is_trained); } -static void write_ProductQuantizer (const ProductQuantizer *pq, FILE *f) { +static void write_ProductQuantizer ( + const ProductQuantizer *pq, IOWriter *f) { WRITE1 (pq->d); WRITE1 (pq->M); WRITE1 (pq->nbits); WRITEVECTOR (pq->centroids); } - -static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) { + +static void write_ScalarQuantizer ( + const ScalarQuantizer *ivsc, IOWriter *f) { WRITE1 (ivsc->qtype); WRITE1 (ivsc->rangestat); WRITE1 (ivsc->rangestat_arg); @@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) { WRITEVECTOR (ivsc->trained); } -static void write_InvertedLists (const InvertedLists *ils, FILE *f) { +static void write_InvertedLists (const InvertedLists *ils, IOWriter *f) { if (ils == nullptr) { uint32_t h = fourcc ("il00"); WRITE1 (h); @@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) { FILE *f = fopen (fname, "w"); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); ScopeFileCloser closer(f); - write_ProductQuantizer (pq, f); + + FileIOWriter writer(f); + write_ProductQuantizer (pq, &writer); } -static void write_HNSW (const HNSW *hnsw, FILE *f) { +static void write_HNSW (const HNSW *hnsw, IOWriter *f) { WRITEVECTOR (hnsw->assign_probas); WRITEVECTOR (hnsw->cum_nneighbor_per_level); @@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) { WRITE1 (hnsw->efConstruction); WRITE1 (hnsw->efSearch); WRITE1 (hnsw->upper_beam); - } -static void write_ivf_header (const IndexIVF * ivf, FILE *f) { +static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) { write_index_header (ivf, f); WRITE1 (ivf->nlist); WRITE1 (ivf->nprobe); @@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) { WRITEVECTOR (ivf->direct_map); } -void write_index (const Index *idx, FILE *f) { +void write_index (const Index *idx, IOWriter *f) { if (const IndexFlat * idxf = dynamic_cast (idx)) { uint32_t h = fourcc ( idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" : @@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) { } } +void write_index (const Index *idx, FILE *f) { + FileIOWriter writer(f); + write_index(idx, &writer); +} + void write_index (const Index *idx, const char *fname) { FILE *f = fopen (fname, "w"); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); @@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) { FILE *f = fopen (fname, "w"); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); ScopeFileCloser closer(f); - write_VectorTransform (vt, f); + + FileIOWriter writer(f); + write_VectorTransform (vt, &writer); } /************************************************************* * Read **************************************************************/ -static void read_index_header (Index *idx, FILE *f) { +static void read_index_header (Index *idx, IOReader *f) { READ1 (idx->d); READ1 (idx->ntotal); Index::idx_t dummy; @@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) { idx->verbose = false; } -VectorTransform* read_VectorTransform (FILE *f) { +VectorTransform* read_VectorTransform (IOReader *f) { uint32_t h; READ1 (h); VectorTransform *vt = nullptr; @@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) { static void read_ArrayInvertedLists_sizes ( - FILE *f, std::vector & sizes) + IOReader *f, std::vector & sizes) { size_t nlist = sizes.size(); uint32_t list_type; @@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes ( } } - -InvertedLists *read_InvertedLists (FILE *f, int io_flags) { +InvertedLists *read_InvertedLists (IOReader *f, int io_flags) { uint32_t h; READ1 (h); if (h == fourcc ("il00")) { @@ -545,6 +580,10 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { } return ails; } else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) { + auto impl = dynamic_cast(f); + FAISS_THROW_IF_NOT(NULL != impl); + FILE *raw_f = impl->f; + auto ails = new OnDiskInvertedLists (); READ1 (ails->nlist); READ1 (ails->code_size); @@ -552,16 +591,16 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ails->lists.resize (ails->nlist); std::vector sizes (ails->nlist); read_ArrayInvertedLists_sizes (f, sizes); - size_t o0 = ftell (f), o = o0; + size_t o0 = ftell (raw_f), o = o0; { // do the mmap struct stat buf; - int ret = fstat (fileno(f), &buf); + int ret = fstat (fileno(raw_f), &buf); FAISS_THROW_IF_NOT_FMT (ret == 0, "fstat failed: %s", strerror(errno)); ails->totsize = buf.st_size; ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize, PROT_READ, MAP_SHARED, - fileno (f), 0); + fileno (raw_f), 0); FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED, "could not mmap: %s", strerror(errno)); @@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ails->code_size); } // resume normal reading of file - fseek (f, o, SEEK_SET); + fseek (raw_f, o, SEEK_SET); return ails; } else if (h == fourcc ("ilod")) { OnDiskInvertedLists *od = new OnDiskInvertedLists(); @@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { } } -static void read_InvertedLists (IndexIVF *ivf, FILE *f, int io_flags) { +static void read_InvertedLists ( + IndexIVF *ivf, IOReader *f, int io_flags) { InvertedLists *ils = read_InvertedLists (f, io_flags); FAISS_THROW_IF_NOT (ils->nlist == ivf->nlist && ils->code_size == ivf->code_size); ivf->invlists = ils; ivf->own_invlists = true; } - - -static void read_ProductQuantizer (ProductQuantizer *pq, FILE *f) { + +static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) { READ1 (pq->d); READ1 (pq->M); READ1 (pq->nbits); pq->set_derived_values (); READVECTOR (pq->centroids); } - -static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) { + +static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) { READ1 (ivsc->qtype); READ1 (ivsc->rangestat); READ1 (ivsc->rangestat_arg); @@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) { } -static void read_HNSW (HNSW *hnsw, FILE *f) { +static void read_HNSW (HNSW *hnsw, IOReader *f) { READVECTOR (hnsw->assign_probas); READVECTOR (hnsw->cum_nneighbor_per_level); READVECTOR (hnsw->levels); @@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) { ScopeFileCloser closer(f); ProductQuantizer *pq = new ProductQuantizer(); ScopeDeleter1 del (pq); - read_ProductQuantizer(pq, f); + + FileIOReader reader(f); + read_ProductQuantizer(pq, &reader); del.release (); return pq; } static void read_ivf_header ( - IndexIVF * ivf, FILE *f, - std::vector > *ids = nullptr) + IndexIVF *ivf, IOReader *f, + std::vector > *ids = nullptr) { read_index_header (ivf, f); READ1 (ivf->nlist); @@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist( return ail; } -static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags) +static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags) { bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ"); @@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags) int read_old_fmt_hack = 0; -Index *read_index (FILE * f, int io_flags) { +Index *read_index (IOReader *f, int io_flags) { Index * idx = nullptr; uint32_t h; READ1 (h); @@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) { } +Index *read_index (FILE * f, int io_flags) { + FileIOReader reader(f); + return read_index(&reader, io_flags); +} Index *read_index (const char *fname, int io_flags) { FILE *f = fopen (fname, "r"); @@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) { perror (""); abort (); } - VectorTransform *vt = read_VectorTransform (f); + + FileIOReader reader(f); + VectorTransform *vt = read_VectorTransform (&reader); fclose (f); return vt; } diff --git a/index_io.h b/index_io.h index e473cd4e8..b95391191 100644 --- a/index_io.h +++ b/index_io.h @@ -21,17 +21,21 @@ struct Index; struct VectorTransform; struct IndexIVF; struct ProductQuantizer; +struct IOReader; +struct IOWriter; void write_index (const Index *idx, FILE *f); void write_index (const Index *idx, const char *fname); +void write_index (const Index *idx, IOWriter *writer); + const int IO_FLAG_MMAP = 1; const int IO_FLAG_READ_ONLY = 2; Index *read_index (FILE * f, int io_flags = 0); Index *read_index (const char *fname, int io_flags = 0); - +Index *read_index (IOReader *reader, int io_flags = 0); void write_VectorTransform (const VectorTransform *vt, const char *fname); @@ -55,6 +59,21 @@ struct Cloner { virtual ~Cloner() {} }; +struct IOReader { + // fread + virtual size_t operator()( + void *ptr, size_t size, size_t nitems) = 0; + virtual ~IOReader() {} +}; + +struct IOWriter { + // fwrite + virtual size_t operator()( + const void *ptr, size_t size, size_t nitems) = 0; + + virtual ~IOWriter() {} +}; + } #endif