faiss/IndexFlat.h

172 lines
4.0 KiB
C++

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved
// -*- c++ -*-
#ifndef INDEX_FLAT_H
#define INDEX_FLAT_H
#include <vector>
#include "Index.h"
namespace faiss {
/** Index that stores the full vectors and performs exhaustive search */
struct IndexFlat: Index {
/// database vectors, size ntotal * d
std::vector<float> xb;
explicit IndexFlat (idx_t d, MetricType metric = METRIC_INNER_PRODUCT);
virtual void set_typename() override;
virtual void add (idx_t n, const float *x) override;
virtual void reset() override;
virtual void search (
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels) const override;
virtual void range_search (
idx_t n,
const float *x,
float radius,
RangeSearchResult *result) const override;
virtual void reconstruct (idx_t key, float * recons)
const override;
/** compute distance with a subset of vectors
*
* @param x query vectors, size n * d
* @param labels indices of the vectors that should be compared
* for each query vector, size n * k
* @param distances
* corresponding output distances, size n * k
*/
void compute_distance_subset (
idx_t n,
const float *x,
idx_t k,
float *distances,
const idx_t *labels) const;
IndexFlat () {}
};
/// IndexFlat whose one-argument constructor selects METRIC_INNER_PRODUCT.
struct IndexFlatIP: IndexFlat {

    explicit IndexFlatIP (idx_t d)
        : IndexFlat (d, METRIC_INNER_PRODUCT)
    {}

    /// Default constructor with an empty body.
    IndexFlatIP () {}
};
/// IndexFlat whose one-argument constructor selects METRIC_L2.
struct IndexFlatL2: IndexFlat {

    explicit IndexFlatL2 (idx_t d)
        : IndexFlat (d, METRIC_L2)
    {}

    /// Default constructor with an empty body.
    IndexFlatL2 () {}
};
// same as an IndexFlatL2 but a value is subtracted from each distance
struct IndexFlatL2BaseShift: IndexFlatL2 {
std::vector<float> shift;
IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
virtual void search (
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels) const override;
};
/** Index that queries in a base_index (a fast one) and refines the
* results with an exact search, hopefully improving the results.
*/
struct IndexRefineFlat: Index {
/// storage for full vectors
IndexFlat refine_index;
/// faster index to pre-select the vectors that should be filtered
Index *base_index;
bool own_fields; ///< should the base index be deallocated?
/// factor between k requested in search and the k requested from
/// the base_index (should be >= 1)
float k_factor;
explicit IndexRefineFlat (Index *base_index);
IndexRefineFlat ();
virtual void train (idx_t n, const float *x) override;
virtual void add (idx_t n, const float *x) override;
virtual void reset() override;
virtual void search (
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels) const override;
virtual void set_typename () override;
virtual ~IndexRefineFlat ();
};
/// optimized version for 1D "vectors"
struct IndexFlat1D:IndexFlatL2 {
bool continuous_update; ///< is the permutation updated continuously?
std::vector<idx_t> perm; ///< sorted database indices
explicit IndexFlat1D (bool continuous_update=true);
/// if not continuous_update, call this between the last add and
/// the first search
void update_permutation ();
virtual void add (idx_t n, const float *x) override;
virtual void reset() override;
/// Warn: the distances returned are L1 not L2
virtual void search (
idx_t n,
const float *x,
idx_t k,
float *distances,
idx_t *labels) const override;
};
}
#endif