Integrate IVF-Flat from RAFT (#2521)

Summary:
This is a design proposal that demonstrates an approach to enabling optional support for [RAFT](https://github.com/rapidsai/raft) versions of IVF-PQ and IVF-Flat (and brute force with fused k-selection when k <= 64). There are still a few open issues and design discussions needed before the new RAFT index types support the full range of features of FAISS's current GPU index types.

Checklist for the integration todos:
- [x] Rebase on current `main` branch
- [x] The RAFT handle has been plugged directly into `StandardGpuResources`
- [x] `FlatIndex` passing Googletests
- [x] Use `CodePacker` to support `copyFrom()` and `copyTo()` (see the sketch below)
- [x] `IVF-Flat` passing Googletests
- [ ] Raise appropriate exceptions for operations which are not yet supported by RAFT

Additional features we've discussed:
- [x] Separate IVF lists into individual memory chunks
- [ ] Saving/loading
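
As a sketch of the `CodePacker` integration mentioned above: the GPU stores each inverted list in an interleaved block layout, and the packer translates codes between that layout and the flat CPU encoding whenever lists are copied in either direction. The helper below is hypothetical; the `unpack_all()` call and the `nvec`/`code_size` fields come from the `CodePacker` interface this PR subclasses:

```cpp
#include <faiss/impl/CodePacker.h>

#include <cstdint>
#include <vector>

// Unpack one inverted list from the GPU interleaved block layout into the
// flat CPU encoding, mirroring what getListVectorData() does in this PR.
std::vector<uint8_t> unpackList(
        const faiss::CodePacker& packer,
        const std::vector<uint8_t>& interleavedCodes) {
    std::vector<uint8_t> flatCodes(packer.nvec * packer.code_size);
    packer.unpack_all(interleavedCodes.data(), flatCodes.data());
    return flatCodes;
}
```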

To build FAISS w/ optional RAFT support:
```
mkdir build
cd build
cmake ../ -DFAISS_ENABLE_RAFT=ON -DFAISS_ENABLE_GPU=ON
make -j
```

For development/testing, we've also supplied a bash script to make things easier: `build.sh`
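
Once built with RAFT support, the RAFT implementation is opt-in per index through the index config. A minimal usage sketch based on the test changes in this PR (note that the RAFT index types currently require `INDICES_64_BIT`):

```cpp
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/StandardGpuResources.h>

int main() {
    faiss::gpu::StandardGpuResources res;

    faiss::gpu::GpuIndexIVFFlatConfig config;
    config.use_raft = true; // opt into the RAFT-backed implementation
    config.indicesOptions = faiss::gpu::INDICES_64_BIT; // required by RAFT indices

    int dim = 128, nlist = 1024;
    faiss::gpu::GpuIndexIVFFlat index(
            &res, dim, nlist, faiss::METRIC_L2, config);
    // train/add/search then behave as with the non-RAFT GPU index:
    // index.train(n, trainVecs); index.add(n, addVecs);
    return 0;
}
```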

Below is a benchmark comparing training times for the RAFT and FAISS IVF-Flat indices:
![image](https://user-images.githubusercontent.com/1242464/194944737-8b808f11-e28e-4556-82d1-1ea4b0707283.png)

The benchmark was produced with Google Benchmark in [this](https://github.com/tfeher/raft/tree/raft_faiss_bench) RAFT fork. We will provide benchmarks for queries as well. A couple of bottlenecks still need to be removed from the IVF-Flat training implementation, and we will update the benchmark once they are.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2521

Test Plan: `buck test mode/dev-nosan //faiss/gpu/test:test_gpu_index_ivfflat`

Reviewed By: algoriddle

Differential Revision: D49118319

Pulled By: mdouze

fbshipit-source-id: 5916108bc27154acf7c92021ba579a6ca85d730b
Corey J. Nolet 2023-10-04 23:42:30 -07:00 committed by Facebook GitHub Bot
parent 458633c203
commit edcf7438bb
14 changed files with 1232 additions and 92 deletions

build.sh (new file, 58 lines)

@ -0,0 +1,58 @@
#!/bin/bash
# NOTE: This file is temporary for the proof-of-concept branch and will be removed before this PR is merged
BUILD_TYPE=Release
BUILD_DIR=build/
RAFT_REPO_REL=""
EXTRA_CMAKE_ARGS=""
set -e
if [[ ${RAFT_REPO_REL} != "" ]]; then
RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`"
EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}"
fi
if [ "$1" == "clean" ]; then
rm -rf build
rm -rf .cache
exit 0
fi
if [ "$1" == "test" ]; then
make -C build -j test
exit 0
fi
if [ "$1" == "test-raft" ]; then
./build/faiss/gpu/test/TestRaftIndexIVFFlat
exit 0
fi
mkdir -p $BUILD_DIR
cd $BUILD_DIR
cmake \
-DFAISS_ENABLE_GPU=ON \
-DFAISS_ENABLE_RAFT=ON \
-DFAISS_ENABLE_PYTHON=OFF \
-DBUILD_TESTING=ON \
-DBUILD_SHARED_LIBS=OFF \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DFAISS_OPT_LEVEL=avx2 \
-DRAFT_NVTX=OFF \
-DCMAKE_CUDA_ARCHITECTURES="NATIVE" \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
${EXTRA_CMAKE_ARGS} \
../
# make -C build -j12 faiss
cmake --build . -j12
# make -C build -j12 swigfaiss
# (cd build/faiss/python && python setup.py install)

cmake/thirdparty/fetch_rapids.cmake

@ -15,7 +15,7 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================
set(RAPIDS_VERSION "23.06")
set(RAPIDS_VERSION "23.08")
if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake

faiss/gpu/CMakeLists.txt

@ -238,9 +238,11 @@ generate_ivf_interleaved_code()
if(FAISS_ENABLE_RAFT)
list(APPEND FAISS_GPU_HEADERS
impl/RaftIVFFlat.cuh
impl/RaftFlatIndex.cuh)
list(APPEND FAISS_GPU_SRC
impl/RaftFlatIndex.cu)
impl/RaftFlatIndex.cu
impl/RaftIVFFlat.cu)
target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1)
target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1)

faiss/gpu/GpuIndexIVF.cu

@ -16,6 +16,11 @@
#include <faiss/gpu/impl/IVFBase.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#if defined USE_NVIDIA_RAFT
#include <raft/core/handle.hpp>
#include <raft/neighbors/ivf_flat.cuh>
#endif
namespace faiss {
namespace gpu {
@ -444,14 +449,46 @@ void GpuIndexIVF::trainQuantizer_(idx_t n, const float* x) {
printf("Training IVF quantizer on %ld vectors in %dD\n", n, d);
}
// leverage the CPU-side k-means code, which works for the GPU
// flat index as well
quantizer->reset();
Clustering clus(this->d, nlist, this->cp);
clus.verbose = verbose;
clus.train(n, x, *quantizer);
quantizer->is_trained = true;
#if defined USE_NVIDIA_RAFT
if (config_.use_raft) {
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
raft::neighbors::ivf_flat::index_params raft_idx_params;
raft_idx_params.n_lists = nlist;
raft_idx_params.metric = metric_type == faiss::METRIC_L2
? raft::distance::DistanceType::L2Expanded
: raft::distance::DistanceType::InnerProduct;
raft_idx_params.add_data_on_build = false;
raft_idx_params.kmeans_trainset_fraction = 1.0;
raft_idx_params.kmeans_n_iters = cp.niter;
raft_idx_params.adaptive_centers = !cp.frozen_centroids;
auto raft_index = raft::neighbors::ivf_flat::build(
raft_handle, raft_idx_params, x, n, (idx_t)d);
raft_handle.sync_stream();
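// Copy the RAFT-trained centroids into the FAISS coarse quantizer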
quantizer->train(nlist, raft_index.centers().data_handle());
quantizer->add(nlist, raft_index.centers().data_handle());
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
// leverage the CPU-side k-means code, which works for the GPU
// flat index as well
Clustering clus(this->d, nlist, this->cp);
clus.verbose = verbose;
clus.train(n, x, *quantizer);
}
quantizer->is_trained = true;
FAISS_ASSERT(quantizer->ntotal == nlist);
}

faiss/gpu/GpuIndexIVF.h

@ -73,10 +73,10 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
virtual void updateQuantizer() = 0;
/// Returns the number of inverted lists we're managing
idx_t getNumLists() const;
virtual idx_t getNumLists() const;
/// Returns the number of vectors present in a particular inverted list
idx_t getListLength(idx_t listId) const;
virtual idx_t getListLength(idx_t listId) const;
/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes.
@ -84,12 +84,13 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat = false)
const;
virtual std::vector<uint8_t> getListVectorData(
idx_t listId,
bool gpuFormat = false) const;
/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.
std::vector<idx_t> getListIndices(idx_t listId) const;
virtual std::vector<idx_t> getListIndices(idx_t listId) const;
void search_preassigned(
idx_t n,
@ -121,7 +122,7 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
int getCurrentNProbe_(const SearchParameters* params) const;
void verifyIVFSettings_() const;
bool addImplRequiresIDs_() const override;
void trainQuantizer_(idx_t n, const float* x);
virtual void trainQuantizer_(idx_t n, const float* x);
/// Called from GpuIndex for add/add_with_ids
void addImpl_(idx_t n, const float* x, const idx_t* ids) override;

faiss/gpu/GpuIndexIVFFlat.cu

@ -15,6 +15,10 @@
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#if defined USE_NVIDIA_RAFT
#include <faiss/gpu/impl/RaftIVFFlat.cuh>
#endif
#include <limits>
namespace faiss {
@ -70,8 +74,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
// no other quantizer that we need to train, so this is sufficient
if (this->is_trained) {
FAISS_ASSERT(this->quantizer);
index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
@ -81,7 +84,7 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();
}
@ -89,6 +92,54 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
GpuIndexIVFFlat::~GpuIndexIVFFlat() {}
void GpuIndexIVFFlat::set_index_(
GpuResources* resources,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space) {
#if defined USE_NVIDIA_RAFT
if (config_.use_raft) {
index_.reset(new RaftIVFFlat(
resources,
dim,
nlist,
metric,
metricArg,
useResidual,
scalarQ,
interleavedLayout,
indicesOptions,
space));
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
index_.reset(new IVFFlat(
resources,
dim,
nlist,
metric,
metricArg,
useResidual,
scalarQ,
interleavedLayout,
indicesOptions,
space));
}
}
void GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
DeviceScope scope(config_.device);
@ -110,25 +161,25 @@ void GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
// The other index might not be trained
if (!index->is_trained) {
FAISS_ASSERT(!this->is_trained);
FAISS_ASSERT(!is_trained);
return;
}
// Otherwise, we can populate ourselves from the other index
FAISS_ASSERT(this->is_trained);
FAISS_ASSERT(is_trained);
// Copy our lists as well
index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
d,
nlist,
index->metric_type,
index->metric_arg,
false, // no residual
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();
@ -201,18 +252,30 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
FAISS_ASSERT(!index_);
// FIXME: GPUize more of this
// First, make sure that the data is resident on the CPU, if it is not on
// the CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>(
(float*)x,
resources_->getDefaultStream(config_.device),
{n, this->d});
trainQuantizer_(n, hostData.data());
#if defined USE_NVIDIA_RAFT
if (config_.use_raft) {
// No need to copy the data to host
trainQuantizer_(n, x);
} else
#else
if (config_.use_raft) {
FAISS_THROW_MSG(
"RAFT has not been compiled into the current version so it cannot be used.");
} else
#endif
{
// FIXME: GPUize more of this
// First, make sure that the data is resident on the CPU, if it is not
// on the CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>(
(float*)x,
resources_->getDefaultStream(config_.device),
{n, this->d});
trainQuantizer_(n, hostData.data());
}
// The quantizer is now trained; construct the IVF index
index_.reset(new IVFFlat(
set_index_(
resources_.get(),
this->d,
this->nlist,
@ -222,7 +285,7 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
config_.memorySpace);
baseIndex_ = std::static_pointer_cast<IVFBase, IVFFlat>(index_);
updateQuantizer();

faiss/gpu/GpuIndexIVFFlat.h

@ -8,6 +8,8 @@
#pragma once
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/impl/ScalarQuantizer.h>
#include <memory>
namespace faiss {
@ -86,6 +88,19 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
void train(idx_t n, const float* x) override;
protected:
void set_index_(
GpuResources* resources,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space);
/// Our configuration options
const GpuIndexIVFFlatConfig ivfFlatConfig_;

faiss/gpu/StandardGpuResources.cpp

@ -362,7 +362,11 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
defaultStreams_[device] = defaultStream;
cudaStream_t asyncCopyStream = nullptr;
#if defined USE_NVIDIA_RAFT
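// construct this device's RAFT handle on the FAISS default stream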
raftHandles_.emplace(std::make_pair(device, defaultStream));
#endif
cudaStream_t asyncCopyStream = 0;
CUDA_VERIFY(
cudaStreamCreateWithFlags(&asyncCopyStream, cudaStreamNonBlocking));

faiss/gpu/impl/IVFBase.cuh

@ -45,7 +45,7 @@ class IVFBase {
/// Clear out all inverted lists, but retain the coarse quantizer
/// and the product quantizer info
void reset();
virtual void reset();
/// Return the number of dimensions we are indexing
idx_t getDim() const;
@ -59,29 +59,30 @@ class IVFBase {
/// For debugging purposes, return the list length of a particular
/// list
idx_t getListLength(idx_t listId) const;
virtual idx_t getListLength(idx_t listId) const;
/// Return the list indices of a particular list back to the CPU
std::vector<idx_t> getListIndices(idx_t listId) const;
virtual std::vector<idx_t> getListIndices(idx_t listId) const;
/// Return the encoded vectors of a particular list back to the CPU
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat) const;
virtual std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat)
const;
/// Copy all inverted lists from a CPU representation to ourselves
void copyInvertedListsFrom(const InvertedLists* ivf);
virtual void copyInvertedListsFrom(const InvertedLists* ivf);
/// Copy all inverted lists from ourselves to a CPU representation
void copyInvertedListsTo(InvertedLists* ivf);
virtual void copyInvertedListsTo(InvertedLists* ivf);
/// Update our coarse quantizer with this quantizer instance; may be a CPU
/// or GPU quantizer
void updateQuantizer(Index* quantizer);
virtual void updateQuantizer(Index* quantizer);
/// Classify and encode/add vectors to our IVF lists.
/// The input data must be on our current device.
/// Returns the number of vectors successfully added. Vectors may
/// not be able to be added because they contain NaNs.
idx_t addVectors(
virtual idx_t addVectors(
Index* coarseQuantizer,
Tensor<float, 2, true>& vecs,
Tensor<idx_t, 1, true>& indices);
@ -111,7 +112,7 @@ class IVFBase {
protected:
/// Adds a set of codes and indices to a list, with the representation
/// coming from the CPU equivalent
void addEncodedVectorsToList_(
virtual void addEncodedVectorsToList_(
idx_t listId,
// resident on the host
const void* codes,

faiss/gpu/impl/IVFFlat.cuh

@ -60,17 +60,17 @@ class IVFFlat : public IVFBase {
size_t getCpuVectorsEncodingSize_(idx_t numVecs) const override;
/// Translate to our preferred GPU encoding
std::vector<uint8_t> translateCodesToGpu_(
virtual std::vector<uint8_t> translateCodesToGpu_(
std::vector<uint8_t> codes,
idx_t numVecs) const override;
/// Translate from our preferred GPU encoding
std::vector<uint8_t> translateCodesFromGpu_(
virtual std::vector<uint8_t> translateCodesFromGpu_(
std::vector<uint8_t> codes,
idx_t numVecs) const override;
/// Encode the vectors that we're adding and append to our IVF lists
void appendVectors_(
virtual void appendVectors_(
Tensor<float, 2, true>& vecs,
Tensor<float, 2, true>& ivfCentroidResiduals,
Tensor<idx_t, 1, true>& indices,
@ -84,7 +84,7 @@ class IVFFlat : public IVFBase {
/// Shared IVF search implementation, used by both search and
/// searchPreassigned
void searchImpl_(
virtual void searchImpl_(
Tensor<float, 2, true>& queries,
Tensor<float, 2, true>& coarseDistances,
Tensor<idx_t, 2, true>& coarseIndices,

faiss/gpu/impl/RaftIVFFlat.cu (new file)

@ -0,0 +1,604 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <faiss/gpu/GpuIndex.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/gpu/impl/RemapIndices.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <thrust/host_vector.h>
#include <faiss/gpu/impl/FlatIndex.cuh>
#include <faiss/gpu/impl/IVFAppend.cuh>
#include <faiss/gpu/impl/IVFFlat.cuh>
#include <faiss/gpu/impl/IVFFlatScan.cuh>
#include <faiss/gpu/impl/IVFInterleaved.cuh>
#include <faiss/gpu/impl/RaftIVFFlat.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/HostTensor.cuh>
#include <faiss/gpu/utils/Transpose.cuh>
#include <limits>
#include <unordered_map>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/neighbors/ivf_flat_codepacker.hpp>
#include <raft/neighbors/ivf_flat.cuh>
namespace faiss {
namespace gpu {
RaftIVFFlat::RaftIVFFlat(
GpuResources* res,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space)
: IVFFlat(res,
dim,
nlist,
metric,
metricArg,
useResidual,
scalarQ,
interleavedLayout,
indicesOptions,
space) {
FAISS_THROW_IF_NOT_MSG(
indicesOptions == INDICES_64_BIT,
"only INDICES_64_BIT is supported for RAFT index");
reset();
}
RaftIVFFlat::~RaftIVFFlat() {}
/// Find the approximate k nearest neighbors for `queries` against
/// our database
void RaftIVFFlat::search(
Index* coarseQuantizer,
Tensor<float, 2, true>& queries,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices) {
// TODO: We probably don't want to ignore the coarse quantizer here...
uint32_t numQueries = queries.getSize(0);
uint32_t cols = queries.getSize(1);
uint32_t k_ = k;
// Device is already set in GpuIndex::search
FAISS_ASSERT(raft_knn_index.has_value());
FAISS_ASSERT(numQueries > 0);
FAISS_ASSERT(cols == dim_);
FAISS_THROW_IF_NOT(nprobe > 0 && nprobe <= numLists_);
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
raft::neighbors::ivf_flat::search_params pams;
pams.n_probes = nprobe;
auto queries_view = raft::make_device_matrix_view<const float, idx_t>(
queries.data(), (idx_t)numQueries, (idx_t)cols);
auto out_inds_view = raft::make_device_matrix_view<idx_t, idx_t>(
outIndices.data(), (idx_t)numQueries, (idx_t)k_);
auto out_dists_view = raft::make_device_matrix_view<float, idx_t>(
outDistances.data(), (idx_t)numQueries, (idx_t)k_);
raft::neighbors::ivf_flat::search<float, idx_t>(
raft_handle,
pams,
raft_knn_index.value(),
queries_view,
out_inds_view,
out_dists_view);
/// Identify NaN rows and mask their nearest neighbors
auto nan_flag = raft::make_device_vector<bool>(raft_handle, numQueries);
validRowIndices_(queries, nan_flag.data_handle());
raft::linalg::map_offset(
raft_handle,
raft::make_device_vector_view(outIndices.data(), numQueries * k_),
[nan_flag = nan_flag.data_handle(),
out_inds = outIndices.data(),
k_] __device__(uint32_t i) {
uint32_t row = i / k_;
if (!nan_flag[row])
return idx_t(-1);
return out_inds[i];
});
float max_val = std::numeric_limits<float>::max();
raft::linalg::map_offset(
raft_handle,
raft::make_device_vector_view(outDistances.data(), numQueries * k_),
[nan_flag = nan_flag.data_handle(),
out_dists = outDistances.data(),
max_val,
k_] __device__(uint32_t i) {
uint32_t row = i / k_;
if (!nan_flag[row])
return max_val;
return out_dists[i];
});
}
/// Classify and encode/add vectors to our IVF lists.
/// The input data must be on our current device.
/// Returns the number of vectors successfully added. Vectors may
/// not be able to be added because they contain NaNs.
idx_t RaftIVFFlat::addVectors(
Index* coarseQuantizer,
Tensor<float, 2, true>& vecs,
Tensor<idx_t, 1, true>& indices) {
/// TODO: We probably don't want to ignore the coarse quantizer here
idx_t n_rows = vecs.getSize(0);
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
/// Remove NaN values
auto nan_flag = raft::make_device_vector<bool, idx_t>(raft_handle, n_rows);
validRowIndices_(vecs, nan_flag.data_handle());
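// counting the true flags gives the number of rows that contain no NaNs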
idx_t n_rows_valid = thrust::reduce(
raft_handle.get_thrust_policy(),
nan_flag.data_handle(),
nan_flag.data_handle() + n_rows,
0);
if (n_rows_valid < n_rows) {
auto gather_indices = raft::make_device_vector<idx_t, idx_t>(
raft_handle, n_rows_valid);
auto count = thrust::make_counting_iterator(0);
thrust::copy_if(
raft_handle.get_thrust_policy(),
count,
count + n_rows,
gather_indices.data_handle(),
[nan_flag = nan_flag.data_handle()] __device__(auto i) {
return nan_flag[i];
});
raft::matrix::gather(
raft_handle,
raft::make_device_matrix_view<float, idx_t>(
vecs.data(), n_rows, dim_),
raft::make_const_mdspan(gather_indices.view()),
(idx_t)16);
auto valid_indices = raft::make_device_vector<idx_t, idx_t>(
raft_handle, n_rows_valid);
raft::matrix::gather(
raft_handle,
raft::make_device_matrix_view<idx_t>(
indices.data(), n_rows, (idx_t)1),
raft::make_const_mdspan(gather_indices.view()));
}
FAISS_ASSERT(raft_knn_index.has_value());
raft_knn_index.emplace(raft::neighbors::ivf_flat::extend(
raft_handle,
raft::make_device_matrix_view<const float, idx_t>(
vecs.data(), n_rows_valid, dim_),
std::make_optional<raft::device_vector_view<const idx_t, idx_t>>(
raft::make_device_vector_view<const idx_t, idx_t>(
indices.data(), n_rows_valid)),
raft_knn_index.value()));
return n_rows_valid;
}
void RaftIVFFlat::reset() {
raft_knn_index.reset();
}
idx_t RaftIVFFlat::getListLength(idx_t listId) const {
FAISS_ASSERT(raft_knn_index.has_value());
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
uint32_t size;
raft::update_host(
&size,
raft_knn_index.value().list_sizes().data_handle() + listId,
1,
raft_handle.get_stream());
raft_handle.sync_stream();
return static_cast<int>(size);
}
/// Return the list indices of a particular list back to the CPU
std::vector<idx_t> RaftIVFFlat::getListIndices(idx_t listId) const {
FAISS_ASSERT(raft_knn_index.has_value());
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
auto stream = raft_handle.get_stream();
idx_t listSize = getListLength(listId);
std::vector<idx_t> vec(listSize);
// fetch the list indices ptr on host
idx_t* list_indices_ptr;
raft::update_host(
&list_indices_ptr,
raft_knn_index.value().inds_ptrs().data_handle() + listId,
1,
stream);
raft_handle.sync_stream();
raft::update_host(vec.data(), list_indices_ptr, listSize, stream);
raft_handle.sync_stream();
return vec;
}
/// Return the encoded vectors of a particular list back to the CPU
std::vector<uint8_t> RaftIVFFlat::getListVectorData(
idx_t listId,
bool gpuFormat) const {
if (gpuFormat) {
FAISS_THROW_MSG("gpuFormat is not supported for raft indices");
}
FAISS_ASSERT(raft_knn_index.has_value());
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
auto stream = raft_handle.get_stream();
idx_t listSize = getListLength(listId);
// the interleaved block can be slightly larger than the list size (it's
// rounded up)
auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(listSize);
auto cpuListSizeInBytes = getCpuVectorsEncodingSize_(listSize);
std::vector<uint8_t> interleaved_codes(gpuListSizeInBytes);
std::vector<uint8_t> flat_codes(cpuListSizeInBytes);
float* list_data_ptr;
// fetch the list data ptr on host
raft::update_host(
&list_data_ptr,
raft_knn_index.value().data_ptrs().data_handle() + listId,
1,
stream);
raft_handle.sync_stream();
raft::update_host(
interleaved_codes.data(),
reinterpret_cast<uint8_t*>(list_data_ptr),
gpuListSizeInBytes,
stream);
raft_handle.sync_stream();
RaftIVFFlatCodePackerInterleaved packer(
(size_t)listSize, dim_, raft_knn_index.value().veclen());
packer.unpack_all(interleaved_codes.data(), flat_codes.data());
return flat_codes;
}
/// Performs search when we are already given the IVF cells to look at
/// (GpuIndexIVF::search_preassigned implementation)
void RaftIVFFlat::searchPreassigned(
Index* coarseQuantizer,
Tensor<float, 2, true>& vecs,
Tensor<float, 2, true>& ivfDistances,
Tensor<idx_t, 2, true>& ivfAssignments,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool storePairs) {
// TODO: Fill this in!
}
void RaftIVFFlat::updateQuantizer(Index* quantizer) {
idx_t quantizer_ntotal = quantizer->ntotal;
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
auto stream = raft_handle.get_stream();
auto total_elems = size_t(quantizer_ntotal) * size_t(quantizer->d);
raft::logger::get().set_level(RAFT_LEVEL_TRACE);
raft::neighbors::ivf_flat::index_params pams;
pams.add_data_on_build = false;
pams.n_lists = this->numLists_;
switch (this->metric_) {
case faiss::METRIC_L2:
pams.metric = raft::distance::DistanceType::L2Expanded;
break;
case faiss::METRIC_INNER_PRODUCT:
pams.metric = raft::distance::DistanceType::InnerProduct;
break;
default:
FAISS_THROW_MSG("Metric is not supported.");
}
raft_knn_index.emplace(raft_handle, pams, (uint32_t)this->dim_);
cudaMemsetAsync(
raft_knn_index.value().list_sizes().data_handle(),
0,
raft_knn_index.value().list_sizes().size() * sizeof(uint32_t),
stream);
cudaMemsetAsync(
raft_knn_index.value().data_ptrs().data_handle(),
0,
raft_knn_index.value().data_ptrs().size() * sizeof(float*),
stream);
cudaMemsetAsync(
raft_knn_index.value().inds_ptrs().data_handle(),
0,
raft_knn_index.value().inds_ptrs().size() * sizeof(idx_t*),
stream);
/// Copy (reconstructed) centroids over, rather than re-training
std::vector<float> buf_host(total_elems);
quantizer->reconstruct_n(0, quantizer_ntotal, buf_host.data());
raft::update_device(
raft_knn_index.value().centers().data_handle(),
buf_host.data(),
total_elems,
stream);
}
void RaftIVFFlat::copyInvertedListsFrom(const InvertedLists* ivf) {
size_t nlist = ivf ? ivf->nlist : 0;
size_t ntotal = ivf ? ivf->compute_ntotal() : 0;
raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
std::vector<uint32_t> list_sizes_(nlist);
std::vector<idx_t> indices_(ntotal);
// the index must already exist
FAISS_ASSERT(raft_knn_index.has_value());
auto& raft_lists = raft_knn_index.value().lists();
// conservative memory alloc for cloning cpu inverted lists
raft::neighbors::ivf_flat::list_spec<uint32_t, float, idx_t> raft_list_spec{
static_cast<uint32_t>(dim_), true};
for (size_t i = 0; i < nlist; ++i) {
size_t listSize = ivf->list_size(i);
// GPU index can only support max int entries per list
FAISS_THROW_IF_NOT_FMT(
listSize <= (size_t)std::numeric_limits<int>::max(),
"GPU inverted list can only support "
"%zu entries; %zu found",
(size_t)std::numeric_limits<int>::max(),
listSize);
// store the list size
list_sizes_[i] = static_cast<uint32_t>(listSize);
raft::neighbors::ivf::resize_list(
raft_handle,
raft_lists[i],
raft_list_spec,
(uint32_t)listSize,
(uint32_t)0);
}
// Update the pointers and the sizes
raft_knn_index.value().recompute_internal_state(raft_handle);
for (size_t i = 0; i < nlist; ++i) {
size_t listSize = ivf->list_size(i);
addEncodedVectorsToList_(
i, ivf->get_codes(i), ivf->get_ids(i), listSize);
}
raft::update_device(
raft_knn_index.value().list_sizes().data_handle(),
list_sizes_.data(),
nlist,
raft_handle.get_stream());
// Precompute the centers vector norms for L2Expanded distance
if (this->metric_ == faiss::METRIC_L2) {
raft_knn_index.value().allocate_center_norms(raft_handle);
raft::linalg::rowNorm(
raft_knn_index.value().center_norms().value().data_handle(),
raft_knn_index.value().centers().data_handle(),
raft_knn_index.value().dim(),
(uint32_t)nlist,
raft::linalg::L2Norm,
true,
raft_handle.get_stream());
}
}
size_t RaftIVFFlat::getGpuVectorsEncodingSize_(idx_t numVecs) const {
idx_t bits = 32 /* float */;
// bytes to encode a block of 32 vectors (single dimension)
idx_t bytesPerDimBlock = bits * 32 / 8; // = 128
// bytes to fully encode 32 vectors
idx_t bytesPerBlock = bytesPerDimBlock * dim_;
// number of blocks of 32 vectors we have
idx_t numBlocks =
utils::divUp(numVecs, raft::neighbors::ivf_flat::kIndexGroupSize);
// total size to encode numVecs
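// e.g. (illustrative) dim_ = 64, numVecs = 100, kIndexGroupSize = 32:
// numBlocks = ceil(100 / 32) = 4, bytesPerBlock = 128 * 64 = 8192,
// so the list is padded to 4 * 8192 = 32768 bytes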
return bytesPerBlock * numBlocks;
}
void RaftIVFFlat::addEncodedVectorsToList_(
idx_t listId,
const void* codes,
const idx_t* indices,
idx_t numVecs) {
auto stream = resources_->getDefaultStreamCurrentDevice();
// This list must already exist
FAISS_ASSERT(raft_knn_index.has_value());
// This list must currently be empty
FAISS_ASSERT(getListLength(listId) == 0);
// If there's nothing to add, then there's nothing we have to do
if (numVecs == 0) {
return;
}
// The GPU might have a different layout of the memory
auto gpuListSizeInBytes = getGpuVectorsEncodingSize_(numVecs);
auto cpuListSizeInBytes = getCpuVectorsEncodingSize_(numVecs);
// We only have int32 length representations on the GPU per each
// list; the length is in sizeof(char)
FAISS_ASSERT(gpuListSizeInBytes <= (size_t)std::numeric_limits<int>::max());
std::vector<uint8_t> interleaved_codes(gpuListSizeInBytes);
RaftIVFFlatCodePackerInterleaved packer(
(size_t)numVecs, (uint32_t)dim_, raft_knn_index.value().veclen());
packer.pack_all(
reinterpret_cast<const uint8_t*>(codes), interleaved_codes.data());
float* list_data_ptr;
const raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
/// fetch the list data ptr on host
raft::update_host(
&list_data_ptr,
raft_knn_index.value().data_ptrs().data_handle() + listId,
1,
stream);
raft_handle.sync_stream();
raft::update_device(
reinterpret_cast<uint8_t*>(list_data_ptr),
interleaved_codes.data(),
gpuListSizeInBytes,
stream);
/// Handle the indices as well
idx_t* list_indices_ptr;
// fetch the list indices ptr on host
raft::update_host(
&list_indices_ptr,
raft_knn_index.value().inds_ptrs().data_handle() + listId,
1,
stream);
raft_handle.sync_stream();
raft::update_device(list_indices_ptr, indices, numVecs, stream);
}
void RaftIVFFlat::validRowIndices_(
Tensor<float, 2, true>& vecs,
bool* nan_flag) {
raft::device_resources& raft_handle =
resources_->getRaftHandleCurrentDevice();
idx_t n_rows = vecs.getSize(0);
thrust::fill_n(raft_handle.get_thrust_policy(), nan_flag, n_rows, true);
raft::linalg::map_offset(
raft_handle,
raft::make_device_vector_view<bool, idx_t>(nan_flag, n_rows),
[vecs = vecs.data(), dim_ = this->dim_] __device__(idx_t i) {
for (idx_t col = 0; col < dim_; col++) {
if (!isfinite(vecs[i * dim_ + col])) {
return false;
}
}
return true;
});
}
RaftIVFFlatCodePackerInterleaved::RaftIVFFlatCodePackerInterleaved(
size_t list_size,
uint32_t dim,
uint32_t chunk_size) {
this->dim = dim;
this->chunk_size = chunk_size;
// NB: dim should be divisible by the number of 4 byte records in one chunk
FAISS_ASSERT(dim % chunk_size == 0);
nvec = list_size;
code_size = dim * 4;
block_size =
utils::roundUp(nvec, raft::neighbors::ivf_flat::kIndexGroupSize);
}
void RaftIVFFlatCodePackerInterleaved::pack_1(
const uint8_t* flat_code,
size_t offset,
uint8_t* block) const {
raft::neighbors::ivf_flat::codepacker::pack_1(
reinterpret_cast<const uint32_t*>(flat_code),
reinterpret_cast<uint32_t*>(block),
dim,
chunk_size,
static_cast<uint32_t>(offset));
}
void RaftIVFFlatCodePackerInterleaved::unpack_1(
const uint8_t* block,
size_t offset,
uint8_t* flat_code) const {
raft::neighbors::ivf_flat::codepacker::unpack_1(
reinterpret_cast<const uint32_t*>(block),
reinterpret_cast<uint32_t*>(flat_code),
dim,
chunk_size,
static_cast<uint32_t>(offset));
}
} // namespace gpu
} // namespace faiss

faiss/gpu/impl/RaftIVFFlat.cuh (new file)

@ -0,0 +1,149 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <faiss/gpu/impl/GpuScalarQuantizer.cuh>
#include <faiss/gpu/impl/IVFBase.cuh>
#include <faiss/gpu/impl/IVFFlat.cuh>
#include <faiss/impl/CodePacker.h>
#include <raft/neighbors/ivf_flat.cuh>
#include <optional>
namespace faiss {
namespace gpu {
class RaftIVFFlat : public IVFFlat {
public:
RaftIVFFlat(
GpuResources* resources,
int dim,
int nlist,
faiss::MetricType metric,
float metricArg,
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space);
~RaftIVFFlat() override;
/// Find the approximate k nearest neighbors for `queries` against
/// our database
void search(
Index* coarseQuantizer,
Tensor<float, 2, true>& queries,
int nprobe,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices) override;
/// Performs search when we are already given the IVF cells to look at
/// (GpuIndexIVF::search_preassigned implementation)
void searchPreassigned(
Index* coarseQuantizer,
Tensor<float, 2, true>& vecs,
Tensor<float, 2, true>& ivfDistances,
Tensor<idx_t, 2, true>& ivfAssignments,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool storePairs) override;
/// Classify and encode/add vectors to our IVF lists.
/// The input data must be on our current device.
/// Returns the number of vectors successfully added. Vectors may
/// not be able to be added because they contain NaNs.
idx_t addVectors(
Index* coarseQuantizer,
Tensor<float, 2, true>& vecs,
Tensor<idx_t, 1, true>& indices) override;
/// Reserve GPU memory in our inverted lists for this number of vectors
// void reserveMemory(idx_t numVecs) override;
/// Clear out all inverted lists, but retain the coarse quantizer
/// and the product quantizer info
void reset() override;
/// For debugging purposes, return the list length of a particular
/// list
idx_t getListLength(idx_t listId) const override;
/// Return the list indices of a particular list back to the CPU
std::vector<idx_t> getListIndices(idx_t listId) const override;
/// Return the encoded vectors of a particular list back to the CPU
std::vector<uint8_t> getListVectorData(idx_t listId, bool gpuFormat)
const override;
void updateQuantizer(Index* quantizer) override;
/// Copy all inverted lists from a CPU representation to ourselves
void copyInvertedListsFrom(const InvertedLists* ivf) override;
/// Filter out matrix rows containing NaN values
void validRowIndices_(Tensor<float, 2, true>& vecs, bool* nan_flag);
protected:
/// Adds a set of codes and indices to a list, with the representation
/// coming from the CPU equivalent
void addEncodedVectorsToList_(
idx_t listId,
// resident on the host
const void* codes,
// resident on the host
const idx_t* indices,
idx_t numVecs) override;
/// Returns the number of bytes in which an IVF list containing numVecs
/// vectors is encoded on the device. Note that due to padding this is not
/// the same as the encoding size for a subset of vectors in an IVF list;
/// this is the size for an entire IVF list
size_t getGpuVectorsEncodingSize_(idx_t numVecs) const override;
std::optional<raft::neighbors::ivf_flat::index<float, idx_t>>
raft_knn_index{std::nullopt};
};
struct RaftIVFFlatCodePackerInterleaved : CodePacker {
RaftIVFFlatCodePackerInterleaved(
size_t list_size,
uint32_t dim,
uint32_t chunk_size);
void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
const final;
void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
const final;
protected:
uint32_t chunk_size;
uint32_t dim;
};
} // namespace gpu
} // namespace faiss

faiss/gpu/test/TestGpuIndexFlat.cpp

@ -749,7 +749,6 @@ void testSearchAndReconstruct(bool use_raft) {
}
}
}
TEST(TestGpuIndexFlat, SearchAndReconstruct) {
testSearchAndReconstruct(false);
}
@ -767,4 +766,4 @@ int main(int argc, char** argv) {
faiss::gpu::setTestSeed(100);
return RUN_ALL_TESTS();
}
}

faiss/gpu/test/TestGpuIndexIVFFlat.cpp

@ -30,6 +30,7 @@
#include <cmath>
#include <sstream>
#include <vector>
#include "faiss/gpu/GpuIndicesOptions.h"
// FIXME: figure out a better way to test fp16
constexpr float kF16MaxRelErr = 0.3f;
@ -55,6 +56,8 @@ struct Options {
faiss::gpu::INDICES_64_BIT});
device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
use_raft = false;
}
std::string toString() const {
@ -62,7 +65,7 @@ struct Options {
str << "IVFFlat device " << device << " numVecs " << numAdd << " dim "
<< dim << " numCentroids " << numCentroids << " nprobe " << nprobe
<< " numQuery " << numQuery << " k " << k << " indicesOpt "
<< indicesOpt;
<< indicesOpt << " use_raft " << use_raft;
return str.str();
}
@ -76,6 +79,7 @@ struct Options {
int k;
int device;
faiss::gpu::IndicesOptions indicesOpt;
bool use_raft;
};
void queryTest(
@ -106,6 +110,7 @@ void queryTest(
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
config.use_raft = opt.use_raft;
faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
@ -129,7 +134,10 @@ void queryTest(
}
}
void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
void addTest(
faiss::MetricType metricType,
bool useFloat16CoarseQuantizer,
bool use_raft) {
for (int tries = 0; tries < 2; ++tries) {
Options opt;
@ -153,8 +161,10 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.indicesOptions =
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
config.use_raft = use_raft;
faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
@ -178,7 +188,7 @@ void addTest(faiss::MetricType metricType, bool useFloat16CoarseQuantizer) {
}
}
void copyToTest(bool useFloat16CoarseQuantizer) {
void copyToTest(bool useFloat16CoarseQuantizer, bool use_raft) {
Options opt;
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
@ -188,8 +198,10 @@ void copyToTest(bool useFloat16CoarseQuantizer) {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.indicesOptions =
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
config.use_raft = use_raft;
faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
@ -229,7 +241,7 @@ void copyToTest(bool useFloat16CoarseQuantizer) {
compFloat16 ? 0.30f : 0.015f);
}
void copyFromTest(bool useFloat16CoarseQuantizer) {
void copyFromTest(bool useFloat16CoarseQuantizer, bool use_raft) {
Options opt;
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
@ -247,8 +259,10 @@ void copyFromTest(bool useFloat16CoarseQuantizer) {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.indicesOptions =
use_raft ? faiss::gpu::INDICES_64_BIT : opt.indicesOpt;
config.flatConfig.useFloat16 = useFloat16CoarseQuantizer;
config.use_raft = use_raft;
faiss::gpu::GpuIndexIVFFlat gpuIndex(&res, 1, 1, faiss::METRIC_L2, config);
gpuIndex.nprobe = 1;
@ -280,19 +294,35 @@ void copyFromTest(bool useFloat16CoarseQuantizer) {
}
TEST(TestGpuIndexIVFFlat, Float32_32_Add_L2) {
addTest(faiss::METRIC_L2, false);
addTest(faiss::METRIC_L2, false, false);
#if defined USE_NVIDIA_RAFT
addTest(faiss::METRIC_L2, false, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_32_Add_IP) {
addTest(faiss::METRIC_INNER_PRODUCT, false);
addTest(faiss::METRIC_INNER_PRODUCT, false, false);
#if defined USE_NVIDIA_RAFT
addTest(faiss::METRIC_INNER_PRODUCT, false, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float16_32_Add_L2) {
addTest(faiss::METRIC_L2, true);
addTest(faiss::METRIC_L2, true, false);
#if defined USE_NVIDIA_RAFT
addTest(faiss::METRIC_L2, true, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) {
addTest(faiss::METRIC_INNER_PRODUCT, true);
addTest(faiss::METRIC_INNER_PRODUCT, true, false);
#if defined USE_NVIDIA_RAFT
addTest(faiss::METRIC_INNER_PRODUCT, true, true);
#endif
}
//
@ -300,11 +330,25 @@ TEST(TestGpuIndexIVFFlat, Float16_32_Add_IP) {
//
TEST(TestGpuIndexIVFFlat, Float32_Query_L2) {
queryTest(Options(), faiss::METRIC_L2, false);
Options opt;
queryTest(opt, faiss::METRIC_L2, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_L2, false);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_Query_IP) {
queryTest(Options(), faiss::METRIC_INNER_PRODUCT, false);
Options opt;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#endif
}
TEST(TestGpuIndexIVFFlat, LargeBatch) {
@ -312,16 +356,36 @@ TEST(TestGpuIndexIVFFlat, LargeBatch) {
opt.dim = 3;
opt.numQuery = 100000;
queryTest(opt, faiss::METRIC_L2, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_L2, false);
#endif
}
// float16 coarse quantizer
TEST(TestGpuIndexIVFFlat, Float16_32_Query_L2) {
queryTest(Options(), faiss::METRIC_L2, true);
Options opt;
queryTest(opt, faiss::METRIC_L2, true);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_L2, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float16_32_Query_IP) {
queryTest(Options(), faiss::METRIC_INNER_PRODUCT, true);
Options opt;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, true);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, true);
#endif
}
//
@ -333,24 +397,48 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_L2_64) {
Options opt;
opt.dim = 64;
queryTest(opt, faiss::METRIC_L2, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_L2, false);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_Query_IP_64) {
Options opt;
opt.dim = 64;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_Query_L2_128) {
Options opt;
opt.dim = 128;
queryTest(opt, faiss::METRIC_L2, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_L2, false);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) {
Options opt;
opt.dim = 128;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#if defined USE_NVIDIA_RAFT
opt.use_raft = true;
opt.indicesOpt = faiss::gpu::INDICES_64_BIT;
queryTest(opt, faiss::METRIC_INNER_PRODUCT, false);
#endif
}
//
@ -358,11 +446,19 @@ TEST(TestGpuIndexIVFFlat, Float32_Query_IP_128) {
//
TEST(TestGpuIndexIVFFlat, Float32_32_CopyTo) {
copyToTest(false);
copyToTest(false, false);
#if defined USE_NVIDIA_RAFT
copyToTest(false, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_32_CopyFrom) {
copyFromTest(false);
copyFromTest(false, false);
#if defined USE_NVIDIA_RAFT
copyFromTest(false, true);
#endif
}
TEST(TestGpuIndexIVFFlat, Float32_negative) {
@ -392,6 +488,14 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
faiss::gpu::StandardGpuResources res;
res.noTempMemory();
// Construct a positive test set
auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
// Put all vecs on positive size
for (auto& f : queryVecs) {
f = std::abs(f);
}
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
@ -401,14 +505,6 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
gpuIndex.copyFrom(&cpuIndex);
gpuIndex.nprobe = opt.nprobe;
// Construct a positive test set
auto queryVecs = faiss::gpu::randVecs(opt.numQuery, opt.dim);
// Put all vecs on positive size
for (auto& f : queryVecs) {
f = std::abs(f);
}
bool compFloat16 = false;
faiss::gpu::compareIndices(
queryVecs,
@ -424,6 +520,31 @@ TEST(TestGpuIndexIVFFlat, Float32_negative) {
// in fp16. Figure out another way to test
compFloat16 ? 0.99f : 0.1f,
compFloat16 ? 0.65f : 0.015f);
#if defined USE_NVIDIA_RAFT
config.use_raft = true;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
&res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type, config);
raftGpuIndex.copyFrom(&cpuIndex);
raftGpuIndex.nprobe = opt.nprobe;
faiss::gpu::compareIndices(
queryVecs,
cpuIndex,
raftGpuIndex,
opt.numQuery,
opt.dim,
opt.k,
opt.toString(),
compFloat16 ? kF16MaxRelErr : kF32MaxRelErr,
// FIXME: the fp16 bounds are
// useless when math (the accumulator) is
// in fp16. Figure out another way to test
compFloat16 ? 0.99f : 0.1f,
compFloat16 ? 0.65f : 0.015f);
#endif
}
//
@ -439,6 +560,13 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
faiss::gpu::StandardGpuResources res;
res.noTempMemory();
int numQuery = 10;
std::vector<float> nans(
numQuery * opt.dim, std::numeric_limits<float>::quiet_NaN());
std::vector<float> distances(numQuery * opt.k, 0);
std::vector<faiss::idx_t> indices(numQuery * opt.k, 0);
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
@ -451,13 +579,6 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
gpuIndex.train(opt.numTrain, trainVecs.data());
gpuIndex.add(opt.numAdd, addVecs.data());
int numQuery = 10;
std::vector<float> nans(
numQuery * opt.dim, std::numeric_limits<float>::quiet_NaN());
std::vector<float> distances(numQuery * opt.k, 0);
std::vector<faiss::idx_t> indices(numQuery * opt.k, 0);
gpuIndex.search(
numQuery, nans.data(), opt.k, distances.data(), indices.data());
@ -469,6 +590,31 @@ TEST(TestGpuIndexIVFFlat, QueryNaN) {
std::numeric_limits<float>::max());
}
}
#if defined USE_NVIDIA_RAFT
config.use_raft = true;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
std::fill(distances.begin(), distances.end(), 0);
std::fill(indices.begin(), indices.end(), 0);
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
raftGpuIndex.nprobe = opt.nprobe;
raftGpuIndex.train(opt.numTrain, trainVecs.data());
raftGpuIndex.add(opt.numAdd, addVecs.data());
raftGpuIndex.search(
numQuery, nans.data(), opt.k, distances.data(), indices.data());
for (int q = 0; q < numQuery; ++q) {
for (int k = 0; k < opt.k; ++k) {
EXPECT_EQ(indices[q * opt.k + k], -1);
EXPECT_EQ(
distances[q * opt.k + k],
std::numeric_limits<float>::max());
}
}
#endif
}
TEST(TestGpuIndexIVFFlat, AddNaN) {
@ -477,15 +623,6 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
faiss::gpu::StandardGpuResources res;
res.noTempMemory();
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.flatConfig.useFloat16 = faiss::gpu::randBool();
faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
gpuIndex.nprobe = opt.nprobe;
int numNans = 10;
std::vector<float> nans(
numNans * opt.dim, std::numeric_limits<float>::quiet_NaN());
@ -497,6 +634,14 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
}
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = opt.device;
config.indicesOptions = opt.indicesOpt;
config.flatConfig.useFloat16 = faiss::gpu::randBool();
faiss::gpu::GpuIndexIVFFlat gpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
gpuIndex.nprobe = opt.nprobe;
gpuIndex.train(opt.numTrain, trainVecs.data());
// should not crash
@ -514,6 +659,27 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
opt.k,
distance.data(),
indices.data());
#if defined USE_NVIDIA_RAFT
config.use_raft = true;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
&res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
raftGpuIndex.nprobe = opt.nprobe;
raftGpuIndex.train(opt.numTrain, trainVecs.data());
// should not crash
EXPECT_EQ(raftGpuIndex.ntotal, 0);
raftGpuIndex.add(numNans, nans.data());
// should not crash
raftGpuIndex.search(
opt.numQuery,
queryVecs.data(),
opt.k,
distance.data(),
indices.data());
#endif
}
TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
@ -570,6 +736,26 @@ TEST(TestGpuIndexIVFFlat, UnifiedMemory) {
kF32MaxRelErr,
0.1f,
0.015f);
#if defined USE_NVIDIA_RAFT
config.use_raft = true;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
&res, dim, numCentroids, faiss::METRIC_L2, config);
raftGpuIndex.copyFrom(&cpuIndex);
raftGpuIndex.nprobe = nprobe;
faiss::gpu::compareIndices(
cpuIndex,
raftGpuIndex,
numQuery,
dim,
k,
"Unified Memory",
kF32MaxRelErr,
0.1f,
0.015f);
#endif
}
TEST(TestGpuIndexIVFFlat, LongIVFList) {
@ -628,6 +814,27 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
kF32MaxRelErr,
0.1f,
0.015f);
#if defined USE_NVIDIA_RAFT
config.use_raft = true;
config.indicesOptions = faiss::gpu::INDICES_64_BIT;
faiss::gpu::GpuIndexIVFFlat raftGpuIndex(
&res, dim, numCentroids, faiss::METRIC_L2, config);
raftGpuIndex.train(numTrain, trainVecs.data());
raftGpuIndex.add(numAdd, addVecs.data());
raftGpuIndex.nprobe = 1;
faiss::gpu::compareIndices(
cpuIndex,
raftGpuIndex,
numQuery,
dim,
k,
"Unified Memory",
kF32MaxRelErr,
0.1f,
0.015f);
#endif
}
int main(int argc, char** argv) {
@ -637,4 +844,4 @@ int main(int argc, char** argv) {
faiss::gpu::setTestSeed(100);
return RUN_ALL_TESTS();
}
}