109 lines
2.8 KiB
C++
109 lines
2.8 KiB
C++
/**
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
// -*- c++ -*-
|
|
|
|
#ifndef INDEX_FLAT_H
|
|
#define INDEX_FLAT_H
|
|
|
|
#include <vector>
|
|
|
|
#include <faiss/IndexFlatCodes.h>
|
|
|
|
namespace faiss {
|
|
|
|
/** Index that stores the full vectors and performs exhaustive search */
|
|
struct IndexFlat : IndexFlatCodes {
|
|
explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
|
|
|
|
void search(
|
|
idx_t n,
|
|
const float* x,
|
|
idx_t k,
|
|
float* distances,
|
|
idx_t* labels) const override;
|
|
|
|
void range_search(
|
|
idx_t n,
|
|
const float* x,
|
|
float radius,
|
|
RangeSearchResult* result) const override;
|
|
|
|
void reconstruct(idx_t key, float* recons) const override;
|
|
|
|
/** compute distance with a subset of vectors
|
|
*
|
|
* @param x query vectors, size n * d
|
|
* @param labels indices of the vectors that should be compared
|
|
* for each query vector, size n * k
|
|
* @param distances
|
|
* corresponding output distances, size n * k
|
|
*/
|
|
void compute_distance_subset(
|
|
idx_t n,
|
|
const float* x,
|
|
idx_t k,
|
|
float* distances,
|
|
const idx_t* labels) const;
|
|
|
|
// get pointer to the floating point data
|
|
float* get_xb() {
|
|
return (float*)codes.data();
|
|
}
|
|
const float* get_xb() const {
|
|
return (const float*)codes.data();
|
|
}
|
|
|
|
IndexFlat() {}
|
|
|
|
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
|
|
|
/* The stanadlone codec interface (just memcopies in this case) */
|
|
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
|
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
};
|
|
|
|
struct IndexFlatIP : IndexFlat {
|
|
explicit IndexFlatIP(idx_t d) : IndexFlat(d, METRIC_INNER_PRODUCT) {}
|
|
IndexFlatIP() {}
|
|
};
|
|
|
|
struct IndexFlatL2 : IndexFlat {
|
|
explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
|
|
IndexFlatL2() {}
|
|
};
|
|
|
|
/// optimized version for 1D "vectors".
|
|
struct IndexFlat1D : IndexFlatL2 {
|
|
bool continuous_update; ///< is the permutation updated continuously?
|
|
|
|
std::vector<idx_t> perm; ///< sorted database indices
|
|
|
|
explicit IndexFlat1D(bool continuous_update = true);
|
|
|
|
/// if not continuous_update, call this between the last add and
|
|
/// the first search
|
|
void update_permutation();
|
|
|
|
void add(idx_t n, const float* x) override;
|
|
|
|
void reset() override;
|
|
|
|
/// Warn: the distances returned are L1 not L2
|
|
void search(
|
|
idx_t n,
|
|
const float* x,
|
|
idx_t k,
|
|
float* distances,
|
|
idx_t* labels) const override;
|
|
};
|
|
|
|
} // namespace faiss
|
|
|
|
#endif
|