281 lines
7.8 KiB
C++
281 lines
7.8 KiB
C++
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
// -*- c++ -*-

// Auxiliary index structures that are used in indexes but that can
// be forward-declared
|
|
|
|
#ifndef FAISS_AUX_INDEX_STRUCTURES_H
|
|
#define FAISS_AUX_INDEX_STRUCTURES_H
|
|
|
|
#include <stdint.h>

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "Index.h"
|
|
|
|
namespace faiss {
|
|
|
|
/** The objective is to have a simple result structure while
|
|
* minimizing the number of mem copies in the result. The method
|
|
* do_allocation can be overloaded to allocate the result tables in
|
|
* the matrix type of a scripting language like Lua or Python. */
|
|
struct RangeSearchResult {
|
|
size_t nq; ///< nb of queries
|
|
size_t *lims; ///< size (nq + 1)
|
|
|
|
typedef Index::idx_t idx_t;
|
|
|
|
idx_t *labels; ///< result for query i is labels[lims[i]:lims[i+1]]
|
|
float *distances; ///< corresponding distances (not sorted)
|
|
|
|
size_t buffer_size; ///< size of the result buffers used
|
|
|
|
/// lims must be allocated on input to range_search.
|
|
explicit RangeSearchResult (idx_t nq, bool alloc_lims=true);
|
|
|
|
/// called when lims contains the nb of elements result entries
|
|
/// for each query
|
|
virtual void do_allocation ();
|
|
|
|
virtual ~RangeSearchResult ();
|
|
};
|
|
|
|
|
|
/** Encapsulates a set of ids to remove. */
|
|
struct IDSelector {
|
|
typedef Index::idx_t idx_t;
|
|
virtual bool is_member (idx_t id) const = 0;
|
|
virtual ~IDSelector() {}
|
|
};
|
|
|
|
|
|
|
|
/** remove ids between [imni, imax) */
|
|
struct IDSelectorRange: IDSelector {
|
|
idx_t imin, imax;
|
|
|
|
IDSelectorRange (idx_t imin, idx_t imax);
|
|
bool is_member(idx_t id) const override;
|
|
~IDSelectorRange() override {}
|
|
};
|
|
|
|
|
|
/** Remove ids from a set. Repetitions of ids in the indices set
|
|
* passed to the constructor does not hurt performance. The hash
|
|
* function used for the bloom filter and GCC's implementation of
|
|
* unordered_set are just the least significant bits of the id. This
|
|
* works fine for random ids or ids in sequences but will produce many
|
|
* hash collisions if lsb's are always the same */
|
|
struct IDSelectorBatch: IDSelector {
|
|
|
|
std::unordered_set<idx_t> set;
|
|
|
|
typedef unsigned char uint8_t;
|
|
std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
|
|
int nbits;
|
|
idx_t mask;
|
|
|
|
IDSelectorBatch (long n, const idx_t *indices);
|
|
bool is_member(idx_t id) const override;
|
|
~IDSelectorBatch() override {}
|
|
};
|
|
|
|
/****************************************************************
 * Result structures for range search.
 *
 * The main constraint here is that we want to support parallel
 * queries from different threads in various ways: 1 thread per query,
 * several threads per query. We store the actual results in blocks of
 * fixed size rather than exponentially increasing memory. At the end,
 * we copy the block content to a linear result array.
 *****************************************************************/
|
|
|
|
/** List of temporary buffers used to store results before they are
|
|
* copied to the RangeSearchResult object. */
|
|
struct BufferList {
|
|
typedef Index::idx_t idx_t;
|
|
|
|
// buffer sizes in # entries
|
|
size_t buffer_size;
|
|
|
|
struct Buffer {
|
|
idx_t *ids;
|
|
float *dis;
|
|
};
|
|
|
|
std::vector<Buffer> buffers;
|
|
size_t wp; ///< write pointer in the last buffer.
|
|
|
|
explicit BufferList (size_t buffer_size);
|
|
|
|
~BufferList ();
|
|
|
|
/// create a new buffer
|
|
void append_buffer ();
|
|
|
|
/// add one result, possibly appending a new buffer if needed
|
|
void add (idx_t id, float dis);
|
|
|
|
/// copy elemnts ofs:ofs+n-1 seen as linear data in the buffers to
|
|
/// tables dest_ids, dest_dis
|
|
void copy_range (size_t ofs, size_t n,
|
|
idx_t * dest_ids, float *dest_dis);
|
|
|
|
};
|
|
|
|
struct RangeSearchPartialResult;
|
|
|
|
/// result structure for a single query
|
|
struct RangeQueryResult {
|
|
using idx_t = Index::idx_t;
|
|
idx_t qno; //< id of the query
|
|
size_t nres; //< nb of results for this query
|
|
RangeSearchPartialResult * pres;
|
|
|
|
/// called by search function to report a new result
|
|
void add (float dis, idx_t id);
|
|
};
|
|
|
|
/// the entries in the buffers are split per query
|
|
struct RangeSearchPartialResult: BufferList {
|
|
RangeSearchResult * res;
|
|
|
|
/// eventually the result will be stored in res_in
|
|
explicit RangeSearchPartialResult (RangeSearchResult * res_in);
|
|
|
|
/// query ids + nb of results per query.
|
|
std::vector<RangeQueryResult> queries;
|
|
|
|
/// begin a new result
|
|
RangeQueryResult & new_result (idx_t qno);
|
|
|
|
/*****************************************
|
|
* functions used at the end of the search to merge the result
|
|
* lists */
|
|
void finalize ();
|
|
|
|
/// called by range_search before do_allocation
|
|
void set_lims ();
|
|
|
|
/// called by range_search after do_allocation
|
|
void copy_result (bool incremental = false);
|
|
|
|
/// merge a set of PartialResult's into one RangeSearchResult
|
|
/// on ouptut the partialresults are empty!
|
|
static void merge (std::vector <RangeSearchPartialResult *> &
|
|
partial_results, bool do_delete=true);
|
|
|
|
};
|
|
|
|
/***********************************************************
 * Abstract I/O objects
 ***********************************************************/
|
|
|
|
/** Abstract input object: reads go through operator() instead of a
 * FILE*. */
struct IOReader {
    /// name that can be used in error messages
    std::string name;

