mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
read/write index with std::function wrapper (#427)
* add access function to IndexIVF; * access for IndexIVF; write_index/read_index with std::function<...>; * fix test compile on macOS; adjust write/read with std::function; * replace std::function with IOReader/IOWriter; * remove IndexIVF::access // tmp * PFN_WRITE/READ => WRITE; * revert macOS compile fix; * rename; * fix compile; * reset CMakeLists.txt; * format; remove unused function/header;
This commit is contained in:
parent
433f5c0fa5
commit
abe2b0fd19
@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
||||
}
|
||||
|
||||
|
||||
|
||||
IndexIVF::~IndexIVF()
|
||||
{
|
||||
if (own_invlists) {
|
||||
|
129
index_io.cpp
129
index_io.cpp
@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) {
|
||||
**************************************************************/
|
||||
|
||||
|
||||
#define WRITEANDCHECK(ptr, n) { \
|
||||
size_t ret = fwrite (ptr, sizeof (* (ptr)), n, f); \
|
||||
FAISS_THROW_IF_NOT_MSG (ret == (n), "write error"); \
|
||||
#define WRITEANDCHECK(ptr, n) { \
|
||||
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
|
||||
FAISS_THROW_IF_NOT_MSG(ret == (n), "write error"); \
|
||||
}
|
||||
|
||||
#define READANDCHECK(ptr, n) { \
|
||||
size_t ret = fread (ptr, sizeof (* (ptr)), n, f); \
|
||||
FAISS_THROW_IF_NOT_MSG (ret == (n), "read error"); \
|
||||
#define READANDCHECK(ptr, n) { \
|
||||
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
|
||||
FAISS_THROW_IF_NOT_MSG(ret == (n), "read error"); \
|
||||
}
|
||||
|
||||
#define WRITE1(x) WRITEANDCHECK(&(x), 1)
|
||||
@ -106,15 +106,41 @@ struct ScopeFileCloser {
|
||||
~ScopeFileCloser () {fclose (f); }
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
struct FileIOReader: IOReader {
|
||||
FILE *f = nullptr;
|
||||
|
||||
FileIOReader(FILE *rf): f(rf) {}
|
||||
|
||||
~FileIOReader() = default;
|
||||
|
||||
virtual size_t operator()(
|
||||
void *ptr, size_t size, size_t nitems) override {
|
||||
return fread(ptr, size, nitems, f);
|
||||
}
|
||||
};
|
||||
|
||||
struct FileIOWriter: IOWriter {
|
||||
FILE *f = nullptr;
|
||||
|
||||
FileIOWriter(FILE *wf): f(wf) {}
|
||||
~FileIOWriter() = default;
|
||||
|
||||
virtual size_t operator()(
|
||||
const void *ptr, size_t size, size_t nitems) override {
|
||||
return fwrite(ptr, size, nitems, f);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
/*************************************************************
|
||||
* Write
|
||||
**************************************************************/
|
||||
|
||||
static void write_index_header (const Index *idx, FILE *f) {
|
||||
static void write_index_header (const Index *idx, IOWriter *f) {
|
||||
WRITE1 (idx->d);
|
||||
WRITE1 (idx->ntotal);
|
||||
Index::idx_t dummy = 1 << 20;
|
||||
@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) {
|
||||
WRITE1 (idx->metric_type);
|
||||
}
|
||||
|
||||
void write_VectorTransform (const VectorTransform *vt, FILE *f) {
|
||||
void write_VectorTransform (const VectorTransform *vt, IOWriter *f) {
|
||||
if (const LinearTransform * lt =
|
||||
dynamic_cast < const LinearTransform *> (vt)) {
|
||||
if (dynamic_cast<const RandomRotationMatrix *>(lt)) {
|
||||
@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) {
|
||||
WRITE1 (vt->is_trained);
|
||||
}
|
||||
|
||||
static void write_ProductQuantizer (const ProductQuantizer *pq, FILE *f) {
|
||||
static void write_ProductQuantizer (
|
||||
const ProductQuantizer *pq, IOWriter *f) {
|
||||
WRITE1 (pq->d);
|
||||
WRITE1 (pq->M);
|
||||
WRITE1 (pq->nbits);
|
||||
WRITEVECTOR (pq->centroids);
|
||||
}
|
||||
|
||||
static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
|
||||
|
||||
static void write_ScalarQuantizer (
|
||||
const ScalarQuantizer *ivsc, IOWriter *f) {
|
||||
WRITE1 (ivsc->qtype);
|
||||
WRITE1 (ivsc->rangestat);
|
||||
WRITE1 (ivsc->rangestat_arg);
|
||||
@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
|
||||
WRITEVECTOR (ivsc->trained);
|
||||
}
|
||||
|
||||
static void write_InvertedLists (const InvertedLists *ils, FILE *f) {
|
||||
static void write_InvertedLists (const InvertedLists *ils, IOWriter *f) {
|
||||
if (ils == nullptr) {
|
||||
uint32_t h = fourcc ("il00");
|
||||
WRITE1 (h);
|
||||
@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) {
|
||||
FILE *f = fopen (fname, "w");
|
||||
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
|
||||
ScopeFileCloser closer(f);
|
||||
write_ProductQuantizer (pq, f);
|
||||
|
||||
FileIOWriter writer(f);
|
||||
write_ProductQuantizer (pq, &writer);
|
||||
}
|
||||
|
||||
static void write_HNSW (const HNSW *hnsw, FILE *f) {
|
||||
static void write_HNSW (const HNSW *hnsw, IOWriter *f) {
|
||||
|
||||
WRITEVECTOR (hnsw->assign_probas);
|
||||
WRITEVECTOR (hnsw->cum_nneighbor_per_level);
|
||||
@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) {
|
||||
WRITE1 (hnsw->efConstruction);
|
||||
WRITE1 (hnsw->efSearch);
|
||||
WRITE1 (hnsw->upper_beam);
|
||||
|
||||
}
|
||||
|
||||
static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
|
||||
static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) {
|
||||
write_index_header (ivf, f);
|
||||
WRITE1 (ivf->nlist);
|
||||
WRITE1 (ivf->nprobe);
|
||||
@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
|
||||
WRITEVECTOR (ivf->direct_map);
|
||||
}
|
||||
|
||||
void write_index (const Index *idx, FILE *f) {
|
||||
void write_index (const Index *idx, IOWriter *f) {
|
||||
if (const IndexFlat * idxf = dynamic_cast<const IndexFlat *> (idx)) {
|
||||
uint32_t h = fourcc (
|
||||
idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" :
|
||||
@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) {
|
||||
}
|
||||
}
|
||||
|
||||
void write_index (const Index *idx, FILE *f) {
|
||||
FileIOWriter writer(f);
|
||||
write_index(idx, &writer);
|
||||
}
|
||||
|
||||
void write_index (const Index *idx, const char *fname) {
|
||||
FILE *f = fopen (fname, "w");
|
||||
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
|
||||
@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) {
|
||||
FILE *f = fopen (fname, "w");
|
||||
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
|
||||
ScopeFileCloser closer(f);
|
||||
write_VectorTransform (vt, f);
|
||||
|
||||
FileIOWriter writer(f);
|
||||
write_VectorTransform (vt, &writer);
|
||||
}
|
||||
|
||||
/*************************************************************
|
||||
* Read
|
||||
**************************************************************/
|
||||
|
||||
static void read_index_header (Index *idx, FILE *f) {
|
||||
static void read_index_header (Index *idx, IOReader *f) {
|
||||
READ1 (idx->d);
|
||||
READ1 (idx->ntotal);
|
||||
Index::idx_t dummy;
|
||||
@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) {
|
||||
idx->verbose = false;
|
||||
}
|
||||
|
||||
VectorTransform* read_VectorTransform (FILE *f) {
|
||||
VectorTransform* read_VectorTransform (IOReader *f) {
|
||||
uint32_t h;
|
||||
READ1 (h);
|
||||
VectorTransform *vt = nullptr;
|
||||
@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) {
|
||||
|
||||
|
||||
static void read_ArrayInvertedLists_sizes (
|
||||
FILE *f, std::vector<size_t> & sizes)
|
||||
IOReader *f, std::vector<size_t> & sizes)
|
||||
{
|
||||
size_t nlist = sizes.size();
|
||||
uint32_t list_type;
|
||||
@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes (
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
|
||||
InvertedLists *read_InvertedLists (IOReader *f, int io_flags) {
|
||||
uint32_t h;
|
||||
READ1 (h);
|
||||
if (h == fourcc ("il00")) {
|
||||
@ -545,6 +580,10 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
|
||||
}
|
||||
return ails;
|
||||
} else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) {
|
||||
auto impl = dynamic_cast<FileIOReader*>(f);
|
||||
FAISS_THROW_IF_NOT(NULL != impl);
|
||||
FILE *raw_f = impl->f;
|
||||
|
||||
auto ails = new OnDiskInvertedLists ();
|
||||
READ1 (ails->nlist);
|
||||
READ1 (ails->code_size);
|
||||
@ -552,16 +591,16 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
|
||||
ails->lists.resize (ails->nlist);
|
||||
std::vector<size_t> sizes (ails->nlist);
|
||||
read_ArrayInvertedLists_sizes (f, sizes);
|
||||
size_t o0 = ftell (f), o = o0;
|
||||
size_t o0 = ftell (raw_f), o = o0;
|
||||
{ // do the mmap
|
||||
struct stat buf;
|
||||
int ret = fstat (fileno(f), &buf);
|
||||
int ret = fstat (fileno(raw_f), &buf);
|
||||
FAISS_THROW_IF_NOT_FMT (ret == 0,
|
||||
"fstat failed: %s", strerror(errno));
|
||||
ails->totsize = buf.st_size;
|
||||
ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize,
|
||||
PROT_READ, MAP_SHARED,
|
||||
fileno (f), 0);
|
||||
fileno (raw_f), 0);
|
||||
FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED,
|
||||
"could not mmap: %s",
|
||||
strerror(errno));
|
||||
@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
|
||||
ails->code_size);
|
||||
}
|
||||
// resume normal reading of file
|
||||
fseek (f, o, SEEK_SET);
|
||||
fseek (raw_f, o, SEEK_SET);
|
||||
return ails;
|
||||
} else if (h == fourcc ("ilod")) {
|
||||
OnDiskInvertedLists *od = new OnDiskInvertedLists();
|
||||
@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
|
||||
}
|
||||
}
|
||||
|
||||
static void read_InvertedLists (IndexIVF *ivf, FILE *f, int io_flags) {
|
||||
/// Reads inverted lists from `f` and attaches them to `ivf`.
///
/// On success `ivf->invlists` points at the lists that were read and
/// `ivf->own_invlists` is set to true.
///
/// @param ivf       index that receives the lists
/// @param f         reader positioned at the inverted-lists record
/// @param io_flags  forwarded unchanged to read_InvertedLists (f, io_flags)
/// @throws via FAISS_THROW_IF_NOT when the lists' nlist / code_size do not
///         match the index; the lists are freed before the exception leaves.
static void read_InvertedLists (
        IndexIVF *ivf, IOReader *f, int io_flags) {
    InvertedLists *ils = read_InvertedLists (f, io_flags);

    // Guard the lists until ownership is handed to ivf, so a failed
    // consistency check below does not leak them.
    ScopeDeleter1<InvertedLists> del (ils);

    // NOTE(review): the writer emits an "il00" record for a null list; the
    // matching reader branch is not fully visible here and presumably
    // yields nullptr, so a null result is accepted rather than dereferenced.
    FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist &&
                                 ils->code_size == ivf->code_size));

    del.release ();
    ivf->invlists = ils;
    ivf->own_invlists = true;
}
|
||||
|
||||
|
||||
static void read_ProductQuantizer (ProductQuantizer *pq, FILE *f) {
|
||||
|
||||
/// Fills `pq` from the reader `f`.
///
/// Reads d, M and nbits (one value each), calls set_derived_values(), then
/// reads the centroids vector. The order mirrors write_ProductQuantizer.
static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f)
{
    // The READ* macros reach the reader through the identifier `f`.
    READ1 (pq->d);
    READ1 (pq->M);
    READ1 (pq->nbits);

    // Runs after the three scalar fields are loaded and before the
    // centroids are read (presumably recomputes fields derived from them;
    // defined outside this view).
    pq->set_derived_values ();

    READVECTOR (pq->centroids);
}
|
||||
|
||||
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
|
||||
|
||||
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) {
|
||||
READ1 (ivsc->qtype);
|
||||
READ1 (ivsc->rangestat);
|
||||
READ1 (ivsc->rangestat_arg);
|
||||
@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
|
||||
}
|
||||
|
||||
|
||||
static void read_HNSW (HNSW *hnsw, FILE *f) {
|
||||
static void read_HNSW (HNSW *hnsw, IOReader *f) {
|
||||
READVECTOR (hnsw->assign_probas);
|
||||
READVECTOR (hnsw->cum_nneighbor_per_level);
|
||||
READVECTOR (hnsw->levels);
|
||||
@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) {
|
||||
ScopeFileCloser closer(f);
|
||||
ProductQuantizer *pq = new ProductQuantizer();
|
||||
ScopeDeleter1<ProductQuantizer> del (pq);
|
||||
read_ProductQuantizer(pq, f);
|
||||
|
||||
FileIOReader reader(f);
|
||||
read_ProductQuantizer(pq, &reader);
|
||||
del.release ();
|
||||
return pq;
|
||||
}
|
||||
|
||||
static void read_ivf_header (
|
||||
IndexIVF * ivf, FILE *f,
|
||||
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
|
||||
IndexIVF *ivf, IOReader *f,
|
||||
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
|
||||
{
|
||||
read_index_header (ivf, f);
|
||||
READ1 (ivf->nlist);
|
||||
@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist(
|
||||
return ail;
|
||||
}
|
||||
|
||||
static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
|
||||
static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags)
|
||||
{
|
||||
bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ");
|
||||
|
||||
@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
|
||||
|
||||
int read_old_fmt_hack = 0;
|
||||
|
||||
Index *read_index (FILE * f, int io_flags) {
|
||||
Index *read_index (IOReader *f, int io_flags) {
|
||||
Index * idx = nullptr;
|
||||
uint32_t h;
|
||||
READ1 (h);
|
||||
@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) {
|
||||
}
|
||||
|
||||
|
||||
/// Reads an index from an already-open stdio stream.
///
/// Thin adapter: wraps `f` in a FileIOReader and forwards to the IOReader
/// overload of read_index, returning whatever that overload returns.
/// The stream is not closed here.
Index *read_index (FILE * f, int io_flags) {
    FileIOReader file_reader (f);
    IOReader *source = &file_reader;
    Index *loaded = read_index (source, io_flags);
    return loaded;
}
|
||||
|
||||
Index *read_index (const char *fname, int io_flags) {
|
||||
FILE *f = fopen (fname, "r");
|
||||
@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) {
|
||||
perror ("");
|
||||
abort ();
|
||||
}
|
||||
VectorTransform *vt = read_VectorTransform (f);
|
||||
|
||||
FileIOReader reader(f);
|
||||
VectorTransform *vt = read_VectorTransform (&reader);
|
||||
fclose (f);
|
||||
return vt;
|
||||
}
|
||||
|
21
index_io.h
21
index_io.h
@ -21,17 +21,21 @@ struct Index;
|
||||
struct VectorTransform;
|
||||
struct IndexIVF;
|
||||
struct ProductQuantizer;
|
||||
struct IOReader;
|
||||
struct IOWriter;
|
||||
|
||||
void write_index (const Index *idx, FILE *f);
|
||||
void write_index (const Index *idx, const char *fname);
|
||||
|
||||
void write_index (const Index *idx, IOWriter *writer);
|
||||
|
||||
|
||||
const int IO_FLAG_MMAP = 1;
|
||||
const int IO_FLAG_READ_ONLY = 2;
|
||||
|
||||
Index *read_index (FILE * f, int io_flags = 0);
|
||||
Index *read_index (const char *fname, int io_flags = 0);
|
||||
|
||||
Index *read_index (IOReader *reader, int io_flags = 0);
|
||||
|
||||
|
||||
void write_VectorTransform (const VectorTransform *vt, const char *fname);
|
||||
@ -55,6 +59,21 @@ struct Cloner {
|
||||
virtual ~Cloner() {}
|
||||
};
|
||||
|
||||
/// Abstract byte source with an fread-shaped call operator.
///
/// Implementations copy up to `nitems` objects of `size` bytes into `ptr`
/// and return how many objects were delivered; the read macros treat any
/// count other than the requested one as a "read error".
struct IOReader {
    virtual ~IOReader() = default;

    // fread
    virtual size_t operator()(void *ptr, size_t size, size_t nitems) = 0;
};
|
||||
|
||||
/// Abstract byte sink with an fwrite-shaped call operator.
///
/// Implementations consume `nitems` objects of `size` bytes from `ptr` and
/// return how many objects were accepted; the write macros treat any count
/// other than the requested one as a "write error".
struct IOWriter {
    virtual ~IOWriter() = default;

    // fwrite
    virtual size_t operator()(const void *ptr, size_t size, size_t nitems) = 0;
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
Loading…
x
Reference in New Issue
Block a user