faiss/gpu/GpuIndexFlat.h

183 lines
5.4 KiB
C++

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include "GpuIndex.h"
namespace faiss {
struct IndexFlat;
struct IndexFlatL2;
struct IndexFlatIP;
}
namespace faiss { namespace gpu {
struct FlatIndex;
/// Wrapper around the GPU implementation that looks like
/// faiss::IndexFlat; copies over centroid data from a given
/// faiss::IndexFlat
class GpuIndexFlat : public GpuIndex {
public:
/// Construct from a pre-existing faiss::IndexFlat instance, copying
/// data over to the given GPU
GpuIndexFlat(GpuResources* resources,
int device,
bool useFloat16,
const faiss::IndexFlat* index);
/// Construct an empty instance that can be added to
GpuIndexFlat(GpuResources* resources,
int device,
int dims,
bool useFloat16,
faiss::MetricType metric);
~GpuIndexFlat() override;
/// Set the minimum data size for searches (in MiB) for which we use
/// CPU -> GPU paging
void setMinPagingSize(size_t size);
/// Returns the current minimum data size for paged searches
size_t getMinPagingSize() const;
/// Do we store vectors and perform math in float16?
bool getUseFloat16() const;
/// Initialize ourselves from the given CPU index; will overwrite
/// all data in ourselves
void copyFrom(const faiss::IndexFlat* index);
/// Copy ourselves to the given CPU index; will overwrite all data
/// in the index instance
void copyTo(faiss::IndexFlat* index) const;
/// Returns the number of vectors we contain
size_t getNumVecs() const;
/// Clears all vectors from this index
void reset() override;
/// This index is not trained, so this does nothing
void train(Index::idx_t n, const float* x) override;
/// `x` can be resident on the CPU or any GPU; the proper copies are
/// performed
void add(Index::idx_t n, const float* x) override;
/// `x`, `distances` and `labels` can be resident on the CPU or any
/// GPU; copies are performed as needed
void search(faiss::Index::idx_t n,
const float* x,
faiss::Index::idx_t k,
float* distances,
faiss::Index::idx_t* labels) const override;
/// Reconstruction methods; prefer the batch reconstruct as it will
/// be more efficient
void reconstruct(faiss::Index::idx_t key, float* out) const override;
/// Batch reconstruction method
void reconstruct_n(faiss::Index::idx_t i0,
faiss::Index::idx_t num,
float* out) const override;
void set_typename() override;
/// For internal access
inline FlatIndex* getGpuData() { return data_; }
protected:
/// Called from search when the input data is on the CPU;
/// potentially allows for pinned memory usage
void searchFromCpuPaged_(int n,
const float* x,
int k,
float* outDistancesData,
int* outIndicesData) const;
void searchNonPaged_(int n,
const float* x,
int k,
float* outDistancesData,
int* outIndicesData) const;
protected:
/// Size above which we page copies from the CPU to GPU
size_t minPagedSize_;
/// Whether or not we store our vectors in float32 or float16
const bool useFloat16_;
/// Holds our GPU data containing the list of vectors
FlatIndex* data_;
};
/// Wrapper around the GPU implementation that looks like
/// faiss::IndexFlatL2; copies over centroid data from a given
/// faiss::IndexFlat
class GpuIndexFlatL2 : public GpuIndexFlat {
public:
/// Construct from a pre-existing faiss::IndexFlatL2 instance, copying
/// data over to the given GPU
GpuIndexFlatL2(GpuResources* resources,
int device,
bool useFloat16,
faiss::IndexFlatL2* index);
/// Construct an empty instance that can be added to
GpuIndexFlatL2(GpuResources* resources,
int device,
int dims,
bool useFloat16);
/// Initialize ourselves from the given CPU index; will overwrite
/// all data in ourselves
void copyFrom(faiss::IndexFlatL2* index);
/// Copy ourselves to the given CPU index; will overwrite all data
/// in the index instance
void copyTo(faiss::IndexFlatL2* index);
};
/// Wrapper around the GPU implementation that looks like
/// faiss::IndexFlatIP; copies over centroid data from a given
/// faiss::IndexFlat
class GpuIndexFlatIP : public GpuIndexFlat {
public:
/// Construct from a pre-existing faiss::IndexFlatIP instance, copying
/// data over to the given GPU
GpuIndexFlatIP(GpuResources* resources,
int device,
bool useFloat16,
faiss::IndexFlatIP* index);
/// Construct an empty instance that can be added to
GpuIndexFlatIP(GpuResources* resources,
int device,
int dims,
bool useFloat16);
/// Initialize ourselves from the given CPU index; will overwrite
/// all data in ourselves
void copyFrom(faiss::IndexFlatIP* index);
/// Copy ourselves to the given CPU index; will overwrite all data
/// in the index instance
void copyTo(faiss::IndexFlatIP* index);
};
} } // namespace