// faiss/gpu/GpuIndexBinaryFlat.h
//
// 90 lines
// 2.7 KiB
// C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/IndexBinaryFlat.h>
#include <faiss/gpu/GpuIndex.h>
namespace faiss { namespace gpu {

class BinaryFlatIndex;
class GpuResources;

/// Configuration options for GpuIndexBinaryFlat. Adds no fields of its own;
/// everything comes from GpuIndexConfig.
struct GpuIndexBinaryFlatConfig : public GpuIndexConfig {
};

/// A GPU version of IndexBinaryFlat for brute-force comparison of bit vectors
/// via Hamming distance
///
/// Instances are non-copyable. The GPU vector storage is held through a raw
/// `BinaryFlatIndex*` (`data_`), so a memberwise copy would leave two indices
/// aliasing the same storage. If the destructor frees it, that would also be a
/// double free.
class GpuIndexBinaryFlat : public IndexBinary {
 public:
  /// Construct from a pre-existing faiss::IndexBinaryFlat instance, copying
  /// data over to the given GPU
  GpuIndexBinaryFlat(GpuResources* resources,
                     const faiss::IndexBinaryFlat* index,
                     GpuIndexBinaryFlatConfig config =
                     GpuIndexBinaryFlatConfig());

  /// Construct an empty instance that can be added to
  GpuIndexBinaryFlat(GpuResources* resources,
                     int dims,
                     GpuIndexBinaryFlatConfig config =
                     GpuIndexBinaryFlatConfig());

  ~GpuIndexBinaryFlat() override;

  // Copying is disabled because `data_` is a raw pointer to storage managed
  // by this object (rule of five). Moves are not implicitly declared either,
  // because of the user-declared destructor.
  GpuIndexBinaryFlat(const GpuIndexBinaryFlat&) = delete;
  GpuIndexBinaryFlat& operator=(const GpuIndexBinaryFlat&) = delete;

  /// Initialize ourselves from the given CPU index; will overwrite
  /// all data in ourselves
  void copyFrom(const faiss::IndexBinaryFlat* index);

  /// Copy ourselves to the given CPU index; will overwrite all data
  /// in the index instance
  void copyTo(faiss::IndexBinaryFlat* index) const;

  /// Add `n` vectors read from `x` (override of IndexBinary::add).
  /// NOTE(review): assumes the faiss::IndexBinary layout of `n` packed codes
  /// stored contiguously in `x`; confirm against IndexBinary.
  void add(faiss::IndexBinary::idx_t n,
           const uint8_t* x) override;

  /// Override of IndexBinary::reset; see faiss::IndexBinary for the contract.
  void reset() override;

  /// For each of the `n` query vectors in `x`, return up to `k` nearest
  /// stored vectors by Hamming distance. Distances are written as int32_t into
  /// `distances` and ids into `labels`.
  /// NOTE(review): both output buffers are presumably sized n * k, per the
  /// IndexBinary::search convention; confirm.
  void search(faiss::IndexBinary::idx_t n,
              const uint8_t* x,
              faiss::IndexBinary::idx_t k,
              int32_t* distances,
              faiss::IndexBinary::idx_t* labels) const override;

  /// Write the stored vector whose id is `key` into `recons`
  /// (override of IndexBinary::reconstruct).
  void reconstruct(faiss::IndexBinary::idx_t key,
                   uint8_t* recons) const override;

 protected:
  /// Called from search when the input data is on the CPU;
  /// potentially allows for pinned memory usage
  /// NOTE(review): `n`/`k` are `int` here but `idx_t` in search(); the caller
  /// is presumably responsible for range-checking before narrowing; confirm.
  void searchFromCpuPaged_(int n,
                           const uint8_t* x,
                           int k,
                           int32_t* outDistancesData,
                           int* outIndicesData) const;

  /// Search helper used when the paged CPU path is not taken.
  /// NOTE(review): the dispatch conditions between this and
  /// searchFromCpuPaged_ live in the implementation file; confirm there.
  void searchNonPaged_(int n,
                       const uint8_t* x,
                       int k,
                       int32_t* outDistancesData,
                       int* outIndicesData) const;

 protected:
  /// Manages streams, cuBLAS handles and scratch memory for devices.
  /// NOTE(review): received as a raw pointer from the constructor; presumably
  /// not owned by this index. Confirm in the implementation file.
  GpuResources* resources_;

  /// Configuration options
  GpuIndexBinaryFlatConfig config_;

  /// Holds our GPU data containing the list of vectors; is managed via raw
  /// pointer so as to allow non-CUDA compilers to see this header
  BinaryFlatIndex* data_;
};

} // namespace gpu
} // namespace faiss