    /// fread-style read of nitems elements of size bytes each into ptr
    /// (return value presumably the nb of items read, as with fread --
    /// confirm against implementations)
    virtual size_t operator()(
         void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    // NOTE(review): POSIX permits fileno to be a macro in <stdio.h>,
    // which would clash with this member name on such platforms.
    virtual int fileno ();

    virtual ~IOReader() = default;
};
|
|
|
|
/** Abstract output object: writes go through operator() instead of a
 * FILE*. */
struct IOWriter {
    /// name that can be used in error messages
    std::string name;

    /// fwrite-style write of nitems elements of size bytes each from ptr
    /// (return value presumably the nb of items written, as with
    /// fwrite -- confirm against implementations)
    virtual size_t operator()(
         const void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    // NOTE(review): POSIX permits fileno to be a macro in <stdio.h>,
    // which would clash with this member name on such platforms.
    virtual int fileno ();

    virtual ~IOWriter() = default;
};
|
|
|
|
|
|
struct VectorIOReader:IOReader {
|
|
std::vector<uint8_t> data;
|
|
size_t rp = 0;
|
|
size_t operator()(void *ptr, size_t size, size_t nitems) override;
|
|
};
|
|
|
|
struct VectorIOWriter:IOWriter {
|
|
std::vector<uint8_t> data;
|
|
size_t operator()(const void *ptr, size_t size, size_t nitems) override;
|
|
};
|
|
|
|
/***********************************************************
|
|
* The distance computer maintains a current query and computes
|
|
* distances to elements in an index that supports random access.
|
|
*
|
|
* The DistanceComputer is not intended to be thread-safe (eg. because
|
|
* it maintains counters) so the distance functions are not const,
|
|
* instanciate one from each thread if needed.
|
|
***********************************************************/
|
|
struct DistanceComputer {
|
|
using idx_t = Index::idx_t;
|
|
|
|
/// called before computing distances
|
|
virtual void set_query(const float *x) = 0;
|
|
|
|
/// compute distance of vector i to current query
|
|
virtual float operator () (idx_t i) = 0;
|
|
|
|
/// compute distance between two stored vectors
|
|
virtual float symmetric_dis (idx_t i, idx_t j) = 0;
|
|
|
|
virtual ~DistanceComputer() {}
|
|
};
|
|
|
|
/***********************************************************
 * Interrupt callback
 ***********************************************************/
|
|
|
|
/** Interface for a user-supplied interruption test, plus static
 * helpers that consult the installed callback. */
struct InterruptCallback {
    /// @return true if the running computation should be interrupted
    virtual bool want_interrupt () = 0;

    virtual ~InterruptCallback() = default;

    /// the installed callback (defined outside this header)
    static std::unique_ptr<InterruptCallback> instance;

    /** check if:
     * - an interrupt callback is set
     * - the callback returns true
     * if this is the case, then throw an exception
     */
    static void check ();

    /// same as check() but return true if is interrupted instead of
    /// throwing
    static bool is_interrupted ();

    /** assuming each iteration takes a certain number of flops, what
     * is a reasonable interval to check for interrupts?
     */
    static size_t get_period_hint (size_t flops);
};
|
|
|
|
|
|
|
|
}; // namespace faiss
|
|
|
|
|
|
|
|
#endif
|