Improve Faiss / PyTorch GPU interoperability
Summary: PyTorch GPU code is in general free to use whatever stream it currently wants, based on `torch.cuda.current_stream()`. Due to C++/Python language barrier issues, we could not previously pass the actual `cudaStream_t` that is currently in use on a given device from PyTorch C++ to Faiss C++ via Python.

This diff adds conversion functions to convert a Python integer representing a pointer to a `cudaStream_t` (which is itself a `CUstream_st*`), so we can pass the stream specified in `torch.cuda.current_stream()` to `StandardGpuResources::setDefaultStream`. We thus guarantee that all Faiss work is ordered on the same stream that is in use in PyTorch.

For use in Python, there is now the `faiss.contrib.pytorch_tensors.using_stream` context object, which automatically sets and unsets the current PyTorch stream within Faiss. It takes a `StandardGpuResources` object in Python, plus an optional `torch.cuda.Stream` if one wants to use a stream other than the current one. This is how it is used:

```
# Create a non-default stream
s = torch.cuda.Stream()

# Have Torch use it
with torch.cuda.stream(s):
    # Have Faiss use the same stream as the above
    with faiss.contrib.pytorch_tensors.using_stream(res):
        # Do some work on the GPU
        faiss.bfKnn(res, args)
```

`using_stream` uses the same pattern as the PyTorch `torch.cuda.stream` object. This replaces any brute-force GPU/CPU synchronization work that was necessary before.

Other changes in this diff:
- Cleans up the config objects in the `GpuIndex` subclasses, to distinguish between read-only parameters that can only be set upon index construction and those that can be changed at runtime.
- `StandardGpuResources` now more properly distinguishes between user-supplied streams (like the PyTorch one), which will not be destroyed upon resources destruction, and internal streams, which will be.
- `search_index_pytorch` now needs to take a `StandardGpuResources` object as well; there is no way to get this from an index instance otherwise (or at least, we would have to return a `shared_ptr`, in which case we should just update the Python SWIG stuff to use `shared_ptr` for `GpuResources`).

Reviewed By: mdouze

Differential Revision: D24260026

fbshipit-source-id: b18bb0eb34eb012584b1c923088228776c10b720
parent b459931ae4
commit e796f4f9df
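For reference, the `using_stream` context manager is a thin wrapper over the new conversion function. A minimal sketch of the equivalent manual calls (every function used here appears in this diff; the surrounding setup is assumed):

```
import faiss
import torch

res = faiss.StandardGpuResources()  # assumed to already exist

# torch.cuda.Stream.cuda_stream is a Python integer holding the raw
# cudaStream_t pointer; cast it into something Faiss can accept
stream = torch.cuda.current_stream()
cuda_stream_s = faiss.cast_integer_to_cudastream_t(stream.cuda_stream)

# Remember the prior Faiss stream so it can be restored on exit
dev = torch.cuda.current_device()
prior_stream = res.getDefaultStream(dev)

res.setDefaultStream(dev, cuda_stream_s)
try:
    # ... Faiss GPU work here is now ordered on PyTorch's stream ...
    pass
finally:
    res.setDefaultStream(dev, prior_stream)
```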
@@ -1,5 +1,6 @@
 import faiss
 import torch
+import contextlib
 
 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch trensor (on CPU or GPU) """
@@ -15,9 +16,33 @@ def swig_ptr_from_LongTensor(x):
     return faiss.cast_integer_to_long_ptr(
         x.storage().data_ptr() + x.storage_offset() * 8)
 
-def search_index_pytorch(index, x, k, D=None, I=None):
+@contextlib.contextmanager
+def using_stream(res, pytorch_stream=None):
+    r""" Creates a scoping object to make Faiss GPU use the same stream
+         as pytorch, based on torch.cuda.current_stream().
+         Or, a specific pytorch stream can be passed in as a second
+         argument, in which case we will use that stream.
+    """
+
+    if pytorch_stream is None:
+        pytorch_stream = torch.cuda.current_stream()
+
+    # This is the cudaStream_t that we wish to use
+    cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)
+
+    # So we can revert GpuResources stream state upon exit
+    prior_dev = torch.cuda.current_device()
+    prior_stream = res.getDefaultStream(torch.cuda.current_device())
+
+    res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)
+
+    # Do the user work
+    try:
+        yield
+    finally:
+        res.setDefaultStream(prior_dev, prior_stream)
+
+def search_index_pytorch(res, index, x, k, D=None, I=None):
     """call the search function of an index with pytorch tensor I/O (CPU
     and GPU supported)"""
     assert x.is_contiguous()
@@ -33,13 +58,13 @@ def search_index_pytorch(index, x, k, D=None, I=None):
         I = torch.empty((n, k), dtype=torch.int64, device=x.device)
     else:
         assert I.size() == (n, k)
-    torch.cuda.synchronize()
-    xptr = swig_ptr_from_FloatTensor(x)
-    Iptr = swig_ptr_from_LongTensor(I)
-    Dptr = swig_ptr_from_FloatTensor(D)
-    index.search_c(n, xptr,
-                   k, Dptr, Iptr)
-    torch.cuda.synchronize()
+
+    with using_stream(res):
+        xptr = swig_ptr_from_FloatTensor(x)
+        Iptr = swig_ptr_from_LongTensor(I)
+        Dptr = swig_ptr_from_FloatTensor(D)
+        index.search_c(n, xptr, k, Dptr, Iptr)
+
     return D, I
 
 
@@ -97,6 +122,8 @@ def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
     args.numQueries = nq
     args.outDistances = D_ptr
     args.outIndices = I_ptr
-    faiss.bfKnn(res, args)
+
+    with using_stream(res):
+        faiss.bfKnn(res, args)
+
     return D, I
 
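With the updated signature above, callers pass the resources object explicitly. A hedged usage sketch of the new entry point (the index type, dimensionality, and data below are illustrative, not part of the diff):

```
import numpy as np
import faiss
import torch

from faiss.contrib.pytorch_tensors import search_index_pytorch

res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatL2(res, 64)                   # 64-dim, illustrative
index.add(np.random.rand(1000, 64).astype('float32'))  # database vectors

xq = torch.rand(100, 64, device='cuda')                 # queries as a GPU tensor

# D and I come back as torch tensors on xq's device; internally the
# search is ordered on PyTorch's current stream via using_stream(res)
D, I = search_index_pytorch(res, index, xq, 10)
```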
@@ -42,25 +42,29 @@ GpuIndex::GpuIndex(std::shared_ptr<GpuResources> resources,
                    GpuIndexConfig config) :
     Index(dims, metric),
     resources_(resources),
-    device_(config.device),
-    memorySpace_(config.memorySpace),
+    config_(config),
     minPagedSize_(kMinPageSize) {
-  FAISS_THROW_IF_NOT_FMT(device_ < getNumDevices(),
-                         "Invalid GPU device %d", device_);
+  FAISS_THROW_IF_NOT_FMT(config_.device < getNumDevices(),
+                         "Invalid GPU device %d", config_.device);
 
   FAISS_THROW_IF_NOT_MSG(dims > 0, "Invalid number of dimensions");
 
   FAISS_THROW_IF_NOT_FMT(
-    memorySpace_ == MemorySpace::Device ||
-    (memorySpace_ == MemorySpace::Unified &&
-     getFullUnifiedMemSupport(device_)),
+    config_.memorySpace == MemorySpace::Device ||
+    (config_.memorySpace == MemorySpace::Unified &&
+     getFullUnifiedMemSupport(config_.device)),
     "Device %d does not support full CUDA 8 Unified Memory (CC 6.0+)",
     config.device);
 
   metric_arg = metricArg;
 
   FAISS_ASSERT((bool) resources_);
-  resources_->initializeForDevice(device_);
+  resources_->initializeForDevice(config_.device);
+}
+
+int
+GpuIndex::getDevice() const {
+  return config_.device;
 }
 
 void
@@ -124,7 +128,7 @@ GpuIndex::add_with_ids(Index::idx_t n,
     }
   }
 
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
   addPaged_((int) n, x, ids ? ids : generatedIds.data());
 }
 
@@ -171,7 +175,7 @@ GpuIndex::addPage_(int n,
 
   auto vecs =
     toDeviceTemporary<float, 2>(resources_.get(),
-                                device_,
+                                config_.device,
                                 const_cast<float*>(x),
                                 stream,
                                 {n, this->d});
@@ -179,7 +183,7 @@ GpuIndex::addPage_(int n,
   if (ids) {
     auto indices =
       toDeviceTemporary<Index::idx_t, 1>(resources_.get(),
-                                         device_,
+                                         config_.device,
                                          const_cast<Index::idx_t*>(ids),
                                          stream,
                                          {n});
@@ -214,8 +218,8 @@ GpuIndex::search(Index::idx_t n,
     return;
   }
 
-  DeviceScope scope(device_);
-  auto stream = resources_->getDefaultStream(device_);
+  DeviceScope scope(config_.device);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   // We guarantee that the searchImpl_ will be called with device-resident
   // pointers.
@@ -228,12 +232,12 @@ GpuIndex::search(Index::idx_t n,
   // another level of tiling.
   auto outDistances =
     toDeviceTemporary<float, 2>(
-      resources_.get(), device_, distances, stream,
+      resources_.get(), config_.device, distances, stream,
       {(int) n, (int) k});
 
   auto outLabels =
     toDeviceTemporary<faiss::Index::idx_t, 2>(
-      resources_.get(), device_, labels, stream,
+      resources_.get(), config_.device, labels, stream,
       {(int) n, (int) k});
 
   bool usePaged = false;
@@ -272,12 +276,12 @@ GpuIndex::searchNonPaged_(int n,
                           int k,
                           float* outDistancesData,
                           Index::idx_t* outIndicesData) const {
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   // Make sure arguments are on the device we desire; use temporary
   // memory allocations to move it if necessary
   auto vecs = toDeviceTemporary<float, 2>(resources_.get(),
-                                          device_,
+                                          config_.device,
                                           const_cast<float*>(x),
                                           stream,
                                           {n, (int) this->d});
@@ -335,8 +339,8 @@ GpuIndex::searchFromCpuPaged_(int n,
   // 1 2 3 1 ... (pinned buf A)
   // time ->
   //
-  auto defaultStream = resources_->getDefaultStream(device_);
-  auto copyStream = resources_->getAsyncCopyStream(device_);
+  auto defaultStream = resources_->getDefaultStream(config_.device);
+  auto copyStream = resources_->getAsyncCopyStream(config_.device);
 
   FAISS_ASSERT((size_t) pageSizeInVecs * this->d <=
                (size_t) std::numeric_limits<int>::max());
@@ -36,9 +36,8 @@ class GpuIndex : public faiss::Index {
            float metricArg,
            GpuIndexConfig config);
 
-  inline int getDevice() const {
-    return device_;
-  }
+  /// Returns the device that this index is resident on
+  int getDevice() const;
 
   /// Set the minimum data size for searches (in MiB) for which we use
   /// CPU -> GPU paging
@@ -136,11 +135,8 @@ private:
   /// Manages streams, cuBLAS handles and scratch memory for devices
   std::shared_ptr<GpuResources> resources_;
 
-  /// The GPU device we are resident on
-  const int device_;
-
-  /// The memory space of our primary storage on the GPU
-  const MemorySpace memorySpace_;
+  /// Our configuration options
+  const GpuIndexConfig config_;
 
   /// Size above which we page copies from the CPU to GPU
   size_t minPagedSize_;
@@ -23,7 +23,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
                                        GpuIndexBinaryFlatConfig config)
     : IndexBinary(index->d),
       resources_(provider->getResources()),
-      config_(std::move(config)) {
+      binaryFlatConfig_(config) {
   FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
                          "vector dimension (number of bits) "
                          "must be divisible by 8 (passed %d)",
@@ -41,7 +41,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
                                        GpuIndexBinaryFlatConfig config)
     : IndexBinary(dims),
       resources_(provider->getResources()),
-      config_(std::move(config)) {
+      binaryFlatConfig_(std::move(config)) {
   FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
                          "vector dimension (number of bits) "
                          "must be divisible by 8 (passed %d)",
@@ -51,17 +51,22 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
   this->is_trained = true;
 
   // Construct index
-  DeviceScope scope(config_.device);
-  data_.reset(
-    new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace));
+  DeviceScope scope(binaryFlatConfig_.device);
+  data_.reset(new BinaryFlatIndex(resources_.get(),
+                                  this->d, binaryFlatConfig_.memorySpace));
 }
 
 GpuIndexBinaryFlat::~GpuIndexBinaryFlat() {
 }
 
+int
+GpuIndexBinaryFlat::getDevice() const {
+  return binaryFlatConfig_.device;
+}
+
 void
 GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
-  DeviceScope scope(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
 
   this->d = index->d;
 
@@ -76,20 +81,20 @@ GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
 
   // destroy old first before allocating new
   data_.reset();
-  data_.reset(
-    new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace));
+  data_.reset(new BinaryFlatIndex(resources_.get(),
+                                  this->d, binaryFlatConfig_.memorySpace));
 
   // The index could be empty
   if (index->ntotal > 0) {
     data_->add(index->xb.data(),
                index->ntotal,
-               resources_->getDefaultStream(config_.device));
+               resources_->getDefaultStream(binaryFlatConfig_.device));
   }
 }
 
 void
 GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
-  DeviceScope scope(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
 
   index->d = this->d;
   index->ntotal = this->ntotal;
@@ -101,18 +106,18 @@ GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
   if (this->ntotal > 0) {
     fromDevice(data_->getVectorsRef(),
                index->xb.data(),
-               resources_->getDefaultStream(config_.device));
+               resources_->getDefaultStream(binaryFlatConfig_.device));
   }
 }
 
 void
 GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
                         const uint8_t* x) {
-  DeviceScope scope(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
 
   // To avoid multiple re-allocations, ensure we have enough storage
   // available
-  data_->reserve(n, resources_->getDefaultStream(config_.device));
+  data_->reserve(n, resources_->getDefaultStream(binaryFlatConfig_.device));
 
   // Due to GPU indexing in int32, we can't store more than this
   // number of vectors on a GPU
@@ -123,13 +128,13 @@ GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
 
   data_->add((const unsigned char*) x,
              n,
-             resources_->getDefaultStream(config_.device));
+             resources_->getDefaultStream(binaryFlatConfig_.device));
   this->ntotal += n;
 }
 
 void
 GpuIndexBinaryFlat::reset() {
-  DeviceScope scope(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
 
   // Free the underlying memory
   data_->reset();
@@ -155,8 +160,8 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
                          getMaxKSelection(),
                          (int) k); // select limitation
 
-  DeviceScope scope(config_.device);
-  auto stream = resources_->getDefaultStream(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
+  auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
 
   // The input vectors may be too large for the GPU, but we still
   // assume that the output distances and labels are not.
@@ -165,7 +170,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
   // If we reach a point where all inputs are too big, we can add
   // another level of tiling.
   auto outDistances = toDeviceTemporary<int32_t, 2>(resources_.get(),
-                                                    config_.device,
+                                                    binaryFlatConfig_.device,
                                                     distances,
                                                     stream,
                                                     {(int) n, (int) k});
@@ -203,7 +208,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
   // Convert and copy int indices out
   auto outIndices =
     toDeviceTemporary<faiss::Index::idx_t, 2>(resources_.get(),
-                                              config_.device,
+                                              binaryFlatConfig_.device,
                                               labels,
                                               stream,
                                               {(int) n, (int) k});
@@ -227,12 +232,12 @@ GpuIndexBinaryFlat::searchNonPaged_(int n,
   Tensor<int32_t, 2, true> outDistances(outDistancesData, {n, k});
   Tensor<int, 2, true> outIndices(outIndicesData, {n, k});
 
-  auto stream = resources_->getDefaultStream(config_.device);
+  auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
 
   // Make sure arguments are on the device we desire; use temporary
   // memory allocations to move it if necessary
   auto vecs = toDeviceTemporary<uint8_t, 2>(resources_.get(),
-                                            config_.device,
+                                            binaryFlatConfig_.device,
                                             const_cast<uint8_t*>(x),
                                             stream,
                                             {n, (int) (this->d / 8)});
@@ -272,10 +277,10 @@ GpuIndexBinaryFlat::searchFromCpuPaged_(int n,
 void
 GpuIndexBinaryFlat::reconstruct(faiss::IndexBinary::idx_t key,
                                 uint8_t* out) const {
-  DeviceScope scope(config_.device);
+  DeviceScope scope(binaryFlatConfig_.device);
 
   FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
-  auto stream = resources_->getDefaultStream(config_.device);
+  auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
 
   auto& vecs = data_->getVectorsRef();
   auto vec = vecs[key];
@@ -38,6 +38,9 @@ class GpuIndexBinaryFlat : public IndexBinary {
 
   ~GpuIndexBinaryFlat() override;
 
+  /// Returns the device that this index is resident on
+  int getDevice() const;
+
   /// Initialize ourselves from the given CPU index; will overwrite
   /// all data in ourselves
   void copyFrom(const faiss::IndexBinaryFlat* index);
@@ -80,7 +83,7 @@ class GpuIndexBinaryFlat : public IndexBinary {
   std::shared_ptr<GpuResources> resources_;
 
   /// Configuration options
-  GpuIndexBinaryFlatConfig config_;
+  const GpuIndexBinaryFlatConfig binaryFlatConfig_;
 
   /// Holds our GPU data containing the list of vectors
   std::unique_ptr<BinaryFlatIndex> data_;
@@ -27,7 +27,7 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
              index->metric_type,
              index->metric_arg,
              config),
-    config_(std::move(config)) {
+    flatConfig_(config) {
   // Flat index doesn't need training
   this->is_trained = true;
 
@@ -42,7 +42,7 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
              index->metric_type,
              index->metric_arg,
              config),
-    config_(std::move(config)) {
+    flatConfig_(config) {
   // Flat index doesn't need training
   this->is_trained = true;
 
@@ -58,17 +58,17 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
              metric,
              0,
              config),
-    config_(std::move(config)) {
+    flatConfig_(config) {
   // Flat index doesn't need training
   this->is_trained = true;
 
   // Construct index
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
   data_.reset(new FlatIndex(resources_.get(),
                             dims,
-                            config_.useFloat16,
-                            config_.storeTransposed,
-                            memorySpace_));
+                            flatConfig_.useFloat16,
+                            flatConfig_.storeTransposed,
+                            config_.memorySpace));
 }
 
 GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
@@ -76,17 +76,17 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
                            faiss::MetricType metric,
                            GpuIndexFlatConfig config) :
     GpuIndex(resources, dims, metric, 0, config),
-    config_(std::move(config)) {
+    flatConfig_(config) {
   // Flat index doesn't need training
   this->is_trained = true;
 
   // Construct index
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
   data_.reset(new FlatIndex(resources_.get(),
                             dims,
-                            config_.useFloat16,
-                            config_.storeTransposed,
-                            memorySpace_));
+                            flatConfig_.useFloat16,
+                            flatConfig_.storeTransposed,
+                            config_.memorySpace));
 }
 
 GpuIndexFlat::~GpuIndexFlat() {
@@ -94,7 +94,7 @@ GpuIndexFlat::~GpuIndexFlat() {
 
 void
 GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   GpuIndex::copyFrom(index);
 
@@ -109,21 +109,21 @@ GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
   data_.reset();
   data_.reset(new FlatIndex(resources_.get(),
                             this->d,
-                            config_.useFloat16,
-                            config_.storeTransposed,
-                            memorySpace_));
+                            flatConfig_.useFloat16,
+                            flatConfig_.storeTransposed,
+                            config_.memorySpace));
 
   // The index could be empty
   if (index->ntotal > 0) {
     data_->add(index->xb.data(),
                index->ntotal,
-               resources_->getDefaultStream(device_));
+               resources_->getDefaultStream(config_.device));
   }
 }
 
 void
 GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   GpuIndex::copyTo(index);
 
@@ -131,10 +131,10 @@ GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
   FAISS_ASSERT(data_->getSize() == this->ntotal);
   index->xb.resize(this->ntotal * this->d);
 
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   if (this->ntotal > 0) {
-    if (config_.useFloat16) {
+    if (flatConfig_.useFloat16) {
       auto vecFloat32 = data_->getVectorsFloat32Copy(stream);
       fromDevice(vecFloat32, index->xb.data(), stream);
     } else {
@@ -150,7 +150,7 @@ GpuIndexFlat::getNumVecs() const {
 
 void
 GpuIndexFlat::reset() {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // Free the underlying memory
   data_->reset();
@@ -176,15 +176,15 @@ GpuIndexFlat::add(Index::idx_t n, const float* x) {
     return;
   }
 
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // To avoid multiple re-allocations, ensure we have enough storage
   // available
-  data_->reserve(n, resources_->getDefaultStream(device_));
+  data_->reserve(n, resources_->getDefaultStream(config_.device));
 
   // If we're not operating in float16 mode, we don't need the input
   // data to be resident on our device; we can add directly.
-  if (!config_.useFloat16) {
+  if (!flatConfig_.useFloat16) {
     addImpl_(n, x, nullptr);
   } else {
     // Otherwise, perform the paging
@@ -214,7 +214,7 @@ GpuIndexFlat::addImpl_(int n,
                          "GPU index only supports up to %zu indices",
                          (size_t) std::numeric_limits<int>::max());
 
-  data_->add(x, n, resources_->getDefaultStream(device_));
+  data_->add(x, n, resources_->getDefaultStream(config_.device));
   this->ntotal += n;
 }
 
@@ -224,7 +224,7 @@ GpuIndexFlat::searchImpl_(int n,
                           int k,
                           float* distances,
                           Index::idx_t* labels) const {
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   // Input and output data are already resident on the GPU
   Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
@@ -249,12 +249,12 @@
 void
 GpuIndexFlat::reconstruct(faiss::Index::idx_t key,
                           float* out) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
-  if (config_.useFloat16) {
+  if (flatConfig_.useFloat16) {
     // FIXME jhj: kernel for copy
     auto vec = data_->getVectorsFloat32Copy(key, 1, stream);
     fromDevice(vec.data(), out, this->d, stream);
@@ -268,13 +268,13 @@
 GpuIndexFlat::reconstruct_n(faiss::Index::idx_t i0,
                             faiss::Index::idx_t num,
                             float* out) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   FAISS_THROW_IF_NOT_MSG(i0 < this->ntotal, "index out of bounds");
   FAISS_THROW_IF_NOT_MSG(i0 + num - 1 < this->ntotal, "num out of bounds");
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
-  if (config_.useFloat16) {
+  if (flatConfig_.useFloat16) {
     // FIXME jhj: kernel for copy
     auto vec = data_->getVectorsFloat32Copy(i0, num, stream);
     fromDevice(vec.data(), out, num * this->d, stream);
@@ -301,23 +301,24 @@ GpuIndexFlat::compute_residual_n(faiss::Index::idx_t n,
                          "GPU index only supports up to %zu indices",
                          (size_t) std::numeric_limits<int>::max());
 
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   auto vecsDevice =
-    toDeviceTemporary<float, 2>(resources_.get(), device_,
+    toDeviceTemporary<float, 2>(resources_.get(), config_.device,
                                 const_cast<float*>(xs), stream,
                                 {(int) n, (int) this->d});
   auto idsDevice =
     toDeviceTemporary<faiss::Index::idx_t, 1>(
-      resources_.get(), device_,
+      resources_.get(), config_.device,
       const_cast<faiss::Index::idx_t*>(keys),
      stream,
       {(int) n});
 
   auto residualDevice =
-    toDeviceTemporary<float, 2>(resources_.get(), device_, residuals, stream,
+    toDeviceTemporary<float, 2>(resources_.get(),
+                                config_.device, residuals, stream,
                                 {(int) n, (int) this->d});
 
   // Convert idx_t to int
@@ -129,8 +129,8 @@ class GpuIndexFlat : public GpuIndex {
                    faiss::Index::idx_t* labels) const override;
 
 protected:
-  /// Our config object
-  const GpuIndexFlatConfig config_;
+  /// Our configuration options
+  const GpuIndexFlatConfig flatConfig_;
 
   /// Holds our GPU data containing the list of vectors
   std::unique_ptr<FlatIndex> data_;
@@ -22,12 +22,11 @@ GpuIndexIVF::GpuIndexIVF(GpuResourcesProvider* provider,
                          float metricArg,
                          int nlistIn,
                          GpuIndexIVFConfig config) :
-    GpuIndex(provider->getResources(),
-             dims, metric, metricArg, config),
+    GpuIndex(provider->getResources(), dims, metric, metricArg, config),
     nlist(nlistIn),
     nprobe(1),
     quantizer(nullptr),
-    ivfConfig_(std::move(config)) {
+    ivfConfig_(config) {
   init_();
 
   // Only IP and L2 are supported for now
@@ -55,7 +54,7 @@ GpuIndexIVF::init_() {
   // Construct an empty quantizer
   GpuIndexFlatConfig config = ivfConfig_.flatConfig;
   // FIXME: inherit our same device
-  config.device = device_;
+  config.device = config_.device;
 
   if (metric_type == faiss::METRIC_L2) {
     quantizer = new GpuIndexFlatL2(resources_, d, config);
@@ -79,7 +78,7 @@ GpuIndexIVF::getQuantizer() {
 
 void
 GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   GpuIndex::copyFrom(index);
 
@@ -105,7 +104,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
   // Construct an empty quantizer
   GpuIndexFlatConfig config = ivfConfig_.flatConfig;
   // FIXME: inherit our same device
-  config.device = device_;
+  config.device = config_.device;
 
   if (index->metric_type == faiss::METRIC_L2) {
     // FIXME: 2 different float16 options?
@@ -143,7 +142,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
 
 void
 GpuIndexIVF::copyTo(faiss::IndexIVF* index) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   //
   // Index information
@@ -228,7 +227,7 @@ GpuIndexIVF::trainQuantizer_(faiss::Index::idx_t n, const float* x) {
     printf ("Training IVF quantizer on %ld vectors in %dD\n", n, d);
   }
 
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // leverage the CPU-side k-means code, which works for the GPU
   // flat index as well
@@ -83,7 +83,8 @@ class GpuIndexIVF : public GpuIndex {
   GpuIndexFlat* quantizer;
 
 protected:
-  GpuIndexIVFConfig ivfConfig_;
+  /// Our configuration options
+  const GpuIndexIVFConfig ivfConfig_;
 };
 
 } } // namespace
 
@@ -57,14 +57,14 @@ void
 GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
   reserveMemoryVecs_ = numVecs;
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
     index_->reserveMemory(numVecs);
   }
 }
 
 void
 GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   GpuIndexIVF::copyFrom(index);
 
@@ -88,7 +88,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
                                 false, // no residual
                                 nullptr, // no scalar quantizer
                                 ivfFlatConfig_.indicesOptions,
-                                memorySpace_));
+                                config_.memorySpace));
 
   // Copy all of the IVF data
   index_->copyInvertedListsFrom(index->invlists);
@@ -96,7 +96,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
 
 void
 GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // We must have the indices in order to copy to ourselves
   FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF,
@@ -118,7 +118,7 @@ GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
 size_t
 GpuIndexIVFFlat::reclaimMemory() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
 
     return index_->reclaimMemory();
   }
@@ -129,7 +129,7 @@ GpuIndexIVFFlat::reclaimMemory() {
 void
 GpuIndexIVFFlat::reset() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
 
     index_->reset();
     this->ntotal = 0;
@@ -140,7 +140,7 @@ GpuIndexIVFFlat::reset() {
 
 void
 GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   if (this->is_trained) {
     FAISS_ASSERT(quantizer->is_trained);
@@ -161,7 +161,7 @@ GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
                                 false, // no residual
                                 nullptr, // no scalar quantizer
                                 ivfFlatConfig_.indicesOptions,
-                                memorySpace_));
+                                config_.memorySpace));
 
   if (reserveMemoryVecs_) {
     index_->reserveMemory(reserveMemoryVecs_);
@@ -73,8 +73,9 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
                    float* distances,
                    Index::idx_t* labels) const override;
 
-private:
-  GpuIndexIVFFlatConfig ivfFlatConfig_;
+protected:
+  /// Our configuration options
+  const GpuIndexIVFFlatConfig ivfFlatConfig_;
 
   /// Desired inverted list memory reservation
   size_t reserveMemoryVecs_;
@@ -30,6 +30,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
                 index->nlist,
                 config),
     ivfpqConfig_(config),
+    usePrecomputedTables_(config.usePrecomputedTables),
     subQuantizers_(0),
     bitsPerCode_(0),
     reserveMemoryVecs_(0) {
@@ -50,6 +51,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
                 nlist,
                 config),
     ivfpqConfig_(config),
+    usePrecomputedTables_(config.usePrecomputedTables),
     subQuantizers_(subQuantizers),
     bitsPerCode_(bitsPerCode),
     reserveMemoryVecs_(0) {
@@ -64,7 +66,7 @@ GpuIndexIVFPQ::~GpuIndexIVFPQ() {
 
 void
 GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   GpuIndexIVF::copyFrom(index);
 
@@ -105,9 +107,9 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
                              ivfpqConfig_.alternativeLayout,
                              (float*) index->pq.centroids.data(),
                              ivfpqConfig_.indicesOptions,
-                             memorySpace_));
+                             config_.memorySpace));
   // Doesn't make sense to reserve memory here
-  index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables);
+  index_->setPrecomputedCodes(usePrecomputedTables_);
 
   // Copy all of the IVF data
   index_->copyInvertedListsFrom(index->invlists);
@@ -115,7 +117,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
 
 void
 GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // We must have the indices in order to copy to ourselves
   FAISS_THROW_IF_NOT_MSG(ivfpqConfig_.indicesOptions != INDICES_IVF,
@@ -153,9 +155,9 @@ GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
 
     fromDevice<float, 3>(devPQCentroids,
                          index->pq.centroids.data(),
-                         resources_->getDefaultStream(device_));
+                         resources_->getDefaultStream(config_.device));
 
-    if (ivfpqConfig_.usePrecomputedTables) {
+    if (usePrecomputedTables_) {
       index->precompute_table();
     }
   }
@@ -165,16 +167,16 @@ void
 GpuIndexIVFPQ::reserveMemory(size_t numVecs) {
   reserveMemoryVecs_ = numVecs;
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
     index_->reserveMemory(numVecs);
   }
 }
 
 void
 GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
-  ivfpqConfig_.usePrecomputedTables = enable;
+  usePrecomputedTables_ = enable;
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
     index_->setPrecomputedCodes(enable);
   }
 
@@ -183,7 +185,7 @@ GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
 
 bool
 GpuIndexIVFPQ::getPrecomputedCodes() const {
-  return ivfpqConfig_.usePrecomputedTables;
+  return usePrecomputedTables_;
 }
 
 int
@@ -204,7 +206,7 @@ GpuIndexIVFPQ::getCentroidsPerSubQuantizer() const {
 size_t
 GpuIndexIVFPQ::reclaimMemory() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
     return index_->reclaimMemory();
   }
 
@@ -214,7 +216,7 @@ GpuIndexIVFPQ::reclaimMemory() {
 void
 GpuIndexIVFPQ::reset() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
 
     index_->reset();
     this->ntotal = 0;
@@ -264,17 +266,17 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
                              ivfpqConfig_.alternativeLayout,
                              pq.centroids.data(),
                              ivfpqConfig_.indicesOptions,
-                             memorySpace_));
+                             config_.memorySpace));
   if (reserveMemoryVecs_) {
     index_->reserveMemory(reserveMemoryVecs_);
   }
 
-  index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables);
+  index_->setPrecomputedCodes(usePrecomputedTables_);
 }
 
 void
 GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   if (this->is_trained) {
     FAISS_ASSERT(quantizer->is_trained);
@@ -289,7 +291,7 @@ GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
   // First, make sure that the data is resident on the CPU, if it is not on the
   // CPU, as we depend upon parts of the CPU code
   auto hostData = toHost<float, 2>((float*) x,
-                                   resources_->getDefaultStream(device_),
+                                   resources_->getDefaultStream(config_.device),
                                    {(int) n, (int) this->d});
 
   trainQuantizer_(n, hostData.data());
@@ -351,7 +353,7 @@ GpuIndexIVFPQ::getListLength(int listId) const {
 std::vector<unsigned char>
 GpuIndexIVFPQ::getListCodes(int listId) const {
   FAISS_ASSERT(index_);
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   return index_->getListCodes(listId);
 }
@@ -359,7 +361,7 @@ GpuIndexIVFPQ::getListCodes(int listId) const {
 std::vector<long>
 GpuIndexIVFPQ::getListIndices(int listId) const {
   FAISS_ASSERT(index_);
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   return index_->getListIndices(listId);
 }
@@ -398,15 +400,15 @@ GpuIndexIVFPQ::verifySettings_() const {
   // codes per subquantizer
   size_t requiredSmemSize =
     lookupTableSize * subQuantizers_ * utils::pow2(bitsPerCode_);
-  size_t smemPerBlock = getMaxSharedMemPerBlock(device_);
+  size_t smemPerBlock = getMaxSharedMemPerBlock(config_.device);
 
   FAISS_THROW_IF_NOT_FMT(requiredSmemSize
-                         <= getMaxSharedMemPerBlock(device_),
+                         <= getMaxSharedMemPerBlock(config_.device),
                          "Device %d has %zu bytes of shared memory, while "
                          "%d bits per code and %d sub-quantizers requires %zu "
                          "bytes. Consider useFloat16LookupTables and/or "
                          "reduce parameters",
-                         device_, smemPerBlock, bitsPerCode_, subQuantizers_,
+                         config_.device, smemPerBlock, bitsPerCode_, subQuantizers_,
                          requiredSmemSize);
 }
 
@@ -135,13 +135,16 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
              float* distances,
              Index::idx_t* labels) const override;
 
-private:
   void verifySettings_() const;
 
   void trainResidualQuantizer_(Index::idx_t n, const float* x);
 
-private:
-  GpuIndexIVFPQConfig ivfpqConfig_;
+protected:
+  /// Our configuration options that we were initialized with
+  const GpuIndexIVFPQConfig ivfpqConfig_;
+
+  /// Runtime override: whether or not we use precomputed tables
+  bool usePrecomputedTables_;
 
   /// Number of sub-quantizers per encoded vector
   int subQuantizers_;
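This header change illustrates the read-only-config versus runtime-parameter split described in the summary: `ivfpqConfig_` is now `const` and only seeds `usePrecomputedTables_`, which the existing setter can still flip after construction. A hedged sketch of what that permits, through the SWIG-exposed Python API (index sizes and constructor argument order here are illustrative and may differ between Faiss versions):

```
import faiss

res = faiss.StandardGpuResources()

config = faiss.GpuIndexIVFPQConfig()
config.device = 0                     # read-only once the index is built
config.usePrecomputedTables = False   # seeds the runtime override

# d=64, nlist=1024, 8 sub-quantizers, 8 bits per code (illustrative)
index = faiss.GpuIndexIVFPQ(res, 64, 1024, 8, 8, faiss.METRIC_L2, config)

# Allowed at runtime: flips usePrecomputedTables_, not the stored config
index.setPrecomputedCodes(True)
```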
@@ -66,7 +66,7 @@ void
 GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
   reserveMemoryVecs_ = numVecs;
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
     index_->reserveMemory(numVecs);
   }
 }
@@ -74,7 +74,7 @@ GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
 void
 GpuIndexIVFScalarQuantizer::copyFrom(
   const faiss::IndexIVFScalarQuantizer* index) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // Clear out our old data
   index_.reset();
@@ -101,7 +101,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
                             by_residual,
                             &sq,
                             ivfSQConfig_.indicesOptions,
-                            memorySpace_));
+                            config_.memorySpace));
 
   // Copy all of the IVF data
   index_->copyInvertedListsFrom(index->invlists);
@@ -110,7 +110,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
 void
 GpuIndexIVFScalarQuantizer::copyTo(
   faiss::IndexIVFScalarQuantizer* index) const {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   // We must have the indices in order to copy to ourselves
   FAISS_THROW_IF_NOT_MSG(
@@ -135,7 +135,7 @@ GpuIndexIVFScalarQuantizer::copyTo(
 size_t
 GpuIndexIVFScalarQuantizer::reclaimMemory() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
 
     return index_->reclaimMemory();
   }
@@ -146,7 +146,7 @@ GpuIndexIVFScalarQuantizer::reclaimMemory() {
 void
 GpuIndexIVFScalarQuantizer::reset() {
   if (index_) {
-    DeviceScope scope(device_);
+    DeviceScope scope(config_.device);
 
     index_->reset();
     this->ntotal = 0;
@@ -163,7 +163,7 @@ GpuIndexIVFScalarQuantizer::trainResiduals_(Index::idx_t n, const float* x) {
 
 void
 GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
 
   if (this->is_trained) {
     FAISS_ASSERT(quantizer->is_trained);
@@ -178,7 +178,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
   // First, make sure that the data is resident on the CPU, if it is not on the
   // CPU, as we depend upon parts of the CPU code
   auto hostData = toHost<float, 2>((float*) x,
-                                   resources_->getDefaultStream(device_),
+                                   resources_->getDefaultStream(config_.device),
                                    {(int) n, (int) this->d});
 
   trainQuantizer_(n, hostData.data());
@@ -192,7 +192,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
                             by_residual,
                             &sq,
                             ivfSQConfig_.indicesOptions,
-                            memorySpace_));
+                            config_.memorySpace));
 
   if (reserveMemoryVecs_) {
     index_->reserveMemory(reserveMemoryVecs_);
@@ -88,8 +88,9 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF {
   /// Exposed like the CPU version
   bool by_residual;

- private:
-  GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;
+ protected:
+  /// Our configuration options
+  const GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;

   /// Desired inverted list memory reservation
   size_t reserveMemoryVecs_;

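As a quick illustration of what the now-`const` config in the GpuIndexIVFScalarQuantizer header implies on the Python side, here is a hedged sketch: construction-time options such as the device are fixed through the config object and cannot be mutated on a live index. The dimension, list count and quantizer type below are illustrative values only, and the constructor argument order is assumed to follow the header as of this change:

```python
import faiss

res = faiss.StandardGpuResources()

cfg = faiss.GpuIndexIVFScalarQuantizerConfig()
cfg.device = 0  # read-only once the index has been constructed

index = faiss.GpuIndexIVFScalarQuantizer(
    res, 128, 1024, faiss.ScalarQuantizer.QT_8bit,
    faiss.METRIC_L2, True, cfg)
```
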
@@ -101,12 +101,8 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
   for (auto& entry : defaultStreams_) {
     DeviceScope scope(entry.first);

-    auto it = userDefaultStreams_.find(entry.first);
-    if (it == userDefaultStreams_.end()) {
-      // The user did not specify this stream, thus we are the ones
-      // who have created it
-      CUDA_VERIFY(cudaStreamDestroy(entry.second));
-    }
+    // We created these streams, so are responsible for destroying them
+    CUDA_VERIFY(cudaStreamDestroy(entry.second));
   }

   for (auto& entry : alternateStreams_) {

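The ownership rule this destructor change encodes, sketched from the Python side: a stream handed in via `setDefaultStream` belongs to the caller and survives the resources object. Device 0 is an assumption here, and `cast_integer_to_cudastream_t` is the SWIG helper added later in this diff:

```python
import faiss
import torch

s = torch.cuda.Stream()  # owned by PyTorch, not by Faiss

res = faiss.StandardGpuResources()
res.setDefaultStream(0, faiss.cast_integer_to_cudastream_t(s.cuda_stream))

# ... run some Faiss GPU work ordered on s ...

del res          # tears down only the streams Faiss created internally
s.synchronize()  # the user-supplied stream remains valid and usable
```
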
@@ -210,16 +206,14 @@ StandardGpuResourcesImpl::setPinnedMemory(size_t size) {

 void
 StandardGpuResourcesImpl::setDefaultStream(int device, cudaStream_t stream) {
-  auto it = defaultStreams_.find(device);
-  if (it != defaultStreams_.end()) {
-    // Replace this stream with the user stream
-    CUDA_VERIFY(cudaStreamDestroy(it->second));
-    it->second = stream;
-  }
-
   userDefaultStreams_[device] = stream;
 }

+void
+StandardGpuResourcesImpl::revertDefaultStream(int device) {
+  userDefaultStreams_.erase(device);
+}
+
 void
 StandardGpuResourcesImpl::setDefaultNullStreamAllDevices() {
   for (int dev = 0; dev < getNumDevices(); ++dev) {

@@ -274,14 +268,8 @@ StandardGpuResourcesImpl::initializeForDevice(int device) {

   // Create streams
   cudaStream_t defaultStream = 0;
-  auto it = userDefaultStreams_.find(device);
-  if (it != userDefaultStreams_.end()) {
-    // We already have a stream provided by the user
-    defaultStream = it->second;
-  } else {
-    CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
-                                          cudaStreamNonBlocking));
-  }
+  CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
+                                        cudaStreamNonBlocking));

   defaultStreams_[device] = defaultStream;

@@ -341,6 +329,14 @@ StandardGpuResourcesImpl::getBlasHandle(int device) {
 cudaStream_t
 StandardGpuResourcesImpl::getDefaultStream(int device) {
   initializeForDevice(device);

+  auto it = userDefaultStreams_.find(device);
+  if (it != userDefaultStreams_.end()) {
+    // There is a user override stream set
+    return it->second;
+  }
+
+  // Otherwise, our base default stream
   return defaultStreams_[device];
 }

@@ -539,6 +535,11 @@ StandardGpuResources::setDefaultStream(int device, cudaStream_t stream) {
   res_->setDefaultStream(device, stream);
 }

+void
+StandardGpuResources::revertDefaultStream(int device) {
+  res_->revertDefaultStream(device);
+}
+
 void
 StandardGpuResources::setDefaultNullStreamAllDevices() {
   res_->setDefaultNullStreamAllDevices();

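Putting the two provider methods together, a minimal sketch of manually pointing Faiss at the current PyTorch stream and then restoring Faiss' own stream afterwards; this is the bookkeeping that the contrib context manager wraps:

```python
import faiss
import torch

res = faiss.StandardGpuResources()
dev = torch.cuda.current_device()

# PyTorch exposes the raw cudaStream_t as an integer via .cuda_stream
raw = torch.cuda.current_stream().cuda_stream
res.setDefaultStream(dev, faiss.cast_integer_to_cudastream_t(raw))
try:
    ...  # Faiss GPU calls here are ordered on the PyTorch stream
finally:
    res.revertDefaultStream(dev)  # back to the Faiss-created stream
```
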
@@ -41,9 +41,23 @@ class StandardGpuResourcesImpl : public GpuResources {
   /// transfers
   void setPinnedMemory(size_t size);

-  /// Called to change the stream for work ordering
+  /// Called to change the stream for work ordering. We do not own `stream`;
+  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+  /// up.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
   void setDefaultStream(int device, cudaStream_t stream);

+  /// Revert the default stream to the original stream managed by this resources
+  /// object, in case someone called `setDefaultStream`.
+  void revertDefaultStream(int device);
+
+  /// Returns the stream for the given device on which all Faiss GPU work is
+  /// ordered.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
+  cudaStream_t getDefaultStream(int device) override;
+
   /// Called to change the work ordering streams to the null stream
   /// for all devices
   void setDefaultNullStreamAllDevices();

@@ -60,8 +74,6 @@ class StandardGpuResourcesImpl : public GpuResources {

   cublasHandle_t getBlasHandle(int device) override;

-  cudaStream_t getDefaultStream(int device) override;
-
   std::vector<cudaStream_t> getAlternateStreams(int device) override;

   /// Allocate non-temporary GPU memory

@@ -128,7 +140,9 @@ class StandardGpuResourcesImpl : public GpuResources {
 };

 /// Default implementation of GpuResources that allocates a cuBLAS
-/// stream and 2 streams for use, as well as temporary memory
+/// stream and 2 streams for use, as well as temporary memory.
+/// Internally, the Faiss GPU code uses the instance managed by getResources,
+/// but this is the user-facing object that is internally reference counted.
 class StandardGpuResources : public GpuResourcesProvider {
  public:
   StandardGpuResources();

@@ -151,9 +165,17 @@ class StandardGpuResources : public GpuResourcesProvider {
   /// transfers
   void setPinnedMemory(size_t size);

-  /// Called to change the stream for work ordering
+  /// Called to change the stream for work ordering. We do not own `stream`;
+  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+  /// up.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
   void setDefaultStream(int device, cudaStream_t stream);

+  /// Revert the default stream to the original stream managed by this resources
+  /// object, in case someone called `setDefaultStream`.
+  void revertDefaultStream(int device);
+
   /// Called to change the work ordering streams to the null stream
   /// for all devices
   void setDefaultNullStreamAllDevices();

@@ -10,7 +10,7 @@ import unittest
 import faiss
 import torch

-from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch
+from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch, using_stream

 def to_column_major(x):
     if hasattr(torch, 'contiguous_format'):

@@ -22,40 +22,38 @@ def to_column_major(x):
 class PytorchFaissInterop(unittest.TestCase):

     def test_interop(self):
-        d = 16
-        nq = 5
-        nb = 20
+        d = 128
+        nq = 100
+        nb = 1000
+        k = 10

         xq = faiss.randn(nq * d, 1234).reshape(nq, d)
         xb = faiss.randn(nb * d, 1235).reshape(nb, d)

         res = faiss.StandardGpuResources()
-        index = faiss.GpuIndexFlatIP(res, d)
-        index.add(xb)

-        # reference CPU result
-        Dref, Iref = index.search(xq, 5)
+        # Let's run on a non-default stream
+        s = torch.cuda.Stream()

-        # query is pytorch tensor (CPU)
-        xq_torch = torch.FloatTensor(xq)
+        # Torch will run on this stream
+        with torch.cuda.stream(s):
+            # query is pytorch tensor (CPU and GPU)
+            xq_torch_cpu = torch.FloatTensor(xq)
+            xq_torch_gpu = xq_torch_cpu.cuda()

-        D2, I2 = search_index_pytorch(index, xq_torch, 5)
+            index = faiss.GpuIndexFlatIP(res, d)
+            index.add(xb)

-        assert np.all(Iref == I2.numpy())
+            # Query with GPU tensor (this will be done on the current pytorch stream)
+            D2, I2 = search_index_pytorch(res, index, xq_torch_gpu, k)
+            Dref, Iref = index.search(xq, k)

-        # query is pytorch tensor (GPU)
-        xq_torch = xq_torch.cuda()
-        # no need for a sync here
+            assert np.all(Iref == I2.cpu().numpy())

-        D3, I3 = search_index_pytorch(index, xq_torch, 5)
+            # Query with CPU tensor
+            D3, I3 = search_index_pytorch(res, index, xq_torch_cpu, k)

-        # D3 and I3 are on torch tensors on GPU as well.
-        # this does a sync, which is useful because faiss and
-        # pytorch use different Cuda streams.
-        res.syncDefaultStreamCurrentDevice()
-
-        assert np.all(Iref == I3.cpu().numpy())
+            assert np.all(Iref == I3.numpy())

     def test_raw_array_search(self):
         d = 32

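Why the explicit `res.syncDefaultStreamCurrentDevice()` call could be dropped above, as a hedged sketch (assumes a CUDA-capable build; `res` and `index` as in the test): once Faiss enqueues its work on the current PyTorch stream, any later PyTorch op on that same stream is ordered after it.

```python
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    xq_gpu = torch.randn(100, index.d, device='cuda')
    D, I = search_index_pytorch(res, index, xq_gpu, 10)
    # D and I were produced on stream s; .cpu() is also enqueued on s,
    # so it cannot observe the results before the search has completed.
    I_host = I.cpu()
# No brute-force GPU/CPU synchronization is required any more.
```
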
@@ -74,55 +72,51 @@ class PytorchFaissInterop(unittest.TestCase):

         # resource object, can be re-used over calls
         res = faiss.StandardGpuResources()
-        # put on same stream as pytorch to avoid synchronizing streams
-        res.setDefaultNullStreamAllDevices()

-        for xq_row_major in True, False:
-            for xb_row_major in True, False:
+        # Let's have pytorch use a non-default stream
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            for xq_row_major in True, False:
+                for xb_row_major in True, False:

-                # move to pytorch & GPU
-                xq_t = torch.from_numpy(xq).cuda()
-                xb_t = torch.from_numpy(xb).cuda()
+                    # move to pytorch & GPU
+                    xq_t = torch.from_numpy(xq).cuda()
+                    xb_t = torch.from_numpy(xb).cuda()

-                if not xq_row_major:
-                    xq_t = to_column_major(xq_t)
-                    assert not xq_t.is_contiguous()
+                    if not xq_row_major:
+                        xq_t = to_column_major(xq_t)
+                        assert not xq_t.is_contiguous()

-                if not xb_row_major:
-                    xb_t = to_column_major(xb_t)
-                    assert not xb_t.is_contiguous()
+                    if not xb_row_major:
+                        xb_t = to_column_major(xb_t)
+                        assert not xb_t.is_contiguous()

-                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
+                    D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)

-                # back to CPU for verification
-                D = D.cpu().numpy()
-                I = I.cpu().numpy()
+                    # back to CPU for verification
+                    D = D.cpu().numpy()
+                    I = I.cpu().numpy()

-                assert np.all(I == gt_I)
-                assert np.all(np.abs(D - gt_D).max() < 1e-4)
+                    assert np.all(I == gt_I)
+                    assert np.all(np.abs(D - gt_D).max() < 1e-4)

-                # test on subset
-                try:
-                    D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
-                except TypeError:
-                    if not xq_row_major:
-                        # then it is expected
-                        continue
-                    # otherwise it is an error
-                    raise
+                    # test on subset
+                    try:
+                        # This internally uses the current pytorch stream
+                        D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
+                    except TypeError:
+                        if not xq_row_major:
+                            # then it is expected
+                            continue
+                        # otherwise it is an error
+                        raise

-                # back to CPU for verification
-                D = D.cpu().numpy()
-                I = I.cpu().numpy()
+                    # back to CPU for verification
+                    D = D.cpu().numpy()
+                    I = I.cpu().numpy()

-                assert np.all(I == gt_I[60:80])
-                assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
+                    assert np.all(I == gt_I[60:80])
+                    assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)


 if __name__ == '__main__':

@@ -274,7 +274,6 @@ void gpu_sync_all_devices()

 %}

-
 %template() std::pair<int, unsigned long>;
 %template() std::map<std::string, std::pair<int, unsigned long> >;
 %template() std::map<int, std::map<std::string, std::pair<int, unsigned long> > >;

@@ -287,6 +286,21 @@ void gpu_sync_all_devices()
 %include <faiss/gpu/GpuResources.h>
 %include <faiss/gpu/StandardGpuResources.h>

+typedef CUstream_st* cudaStream_t;
+
+%inline %{
+
+// interop between pytorch exposed cudaStream_t and faiss
+cudaStream_t cast_integer_to_cudastream_t(long x) {
+  return (cudaStream_t) x;
+}
+
+long cast_cudastream_t_to_integer(cudaStream_t x) {
+  return (long) x;
+}
+
+%}
+
 #else

 %{

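These casts simply reinterpret the pointer value, so the round trip through a Python integer is lossless; a small sketch (the numeric value is process-specific):

```python
import faiss
import torch

ptr = torch.cuda.current_stream().cuda_stream  # int holding a CUstream_st*
stream = faiss.cast_integer_to_cudastream_t(ptr)
assert faiss.cast_cudastream_t_to_integer(stream) == ptr
```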