mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Integrate IVF-Flat from RAFT (#2521)
Summary: This is a design proposal that demonstrates an approach to enabling optional support for [RAFT](https://github.com/rapidsai/raft) versions of IVF PQ and IVF Flat (and brute force w/ fused k-selection when k <= 64). There are still a few open issues and design discussions needed for the new RAFT index types to support the full range of features of that FAISS' current gpu index types. Checklist for the integration todos: - [x] Rebase on current `main` branch - [X] The raft handle has been plugged directly into the StandardGpuResources - [X] `FlatIndex` passing Googletests - [x] Use `CodePacker` to support `copyFrom()` and `copyTo()` - [X] `IVF-flat passing Googletests - [ ] Raise appropriate exceptions for operations which are not yet supported by RAFT Additional features we've discussed: - [x] Separate IVF lists into individual memory chunks - [ ] Saving/loading To build FAISS w/ optional RAFT support: ``` mkdir build cd build cmake ../ -DFAISS_ENABLE_RAFT=ON -DFAISS_ENABLE_GPU=ON make -j ``` For development/testing, we've also supplied a bash script to make things easier: `build.sh` Below is a benchmark comparing the training of IVF Flat indices for RAFT and FAISS:  The benchmark was produced using Googlebench in [this](https://github.com/tfeher/raft/tree/raft_faiss_bench) RAFT fork. We're going to provide benchmarks for the queries as well. There are still a couple bottlenecks to be removed in the IVF-Flat training implementation and we'll update the current benchmark when ready. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2521 Test Plan: `buck test mode/debuck test mode/dev-nosan //faiss/gpu/test:test_gpu_index_ivfflat` Reviewed By: algoriddle Differential Revision: D49118319 Pulled By: mdouze fbshipit-source-id: 5916108bc27154acf7c92021ba579a6ca85d730b
This commit is contained in:
parent
458633c203
commit
edcf7438bb
58
build.sh
Normal file
58
build.sh
Normal file
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
|
||||
# NOTE: This file is temporary for the proof-of-concept branch and will be removed before this PR is merged
|
||||
|
||||
BUILD_TYPE=Release
|
||||
BUILD_DIR=build/
|
||||
|
||||
RAFT_REPO_REL=""
|
||||
EXTRA_CMAKE_ARGS=""
|
||||
set -e
|
||||
|
||||
if [[ ${RAFT_REPO_REL} != "" ]]; then
|
||||
RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`"
|
||||
EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}"
|
||||
fi
|
||||
|
||||
if [ "$1" == "clean" ]; then
|
||||
rm -rf build
|
||||
rm -rf .cache
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$1" == "test" ]; then
|
||||
make -C build -j test
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$1" == "test-raft" ]; then
|
||||
./build/faiss/gpu/test/TestRaftIndexIVFFlat
|
||||
exit 0
|
||||
fi
|
||||
|
||||
mkdir -p $BUILD_DIR
|
||||
cd $BUILD_DIR
|
||||
|
||||
cmake \
|
||||
-DFAISS_ENABLE_GPU=ON \
|
||||
-DFAISS_ENABLE_RAFT=ON \
|
||||
-DFAISS_ENABLE_PYTHON=OFF \
|
||||
-DBUILD_TESTING=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||
-DFAISS_OPT_LEVEL=avx2 \
|
||||
-DRAFT_NVTX=OFF \
|
||||
-DCMAKE_CUDA_ARCHITECTURES="NATIVE" \
|
||||
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
|
||||
-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
|
||||
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
|
||||
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
|
||||
${EXTRA_CMAKE_ARGS} \
|
||||
../
|
||||
|
||||
|
||||
# make -C build -j12 faiss
|
||||
cmake --build . -j12
|
||||
# make -C build -j12 swigfaiss
|
||||
# (cd build/faiss/python && python setup.py install)
|
||||
|
2
cmake/thirdparty/fetch_rapids.cmake
vendored
2
cmake/thirdparty/fetch_rapids.cmake
vendored
@ -15,7 +15,7 @@
|
||||
# or implied. See the License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
# =============================================================================
|
||||
set(RAPIDS_VERSION "23.06")
|
||||
set(RAPIDS_VERSION "23.08")
|
||||
|
||||
if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
|
||||
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake
|
||||
|
@ -238,9 +238,11 @@ generate_ivf_interleaved_code()
|
||||
|
||||
if(FAISS_ENABLE_RAFT)
|
||||
list(APPEND FAISS_GPU_HEADERS
|
||||
impl/RaftIVFFlat.cuh
|
||||
impl/RaftFlatIndex.cuh)
|
||||
list(APPEND FAISS_GPU_SRC
|
||||
impl/RaftFlatIndex.cu)
|
||||
impl/RaftFlatIndex.cu
|
||||
impl/RaftIVFFlat.cu)
|
||||
|
||||
target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1)
|
||||
target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1)
|
||||
|
@ -16,6 +16,11 @@
|
||||
#include <faiss/gpu/impl/IVFBase.cuh>
|
||||
#include <faiss/gpu/utils/CopyUtils.cuh>
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
#include <raft/core/handle.hpp>
|
||||
#include <raft/neighbors/ivf_flat.cuh>
|
||||
#endif
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
@ -444,14 +449,46 @@ void GpuIndexIVF::trainQuantizer_(idx_t n, const float* x) {
|
||||
printf("Training IVF quantizer on %ld vectors in %dD\n", n, d);
|
||||
}
|
||||
|
||||
// leverage the CPU-side k-means code, which works for the GPU
|
||||
// flat index as well
|
||||
quantizer->reset();
|
||||
Clustering clus(this->d, nlist, this->cp);
|
||||
clus.verbose = verbose;
|
||||
clus.train(n, x, *quantizer);
|
||||
quantizer->is_trained = true;
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
|
||||
if (config_.use_raft) {
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
|
||||
raft::neighbors::ivf_flat::index_params raft_idx_params;
|
||||
raft_idx_params.n_lists = nlist;
|
||||
raft_idx_params.metric = metric_type == faiss::METRIC_L2
|
||||
? raft::distance::DistanceType::L2Expanded
|
||||
: raft::distance::DistanceType::InnerProduct;
|
||||
raft_idx_params.add_data_on_build = false;
|
||||
raft_idx_params.kmeans_trainset_fraction = 1.0;
|
||||
raft_idx_params.kmeans_n_iters = cp.niter;
|
||||
raft_idx_params.adaptive_centers = !cp.frozen_centroids;
|
||||
|
||||
auto raft_index = raft::neighbors::ivf_flat::build(
|
||||
raft_handle, raft_idx_params, x, n, (idx_t)d);
|
||||
|
||||
raft_handle.sync_stream();
|
||||
|
||||
quantizer->train(nlist, raft_index.centers().data_handle());
|
||||
quantizer->add(nlist, raft_index.centers().data_handle());
|
||||
} else
|
||||
#else
|
||||
if (config_.use_raft) {
|
||||
FAISS_THROW_MSG(
|
||||
"RAFT has not been compiled into the current version so it cannot be used.");
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// leverage the CPU-side k-means code, which works for the GPU
|
||||
// flat index as well
|
||||
Clustering clus(this->d, nlist, this->cp);
|
||||
clus.verbose = verbose;
|
||||
clus.train(n, x, *quantizer);
|
||||
}
|
||||
quantizer->is_trained = true;
|
||||
FAISS_ASSERT(quantizer->ntotal == nlist);
|
||||
}
|
||||
|
||||
|
@ -73,10 +73,10 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
|
||||
virtual void updateQuantizer() = 0;
|
||||
|
||||
/// Returns the number of inverted lists we're managing
|
||||
idx_t getNumLists() const;
|
||||
virtual idx_t getNumLists() const;
|
||||
|
||||
/// Returns the number of vectors present in a particular inverted list
|
||||
idx_t getListLength(idx_t listId) const;
|
||||
virtual idx_t getListLength(idx_t listId) const;
|
||||
|
||||
/// Return the encoded vector data contained in a particular inverted list,
|
||||
/// for debugging purposes.
|
||||
@ -84,12 +84,13 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
|
||||
/// GPU-side representation.
|
||||
/// Otherwise, it is converted to the CPU format.
|
||||
/// compliant format, while the native GPU format may differ.
|
||||
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat = false)
|
||||
const;
|
||||
virtual std::vector<uint8_t> getListVectorData(
|
||||
idx_t listId,
|
||||
bool gpuFormat = false) const;
|
||||
|
||||
/// Return the vector indices contained in a particular inverted list, for
|
||||
/// debugging purposes.
|
||||
std::vector<idx_t> getListIndices(idx_t listId) const;
|
||||
virtual std::vector<idx_t> getListIndices(idx_t listId) const;
|
||||
|
||||
void search_preassigned(
|
||||
idx_t n,
|
||||
@ -121,7 +122,7 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
|
||||
int getCurrentNProbe_(const SearchParameters* params) const;
|
||||
void verifyIVFSettings_() const;
|
||||
bool addImplRequiresIDs_() const override;
|
||||
void trainQuantizer_(idx_t n, const float* x);
|
||||
virtual void trainQuantizer_(idx_t n, const float* x);
|
||||
|
||||
/// Called from GpuIndex for add/add_with_ids
|
||||
void addImpl_(idx_t n, const float* x, const idx_t* ids) override;
|
||||
|
@ -15,6 +15,10 @@
|
||||
#include <faiss/gpu/utils/CopyUtils.cuh>
|
||||
#include <faiss/gpu/utils/Float16.cuh>
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
#include <faiss/gpu/impl/RaftIVFFlat.cuh>
|
||||
#endif
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace faiss {
|
||||
@ -70,8 +74,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
|
||||
// no other quantizer that we need to train, so this is sufficient
|
||||
if (this->is_trained) {
|
||||
FAISS_ASSERT(this->quantizer);
|
||||
|
||||
index_.reset(new IVFFlat(
|
||||
set_index_(
|
||||
resources_.get(),
|
||||
this->d,
|
||||
this->nlist,
|
||||
@ -81,7 +84,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
|
||||
nullptr, // no scalar quantizer
|
||||
ivfFlatConfig_.interleavedLayout,
|
||||
ivfFlatConfig_.indicesOptions,
|
||||
config_.memorySpace));
|
||||
config_.memorySpace);
|
||||
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
|
||||
updateQuantizer();
|
||||
}
|
||||
@ -89,6 +92,54 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
|
||||
|
||||
GpuIndexIVFFlat::~GpuIndexIVFFlat() {}
|
||||
|
||||
void GpuIndexIVFFlat::set_index_(
|
||||
GpuResources* resources,
|
||||
int dim,
|
||||
int nlist,
|
||||
faiss::MetricType metric,
|
||||
float metricArg,
|
||||
bool useResidual,
|
||||
/// Optional ScalarQuantizer
|
||||
faiss::ScalarQuantizer* scalarQ,
|
||||
bool interleavedLayout,
|
||||
IndicesOptions indicesOptions,
|
||||
MemorySpace space) {
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
|
||||
if (config_.use_raft) {
|
||||
index_.reset(new RaftIVFFlat(
|
||||
resources,
|
||||
dim,
|
||||
nlist,
|
||||
metric,
|
||||
metricArg,
|
||||
useResidual,
|
||||
scalarQ,
|
||||
interleavedLayout,
|
||||
indicesOptions,
|
||||
space));
|
||||
} else
|
||||
#else
|
||||
if (config_.use_raft) {
|
||||
FAISS_THROW_MSG(
|
||||
"RAFT has not been compiled into the current version so it cannot be used.");
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
index_.reset(new IVFFlat(
|
||||
resources,
|
||||
dim,
|
||||
nlist,
|
||||
metric,
|
||||
metricArg,
|
||||
useResidual,
|
||||
scalarQ,
|
||||
interleavedLayout,
|
||||
indicesOptions,
|
||||
space));
|
||||
}
|
||||
}
|
||||
|
||||
void GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
@ -110,25 +161,25 @@ void GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
|
||||
|
||||
// The other index might not be trained
|
||||
if (!index->is_trained) {
|
||||
FAISS_ASSERT(!this->is_trained);
|
||||
FAISS_ASSERT(!is_trained);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, we can populate ourselves from the other index
|
||||
FAISS_ASSERT(this->is_trained);
|
||||
FAISS_ASSERT(is_trained);
|
||||
|
||||
// Copy our lists as well
|
||||
index_.reset(new IVFFlat(
|
||||
set_index_(
|
||||
resources_.get(),
|
||||
this->d,
|
||||
this->nlist,
|
||||
d,
|
||||
nlist,
|
||||
index->metric_type,
|
||||
index->metric_arg,
|
||||
false, // no residual
|
||||
nullptr, // no scalar quantizer
|
||||
ivfFlatConfig_.interleavedLayout,
|
||||
ivfFlatConfig_.indicesOptions,
|
||||
config_.memorySpace));
|
||||
config_.memorySpace);
|
||||
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
|
||||
updateQuantizer();
|
||||
|
||||
@ -201,18 +252,30 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
|
||||
|
||||
FAISS_ASSERT(!index_);
|
||||
|
||||
// FIXME: GPUize more of this
|
||||
// First, make sure that the data is resident on the CPU, if it is not on
|
||||
// the CPU, as we depend upon parts of the CPU code
|
||||
auto hostData = toHost<float, 2>(
|
||||
(float*)x,
|
||||
resources_->getDefaultStream(config_.device),
|
||||
{n, this->d});
|
||||
|
||||
trainQuantizer_(n, hostData.data());
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
if (config_.use_raft) {
|
||||
// No need to copy the data to host
|
||||
trainQuantizer_(n, x);
|
||||
} else
|
||||
#else
|
||||
if (config_.use_raft) {
|
||||
FAISS_THROW_MSG(
|
||||
"RAFT has not been compiled into the current version so it cannot be used.");
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// FIXME: GPUize more of this
|
||||
// First, make sure that the data is resident on the CPU, if it is not
|
||||
// on the CPU, as we depend upon parts of the CPU code
|
||||
auto hostData = toHost<float, 2>(
|
||||
(float*)x,
|
||||
resources_->getDefaultStream(config_.device),
|
||||
{n, this->d});
|
||||
trainQuantizer_(n, hostData.data());
|
||||
}
|
||||
|
||||
// The quantizer is now trained; construct the IVF index
|
||||
index_.reset(new IVFFlat(
|
||||
set_index_(
|
||||
resources_.get(),
|
||||
this->d,
|
||||
this->nlist,
|
||||
@ -222,7 +285,7 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
|
||||
nullptr, // no scalar quantizer
|
||||
ivfFlatConfig_.interleavedLayout,
|
||||
ivfFlatConfig_.indicesOptions,
|
||||
config_.memorySpace));
|
||||
config_.memorySpace);
|
||||
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
|
||||
updateQuantizer();
|
||||
|
||||
|
@ -8,6 +8,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <faiss/gpu/GpuIndexIVF.h>
|
||||
#include <faiss/impl/ScalarQuantizer.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace faiss {
|
||||
@ -86,6 +88,19 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
|
||||
void train(idx_t n, const float* x) override;
|
||||
|
||||
protected:
|
||||
void set_index_(
|
||||
GpuResources* resources,
|
||||
int dim,
|
||||
int nlist,
|
||||
faiss::MetricType metric,
|
||||
float metricArg,
|
||||
bool useResidual,
|
||||
/// Optional ScalarQuantizer
|
||||
faiss::ScalarQuantizer* scalarQ,
|
||||
bool interleavedLayout,
|
||||
IndicesOptions indicesOptions,
|
||||
MemorySpace space);
|
||||
|
||||
/// Our configuration options
|
||||
const GpuIndexIVFFlatConfig ivfFlatConfig_;
|
||||
|
||||
|
@ -362,7 +362,11 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
|
||||
|
||||
defaultStreams_[device] = defaultStream;
|
||||
|
||||
cudaStream_t asyncCopyStream = nullptr;
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
raftHandles_.emplace(std::make_pair(device, defaultStream));
|
||||
#endif
|
||||
|
||||
cudaStream_t asyncCopyStream = 0;
|
||||
CUDA_VERIFY(
|
||||
cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking));
|
||||
|
||||
|
@ -45,7 +45,7 @@ class IVFBase {
|
||||
|
||||
/// Clear out all inverted lists, but retain the coarse quantizer
|
||||
/// and the product quantizer info
|
||||
void reset();
|
||||
virtual void reset();
|
||||
|
||||
/// Return the number of dimensions we are indexing
|
||||
idx_t getDim() const;
|
||||
@ -59,29 +59,30 @@ class IVFBase {
|
||||
|
||||
/// For debugging purposes, return the list length of a particular
|
||||
/// list
|
||||
idx_t getListLength(idx_t listId) const;
|
||||
virtual idx_t getListLength(idx_t listId) const;
|
||||
|
||||
/// Return the list indices of a particular list back to the CPU
|
||||
std::vector<idx_t> getListIndices(idx_t listId) const;
|
||||
virtual std::vector<idx_t> getListIndices(idx_t listId) const;
|
||||
|
||||
/// Return the encoded vectors of a particular list back to the CPU
|
||||
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat) const;
|
||||
virtual std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat)
|
||||
const;
|
||||
|
||||
/// Copy all inverted lists from a CPU representation to ourselves
|
||||
void copyInvertedListsFrom(const InvertedLists* ivf);
|
||||
virtual void copyInvertedListsFrom(const InvertedLists* ivf);
|
||||
|
||||
/// Copy all inverted lists from ourselves to a CPU representation
|
||||
void copyInvertedListsTo(InvertedLists* ivf);
|
||||
virtual void copyInvertedListsTo(InvertedLists* ivf);
|
||||
|
||||
/// Update our coarse quantizer with this quantizer instance; may be a CPU
|
||||
/// or GPU quantizer
|
||||
void updateQuantizer(Index* quantizer);
|
||||
virtual void updateQuantizer(Index* quantizer);
|
||||
|
||||
/// Classify and encode/add vectors to our IVF lists.
|
||||
/// The input data must be on our current device.
|
||||
/// Returns the number of vectors successfully added. Vectors may
|
||||
/// not be able to be added because they contain NaNs.
|
||||
idx_t addVectors(
|
||||
virtual idx_t addVectors(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<idx_t, 1, true>& indices);
|
||||
@ -111,7 +112,7 @@ class IVFBase {
|
||||
protected:
|
||||
/// Adds a set of codes and indices to a list, with the representation
|
||||
/// coming from the CPU equivalent
|
||||
void addEncodedVectorsToList_(
|
||||
virtual void addEncodedVectorsToList_(
|
||||
idx_t listId,
|
||||
// resident on the host
|
||||
const void* codes,
|
||||
|
@ -60,17 +60,17 @@ class IVFFlat : public IVFBase {
|
||||
size_t getCpuVectorsEncodingSize_(idx_t numVecs) const override;
|
||||
|
||||
/// Translate to our preferred GPU encoding
|
||||
std::vector<uint8_t> translateCodesToGpu_(
|
||||
virtual std::vector<uint8_t> translateCodesToGpu_(
|
||||
std::vector<uint8_t> codes,
|
||||
idx_t numVecs) const override;
|
||||
|
||||
/// Translate from our preferred GPU encoding
|
||||
std::vector<uint8_t> translateCodesFromGpu_(
|
||||
virtual std::vector<uint8_t> translateCodesFromGpu_(
|
||||
std::vector<uint8_t> codes,
|
||||
idx_t numVecs) const override;
|
||||
|
||||
/// Encode the vectors that we're adding and append to our IVF lists
|
||||
void appendVectors_(
|
||||
virtual void appendVectors_(
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<float, 2, true>& ivfCentroidResiduals,
|
||||
Tensor<idx_t, 1, true>& indices,
|
||||
@ -84,7 +84,7 @@ class IVFFlat : public IVFBase {
|
||||
|
||||
/// Shared IVF search implementation, used by both search and
|
||||
/// searchPreassigned
|
||||
void searchImpl_(
|
||||
virtual void searchImpl_(
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<idx_t, 2, true>& coarseIndices,
|
||||
|
604
faiss/gpu/impl/RaftIVFFlat.cu
Normal file
604
faiss/gpu/impl/RaftIVFFlat.cu
Normal file
@ -0,0 +1,604 @@
|
||||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <faiss/gpu/GpuIndex.h>
|
||||
#include <faiss/gpu/GpuResources.h>
|
||||
#include <faiss/gpu/impl/InterleavedCodes.h>
|
||||
#include <faiss/gpu/impl/RemapIndices.h>
|
||||
#include <faiss/gpu/utils/DeviceUtils.h>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <faiss/gpu/impl/FlatIndex.cuh>
|
||||
#include <faiss/gpu/impl/IVFAppend.cuh>
|
||||
#include <faiss/gpu/impl/IVFFlat.cuh>
|
||||
#include <faiss/gpu/impl/IVFFlatScan.cuh>
|
||||
#include <faiss/gpu/impl/IVFInterleaved.cuh>
|
||||
#include <faiss/gpu/impl/RaftIVFFlat.cuh>
|
||||
#include <faiss/gpu/utils/ConversionOperators.cuh>
|
||||
#include <faiss/gpu/utils/CopyUtils.cuh>
|
||||
#include <faiss/gpu/utils/DeviceDefs.cuh>
|
||||
#include <faiss/gpu/utils/Float16.cuh>
|
||||
#include <faiss/gpu/utils/HostTensor.cuh>
|
||||
#include <faiss/gpu/utils/Transpose.cuh>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <raft/core/device_mdspan.hpp>
|
||||
#include <raft/core/handle.hpp>
|
||||
#include <raft/neighbors/ivf_flat_codepacker.hpp>
|
||||
#include <raft/neighbors/ivf_flat.cuh>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
RaftIVFFlat::RaftIVFFlat(
|
||||
GpuResources* res,
|
||||
int dim,
|
||||
int nlist,
|
||||
faiss::MetricType metric,
|
||||
float metricArg,
|
||||
bool useResidual,
|
||||
faiss::ScalarQuantizer* scalarQ,
|
||||
bool interleavedLayout,
|
||||
IndicesOptions indicesOptions,
|
||||
MemorySpace space)
|
||||
: IVFFlat(res,
|
||||
dim,
|
||||
nlist,
|
||||
metric,
|
||||
metricArg,
|
||||
useResidual,
|
||||
scalarQ,
|
||||
interleavedLayout,
|
||||
indicesOptions,
|
||||
space) {
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
indicesOptions == INDICES_64_BIT,
|
||||
"only INDICES_64_BIT is supported for RAFT index");
|
||||
reset();
|
||||
}
|
||||
|
||||
RaftIVFFlat::~RaftIVFFlat() {}
|
||||
|
||||
/// Find the approximate k nearest neighbors for `queries` against
|
||||
/// our database
|
||||
void RaftIVFFlat::search(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& queries,
|
||||
int nprobe,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<idx_t, 2, true>& outIndices) {
|
||||
// TODO: We probably don't want to ignore the coarse quantizer here...
|
||||
|
||||
uint32_t numQueries = queries.getSize(0);
|
||||
uint32_t cols = queries.getSize(1);
|
||||
uint32_t k_ = k;
|
||||
|
||||
// Device is already set in GpuIndex::search
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
FAISS_ASSERT(numQueries > 0);
|
||||
FAISS_ASSERT(cols == dim_);
|
||||
FAISS_THROW_IF_NOT(nprobe > 0 && nprobe <= numLists_);
|
||||
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
raft::neighbors::ivf_flat::search_params pams;
|
||||
pams.n_probes = nprobe;
|
||||
|
||||
auto queries_view = raft::make_device_matrix_view<const float, idx_t>(
|
||||
queries.data(), (idx_t)numQueries, (idx_t)cols);
|
||||
auto out_inds_view = raft::make_device_matrix_view<idx_t, idx_t>(
|
||||
outIndices.data(), (idx_t)numQueries, (idx_t)k_);
|
||||
auto out_dists_view = raft::make_device_matrix_view<float, idx_t>(
|
||||
outDistances.data(), (idx_t)numQueries, (idx_t)k_);
|
||||
|
||||
raft::neighbors::ivf_flat::search<float, idx_t>(
|
||||
raft_handle,
|
||||
pams,
|
||||
raft_knn_index.value(),
|
||||
queries_view,
|
||||
out_inds_view,
|
||||
out_dists_view);
|
||||
|
||||
/// Identify NaN rows and mask their nearest neighbors
|
||||
auto nan_flag = raft::make_device_vector<bool>(raft_handle, numQueries);
|
||||
|
||||
validRowIndices_(queries, nan_flag.data_handle());
|
||||
|
||||
raft::linalg::map_offset(
|
||||
raft_handle,
|
||||
raft::make_device_vector_view(outIndices.data(), numQueries * k_),
|
||||
[nan_flag = nan_flag.data_handle(),
|
||||
out_inds = outIndices.data(),
|
||||
k_] __device__(uint32_t i) {
|
||||
uint32_t row = i / k_;
|
||||
if (!nan_flag[row])
|
||||
return idx_t(-1);
|
||||
return out_inds[i];
|
||||
});
|
||||
|
||||
float max_val = std::numeric_limits<float>::max();
|
||||
raft::linalg::map_offset(
|
||||
raft_handle,
|
||||
raft::make_device_vector_view(outDistances.data(), numQueries * k_),
|
||||
[nan_flag = nan_flag.data_handle(),
|
||||
out_dists = outDistances.data(),
|
||||
max_val,
|
||||
k_] __device__(uint32_t i) {
|
||||
uint32_t row = i / k_;
|
||||
if (!nan_flag[row])
|
||||
return max_val;
|
||||
return out_dists[i];
|
||||
});
|
||||
}
|
||||
|
||||
/// Classify and encode/add vectors to our IVF lists.
|
||||
/// The input data must be on our current device.
|
||||
/// Returns the number of vectors successfully added. Vectors may
|
||||
/// not be able to be added because they contain NaNs.
|
||||
idx_t RaftIVFFlat::addVectors(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<idx_t, 1, true>& indices) {
|
||||
/// TODO: We probably don't want to ignore the coarse quantizer here
|
||||
|
||||
idx_t n_rows = vecs.getSize(0);
|
||||
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
|
||||
/// Remove NaN values
|
||||
auto nan_flag = raft::make_device_vector<bool, idx_t>(raft_handle, n_rows);
|
||||
|
||||
validRowIndices_(vecs, nan_flag.data_handle());
|
||||
|
||||
idx_t n_rows_valid = thrust::reduce(
|
||||
raft_handle.get_thrust_policy(),
|
||||
nan_flag.data_handle(),
|
||||
nan_flag.data_handle() + n_rows,
|
||||
0);
|
||||
|
||||
if (n_rows_valid < n_rows) {
|
||||
auto gather_indices = raft::make_device_vector<idx_t, idx_t>(
|
||||
raft_handle, n_rows_valid);
|
||||
|
||||
auto count = thrust::make_counting_iterator(0);
|
||||
|
||||
thrust::copy_if(
|
||||
raft_handle.get_thrust_policy(),
|
||||
count,
|
||||
count + n_rows,
|
||||
gather_indices.data_handle(),
|
||||
[nan_flag = nan_flag.data_handle()] __device__(auto i) {
|
||||
return nan_flag[i];
|
||||
});
|
||||
|
||||
raft::matrix::gather(
|
||||
raft_handle,
|
||||
raft::make_device_matrix_view<float, idx_t>(
|
||||
vecs.data(), n_rows, dim_),
|
||||
raft::make_const_mdspan(gather_indices.view()),
|
||||
(idx_t)16);
|
||||
|
||||
auto valid_indices = raft::make_device_vector<idx_t, idx_t>(
|
||||
raft_handle, n_rows_valid);
|
||||
|
||||
raft::matrix::gather(
|
||||
raft_handle,
|
||||
raft::make_device_matrix_view<idx_t>(
|
||||
indices.data(), n_rows, (idx_t)1),
|
||||
raft::make_const_mdspan(gather_indices.view()));
|
||||
}
|
||||
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
raft_knn_index.emplace(raft::neighbors::ivf_flat::extend(
|
||||
raft_handle,
|
||||
raft::make_device_matrix_view<const float, idx_t>(
|
||||
vecs.data(), n_rows_valid, dim_),
|
||||
std::make_optional<raft::device_vector_view<const idx_t, idx_t>>(
|
||||
raft::make_device_vector_view<const idx_t, idx_t>(
|
||||
indices.data(), n_rows_valid)),
|
||||
raft_knn_index.value()));
|
||||
|
||||
return n_rows_valid;
|
||||
}
|
||||
|
||||
void RaftIVFFlat::reset() {
|
||||
raft_knn_index.reset();
|
||||
}
|
||||
|
||||
idx_t RaftIVFFlat::getListLength(idx_t listId) const {
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
|
||||
uint32_t size;
|
||||
raft::update_host(
|
||||
&size,
|
||||
raft_knn_index.value().list_sizes().data_handle() + listId,
|
||||
1,
|
||||
raft_handle.get_stream());
|
||||
raft_handle.sync_stream();
|
||||
|
||||
return static_cast<int>(size);
|
||||
}
|
||||
|
||||
/// Return the list indices of a particular list back to the CPU
|
||||
std::vector<idx_t> RaftIVFFlat::getListIndices(idx_t listId) const {
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
auto stream = raft_handle.get_stream();
|
||||
|
||||
idx_t listSize = getListLength(listId);
|
||||
|
||||
std::vector<idx_t> vec(listSize);
|
||||
|
||||
// fetch the list indices ptr on host
|
||||
idx_t* list_indices_ptr;
|
||||
|
||||
// fetch the list indices ptr on host
|
||||
raft::update_host(
|
||||
&list_indices_ptr,
|
||||
raft_knn_index.value().inds_ptrs().data_handle() + listId,
|
||||
1,
|
||||
stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
raft::update_host(vec.data(), list_indices_ptr, listSize, stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
return vec;
|
||||
}
|
||||
|
||||
/// Return the encoded vectors of a particular list back to the CPU
|
||||
std::vector<uint8_t> RaftIVFFlat::getListVectorData(
|
||||
idx_t listId,
|
||||
bool gpuFormat) const {
|
||||
if (gpuFormat) {
|
||||
FAISS_THROW_MSG("gpuFormat is not suppported for raft indices");
|
||||
}
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
auto stream = raft_handle.get_stream();
|
||||
|
||||
idx_t listSize = getListLength(listId);
|
||||
|
||||
// the interleaved block can be slightly larger than the list size (it's
|
||||
// rounded up)
|
||||
auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(listSize);
|
||||
auto cpuListSizeInBytes = getCpuVectorsEncodingSize_(listSize);
|
||||
|
||||
std::vector<uint8_t> interleaved_codes(gpuListSizeInBytes);
|
||||
std::vector<uint8_t> flat_codes(cpuListSizeInBytes);
|
||||
|
||||
float* list_data_ptr;
|
||||
|
||||
// fetch the list data ptr on host
|
||||
raft::update_host(
|
||||
&list_data_ptr,
|
||||
raft_knn_index.value().data_ptrs().data_handle() + listId,
|
||||
1,
|
||||
stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
raft::update_host(
|
||||
interleaved_codes.data(),
|
||||
reinterpret_cast<uint8_t*>(list_data_ptr),
|
||||
gpuListSizeInBytes,
|
||||
stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
RaftIVFFlatCodePackerInterleaved packer(
|
||||
(size_t)listSize, dim_, raft_knn_index.value().veclen());
|
||||
packer.unpack_all(interleaved_codes.data(), flat_codes.data());
|
||||
return flat_codes;
|
||||
}
|
||||
|
||||
/// Performs search when we are already given the IVF cells to look at
|
||||
/// (GpuIndexIVF::search_preassigned implementation)
|
||||
void RaftIVFFlat::searchPreassigned(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<float, 2, true>& ivfDistances,
|
||||
Tensor<idx_t, 2, true>& ivfAssignments,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<idx_t, 2, true>& outIndices,
|
||||
bool storePairs) {
|
||||
// TODO: Fill this in!
|
||||
}
|
||||
|
||||
void RaftIVFFlat::updateQuantizer(Index* quantizer) {
|
||||
idx_t quantizer_ntotal = quantizer->ntotal;
|
||||
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
auto stream = raft_handle.get_stream();
|
||||
|
||||
auto total_elems = size_t(quantizer_ntotal) * size_t(quantizer->d);
|
||||
|
||||
raft::logger::get().set_level(RAFT_LEVEL_TRACE);
|
||||
|
||||
raft::neighbors::ivf_flat::index_params pams;
|
||||
pams.add_data_on_build = false;
|
||||
|
||||
pams.n_lists = this->numLists_;
|
||||
|
||||
switch (this->metric_) {
|
||||
case faiss::METRIC_L2:
|
||||
pams.metric = raft::distance::DistanceType::L2Expanded;
|
||||
break;
|
||||
case faiss::METRIC_INNER_PRODUCT:
|
||||
pams.metric = raft::distance::DistanceType::InnerProduct;
|
||||
break;
|
||||
default:
|
||||
FAISS_THROW_MSG("Metric is not supported.");
|
||||
}
|
||||
|
||||
raft_knn_index.emplace(raft_handle, pams, (uint32_t)this->dim_);
|
||||
|
||||
cudaMemsetAsync(
|
||||
raft_knn_index.value().list_sizes().data_handle(),
|
||||
0,
|
||||
raft_knn_index.value().list_sizes().size() * sizeof(uint32_t),
|
||||
stream);
|
||||
cudaMemsetAsync(
|
||||
raft_knn_index.value().data_ptrs().data_handle(),
|
||||
0,
|
||||
raft_knn_index.value().data_ptrs().size() * sizeof(float*),
|
||||
stream);
|
||||
cudaMemsetAsync(
|
||||
raft_knn_index.value().inds_ptrs().data_handle(),
|
||||
0,
|
||||
raft_knn_index.value().inds_ptrs().size() * sizeof(idx_t*),
|
||||
stream);
|
||||
|
||||
/// Copy (reconstructed) centroids over, rather than re-training
|
||||
std::vector<float> buf_host(total_elems);
|
||||
quantizer->reconstruct_n(0, quantizer_ntotal, buf_host.data());
|
||||
|
||||
raft::update_device(
|
||||
raft_knn_index.value().centers().data_handle(),
|
||||
buf_host.data(),
|
||||
total_elems,
|
||||
stream);
|
||||
}
|
||||
|
||||
void RaftIVFFlat::copyInvertedListsFrom(const InvertedLists* ivf) {
|
||||
size_t nlist = ivf ? ivf->nlist : 0;
|
||||
size_t ntotal = ivf ? ivf->compute_ntotal() : 0;
|
||||
|
||||
raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
|
||||
std::vector<uint32_t> list_sizes_(nlist);
|
||||
std::vector<idx_t> indices_(ntotal);
|
||||
|
||||
// the index must already exist
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
|
||||
auto& raft_lists = raft_knn_index.value().lists();
|
||||
|
||||
// conservative memory alloc for cloning cpu inverted lists
|
||||
raft::neighbors::ivf_flat::list_spec<uint32_t, float, idx_t> raft_list_spec{
|
||||
static_cast<uint32_t>(dim_), true};
|
||||
|
||||
for (size_t i = 0; i < nlist; ++i) {
|
||||
size_t listSize = ivf->list_size(i);
|
||||
|
||||
// GPU index can only support max int entries per list
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
listSize <= (size_t)std::numeric_limits<int>::max(),
|
||||
"GPU inverted list can only support "
|
||||
"%zu entries; %zu found",
|
||||
(size_t)std::numeric_limits<int>::max(),
|
||||
listSize);
|
||||
|
||||
// store the list size
|
||||
list_sizes_[i] = static_cast<uint32_t>(listSize);
|
||||
|
||||
raft::neighbors::ivf::resize_list(
|
||||
raft_handle,
|
||||
raft_lists[i],
|
||||
raft_list_spec,
|
||||
(uint32_t)listSize,
|
||||
(uint32_t)0);
|
||||
}
|
||||
|
||||
// Update the pointers and the sizes
|
||||
raft_knn_index.value().recompute_internal_state(raft_handle);
|
||||
|
||||
for (size_t i = 0; i < nlist; ++i) {
|
||||
size_t listSize = ivf->list_size(i);
|
||||
addEncodedVectorsToList_(
|
||||
i, ivf->get_codes(i), ivf->get_ids(i), listSize);
|
||||
}
|
||||
|
||||
raft::update_device(
|
||||
raft_knn_index.value().list_sizes().data_handle(),
|
||||
list_sizes_.data(),
|
||||
nlist,
|
||||
raft_handle.get_stream());
|
||||
|
||||
// Precompute the centers vector norms for L2Expanded distance
|
||||
if (this->metric_ == faiss::METRIC_L2) {
|
||||
raft_knn_index.value().allocate_center_norms(raft_handle);
|
||||
raft::linalg::rowNorm(
|
||||
raft_knn_index.value().center_norms().value().data_handle(),
|
||||
raft_knn_index.value().centers().data_handle(),
|
||||
raft_knn_index.value().dim(),
|
||||
(uint32_t)nlist,
|
||||
raft::linalg::L2Norm,
|
||||
true,
|
||||
raft_handle.get_stream());
|
||||
}
|
||||
}
|
||||
|
||||
size_t RaftIVFFlat::getGpuVectorsEncodingSize_(idx_t numVecs) const {
|
||||
idx_t bits = 32 /* float */;
|
||||
|
||||
// bytes to encode a block of 32 vectors (single dimension)
|
||||
idx_t bytesPerDimBlock = bits * 32 / 8; // = 128
|
||||
|
||||
// bytes to fully encode 32 vectors
|
||||
idx_t bytesPerBlock = bytesPerDimBlock * dim_;
|
||||
|
||||
// number of blocks of 32 vectors we have
|
||||
idx_t numBlocks =
|
||||
utils::divUp(numVecs, raft::neighbors::ivf_flat::kIndexGroupSize);
|
||||
|
||||
// total size to encode numVecs
|
||||
return bytesPerBlock * numBlocks;
|
||||
}
|
||||
|
||||
void RaftIVFFlat::addEncodedVectorsToList_(
|
||||
idx_t listId,
|
||||
const void* codes,
|
||||
const idx_t* indices,
|
||||
idx_t numVecs) {
|
||||
auto stream = resources_->getDefaultStreamCurrentDevice();
|
||||
|
||||
// This list must already exist
|
||||
FAISS_ASSERT(raft_knn_index.has_value());
|
||||
|
||||
// This list must currently be empty
|
||||
FAISS_ASSERT(getListLength(listId) == 0);
|
||||
|
||||
// If there's nothing to add, then there's nothing we have to do
|
||||
if (numVecs == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The GPU might have a different layout of the memory
|
||||
auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(numVecs);
|
||||
auto cpuListSizeInBytes = getCpuVectorsEncodingSize_(numVecs);
|
||||
|
||||
// We only have int32 length representations on the GPU per each
|
||||
// list; the length is in sizeof(char)
|
||||
FAISS_ASSERT(gpuListSizeInBytes <= (size_t)std::numeric_limits<int>::max());
|
||||
|
||||
std::vector<uint8_t> interleaved_codes(gpuListSizeInBytes);
|
||||
RaftIVFFlatCodePackerInterleaved packer(
|
||||
(size_t)numVecs, (uint32_t)dim_, raft_knn_index.value().veclen());
|
||||
|
||||
packer.pack_all(
|
||||
reinterpret_cast<const uint8_t*>(codes), interleaved_codes.data());
|
||||
|
||||
float* list_data_ptr;
|
||||
const raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
|
||||
/// fetch the list data ptr on host
|
||||
raft::update_host(
|
||||
&list_data_ptr,
|
||||
raft_knn_index.value().data_ptrs().data_handle() + listId,
|
||||
1,
|
||||
stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
raft::update_device(
|
||||
reinterpret_cast<uint8_t*>(list_data_ptr),
|
||||
interleaved_codes.data(),
|
||||
gpuListSizeInBytes,
|
||||
stream);
|
||||
|
||||
/// Handle the indices as well
|
||||
idx_t* list_indices_ptr;
|
||||
|
||||
// fetch the list indices ptr on host
|
||||
raft::update_host(
|
||||
&list_indices_ptr,
|
||||
raft_knn_index.value().inds_ptrs().data_handle() + listId,
|
||||
1,
|
||||
stream);
|
||||
raft_handle.sync_stream();
|
||||
|
||||
raft::update_device(list_indices_ptr, indices, numVecs, stream);
|
||||
}
|
||||
|
||||
void RaftIVFFlat::validRowIndices_(
|
||||
Tensor<float, 2, true>& vecs,
|
||||
bool* nan_flag) {
|
||||
raft::device_resources& raft_handle =
|
||||
resources_->getRaftHandleCurrentDevice();
|
||||
idx_t n_rows = vecs.getSize(0);
|
||||
|
||||
thrust::fill_n(raft_handle.get_thrust_policy(), nan_flag, n_rows, true);
|
||||
raft::linalg::map_offset(
|
||||
raft_handle,
|
||||
raft::make_device_vector_view<bool, idx_t>(nan_flag, n_rows),
|
||||
[vecs = vecs.data(), dim_ = this->dim_] __device__(idx_t i) {
|
||||
for (idx_t col = 0; col < dim_; col++) {
|
||||
if (!isfinite(vecs[i * dim_ + col])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
RaftIVFFlatCodePackerInterleaved::RaftIVFFlatCodePackerInterleaved(
|
||||
size_t list_size,
|
||||
uint32_t dim,
|
||||
uint32_t chunk_size) {
|
||||
this->dim = dim;
|
||||
this->chunk_size = chunk_size;
|
||||
// NB: dim should be divisible by the number of 4 byte records in one chunk
|
||||
FAISS_ASSERT(dim % chunk_size == 0);
|
||||
nvec = list_size;
|
||||
code_size = dim * 4;
|
||||
block_size =
|
||||
utils::roundUp(nvec, raft::neighbors::ivf_flat::kIndexGroupSize);
|
||||
}
|
||||
|
||||
void RaftIVFFlatCodePackerInterleaved::pack_1(
|
||||
const uint8_t* flat_code,
|
||||
size_t offset,
|
||||
uint8_t* block) const {
|
||||
raft::neighbors::ivf_flat::codepacker::pack_1(
|
||||
reinterpret_cast<const uint32_t*>(flat_code),
|
||||
reinterpret_cast<uint32_t*>(block),
|
||||
dim,
|
||||
chunk_size,
|
||||
static_cast<uint32_t>(offset));
|
||||
}
|
||||
|
||||
void RaftIVFFlatCodePackerInterleaved::unpack_1(
|
||||
const uint8_t* block,
|
||||
size_t offset,
|
||||
uint8_t* flat_code) const {
|
||||
raft::neighbors::ivf_flat::codepacker::unpack_1(
|
||||
reinterpret_cast<const uint32_t*>(block),
|
||||
reinterpret_cast<uint32_t*>(flat_code),
|
||||
dim,
|
||||
chunk_size,
|
||||
static_cast<uint32_t>(offset));
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
149
faiss/gpu/impl/RaftIVFFlat.cuh
Normal file
149
faiss/gpu/impl/RaftIVFFlat.cuh
Normal file
@ -0,0 +1,149 @@
|
||||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <faiss/gpu/impl/GpuScalarQuantizer.cuh>
|
||||
#include <faiss/gpu/impl/IVFBase.cuh>
|
||||
#include <faiss/gpu/impl/IVFFlat.cuh>
|
||||
|
||||
#include <faiss/impl/CodePacker.h>
|
||||
|
||||
#include <raft/neighbors/ivf_flat.cuh>
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace faiss {
|
||||
namespace gpu {
|
||||
|
||||
class RaftIVFFlat : public IVFFlat {
|
||||
public:
|
||||
RaftIVFFlat(
|
||||
GpuResources* resources,
|
||||
int dim,
|
||||
int nlist,
|
||||
faiss::MetricType metric,
|
||||
float metricArg,
|
||||
bool useResidual,
|
||||
/// Optional ScalarQuantizer
|
||||
faiss::ScalarQuantizer* scalarQ,
|
||||
bool interleavedLayout,
|
||||
IndicesOptions indicesOptions,
|
||||
MemorySpace space);
|
||||
|
||||
~RaftIVFFlat() override;
|
||||
|
||||
/// Find the approximate k nearest neigbors for `queries` against
|
||||
/// our database
|
||||
void search(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& queries,
|
||||
int nprobe,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<idx_t, 2, true>& outIndices) override;
|
||||
|
||||
/// Performs search when we are already given the IVF cells to look at
|
||||
/// (GpuIndexIVF::search_preassigned implementation)
|
||||
void searchPreassigned(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<float, 2, true>& ivfDistances,
|
||||
Tensor<idx_t, 2, true>& ivfAssignments,
|
||||
int k,
|
||||
Tensor<float, 2, true>& outDistances,
|
||||
Tensor<idx_t, 2, true>& outIndices,
|
||||
bool storePairs) override;
|
||||
|
||||
/// Classify and encode/add vectors to our IVF lists.
|
||||
/// The input data must be on our current device.
|
||||
/// Returns the number of vectors successfully added. Vectors may
|
||||
/// not be able to be added because they contain NaNs.
|
||||
idx_t addVectors(
|
||||
Index* coarseQuantizer,
|
||||
Tensor<float, 2, true>& vecs,
|
||||
Tensor<idx_t, 1, true>& indices) override;
|
||||
|
||||
/// Reserve GPU memory in our inverted lists for this number of vectors
|
||||
// void reserveMemory(idx_t numVecs) override;
|
||||
|
||||
/// Clear out all inverted lists, but retain the coarse quantizer
|
||||
/// and the product quantizer info
|
||||
void reset() override;
|
||||
|
||||
/// For debugging purposes, return the list length of a particular
|
||||
/// list
|
||||
idx_t getListLength(idx_t listId) const override;
|
||||
|
||||
/// Return the list indices of a particular list back to the CPU
|
||||
std::vector<idx_t> getListIndices(idx_t listId) const override;
|
||||
|
||||
/// Return the encoded vectors of a particular list back to the CPU
|
||||
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat)
|
||||
const override;
|
||||
|
||||
void updateQuantizer(Index* quantizer) override;
|
||||
|
||||
/// Copy all inverted lists from a CPU representation to ourselves
|
||||
void copyInvertedListsFrom(const InvertedLists* ivf) override;
|
||||
|
||||
/// Filter out matrix rows containing NaN values
|
||||
void validRowIndices_(Tensor<float, 2, true>& vecs, bool* nan_flag);
|
||||
|
||||
protected:
|
||||
/// Adds a set of codes and indices to a list, with the representation
|
||||
/// coming from the CPU equivalent
|
||||
void addEncodedVectorsToList_(
|
||||
idx_t listId,
|
||||
// resident on the host
|
||||
const void* codes,
|
||||
// resident on the host
|
||||
const idx_t* indices,
|
||||
idx_t numVecs) override;
|
||||
|
||||
/// Returns the number of bytes in which an IVF list containing numVecs
|
||||
/// vectors is encoded on the device. Note that due to padding this is not
|
||||
/// the same as the encoding size for a subset of vectors in an IVF list;
|
||||
/// this is the size for an entire IVF list
|
||||
size_t getGpuVectorsEncodingSize_(idx_t numVecs) const override;
|
||||
|
||||
std::optional<raft::neighbors::ivf_flat::index<float, idx_t>>
|
||||
raft_knn_index{std::nullopt};
|
||||
};
|
||||
|
||||
struct RaftIVFFlatCodePackerInterleaved : CodePacker {
|
||||
RaftIVFFlatCodePackerInterleaved(
|
||||
size_t list_size,
|
||||
uint32_t dim,
|
||||
uint32_t chuk_size);
|
||||
void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
|
||||
const final;
|
||||
void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
|
||||
const final;
|
||||
|
||||
protected:
|
||||
uint32_t chunk_size;
|
||||
uint32_t dim;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace faiss
|
@ -749,7 +749,6 @@ void testSearchAndReconstruct(bool use_raft) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexFlat, SearchAndReconstruct) {
|
||||
testSearchAndReconstruct(false);
|
||||
}
|
||||
@ -767,4 +766,4 @@ int main(int argc, char** argv) {
|
||||
faiss::gpu::setTestSeed(100);
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
}
|
@ -30,6 +30,7 @@
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include "faiss/gpu/GpuIndicesOptions.h"
|
||||
|
||||
// FIXME: figure out a better way to test fp16
|
||||
constexpr float kF16MaxRelErr = 0.3f;
|
||||
@ -55,6 +56,8 @@ struct Options {
|
||||
faiss::gpu::INDICES_64_BIT});
|
||||
|
||||
device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
|
||||
|
||||
use_raft = false;
|
||||
}
|
||||
|
||||
std::string toString() const {
|
||||
@ -62,7 +65,7 @@ struct Options {
|
||||
str << "IVFFlat device " << device << " numVecs " << numAdd << " dim "
|
||||
<< dim << " numCentroids " << numCentroids << " nprobe " << nprobe
|
||||
<< " numQuery " << numQuery << " k " << k << " indicesOpt "
|
||||
<< indicesOpt;
|
||||
<< indicesOpt << " use_raft " << use_raft;
|
||||
|
||||
return str.str();
|
||||
}
|
||||
@ -76,6 +79,7 @@ struct Options {
|
||||
int k;
|
||||
int device;
|
||||
faiss::gpu::IndicesOptions indicesOpt;
|
||||
bool use_raft;
|
||||
};
|
||||
|
||||
void queryTest(
|
||||
@ -106,6 +110,7 @@ void queryTest(
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
|
||||
config.use_raft = opt.use_raft;
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(
|
||||
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
|
||||
@ -129,7 +134,10 @@ void queryTest(
|
||||
}
|
||||
}
|
||||
|
||||
void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
|
||||
void addTest(
|
||||
faiss::MetricType metricType,
|
||||
bool useFloat16CoarseQuantizer,
|
||||
bool use_raft) {
|
||||
for (int tries = 0; tries < 2; ++tries) {
|
||||
Options opt;
|
||||
|
||||
@ -153,8 +161,10 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.indicesOptions =
|
||||
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
|
||||
config.use_raft = use_raft;
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(
|
||||
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
|
||||
@ -178,7 +188,7 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
|
||||
}
|
||||
}
|
||||
|
||||
void copyToTest(bool useFloat16CoarseQuantizer) {
|
||||
void copyToTest(bool useFloat16CoarseQuantizer, bool use_raft) {
|
||||
Options opt;
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
@ -188,8 +198,10 @@ void copyToTest(bool useFloat16CoarseQuantizer) {
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.indicesOptions =
|
||||
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
|
||||
config.use_raft = use_raft;
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(
|
||||
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
|
||||
@ -229,7 +241,7 @@ void copyToTest(bool useFloat16CoarseQuantizer) {
|
||||
compFloat16 ? 0.30f : 0.015f);
|
||||
}
|
||||
|
||||
void copyFromTest(bool useFloat16CoarseQuantizer) {
|
||||
void copyFromTest(bool useFloat16CoarseQuantizer, bool use_raft) {
|
||||
Options opt;
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
@ -247,8 +259,10 @@ void copyFromTest(bool useFloat16CoarseQuantizer) {
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.indicesOptions =
|
||||
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
|
||||
config.use_raft = use_raft;
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, 1, 1, faiss::METRIC_L2, config);
|
||||
gpuIndex.nprobe = 1;
|
||||
@ -280,19 +294,35 @@ void copyFromTest(bool useFloat16CoarseQuantizer) {
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_32_Add_L2) {
|
||||
addTest(faiss::METRIC_L2, false);
|
||||
addTest(faiss::METRIC_L2, false, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
addTest(faiss::METRIC_L2, false, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_32_Add_IP) {
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, false);
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, false, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, false, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float16_32_Add_L2) {
|
||||
addTest(faiss::METRIC_L2, true);
|
||||
addTest(faiss::METRIC_L2, true, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
addTest(faiss::METRIC_L2, true, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) {
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, true);
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, true, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
addTest(faiss::METRIC_INNER_PRODUCT, true, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@ -300,11 +330,25 @@ TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) {
|
||||
//
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_Query_L2) {
|
||||
queryTest(Options(), faiss::METRIC_L2, false);
|
||||
Options opt;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_Query_IP) {
|
||||
queryTest(Options(), faiss::METRIC_INNER_PRODUCT, false);
|
||||
Options opt;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, LargeBatch) {
|
||||
@ -312,16 +356,36 @@ TEST(TestGpuIndexIVFFlat, LargeBatch) {
|
||||
opt.dim = 3;
|
||||
opt.numQuery = 100000;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
// float16 coarse quantizer
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float16_32_Query_L2) {
|
||||
queryTest(Options(), faiss::METRIC_L2, true);
|
||||
Options opt;
|
||||
queryTest(opt, faiss::METRIC_L2, true);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_L2, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float16_32_Query_IP) {
|
||||
queryTest(Options(), faiss::METRIC_INNER_PRODUCT, true);
|
||||
Options opt;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, true);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@ -333,24 +397,48 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_L2_64) {
|
||||
Options opt;
|
||||
opt.dim = 64;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_Query_IP_64) {
|
||||
Options opt;
|
||||
opt.dim = 64;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_Query_L2_128) {
|
||||
Options opt;
|
||||
opt.dim = 128;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_L2, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) {
|
||||
Options opt;
|
||||
opt.dim = 128;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
opt.use_raft = true;
|
||||
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
|
||||
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@ -358,11 +446,19 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) {
|
||||
//
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_32_CopyTo) {
|
||||
copyToTest(false);
|
||||
copyToTest(false, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
copyToTest(false, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_32_CopyFrom) {
|
||||
copyFromTest(false);
|
||||
copyFromTest(false, false);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
copyFromTest(false, true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, Float32_negative) {
|
||||
@ -392,6 +488,14 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
res.noTempMemory();
|
||||
|
||||
// Construct a positive test set
|
||||
auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
|
||||
|
||||
// Put all vecs on positive size
|
||||
for (auto& f : queryVecs) {
|
||||
f = std::abs(f);
|
||||
}
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
@ -401,14 +505,6 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
|
||||
gpuIndex.copyFrom(&cpuIndex);
|
||||
gpuIndex.nprobe = opt.nprobe;
|
||||
|
||||
// Construct a positive test set
|
||||
auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
|
||||
|
||||
// Put all vecs on positive size
|
||||
for (auto& f : queryVecs) {
|
||||
f = std::abs(f);
|
||||
}
|
||||
|
||||
bool compFloat16 = false;
|
||||
faiss::gpu::compareIndices(
|
||||
queryVecs,
|
||||
@ -424,6 +520,31 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
|
||||
// in fp16. Figure out another way to test
|
||||
compFloat16 ? 0.99f : 0.1f,
|
||||
compFloat16 ? 0.65f : 0.015f);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
config.use_raft = true;
|
||||
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
|
||||
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
|
||||
raftGpuIndex.copyFrom(&cpuIndex);
|
||||
raftGpuIndex.nprobe = opt.nprobe;
|
||||
|
||||
faiss::gpu::compareIndices(
|
||||
queryVecs,
|
||||
cpuIndex,
|
||||
raftGpuIndex,
|
||||
opt.numQuery,
|
||||
opt.dim,
|
||||
opt.k,
|
||||
opt.toString(),
|
||||
compFloat16 ? kF16MaxRelErr : kF32MaxRelErr,
|
||||
// FIXME: the fp16 bounds are
|
||||
// useless when math (the accumulator) is
|
||||
// in fp16. Figure out another way to test
|
||||
compFloat16 ? 0.99f : 0.1f,
|
||||
compFloat16 ? 0.65f : 0.015f);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
@ -439,6 +560,13 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
res.noTempMemory();
|
||||
|
||||
int numQuery = 10;
|
||||
std::vector<float> nans(
|
||||
numQuery * opt.dim, std::numeric_limits<float>::quiet_NaN());
|
||||
|
||||
std::vector<float> distances(numQuery * opt.k, 0);
|
||||
std::vector<faiss::idx_t> indices(numQuery * opt.k, 0);
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
@ -451,13 +579,6 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
|
||||
gpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
gpuIndex.add(opt.numAdd, addVecs.data());
|
||||
|
||||
int numQuery = 10;
|
||||
std::vector<float> nans(
|
||||
numQuery * opt.dim, std::numeric_limits<float>::quiet_NaN());
|
||||
|
||||
std::vector<float> distances(numQuery * opt.k, 0);
|
||||
std::vector<faiss::idx_t> indices(numQuery * opt.k, 0);
|
||||
|
||||
gpuIndex.search(
|
||||
numQuery, nans.data(), opt.k, distances.data(), indices.data());
|
||||
|
||||
@ -469,6 +590,31 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
|
||||
std::numeric_limits<float>::max());
|
||||
}
|
||||
}
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
config.use_raft = true;
|
||||
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
|
||||
std::fill(distances.begin(), distances.end(), 0);
|
||||
std::fill(indices.begin(), indices.end(), 0);
|
||||
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
|
||||
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
|
||||
raftGpuIndex.nprobe = opt.nprobe;
|
||||
|
||||
raftGpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
raftGpuIndex.add(opt.numAdd, addVecs.data());
|
||||
|
||||
raftGpuIndex.search(
|
||||
numQuery, nans.data(), opt.k, distances.data(), indices.data());
|
||||
|
||||
for (int q = 0; q < numQuery; ++q) {
|
||||
for (int k = 0; k < opt.k; ++k) {
|
||||
EXPECT_EQ(indices[q * opt.k + k], -1);
|
||||
EXPECT_EQ(
|
||||
distances[q * opt.k + k],
|
||||
std::numeric_limits<float>::max());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, AddNaN) {
|
||||
@ -477,15 +623,6 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
res.noTempMemory();
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = faiss::gpu::randBool();
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(
|
||||
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
|
||||
gpuIndex.nprobe = opt.nprobe;
|
||||
|
||||
int numNans = 10;
|
||||
std::vector<float> nans(
|
||||
numNans * opt.dim, std::numeric_limits<float>::quiet_NaN());
|
||||
@ -497,6 +634,14 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
|
||||
}
|
||||
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
|
||||
faiss::gpu::GpuIndexIVFFlatConfig config;
|
||||
config.device = opt.device;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.flatConfig.useFloat16 = faiss::gpu::randBool();
|
||||
faiss::gpu::GpuIndexIVFFlat gpuIndex(
|
||||
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
|
||||
gpuIndex.nprobe = opt.nprobe;
|
||||
gpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
|
||||
// should not crash
|
||||
@ -514,6 +659,27 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
|
||||
opt.k,
|
||||
distance.data(),
|
||||
indices.data());
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
config.use_raft = true;
|
||||
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
|
||||
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
|
||||
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
|
||||
raftGpuIndex.nprobe = opt.nprobe;
|
||||
raftGpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
|
||||
// should not crash
|
||||
EXPECT_EQ(raftGpuIndex.ntotal, 0);
|
||||
raftGpuIndex.add(numNans, nans.data());
|
||||
|
||||
// should not crash
|
||||
raftGpuIndex.search(
|
||||
opt.numQuery,
|
||||
queryVecs.data(),
|
||||
opt.k,
|
||||
distance.data(),
|
||||
indices.data());
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
|
||||
@ -570,6 +736,26 @@ TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
|
||||
kF32MaxRelErr,
|
||||
0.1f,
|
||||
0.015f);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
config.use_raft = true;
|
||||
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
|
||||
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
|
||||
&res, dim, numCentroids, faiss::METRIC_L2, config);
|
||||
raftGpuIndex.copyFrom(&cpuIndex);
|
||||
raftGpuIndex.nprobe = nprobe;
|
||||
|
||||
faiss::gpu::compareIndices(
|
||||
cpuIndex,
|
||||
raftGpuIndex,
|
||||
numQuery,
|
||||
dim,
|
||||
k,
|
||||
"Unified Memory",
|
||||
kF32MaxRelErr,
|
||||
0.1f,
|
||||
0.015f);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFFlat, LongIVFList) {
|
||||
@ -628,6 +814,27 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
|
||||
kF32MaxRelErr,
|
||||
0.1f,
|
||||
0.015f);
|
||||
|
||||
#if defined USE_NVIDIA_RAFT
|
||||
config.use_raft = true;
|
||||
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
|
||||
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
|
||||
&res, dim, numCentroids, faiss::METRIC_L2, config);
|
||||
raftGpuIndex.train(numTrain, trainVecs.data());
|
||||
raftGpuIndex.add(numAdd, addVecs.data());
|
||||
raftGpuIndex.nprobe = 1;
|
||||
|
||||
faiss::gpu::compareIndices(
|
||||
cpuIndex,
|
||||
raftGpuIndex,
|
||||
numQuery,
|
||||
dim,
|
||||
k,
|
||||
"Unified Memory",
|
||||
kF32MaxRelErr,
|
||||
0.1f,
|
||||
0.015f);
|
||||
#endif
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
@ -637,4 +844,4 @@ int main(int argc, char** argv) {
|
||||
faiss::gpu::setTestSeed(100);
|
||||
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user