Improve Faiss / PyTorch GPU interoperability
Summary:
PyTorch GPU code is in general free to use whatever stream it currently wants, based on `torch.cuda.current_stream()`. Due to the C++/Python language barrier, we could not previously pass the actual `cudaStream_t` that is currently in use on a given device from PyTorch C++ to Faiss C++ via Python.

This diff adds conversion functions so that a Python integer holding a pointer to a `cudaStream_t` (which is itself a `CUstream_st*`) can be converted to a usable stream, allowing the stream reported by `torch.cuda.current_stream()` to be passed to `StandardGpuResources::setDefaultStream`. We thus guarantee that all Faiss work is ordered on the same stream that is in use in PyTorch.

For use in Python, there is now the `faiss.contrib.pytorch_tensors.using_stream` context manager, which automatically sets and unsets the current PyTorch stream within Faiss. It takes a `StandardGpuResources` object in Python, plus an optional `torch.cuda.Stream` if one wants to use a different stream; otherwise it uses the current one. This is how it is used:

```
# Create a non-default stream
s = torch.cuda.Stream()

# Have Torch use it
with torch.cuda.stream(s):

    # Have Faiss use the same stream as the above
    with faiss.contrib.pytorch_tensors.using_stream(res):
        # Do some work on the GPU
        faiss.bfKnn(res, args)
```

`using_stream` follows the same pattern as the PyTorch `torch.cuda.stream` context manager. This replaces any brute-force GPU/CPU synchronization work that was necessary before.

Other changes in this diff:

- cleans up the config objects in the GpuIndex subclasses, to distinguish between read-only parameters that can only be set upon index construction and those that can be changed at runtime.
- StandardGpuResources now properly distinguishes between user-supplied streams (like the PyTorch one), which will not be destroyed when the resources object is destroyed, and internal streams.
- `search_index_pytorch` now needs to take a `StandardGpuResources` object as well; there is no way to obtain this from an index instance otherwise (or at least, it would have to return a `shared_ptr`, in which case the Python SWIG layer should just be updated to use `shared_ptr` for `GpuResources`).

Reviewed By: mdouze

Differential Revision: D24260026

fbshipit-source-id: b18bb0eb34eb012584b1c923088228776c10b720
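For reference, a minimal sketch of what the context manager does under the hood, using the SWIG cast helper and the `setDefaultStream`/`revertDefaultStream` calls added in this diff (the `args` object is assumed to be a brute-force kNN parameter struct prepared by the caller, as in the example above):

```python
import faiss
import torch

res = faiss.StandardGpuResources()
dev = torch.cuda.current_device()

# Convert the PyTorch stream (exposed as a Python integer) into a cudaStream_t
torch_stream = faiss.cast_integer_to_cudastream_t(
    torch.cuda.current_stream().cuda_stream)

# Order all subsequent Faiss GPU work on the PyTorch stream
res.setDefaultStream(dev, torch_stream)
try:
    faiss.bfKnn(res, args)  # args: assumed to be set up by the caller
finally:
    # Restore the stream originally managed by the resources object
    res.revertDefaultStream(dev)
```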
parent b459931ae4
commit e796f4f9df
@@ -1,5 +1,6 @@
 import faiss
 import torch
+import contextlib

 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
@@ -15,9 +16,33 @@ def swig_ptr_from_LongTensor(x):
     return faiss.cast_integer_to_long_ptr(
         x.storage().data_ptr() + x.storage_offset() * 8)

+@contextlib.contextmanager
+def using_stream(res, pytorch_stream=None):
+    r""" Creates a scoping object to make Faiss GPU use the same stream
+         as pytorch, based on torch.cuda.current_stream().
+         Or, a specific pytorch stream can be passed in as a second
+         argument, in which case we will use that stream.
+    """
+
+    if pytorch_stream is None:
+        pytorch_stream = torch.cuda.current_stream()

-def search_index_pytorch(index, x, k, D=None, I=None):
+    # This is the cudaStream_t that we wish to use
+    cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)
+
+    # So we can revert GpuResources stream state upon exit
+    prior_dev = torch.cuda.current_device()
+    prior_stream = res.getDefaultStream(torch.cuda.current_device())
+
+    res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)
+
+    # Do the user work
+    try:
+        yield
+    finally:
+        res.setDefaultStream(prior_dev, prior_stream)
+
+def search_index_pytorch(res, index, x, k, D=None, I=None):
     """call the search function of an index with pytorch tensor I/O (CPU
     and GPU supported)"""
     assert x.is_contiguous()
@@ -33,13 +58,13 @@ def search_index_pytorch(index, x, k, D=None, I=None):
         I = torch.empty((n, k), dtype=torch.int64, device=x.device)
     else:
         assert I.size() == (n, k)
-    torch.cuda.synchronize()
-    xptr = swig_ptr_from_FloatTensor(x)
-    Iptr = swig_ptr_from_LongTensor(I)
-    Dptr = swig_ptr_from_FloatTensor(D)
-    index.search_c(n, xptr,
-                   k, Dptr, Iptr)
-    torch.cuda.synchronize()
+
+    with using_stream(res):
+        xptr = swig_ptr_from_FloatTensor(x)
+        Iptr = swig_ptr_from_LongTensor(I)
+        Dptr = swig_ptr_from_FloatTensor(D)
+        index.search_c(n, xptr, k, Dptr, Iptr)
+
     return D, I

@@ -97,6 +122,8 @@ def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
     args.numQueries = nq
     args.outDistances = D_ptr
     args.outIndices = I_ptr
-    faiss.bfKnn(res, args)
+
+    with using_stream(res):
+        faiss.bfKnn(res, args)

     return D, I

@ -42,25 +42,29 @@ GpuIndex::GpuIndex(std::shared_ptr<GpuResources> resources,
|
|||
GpuIndexConfig config) :
|
||||
Index(dims, metric),
|
||||
resources_(resources),
|
||||
device_(config.device),
|
||||
memorySpace_(config.memorySpace),
|
||||
config_(config),
|
||||
minPagedSize_(kMinPageSize) {
|
||||
FAISS_THROW_IF_NOT_FMT(device_ < getNumDevices(),
|
||||
"Invalid GPU device %d", device_);
|
||||
FAISS_THROW_IF_NOT_FMT(config_.device < getNumDevices(),
|
||||
"Invalid GPU device %d", config_.device);
|
||||
|
||||
FAISS_THROW_IF_NOT_MSG(dims > 0, "Invalid number of dimensions");
|
||||
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
memorySpace_ == MemorySpace::Device ||
|
||||
(memorySpace_ == MemorySpace::Unified &&
|
||||
getFullUnifiedMemSupport(device_)),
|
||||
config_.memorySpace == MemorySpace::Device ||
|
||||
(config_.memorySpace == MemorySpace::Unified &&
|
||||
getFullUnifiedMemSupport(config_.device)),
|
||||
"Device %d does not support full CUDA 8 Unified Memory (CC 6.0+)",
|
||||
config.device);
|
||||
|
||||
metric_arg = metricArg;
|
||||
|
||||
FAISS_ASSERT((bool) resources_);
|
||||
resources_->initializeForDevice(device_);
|
||||
resources_->initializeForDevice(config_.device);
|
||||
}
|
||||
|
||||
int
|
||||
GpuIndex::getDevice() const {
|
||||
return config_.device;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -124,7 +128,7 @@ GpuIndex::add_with_ids(Index::idx_t n,
|
|||
}
|
||||
}
|
||||
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
addPaged_((int) n, x, ids ? ids : generatedIds.data());
|
||||
}
|
||||
|
||||
|
@ -171,7 +175,7 @@ GpuIndex::addPage_(int n,
|
|||
|
||||
auto vecs =
|
||||
toDeviceTemporary<float, 2>(resources_.get(),
|
||||
device_,
|
||||
config_.device,
|
||||
const_cast<float*>(x),
|
||||
stream,
|
||||
{n, this->d});
|
||||
|
@ -179,7 +183,7 @@ GpuIndex::addPage_(int n,
|
|||
if (ids) {
|
||||
auto indices =
|
||||
toDeviceTemporary<Index::idx_t, 1>(resources_.get(),
|
||||
device_,
|
||||
config_.device,
|
||||
const_cast<Index::idx_t*>(ids),
|
||||
stream,
|
||||
{n});
|
||||
|
@ -214,8 +218,8 @@ GpuIndex::search(Index::idx_t n,
|
|||
return;
|
||||
}
|
||||
|
||||
DeviceScope scope(device_);
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
// We guarantee that the searchImpl_ will be called with device-resident
|
||||
// pointers.
|
||||
|
@ -228,12 +232,12 @@ GpuIndex::search(Index::idx_t n,
|
|||
// another level of tiling.
|
||||
auto outDistances =
|
||||
toDeviceTemporary<float, 2>(
|
||||
resources_.get(), device_, distances, stream,
|
||||
resources_.get(), config_.device, distances, stream,
|
||||
{(int) n, (int) k});
|
||||
|
||||
auto outLabels =
|
||||
toDeviceTemporary<faiss::Index::idx_t, 2>(
|
||||
resources_.get(), device_, labels, stream,
|
||||
resources_.get(), config_.device, labels, stream,
|
||||
{(int) n, (int) k});
|
||||
|
||||
bool usePaged = false;
|
||||
|
@ -272,12 +276,12 @@ GpuIndex::searchNonPaged_(int n,
|
|||
int k,
|
||||
float* outDistancesData,
|
||||
Index::idx_t* outIndicesData) const {
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
// Make sure arguments are on the device we desire; use temporary
|
||||
// memory allocations to move it if necessary
|
||||
auto vecs = toDeviceTemporary<float, 2>(resources_.get(),
|
||||
device_,
|
||||
config_.device,
|
||||
const_cast<float*>(x),
|
||||
stream,
|
||||
{n, (int) this->d});
|
||||
|
@ -335,8 +339,8 @@ GpuIndex::searchFromCpuPaged_(int n,
|
|||
// 1 2 3 1 ... (pinned buf A)
|
||||
// time ->
|
||||
//
|
||||
auto defaultStream = resources_->getDefaultStream(device_);
|
||||
auto copyStream = resources_->getAsyncCopyStream(device_);
|
||||
auto defaultStream = resources_->getDefaultStream(config_.device);
|
||||
auto copyStream = resources_->getAsyncCopyStream(config_.device);
|
||||
|
||||
FAISS_ASSERT((size_t) pageSizeInVecs * this->d <=
|
||||
(size_t) std::numeric_limits<int>::max());
|
||||
|
|
|
@ -36,9 +36,8 @@ class GpuIndex : public faiss::Index {
|
|||
float metricArg,
|
||||
GpuIndexConfig config);
|
||||
|
||||
inline int getDevice() const {
|
||||
return device_;
|
||||
}
|
||||
/// Returns the device that this index is resident on
|
||||
int getDevice() const;
|
||||
|
||||
/// Set the minimum data size for searches (in MiB) for which we use
|
||||
/// CPU -> GPU paging
|
||||
|
@ -136,11 +135,8 @@ private:
|
|||
/// Manages streams, cuBLAS handles and scratch memory for devices
|
||||
std::shared_ptr<GpuResources> resources_;
|
||||
|
||||
/// The GPU device we are resident on
|
||||
const int device_;
|
||||
|
||||
/// The memory space of our primary storage on the GPU
|
||||
const MemorySpace memorySpace_;
|
||||
/// Our configuration options
|
||||
const GpuIndexConfig config_;
|
||||
|
||||
/// Size above which we page copies from the CPU to GPU
|
||||
size_t minPagedSize_;
|
||||
|
|
|
@ -23,7 +23,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
|
|||
GpuIndexBinaryFlatConfig config)
|
||||
: IndexBinary(index->d),
|
||||
resources_(provider->getResources()),
|
||||
config_(std::move(config)) {
|
||||
binaryFlatConfig_(config) {
|
||||
FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
|
||||
"vector dimension (number of bits) "
|
||||
"must be divisible by 8 (passed %d)",
|
||||
|
@ -41,7 +41,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
|
|||
GpuIndexBinaryFlatConfig config)
|
||||
: IndexBinary(dims),
|
||||
resources_(provider->getResources()),
|
||||
config_(std::move(config)) {
|
||||
binaryFlatConfig_(std::move(config)) {
|
||||
FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
|
||||
"vector dimension (number of bits) "
|
||||
"must be divisible by 8 (passed %d)",
|
||||
|
@ -51,17 +51,22 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
|
|||
this->is_trained = true;
|
||||
|
||||
// Construct index
|
||||
DeviceScope scope(config_.device);
|
||||
data_.reset(
|
||||
new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace));
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
data_.reset(new BinaryFlatIndex(resources_.get(),
|
||||
this->d, binaryFlatConfig_.memorySpace));
|
||||
}
|
||||
|
||||
GpuIndexBinaryFlat::~GpuIndexBinaryFlat() {
|
||||
}
|
||||
|
||||
int
|
||||
GpuIndexBinaryFlat::getDevice() const {
|
||||
return binaryFlatConfig_.device;
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
|
||||
DeviceScope scope(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
|
||||
this->d = index->d;
|
||||
|
||||
|
@ -76,20 +81,20 @@ GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
|
|||
|
||||
// destroy old first before allocating new
|
||||
data_.reset();
|
||||
data_.reset(
|
||||
new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace));
|
||||
data_.reset(new BinaryFlatIndex(resources_.get(),
|
||||
this->d, binaryFlatConfig_.memorySpace));
|
||||
|
||||
// The index could be empty
|
||||
if (index->ntotal > 0) {
|
||||
data_->add(index->xb.data(),
|
||||
index->ntotal,
|
||||
resources_->getDefaultStream(config_.device));
|
||||
resources_->getDefaultStream(binaryFlatConfig_.device));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
|
||||
DeviceScope scope(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
|
||||
index->d = this->d;
|
||||
index->ntotal = this->ntotal;
|
||||
|
@ -101,18 +106,18 @@ GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
|
|||
if (this->ntotal > 0) {
|
||||
fromDevice(data_->getVectorsRef(),
|
||||
index->xb.data(),
|
||||
resources_->getDefaultStream(config_.device));
|
||||
resources_->getDefaultStream(binaryFlatConfig_.device));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
|
||||
const uint8_t* x) {
|
||||
DeviceScope scope(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
|
||||
// To avoid multiple re-allocations, ensure we have enough storage
|
||||
// available
|
||||
data_->reserve(n, resources_->getDefaultStream(config_.device));
|
||||
data_->reserve(n, resources_->getDefaultStream(binaryFlatConfig_.device));
|
||||
|
||||
// Due to GPU indexing in int32, we can't store more than this
|
||||
// number of vectors on a GPU
|
||||
|
@ -123,13 +128,13 @@ GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
|
|||
|
||||
data_->add((const unsigned char*) x,
|
||||
n,
|
||||
resources_->getDefaultStream(config_.device));
|
||||
resources_->getDefaultStream(binaryFlatConfig_.device));
|
||||
this->ntotal += n;
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexBinaryFlat::reset() {
|
||||
DeviceScope scope(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
|
||||
// Free the underlying memory
|
||||
data_->reset();
|
||||
|
@ -155,8 +160,8 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
|
|||
getMaxKSelection(),
|
||||
(int) k); // select limitation
|
||||
|
||||
DeviceScope scope(config_.device);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
|
||||
|
||||
// The input vectors may be too large for the GPU, but we still
|
||||
// assume that the output distances and labels are not.
|
||||
|
@ -165,7 +170,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
|
|||
// If we reach a point where all inputs are too big, we can add
|
||||
// another level of tiling.
|
||||
auto outDistances = toDeviceTemporary<int32_t, 2>(resources_.get(),
|
||||
config_.device,
|
||||
binaryFlatConfig_.device,
|
||||
distances,
|
||||
stream,
|
||||
{(int) n, (int) k});
|
||||
|
@ -203,7 +208,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
|
|||
// Convert and copy int indices out
|
||||
auto outIndices =
|
||||
toDeviceTemporary<faiss::Index::idx_t, 2>(resources_.get(),
|
||||
config_.device,
|
||||
binaryFlatConfig_.device,
|
||||
labels,
|
||||
stream,
|
||||
{(int) n, (int) k});
|
||||
|
@ -227,12 +232,12 @@ GpuIndexBinaryFlat::searchNonPaged_(int n,
|
|||
Tensor<int32_t, 2, true> outDistances(outDistancesData, {n, k});
|
||||
Tensor<int, 2, true> outIndices(outIndicesData, {n, k});
|
||||
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
|
||||
|
||||
// Make sure arguments are on the device we desire; use temporary
|
||||
// memory allocations to move it if necessary
|
||||
auto vecs = toDeviceTemporary<uint8_t, 2>(resources_.get(),
|
||||
config_.device,
|
||||
binaryFlatConfig_.device,
|
||||
const_cast<uint8_t*>(x),
|
||||
stream,
|
||||
{n, (int) (this->d / 8)});
|
||||
|
@ -272,10 +277,10 @@ GpuIndexBinaryFlat::searchFromCpuPaged_(int n,
|
|||
void
|
||||
GpuIndexBinaryFlat::reconstruct(faiss::IndexBinary::idx_t key,
|
||||
uint8_t* out) const {
|
||||
DeviceScope scope(config_.device);
|
||||
DeviceScope scope(binaryFlatConfig_.device);
|
||||
|
||||
FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
|
||||
|
||||
auto& vecs = data_->getVectorsRef();
|
||||
auto vec = vecs[key];
|
||||
|
|
|
@ -38,6 +38,9 @@ class GpuIndexBinaryFlat : public IndexBinary {
|
|||
|
||||
~GpuIndexBinaryFlat() override;
|
||||
|
||||
/// Returns the device that this index is resident on
|
||||
int getDevice() const;
|
||||
|
||||
/// Initialize ourselves from the given CPU index; will overwrite
|
||||
/// all data in ourselves
|
||||
void copyFrom(const faiss::IndexBinaryFlat* index);
|
||||
|
@ -80,7 +83,7 @@ class GpuIndexBinaryFlat : public IndexBinary {
|
|||
std::shared_ptr<GpuResources> resources_;
|
||||
|
||||
/// Configuration options
|
||||
GpuIndexBinaryFlatConfig config_;
|
||||
const GpuIndexBinaryFlatConfig binaryFlatConfig_;
|
||||
|
||||
/// Holds our GPU data containing the list of vectors
|
||||
std::unique_ptr<BinaryFlatIndex> data_;
|
||||
|
|
|
@ -27,7 +27,7 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
|
|||
index->metric_type,
|
||||
index->metric_arg,
|
||||
config),
|
||||
config_(std::move(config)) {
|
||||
flatConfig_(config) {
|
||||
// Flat index doesn't need training
|
||||
this->is_trained = true;
|
||||
|
||||
|
@ -42,7 +42,7 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
|
|||
index->metric_type,
|
||||
index->metric_arg,
|
||||
config),
|
||||
config_(std::move(config)) {
|
||||
flatConfig_(config) {
|
||||
// Flat index doesn't need training
|
||||
this->is_trained = true;
|
||||
|
||||
|
@ -58,17 +58,17 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
|
|||
metric,
|
||||
0,
|
||||
config),
|
||||
config_(std::move(config)) {
|
||||
flatConfig_(config) {
|
||||
// Flat index doesn't need training
|
||||
this->is_trained = true;
|
||||
|
||||
// Construct index
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
data_.reset(new FlatIndex(resources_.get(),
|
||||
dims,
|
||||
config_.useFloat16,
|
||||
config_.storeTransposed,
|
||||
memorySpace_));
|
||||
flatConfig_.useFloat16,
|
||||
flatConfig_.storeTransposed,
|
||||
config_.memorySpace));
|
||||
}
|
||||
|
||||
GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
|
||||
|
@ -76,17 +76,17 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
|
|||
faiss::MetricType metric,
|
||||
GpuIndexFlatConfig config) :
|
||||
GpuIndex(resources, dims, metric, 0, config),
|
||||
config_(std::move(config)) {
|
||||
flatConfig_(config) {
|
||||
// Flat index doesn't need training
|
||||
this->is_trained = true;
|
||||
|
||||
// Construct index
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
data_.reset(new FlatIndex(resources_.get(),
|
||||
dims,
|
||||
config_.useFloat16,
|
||||
config_.storeTransposed,
|
||||
memorySpace_));
|
||||
flatConfig_.useFloat16,
|
||||
flatConfig_.storeTransposed,
|
||||
config_.memorySpace));
|
||||
}
|
||||
|
||||
GpuIndexFlat::~GpuIndexFlat() {
|
||||
|
@ -94,7 +94,7 @@ GpuIndexFlat::~GpuIndexFlat() {
|
|||
|
||||
void
|
||||
GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
GpuIndex::copyFrom(index);
|
||||
|
||||
|
@ -109,21 +109,21 @@ GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
|
|||
data_.reset();
|
||||
data_.reset(new FlatIndex(resources_.get(),
|
||||
this->d,
|
||||
config_.useFloat16,
|
||||
config_.storeTransposed,
|
||||
memorySpace_));
|
||||
flatConfig_.useFloat16,
|
||||
flatConfig_.storeTransposed,
|
||||
config_.memorySpace));
|
||||
|
||||
// The index could be empty
|
||||
if (index->ntotal > 0) {
|
||||
data_->add(index->xb.data(),
|
||||
index->ntotal,
|
||||
resources_->getDefaultStream(device_));
|
||||
resources_->getDefaultStream(config_.device));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
GpuIndex::copyTo(index);
|
||||
|
||||
|
@ -131,10 +131,10 @@ GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
|
|||
FAISS_ASSERT(data_->getSize() == this->ntotal);
|
||||
index->xb.resize(this->ntotal * this->d);
|
||||
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
if (this->ntotal > 0) {
|
||||
if (config_.useFloat16) {
|
||||
if (flatConfig_.useFloat16) {
|
||||
auto vecFloat32 = data_->getVectorsFloat32Copy(stream);
|
||||
fromDevice(vecFloat32, index->xb.data(), stream);
|
||||
} else {
|
||||
|
@ -150,7 +150,7 @@ GpuIndexFlat::getNumVecs() const {
|
|||
|
||||
void
|
||||
GpuIndexFlat::reset() {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// Free the underlying memory
|
||||
data_->reset();
|
||||
|
@ -176,15 +176,15 @@ GpuIndexFlat::add(Index::idx_t n, const float* x) {
|
|||
return;
|
||||
}
|
||||
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// To avoid multiple re-allocations, ensure we have enough storage
|
||||
// available
|
||||
data_->reserve(n, resources_->getDefaultStream(device_));
|
||||
data_->reserve(n, resources_->getDefaultStream(config_.device));
|
||||
|
||||
// If we're not operating in float16 mode, we don't need the input
|
||||
// data to be resident on our device; we can add directly.
|
||||
if (!config_.useFloat16) {
|
||||
if (!flatConfig_.useFloat16) {
|
||||
addImpl_(n, x, nullptr);
|
||||
} else {
|
||||
// Otherwise, perform the paging
|
||||
|
@ -214,7 +214,7 @@ GpuIndexFlat::addImpl_(int n,
|
|||
"GPU index only supports up to %zu indices",
|
||||
(size_t) std::numeric_limits<int>::max());
|
||||
|
||||
data_->add(x, n, resources_->getDefaultStream(device_));
|
||||
data_->add(x, n, resources_->getDefaultStream(config_.device));
|
||||
this->ntotal += n;
|
||||
}
|
||||
|
||||
|
@ -224,7 +224,7 @@ GpuIndexFlat::searchImpl_(int n,
|
|||
int k,
|
||||
float* distances,
|
||||
Index::idx_t* labels) const {
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
// Input and output data are already resident on the GPU
|
||||
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
|
||||
|
@ -249,12 +249,12 @@ GpuIndexFlat::searchImpl_(int n,
|
|||
void
|
||||
GpuIndexFlat::reconstruct(faiss::Index::idx_t key,
|
||||
float* out) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
if (config_.useFloat16) {
|
||||
if (flatConfig_.useFloat16) {
|
||||
// FIXME jhj: kernel for copy
|
||||
auto vec = data_->getVectorsFloat32Copy(key, 1, stream);
|
||||
fromDevice(vec.data(), out, this->d, stream);
|
||||
|
@ -268,13 +268,13 @@ void
|
|||
GpuIndexFlat::reconstruct_n(faiss::Index::idx_t i0,
|
||||
faiss::Index::idx_t num,
|
||||
float* out) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
FAISS_THROW_IF_NOT_MSG(i0 < this->ntotal, "index out of bounds");
|
||||
FAISS_THROW_IF_NOT_MSG(i0 + num - 1 < this->ntotal, "num out of bounds");
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
if (config_.useFloat16) {
|
||||
if (flatConfig_.useFloat16) {
|
||||
// FIXME jhj: kernel for copy
|
||||
auto vec = data_->getVectorsFloat32Copy(i0, num, stream);
|
||||
fromDevice(vec.data(), out, num * this->d, stream);
|
||||
|
@ -301,23 +301,24 @@ GpuIndexFlat::compute_residual_n(faiss::Index::idx_t n,
|
|||
"GPU index only supports up to %zu indices",
|
||||
(size_t) std::numeric_limits<int>::max());
|
||||
|
||||
auto stream = resources_->getDefaultStream(device_);
|
||||
auto stream = resources_->getDefaultStream(config_.device);
|
||||
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
auto vecsDevice =
|
||||
toDeviceTemporary<float, 2>(resources_.get(), device_,
|
||||
toDeviceTemporary<float, 2>(resources_.get(), config_.device,
|
||||
const_cast<float*>(xs), stream,
|
||||
{(int) n, (int) this->d});
|
||||
auto idsDevice =
|
||||
toDeviceTemporary<faiss::Index::idx_t, 1>(
|
||||
resources_.get(), device_,
|
||||
resources_.get(), config_.device,
|
||||
const_cast<faiss::Index::idx_t*>(keys),
|
||||
stream,
|
||||
{(int) n});
|
||||
|
||||
auto residualDevice =
|
||||
toDeviceTemporary<float, 2>(resources_.get(), device_, residuals, stream,
|
||||
toDeviceTemporary<float, 2>(resources_.get(),
|
||||
config_.device, residuals, stream,
|
||||
{(int) n, (int) this->d});
|
||||
|
||||
// Convert idx_t to int
|
||||
|
|
|
@ -129,8 +129,8 @@ class GpuIndexFlat : public GpuIndex {
|
|||
faiss::Index::idx_t* labels) const override;
|
||||
|
||||
protected:
|
||||
/// Our config object
|
||||
const GpuIndexFlatConfig config_;
|
||||
/// Our configuration options
|
||||
const GpuIndexFlatConfig flatConfig_;
|
||||
|
||||
/// Holds our GPU data containing the list of vectors
|
||||
std::unique_ptr<FlatIndex> data_;
|
||||
|
|
|
@ -22,12 +22,11 @@ GpuIndexIVF::GpuIndexIVF(GpuResourcesProvider* provider,
|
|||
float metricArg,
|
||||
int nlistIn,
|
||||
GpuIndexIVFConfig config) :
|
||||
GpuIndex(provider->getResources(),
|
||||
dims, metric, metricArg, config),
|
||||
GpuIndex(provider->getResources(), dims, metric, metricArg, config),
|
||||
nlist(nlistIn),
|
||||
nprobe(1),
|
||||
quantizer(nullptr),
|
||||
ivfConfig_(std::move(config)) {
|
||||
ivfConfig_(config) {
|
||||
init_();
|
||||
|
||||
// Only IP and L2 are supported for now
|
||||
|
@ -55,7 +54,7 @@ GpuIndexIVF::init_() {
|
|||
// Construct an empty quantizer
|
||||
GpuIndexFlatConfig config = ivfConfig_.flatConfig;
|
||||
// FIXME: inherit our same device
|
||||
config.device = device_;
|
||||
config.device = config_.device;
|
||||
|
||||
if (metric_type == faiss::METRIC_L2) {
|
||||
quantizer = new GpuIndexFlatL2(resources_, d, config);
|
||||
|
@ -79,7 +78,7 @@ GpuIndexIVF::getQuantizer() {
|
|||
|
||||
void
|
||||
GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
GpuIndex::copyFrom(index);
|
||||
|
||||
|
@ -105,7 +104,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
|
|||
// Construct an empty quantizer
|
||||
GpuIndexFlatConfig config = ivfConfig_.flatConfig;
|
||||
// FIXME: inherit our same device
|
||||
config.device = device_;
|
||||
config.device = config_.device;
|
||||
|
||||
if (index->metric_type == faiss::METRIC_L2) {
|
||||
// FIXME: 2 different float16 options?
|
||||
|
@ -143,7 +142,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
|
|||
|
||||
void
|
||||
GpuIndexIVF::copyTo(faiss::IndexIVF* index) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
//
|
||||
// Index information
|
||||
|
@ -228,7 +227,7 @@ GpuIndexIVF::trainQuantizer_(faiss::Index::idx_t n, const float* x) {
|
|||
printf ("Training IVF quantizer on %ld vectors in %dD\n", n, d);
|
||||
}
|
||||
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// leverage the CPU-side k-means code, which works for the GPU
|
||||
// flat index as well
|
||||
|
|
|
@ -83,7 +83,8 @@ class GpuIndexIVF : public GpuIndex {
|
|||
GpuIndexFlat* quantizer;
|
||||
|
||||
protected:
|
||||
GpuIndexIVFConfig ivfConfig_;
|
||||
/// Our configuration options
|
||||
const GpuIndexIVFConfig ivfConfig_;
|
||||
};
|
||||
|
||||
} } // namespace
|
||||
|
|
|
@ -57,14 +57,14 @@ void
|
|||
GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
|
||||
reserveMemoryVecs_ = numVecs;
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
index_->reserveMemory(numVecs);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
GpuIndexIVF::copyFrom(index);
|
||||
|
||||
|
@ -88,7 +88,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
|
|||
false, // no residual
|
||||
nullptr, // no scalar quantizer
|
||||
ivfFlatConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
|
||||
// Copy all of the IVF data
|
||||
index_->copyInvertedListsFrom(index->invlists);
|
||||
|
@ -96,7 +96,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
|
|||
|
||||
void
|
||||
GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// We must have the indices in order to copy to ourselves
|
||||
FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF,
|
||||
|
@ -118,7 +118,7 @@ GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
|
|||
size_t
|
||||
GpuIndexIVFFlat::reclaimMemory() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
return index_->reclaimMemory();
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ GpuIndexIVFFlat::reclaimMemory() {
|
|||
void
|
||||
GpuIndexIVFFlat::reset() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
index_->reset();
|
||||
this->ntotal = 0;
|
||||
|
@ -140,7 +140,7 @@ GpuIndexIVFFlat::reset() {
|
|||
|
||||
void
|
||||
GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
if (this->is_trained) {
|
||||
FAISS_ASSERT(quantizer->is_trained);
|
||||
|
@ -161,7 +161,7 @@ GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
|
|||
false, // no residual
|
||||
nullptr, // no scalar quantizer
|
||||
ivfFlatConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
|
||||
if (reserveMemoryVecs_) {
|
||||
index_->reserveMemory(reserveMemoryVecs_);
|
||||
|
|
|
@ -73,8 +73,9 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
|
|||
float* distances,
|
||||
Index::idx_t* labels) const override;
|
||||
|
||||
private:
|
||||
GpuIndexIVFFlatConfig ivfFlatConfig_;
|
||||
protected:
|
||||
/// Our configuration options
|
||||
const GpuIndexIVFFlatConfig ivfFlatConfig_;
|
||||
|
||||
/// Desired inverted list memory reservation
|
||||
size_t reserveMemoryVecs_;
|
||||
|
|
|
@ -30,6 +30,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
|
|||
index->nlist,
|
||||
config),
|
||||
ivfpqConfig_(config),
|
||||
usePrecomputedTables_(config.usePrecomputedTables),
|
||||
subQuantizers_(0),
|
||||
bitsPerCode_(0),
|
||||
reserveMemoryVecs_(0) {
|
||||
|
@ -50,6 +51,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
|
|||
nlist,
|
||||
config),
|
||||
ivfpqConfig_(config),
|
||||
usePrecomputedTables_(config.usePrecomputedTables),
|
||||
subQuantizers_(subQuantizers),
|
||||
bitsPerCode_(bitsPerCode),
|
||||
reserveMemoryVecs_(0) {
|
||||
|
@ -64,7 +66,7 @@ GpuIndexIVFPQ::~GpuIndexIVFPQ() {
|
|||
|
||||
void
|
||||
GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
GpuIndexIVF::copyFrom(index);
|
||||
|
||||
|
@ -105,9 +107,9 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
|
|||
ivfpqConfig_.alternativeLayout,
|
||||
(float*) index->pq.centroids.data(),
|
||||
ivfpqConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
// Doesn't make sense to reserve memory here
|
||||
index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables);
|
||||
index_->setPrecomputedCodes(usePrecomputedTables_);
|
||||
|
||||
// Copy all of the IVF data
|
||||
index_->copyInvertedListsFrom(index->invlists);
|
||||
|
@ -115,7 +117,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
|
|||
|
||||
void
|
||||
GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// We must have the indices in order to copy to ourselves
|
||||
FAISS_THROW_IF_NOT_MSG(ivfpqConfig_.indicesOptions != INDICES_IVF,
|
||||
|
@ -153,9 +155,9 @@ GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
|
|||
|
||||
fromDevice<float, 3>(devPQCentroids,
|
||||
index->pq.centroids.data(),
|
||||
resources_->getDefaultStream(device_));
|
||||
resources_->getDefaultStream(config_.device));
|
||||
|
||||
if (ivfpqConfig_.usePrecomputedTables) {
|
||||
if (usePrecomputedTables_) {
|
||||
index->precompute_table();
|
||||
}
|
||||
}
|
||||
|
@ -165,16 +167,16 @@ void
|
|||
GpuIndexIVFPQ::reserveMemory(size_t numVecs) {
|
||||
reserveMemoryVecs_ = numVecs;
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
index_->reserveMemory(numVecs);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
|
||||
ivfpqConfig_.usePrecomputedTables = enable;
|
||||
usePrecomputedTables_ = enable;
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
index_->setPrecomputedCodes(enable);
|
||||
}
|
||||
|
||||
|
@ -183,7 +185,7 @@ GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
|
|||
|
||||
bool
|
||||
GpuIndexIVFPQ::getPrecomputedCodes() const {
|
||||
return ivfpqConfig_.usePrecomputedTables;
|
||||
return usePrecomputedTables_;
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -204,7 +206,7 @@ GpuIndexIVFPQ::getCentroidsPerSubQuantizer() const {
|
|||
size_t
|
||||
GpuIndexIVFPQ::reclaimMemory() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
return index_->reclaimMemory();
|
||||
}
|
||||
|
||||
|
@ -214,7 +216,7 @@ GpuIndexIVFPQ::reclaimMemory() {
|
|||
void
|
||||
GpuIndexIVFPQ::reset() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
index_->reset();
|
||||
this->ntotal = 0;
|
||||
|
@ -264,17 +266,17 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
|
|||
ivfpqConfig_.alternativeLayout,
|
||||
pq.centroids.data(),
|
||||
ivfpqConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
if (reserveMemoryVecs_) {
|
||||
index_->reserveMemory(reserveMemoryVecs_);
|
||||
}
|
||||
|
||||
index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables);
|
||||
index_->setPrecomputedCodes(usePrecomputedTables_);
|
||||
}
|
||||
|
||||
void
|
||||
GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
if (this->is_trained) {
|
||||
FAISS_ASSERT(quantizer->is_trained);
|
||||
|
@ -289,7 +291,7 @@ GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
|
|||
// First, make sure that the data is resident on the CPU, if it is not on the
|
||||
// CPU, as we depend upon parts of the CPU code
|
||||
auto hostData = toHost<float, 2>((float*) x,
|
||||
resources_->getDefaultStream(device_),
|
||||
resources_->getDefaultStream(config_.device),
|
||||
{(int) n, (int) this->d});
|
||||
|
||||
trainQuantizer_(n, hostData.data());
|
||||
|
@ -351,7 +353,7 @@ GpuIndexIVFPQ::getListLength(int listId) const {
|
|||
std::vector<unsigned char>
|
||||
GpuIndexIVFPQ::getListCodes(int listId) const {
|
||||
FAISS_ASSERT(index_);
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
return index_->getListCodes(listId);
|
||||
}
|
||||
|
@ -359,7 +361,7 @@ GpuIndexIVFPQ::getListCodes(int listId) const {
|
|||
std::vector<long>
|
||||
GpuIndexIVFPQ::getListIndices(int listId) const {
|
||||
FAISS_ASSERT(index_);
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
return index_->getListIndices(listId);
|
||||
}
|
||||
|
@ -398,15 +400,15 @@ GpuIndexIVFPQ::verifySettings_() const {
|
|||
// codes per subquantizer
|
||||
size_t requiredSmemSize =
|
||||
lookupTableSize * subQuantizers_ * utils::pow2(bitsPerCode_);
|
||||
size_t smemPerBlock = getMaxSharedMemPerBlock(device_);
|
||||
size_t smemPerBlock = getMaxSharedMemPerBlock(config_.device);
|
||||
|
||||
FAISS_THROW_IF_NOT_FMT(requiredSmemSize
|
||||
<= getMaxSharedMemPerBlock(device_),
|
||||
<= getMaxSharedMemPerBlock(config_.device),
|
||||
"Device %d has %zu bytes of shared memory, while "
|
||||
"%d bits per code and %d sub-quantizers requires %zu "
|
||||
"bytes. Consider useFloat16LookupTables and/or "
|
||||
"reduce parameters",
|
||||
device_, smemPerBlock, bitsPerCode_, subQuantizers_,
|
||||
config_.device, smemPerBlock, bitsPerCode_, subQuantizers_,
|
||||
requiredSmemSize);
|
||||
}
|
||||
|
||||
|
|
|
@ -135,13 +135,16 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
|
|||
float* distances,
|
||||
Index::idx_t* labels) const override;
|
||||
|
||||
private:
|
||||
void verifySettings_() const;
|
||||
|
||||
void trainResidualQuantizer_(Index::idx_t n, const float* x);
|
||||
|
||||
private:
|
||||
GpuIndexIVFPQConfig ivfpqConfig_;
|
||||
protected:
|
||||
/// Our configuration options that we were initialized with
|
||||
const GpuIndexIVFPQConfig ivfpqConfig_;
|
||||
|
||||
/// Runtime override: whether or not we use precomputed tables
|
||||
bool usePrecomputedTables_;
|
||||
|
||||
/// Number of sub-quantizers per encoded vector
|
||||
int subQuantizers_;
|
||||
|
|
|
@ -66,7 +66,7 @@ void
|
|||
GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
|
||||
reserveMemoryVecs_ = numVecs;
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
index_->reserveMemory(numVecs);
|
||||
}
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
|
|||
void
|
||||
GpuIndexIVFScalarQuantizer::copyFrom(
|
||||
const faiss::IndexIVFScalarQuantizer* index) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// Clear out our old data
|
||||
index_.reset();
|
||||
|
@ -101,7 +101,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
|
|||
by_residual,
|
||||
&sq,
|
||||
ivfSQConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
|
||||
// Copy all of the IVF data
|
||||
index_->copyInvertedListsFrom(index->invlists);
|
||||
|
@ -110,7 +110,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
|
|||
void
|
||||
GpuIndexIVFScalarQuantizer::copyTo(
|
||||
faiss::IndexIVFScalarQuantizer* index) const {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
// We must have the indices in order to copy to ourselves
|
||||
FAISS_THROW_IF_NOT_MSG(
|
||||
|
@ -135,7 +135,7 @@ GpuIndexIVFScalarQuantizer::copyTo(
|
|||
size_t
|
||||
GpuIndexIVFScalarQuantizer::reclaimMemory() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
return index_->reclaimMemory();
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ GpuIndexIVFScalarQuantizer::reclaimMemory() {
|
|||
void
|
||||
GpuIndexIVFScalarQuantizer::reset() {
|
||||
if (index_) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
index_->reset();
|
||||
this->ntotal = 0;
|
||||
|
@ -163,7 +163,7 @@ GpuIndexIVFScalarQuantizer::trainResiduals_(Index::idx_t n, const float* x) {
|
|||
|
||||
void
|
||||
GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
|
||||
DeviceScope scope(device_);
|
||||
DeviceScope scope(config_.device);
|
||||
|
||||
if (this->is_trained) {
|
||||
FAISS_ASSERT(quantizer->is_trained);
|
||||
|
@ -178,7 +178,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
|
|||
// First, make sure that the data is resident on the CPU, if it is not on the
|
||||
// CPU, as we depend upon parts of the CPU code
|
||||
auto hostData = toHost<float, 2>((float*) x,
|
||||
resources_->getDefaultStream(device_),
|
||||
resources_->getDefaultStream(config_.device),
|
||||
{(int) n, (int) this->d});
|
||||
|
||||
trainQuantizer_(n, hostData.data());
|
||||
|
@ -192,7 +192,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
|
|||
by_residual,
|
||||
&sq,
|
||||
ivfSQConfig_.indicesOptions,
|
||||
memorySpace_));
|
||||
config_.memorySpace));
|
||||
|
||||
if (reserveMemoryVecs_) {
|
||||
index_->reserveMemory(reserveMemoryVecs_);
|
||||
|
|
|
@ -88,8 +88,9 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF {
|
|||
/// Exposed like the CPU version
|
||||
bool by_residual;
|
||||
|
||||
private:
|
||||
GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;
|
||||
protected:
|
||||
/// Our configuration options
|
||||
const GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;
|
||||
|
||||
/// Desired inverted list memory reservation
|
||||
size_t reserveMemoryVecs_;
|
||||
|
|
|
@ -101,12 +101,8 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
|
|||
for (auto& entry : defaultStreams_) {
|
||||
DeviceScope scope(entry.first);
|
||||
|
||||
auto it = userDefaultStreams_.find(entry.first);
|
||||
if (it == userDefaultStreams_.end()) {
|
||||
// The user did not specify this stream, thus we are the ones
|
||||
// who have created it
|
||||
CUDA_VERIFY(cudaStreamDestroy(entry.second));
|
||||
}
|
||||
// We created these streams, so are responsible for destroying them
|
||||
CUDA_VERIFY(cudaStreamDestroy(entry.second));
|
||||
}
|
||||
|
||||
for (auto& entry : alternateStreams_) {
|
||||
|
@ -210,16 +206,14 @@ StandardGpuResourcesImpl::setPinnedMemory(size_t size) {
|
|||
|
||||
void
|
||||
StandardGpuResourcesImpl::setDefaultStream(int device, cudaStream_t stream) {
|
||||
auto it = defaultStreams_.find(device);
|
||||
if (it != defaultStreams_.end()) {
|
||||
// Replace this stream with the user stream
|
||||
CUDA_VERIFY(cudaStreamDestroy(it->second));
|
||||
it->second = stream;
|
||||
}
|
||||
|
||||
userDefaultStreams_[device] = stream;
|
||||
}
|
||||
|
||||
void
|
||||
StandardGpuResourcesImpl::revertDefaultStream(int device) {
|
||||
userDefaultStreams_.erase(device);
|
||||
}
|
||||
|
||||
void
|
||||
StandardGpuResourcesImpl::setDefaultNullStreamAllDevices() {
|
||||
for (int dev = 0; dev < getNumDevices(); ++dev) {
|
||||
|
@ -274,14 +268,8 @@ StandardGpuResourcesImpl::initializeForDevice(int device) {
|
|||
|
||||
// Create streams
|
||||
cudaStream_t defaultStream = 0;
|
||||
auto it = userDefaultStreams_.find(device);
|
||||
if (it != userDefaultStreams_.end()) {
|
||||
// We already have a stream provided by the user
|
||||
defaultStream = it->second;
|
||||
} else {
|
||||
CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
|
||||
cudaStreamNonBlocking));
|
||||
}
|
||||
CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
|
||||
cudaStreamNonBlocking));
|
||||
|
||||
defaultStreams_[device] = defaultStream;
|
||||
|
||||
|
@ -341,6 +329,14 @@ StandardGpuResourcesImpl::getBlasHandle(int device) {
|
|||
cudaStream_t
|
||||
StandardGpuResourcesImpl::getDefaultStream(int device) {
|
||||
initializeForDevice(device);
|
||||
|
||||
auto it = userDefaultStreams_.find(device);
|
||||
if (it != userDefaultStreams_.end()) {
|
||||
// There is a user override stream set
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Otherwise, our base default stream
|
||||
return defaultStreams_[device];
|
||||
}
|
||||
|
||||
|
@ -539,6 +535,11 @@ StandardGpuResources::setDefaultStream(int device, cudaStream_t stream) {
|
|||
res_->setDefaultStream(device, stream);
|
||||
}
|
||||
|
||||
void
|
||||
StandardGpuResources::revertDefaultStream(int device) {
|
||||
res_->revertDefaultStream(device);
|
||||
}
|
||||
|
||||
void
|
||||
StandardGpuResources::setDefaultNullStreamAllDevices() {
|
||||
res_->setDefaultNullStreamAllDevices();
|
||||
|
|
|
@ -41,9 +41,23 @@ class StandardGpuResourcesImpl : public GpuResources {
|
|||
/// transfers
|
||||
void setPinnedMemory(size_t size);
|
||||
|
||||
/// Called to change the stream for work ordering
|
||||
/// Called to change the stream for work ordering. We do not own `stream`;
|
||||
/// i.e., it will not be destroyed when the GpuResources object gets cleaned
|
||||
/// up.
|
||||
/// We are guaranteed that all Faiss GPU work is ordered with respect to
|
||||
/// this stream upon exit from an index or other Faiss GPU call.
|
||||
void setDefaultStream(int device, cudaStream_t stream);
|
||||
|
||||
/// Revert the default stream to the original stream managed by this resources
|
||||
/// object, in case someone called `setDefaultStream`.
|
||||
void revertDefaultStream(int device);
|
||||
|
||||
/// Returns the stream for the given device on which all Faiss GPU work is
|
||||
/// ordered.
|
||||
/// We are guaranteed that all Faiss GPU work is ordered with respect to
|
||||
/// this stream upon exit from an index or other Faiss GPU call.
|
||||
cudaStream_t getDefaultStream(int device) override;
|
||||
|
||||
/// Called to change the work ordering streams to the null stream
|
||||
/// for all devices
|
||||
void setDefaultNullStreamAllDevices();
|
||||
|
@ -60,8 +74,6 @@ class StandardGpuResourcesImpl : public GpuResources {
|
|||
|
||||
cublasHandle_t getBlasHandle(int device) override;
|
||||
|
||||
cudaStream_t getDefaultStream(int device) override;
|
||||
|
||||
std::vector<cudaStream_t> getAlternateStreams(int device) override;
|
||||
|
||||
/// Allocate non-temporary GPU memory
|
||||
|
@ -128,7 +140,9 @@ class StandardGpuResourcesImpl : public GpuResources {
|
|||
};
|
||||
|
||||
/// Default implementation of GpuResources that allocates a cuBLAS
|
||||
/// stream and 2 streams for use, as well as temporary memory
|
||||
/// stream and 2 streams for use, as well as temporary memory.
|
||||
/// Internally, the Faiss GPU code uses the instance managed by getResources,
|
||||
/// but this is the user-facing object that is internally reference counted.
|
||||
class StandardGpuResources : public GpuResourcesProvider {
|
||||
public:
|
||||
StandardGpuResources();
|
||||
|
@ -151,9 +165,17 @@ class StandardGpuResources : public GpuResourcesProvider {
|
|||
/// transfers
|
||||
void setPinnedMemory(size_t size);
|
||||
|
||||
/// Called to change the stream for work ordering
|
||||
/// Called to change the stream for work ordering. We do not own `stream`;
|
||||
/// i.e., it will not be destroyed when the GpuResources object gets cleaned
|
||||
/// up.
|
||||
/// We are guaranteed that all Faiss GPU work is ordered with respect to
|
||||
/// this stream upon exit from an index or other Faiss GPU call.
|
||||
void setDefaultStream(int device, cudaStream_t stream);
|
||||
|
||||
/// Revert the default stream to the original stream managed by this resources
|
||||
/// object, in case someone called `setDefaultStream`.
|
||||
void revertDefaultStream(int device);
|
||||
|
||||
/// Called to change the work ordering streams to the null stream
|
||||
/// for all devices
|
||||
void setDefaultNullStreamAllDevices();
|
||||
|
|
|
@ -10,7 +10,7 @@ import unittest
|
|||
import faiss
|
||||
import torch
|
||||
|
||||
from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch
|
||||
from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch, using_stream
|
||||
|
||||
def to_column_major(x):
|
||||
if hasattr(torch, 'contiguous_format'):
|
||||
|
@ -22,40 +22,38 @@ def to_column_major(x):
|
|||
class PytorchFaissInterop(unittest.TestCase):
|
||||
|
||||
def test_interop(self):
|
||||
|
||||
d = 16
|
||||
nq = 5
|
||||
nb = 20
|
||||
d = 128
|
||||
nq = 100
|
||||
nb = 1000
|
||||
k = 10
|
||||
|
||||
xq = faiss.randn(nq * d, 1234).reshape(nq, d)
|
||||
xb = faiss.randn(nb * d, 1235).reshape(nb, d)
|
||||
|
||||
res = faiss.StandardGpuResources()
|
||||
index = faiss.GpuIndexFlatIP(res, d)
|
||||
index.add(xb)
|
||||
|
||||
# reference CPU result
|
||||
Dref, Iref = index.search(xq, 5)
|
||||
# Let's run on a non-default stream
|
||||
s = torch.cuda.Stream()
|
||||
|
||||
# query is pytorch tensor (CPU)
|
||||
xq_torch = torch.FloatTensor(xq)
|
||||
# Torch will run on this stream
|
||||
with torch.cuda.stream(s):
|
||||
# query is pytorch tensor (CPU and GPU)
|
||||
xq_torch_cpu = torch.FloatTensor(xq)
|
||||
xq_torch_gpu = xq_torch_cpu.cuda()
|
||||
|
||||
D2, I2 = search_index_pytorch(index, xq_torch, 5)
|
||||
index = faiss.GpuIndexFlatIP(res, d)
|
||||
index.add(xb)
|
||||
|
||||
assert np.all(Iref == I2.numpy())
|
||||
# Query with GPU tensor (this will be done on the current pytorch stream)
|
||||
D2, I2 = search_index_pytorch(res, index, xq_torch_gpu, k)
|
||||
Dref, Iref = index.search(xq, k)
|
||||
|
||||
# query is pytorch tensor (GPU)
|
||||
xq_torch = xq_torch.cuda()
|
||||
# no need for a sync here
|
||||
assert np.all(Iref == I2.cpu().numpy())
|
||||
|
||||
D3, I3 = search_index_pytorch(index, xq_torch, 5)
|
||||
# Query with CPU tensor
|
||||
D3, I3 = search_index_pytorch(res, index, xq_torch_cpu, k)
|
||||
|
||||
# D3 and I3 are on torch tensors on GPU as well.
|
||||
# this does a sync, which is useful because faiss and
|
||||
# pytorch use different Cuda streams.
|
||||
res.syncDefaultStreamCurrentDevice()
|
||||
|
||||
assert np.all(Iref == I3.cpu().numpy())
|
||||
assert np.all(Iref == I3.numpy())
|
||||
|
||||
def test_raw_array_search(self):
|
||||
d = 32
|
||||
|
@ -74,55 +72,51 @@ class PytorchFaissInterop(unittest.TestCase):
|
|||
|
||||
# resource object, can be re-used over calls
|
||||
res = faiss.StandardGpuResources()
|
||||
# put on same stream as pytorch to avoid synchronizing streams
|
||||
res.setDefaultNullStreamAllDevices()
|
||||
|
||||
for xq_row_major in True, False:
|
||||
for xb_row_major in True, False:
|
||||
# Let's have pytorch use a non-default stream
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
for xq_row_major in True, False:
|
||||
for xb_row_major in True, False:
|
||||
|
||||
# move to pytorch & GPU
|
||||
xq_t = torch.from_numpy(xq).cuda()
|
||||
xb_t = torch.from_numpy(xb).cuda()
|
||||
# move to pytorch & GPU
|
||||
xq_t = torch.from_numpy(xq).cuda()
|
||||
xb_t = torch.from_numpy(xb).cuda()
|
||||
|
||||
if not xq_row_major:
|
||||
xq_t = to_column_major(xq_t)
|
||||
assert not xq_t.is_contiguous()
|
||||
|
||||
if not xb_row_major:
|
||||
xb_t = to_column_major(xb_t)
|
||||
assert not xb_t.is_contiguous()
|
||||
|
||||
D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
|
||||
|
||||
# back to CPU for verification
|
||||
D = D.cpu().numpy()
|
||||
I = I.cpu().numpy()
|
||||
|
||||
assert np.all(I == gt_I)
|
||||
assert np.all(np.abs(D - gt_D).max() < 1e-4)
|
||||
|
||||
|
||||
|
||||
# test on subset
|
||||
try:
|
||||
D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
|
||||
except TypeError:
|
||||
if not xq_row_major:
|
||||
# then it is expected
|
||||
continue
|
||||
# otherwise it is an error
|
||||
raise
|
||||
xq_t = to_column_major(xq_t)
|
||||
assert not xq_t.is_contiguous()
|
||||
|
||||
# back to CPU for verification
|
||||
D = D.cpu().numpy()
|
||||
I = I.cpu().numpy()
|
||||
if not xb_row_major:
|
||||
xb_t = to_column_major(xb_t)
|
||||
assert not xb_t.is_contiguous()
|
||||
|
||||
assert np.all(I == gt_I[60:80])
|
||||
assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
|
||||
D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
|
||||
|
||||
# back to CPU for verification
|
||||
D = D.cpu().numpy()
|
||||
I = I.cpu().numpy()
|
||||
|
||||
assert np.all(I == gt_I)
|
||||
assert np.all(np.abs(D - gt_D).max() < 1e-4)
|
||||
|
||||
# test on subset
|
||||
try:
|
||||
# This internally uses the current pytorch stream
|
||||
D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
|
||||
except TypeError:
|
||||
if not xq_row_major:
|
||||
# then it is expected
|
||||
continue
|
||||
# otherwise it is an error
|
||||
raise
|
||||
|
||||
# back to CPU for verification
|
||||
D = D.cpu().numpy()
|
||||
I = I.cpu().numpy()
|
||||
|
||||
assert np.all(I == gt_I[60:80])
|
||||
assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@@ -274,7 +274,6 @@ void gpu_sync_all_devices()

 %}

-
 %template() std::pair<int, unsigned long>;
 %template() std::map<std::string, std::pair<int, unsigned long> >;
 %template() std::map<int, std::map<std::string, std::pair<int, unsigned long> > >;
@@ -287,6 +286,21 @@ void gpu_sync_all_devices()
 %include <faiss/gpu/GpuResources.h>
 %include <faiss/gpu/StandardGpuResources.h>

+typedef CUstream_st* cudaStream_t;
+
+%inline %{
+
+// interop between pytorch exposed cudaStream_t and faiss
+cudaStream_t cast_integer_to_cudastream_t(long x) {
+  return (cudaStream_t) x;
+}
+
+long cast_cudastream_t_to_integer(cudaStream_t x) {
+  return (long) x;
+}
+
+%}
+
 #else

 %{