Improve Faiss / PyTorch GPU interoperability

Summary:
PyTorch GPU code is, in general, free to run on whatever stream is current, as reported by `torch.cuda.current_stream()`. Because of the C++/Python language barrier, we previously could not pass the actual `cudaStream_t` in use on a given device from PyTorch C++ to Faiss C++ via Python.

This diff adds conversion functions that turn a Python integer holding the pointer value of a `cudaStream_t` (which is itself a `CUstream_st*`) into a usable stream handle, so that the stream reported by `torch.cuda.current_stream()` can be passed to `StandardGpuResources::setDefaultStream`. This guarantees that all Faiss work is ordered on the same stream that PyTorch is using.
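
Under the hood, the conversion amounts to roughly the following (a sketch of what `using_stream` does internally, based on the code in this diff):

```
# Hand PyTorch's current stream to Faiss by raw pointer value
res = faiss.StandardGpuResources()

torch_stream = torch.cuda.current_stream()
# .cuda_stream is the integer value of the underlying cudaStream_t
faiss_stream = faiss.cast_integer_to_cudastream_t(torch_stream.cuda_stream)
res.setDefaultStream(torch.cuda.current_device(), faiss_stream)
```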

For Python use, there is now the `faiss.contrib.pytorch_tensors.using_stream` context manager, which automatically sets and unsets the current PyTorch stream within Faiss. It takes a `StandardGpuResources` object and, optionally, a `torch.cuda.Stream` if one wants to use a different stream; otherwise the current PyTorch stream is used. It is used like this:

```
# Create a non-default stream
s = torch.cuda.Stream()

# Have Torch use it
with torch.cuda.stream(s):

  # Have Faiss use the same stream as the above
  with faiss.contrib.pytorch_tensors.using_stream(res):
    # Do some work on the GPU
    faiss.bfKnn(res, args)
```

`using_stream` follows the same pattern as PyTorch's own `torch.cuda.stream` context manager.
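
A specific stream can also be passed explicitly instead of relying on the current one (an illustrative snippet based on the `using_stream(res, pytorch_stream=None)` signature; `args` is the brute-force kNN argument object from the example above):

```
s = torch.cuda.Stream()

# Order Faiss work on `s` without making it PyTorch's current stream
with faiss.contrib.pytorch_tensors.using_stream(res, s):
    faiss.bfKnn(res, args)
```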

This replaces the brute-force GPU/CPU synchronization that was previously necessary.

Other changes in this diff:
- Cleans up the config objects in the GpuIndex subclasses to distinguish read-only parameters that can only be set at index construction from those that can be changed at runtime.
- `StandardGpuResources` now properly distinguishes user-supplied streams (like the PyTorch one), which are not destroyed when the resources object is destroyed, from internally created streams.
- `search_index_pytorch` now needs to take a `StandardGpuResources` object as well; there is no way to obtain this from an index instance otherwise (or at least, we would have to return a `shared_ptr`, in which case we should just update the Python SWIG layer to use `shared_ptr` for `GpuResources`). See the example after this list.
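
For example (a sketch matching the updated unit test; `xq_torch` is an illustrative query tensor):

```
# Before: D, I = search_index_pytorch(index, xq_torch, k)
# After: the resources object is passed explicitly
D, I = search_index_pytorch(res, index, xq_torch, k)
```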

Reviewed By: mdouze

Differential Revision: D24260026

fbshipit-source-id: b18bb0eb34eb012584b1c923088228776c10b720
pull/1462/head
Jeff Johnson 2020-10-13 09:09:52 -07:00 committed by Facebook GitHub Bot
parent b459931ae4
commit e796f4f9df
19 changed files with 316 additions and 242 deletions


@@ -1,5 +1,6 @@
 import faiss
 import torch
+import contextlib
 
 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch trensor (on CPU or GPU) """
@@ -15,9 +16,33 @@ def swig_ptr_from_LongTensor(x):
     return faiss.cast_integer_to_long_ptr(
         x.storage().data_ptr() + x.storage_offset() * 8)
 
+@contextlib.contextmanager
+def using_stream(res, pytorch_stream=None):
+    r""" Creates a scoping object to make Faiss GPU use the same stream
+         as pytorch, based on torch.cuda.current_stream().
+         Or, a specific pytorch stream can be passed in as a second
+         argument, in which case we will use that stream.
+    """
+    if pytorch_stream is None:
+        pytorch_stream = torch.cuda.current_stream()
+
+    # This is the cudaStream_t that we wish to use
+    cuda_stream_s = faiss.cast_integer_to_cudastream_t(pytorch_stream.cuda_stream)
+
+    # So we can revert GpuResources stream state upon exit
+    prior_dev = torch.cuda.current_device()
+    prior_stream = res.getDefaultStream(torch.cuda.current_device())
+
+    res.setDefaultStream(torch.cuda.current_device(), cuda_stream_s)
+
+    # Do the user work
+    try:
+        yield
+    finally:
+        res.setDefaultStream(prior_dev, prior_stream)
+
-def search_index_pytorch(index, x, k, D=None, I=None):
+def search_index_pytorch(res, index, x, k, D=None, I=None):
     """call the search function of an index with pytorch tensor I/O (CPU
     and GPU supported)"""
     assert x.is_contiguous()
@@ -33,13 +58,13 @@ def search_index_pytorch(index, x, k, D=None, I=None):
         I = torch.empty((n, k), dtype=torch.int64, device=x.device)
     else:
         assert I.size() == (n, k)
-    torch.cuda.synchronize()
-    xptr = swig_ptr_from_FloatTensor(x)
-    Iptr = swig_ptr_from_LongTensor(I)
-    Dptr = swig_ptr_from_FloatTensor(D)
-    index.search_c(n, xptr,
-                   k, Dptr, Iptr)
-    torch.cuda.synchronize()
+
+    with using_stream(res):
+        xptr = swig_ptr_from_FloatTensor(x)
+        Iptr = swig_ptr_from_LongTensor(I)
+        Dptr = swig_ptr_from_FloatTensor(D)
+        index.search_c(n, xptr, k, Dptr, Iptr)
+
     return D, I
@@ -97,6 +122,8 @@ def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
     args.numQueries = nq
     args.outDistances = D_ptr
     args.outIndices = I_ptr
-    faiss.bfKnn(res, args)
+
+    with using_stream(res):
+        faiss.bfKnn(res, args)
 
     return D, I


@@ -42,25 +42,29 @@ GpuIndex::GpuIndex(std::shared_ptr<GpuResources> resources,
                    GpuIndexConfig config) :
     Index(dims, metric),
     resources_(resources),
-    device_(config.device),
-    memorySpace_(config.memorySpace),
+    config_(config),
     minPagedSize_(kMinPageSize) {
-  FAISS_THROW_IF_NOT_FMT(device_ < getNumDevices(),
-                         "Invalid GPU device %d", device_);
+  FAISS_THROW_IF_NOT_FMT(config_.device < getNumDevices(),
+                         "Invalid GPU device %d", config_.device);
 
   FAISS_THROW_IF_NOT_MSG(dims > 0, "Invalid number of dimensions");
 
   FAISS_THROW_IF_NOT_FMT(
-      memorySpace_ == MemorySpace::Device ||
-      (memorySpace_ == MemorySpace::Unified &&
-       getFullUnifiedMemSupport(device_)),
+      config_.memorySpace == MemorySpace::Device ||
+      (config_.memorySpace == MemorySpace::Unified &&
+       getFullUnifiedMemSupport(config_.device)),
       "Device %d does not support full CUDA 8 Unified Memory (CC 6.0+)",
       config.device);
 
   metric_arg = metricArg;
 
   FAISS_ASSERT((bool) resources_);
-  resources_->initializeForDevice(device_);
+  resources_->initializeForDevice(config_.device);
+}
+
+int
+GpuIndex::getDevice() const {
+  return config_.device;
 }
 
 void
@@ -124,7 +128,7 @@ GpuIndex::add_with_ids(Index::idx_t n,
     }
   }
 
-  DeviceScope scope(device_);
+  DeviceScope scope(config_.device);
   addPaged_((int) n, x, ids ? ids : generatedIds.data());
 }
@@ -171,7 +175,7 @@ GpuIndex::addPage_(int n,
   auto vecs =
     toDeviceTemporary<float, 2>(resources_.get(),
-                                device_,
+                                config_.device,
                                 const_cast<float*>(x),
                                 stream,
                                 {n, this->d});
@@ -179,7 +183,7 @@ GpuIndex::addPage_(int n,
   if (ids) {
     auto indices =
       toDeviceTemporary<Index::idx_t, 1>(resources_.get(),
-                                         device_,
+                                         config_.device,
                                          const_cast<Index::idx_t*>(ids),
                                          stream,
                                          {n});
@@ -214,8 +218,8 @@ GpuIndex::search(Index::idx_t n,
     return;
   }
 
-  DeviceScope scope(device_);
-  auto stream = resources_->getDefaultStream(device_);
+  DeviceScope scope(config_.device);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   // We guarantee that the searchImpl_ will be called with device-resident
   // pointers.
@@ -228,12 +232,12 @@ GpuIndex::search(Index::idx_t n,
   // another level of tiling.
   auto outDistances =
     toDeviceTemporary<float, 2>(
-      resources_.get(), device_, distances, stream,
+      resources_.get(), config_.device, distances, stream,
       {(int) n, (int) k});
 
   auto outLabels =
     toDeviceTemporary<faiss::Index::idx_t, 2>(
-      resources_.get(), device_, labels, stream,
+      resources_.get(), config_.device, labels, stream,
       {(int) n, (int) k});
 
   bool usePaged = false;
@@ -272,12 +276,12 @@ GpuIndex::searchNonPaged_(int n,
                           int k,
                           float* outDistancesData,
                           Index::idx_t* outIndicesData) const {
-  auto stream = resources_->getDefaultStream(device_);
+  auto stream = resources_->getDefaultStream(config_.device);
 
   // Make sure arguments are on the device we desire; use temporary
   // memory allocations to move it if necessary
   auto vecs = toDeviceTemporary<float, 2>(resources_.get(),
-                                          device_,
+                                          config_.device,
                                           const_cast<float*>(x),
                                           stream,
                                           {n, (int) this->d});
@@ -335,8 +339,8 @@ GpuIndex::searchFromCpuPaged_(int n,
   //                 1 2 3 1 ... (pinned buf A)
   // time ->
   //
-  auto defaultStream = resources_->getDefaultStream(device_);
-  auto copyStream = resources_->getAsyncCopyStream(device_);
+  auto defaultStream = resources_->getDefaultStream(config_.device);
+  auto copyStream = resources_->getAsyncCopyStream(config_.device);
 
   FAISS_ASSERT((size_t) pageSizeInVecs * this->d <=
                (size_t) std::numeric_limits<int>::max());


@@ -36,9 +36,8 @@ class GpuIndex : public faiss::Index {
            float metricArg,
            GpuIndexConfig config);
 
-  inline int getDevice() const {
-    return device_;
-  }
+  /// Returns the device that this index is resident on
+  int getDevice() const;
 
   /// Set the minimum data size for searches (in MiB) for which we use
   /// CPU -> GPU paging
@@ -136,11 +135,8 @@ private:
   /// Manages streams, cuBLAS handles and scratch memory for devices
   std::shared_ptr<GpuResources> resources_;
 
-  /// The GPU device we are resident on
-  const int device_;
-
-  /// The memory space of our primary storage on the GPU
-  const MemorySpace memorySpace_;
+  /// Our configuration options
+  const GpuIndexConfig config_;
 
   /// Size above which we page copies from the CPU to GPU
   size_t minPagedSize_;


@ -23,7 +23,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
GpuIndexBinaryFlatConfig config) GpuIndexBinaryFlatConfig config)
: IndexBinary(index->d), : IndexBinary(index->d),
resources_(provider->getResources()), resources_(provider->getResources()),
config_(std::move(config)) { binaryFlatConfig_(config) {
FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0, FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
"vector dimension (number of bits) " "vector dimension (number of bits) "
"must be divisible by 8 (passed %d)", "must be divisible by 8 (passed %d)",
@ -41,7 +41,7 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
GpuIndexBinaryFlatConfig config) GpuIndexBinaryFlatConfig config)
: IndexBinary(dims), : IndexBinary(dims),
resources_(provider->getResources()), resources_(provider->getResources()),
config_(std::move(config)) { binaryFlatConfig_(std::move(config)) {
FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0, FAISS_THROW_IF_NOT_FMT(this->d % 8 == 0,
"vector dimension (number of bits) " "vector dimension (number of bits) "
"must be divisible by 8 (passed %d)", "must be divisible by 8 (passed %d)",
@ -51,17 +51,22 @@ GpuIndexBinaryFlat::GpuIndexBinaryFlat(GpuResourcesProvider* provider,
this->is_trained = true; this->is_trained = true;
// Construct index // Construct index
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
data_.reset( data_.reset(new BinaryFlatIndex(resources_.get(),
new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace)); this->d, binaryFlatConfig_.memorySpace));
} }
GpuIndexBinaryFlat::~GpuIndexBinaryFlat() { GpuIndexBinaryFlat::~GpuIndexBinaryFlat() {
} }
int
GpuIndexBinaryFlat::getDevice() const {
return binaryFlatConfig_.device;
}
void void
GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) { GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
this->d = index->d; this->d = index->d;
@ -76,20 +81,20 @@ GpuIndexBinaryFlat::copyFrom(const faiss::IndexBinaryFlat* index) {
// destroy old first before allocating new // destroy old first before allocating new
data_.reset(); data_.reset();
data_.reset( data_.reset(new BinaryFlatIndex(resources_.get(),
new BinaryFlatIndex(resources_.get(), this->d, config_.memorySpace)); this->d, binaryFlatConfig_.memorySpace));
// The index could be empty // The index could be empty
if (index->ntotal > 0) { if (index->ntotal > 0) {
data_->add(index->xb.data(), data_->add(index->xb.data(),
index->ntotal, index->ntotal,
resources_->getDefaultStream(config_.device)); resources_->getDefaultStream(binaryFlatConfig_.device));
} }
} }
void void
GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const { GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
index->d = this->d; index->d = this->d;
index->ntotal = this->ntotal; index->ntotal = this->ntotal;
@ -101,18 +106,18 @@ GpuIndexBinaryFlat::copyTo(faiss::IndexBinaryFlat* index) const {
if (this->ntotal > 0) { if (this->ntotal > 0) {
fromDevice(data_->getVectorsRef(), fromDevice(data_->getVectorsRef(),
index->xb.data(), index->xb.data(),
resources_->getDefaultStream(config_.device)); resources_->getDefaultStream(binaryFlatConfig_.device));
} }
} }
void void
GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n, GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
const uint8_t* x) { const uint8_t* x) {
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
// To avoid multiple re-allocations, ensure we have enough storage // To avoid multiple re-allocations, ensure we have enough storage
// available // available
data_->reserve(n, resources_->getDefaultStream(config_.device)); data_->reserve(n, resources_->getDefaultStream(binaryFlatConfig_.device));
// Due to GPU indexing in int32, we can't store more than this // Due to GPU indexing in int32, we can't store more than this
// number of vectors on a GPU // number of vectors on a GPU
@ -123,13 +128,13 @@ GpuIndexBinaryFlat::add(faiss::IndexBinary::idx_t n,
data_->add((const unsigned char*) x, data_->add((const unsigned char*) x,
n, n,
resources_->getDefaultStream(config_.device)); resources_->getDefaultStream(binaryFlatConfig_.device));
this->ntotal += n; this->ntotal += n;
} }
void void
GpuIndexBinaryFlat::reset() { GpuIndexBinaryFlat::reset() {
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
// Free the underlying memory // Free the underlying memory
data_->reset(); data_->reset();
@ -155,8 +160,8 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
getMaxKSelection(), getMaxKSelection(),
(int) k); // select limitation (int) k); // select limitation
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
auto stream = resources_->getDefaultStream(config_.device); auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
// The input vectors may be too large for the GPU, but we still // The input vectors may be too large for the GPU, but we still
// assume that the output distances and labels are not. // assume that the output distances and labels are not.
@ -165,7 +170,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
// If we reach a point where all inputs are too big, we can add // If we reach a point where all inputs are too big, we can add
// another level of tiling. // another level of tiling.
auto outDistances = toDeviceTemporary<int32_t, 2>(resources_.get(), auto outDistances = toDeviceTemporary<int32_t, 2>(resources_.get(),
config_.device, binaryFlatConfig_.device,
distances, distances,
stream, stream,
{(int) n, (int) k}); {(int) n, (int) k});
@ -203,7 +208,7 @@ GpuIndexBinaryFlat::search(faiss::IndexBinary::idx_t n,
// Convert and copy int indices out // Convert and copy int indices out
auto outIndices = auto outIndices =
toDeviceTemporary<faiss::Index::idx_t, 2>(resources_.get(), toDeviceTemporary<faiss::Index::idx_t, 2>(resources_.get(),
config_.device, binaryFlatConfig_.device,
labels, labels,
stream, stream,
{(int) n, (int) k}); {(int) n, (int) k});
@ -227,12 +232,12 @@ GpuIndexBinaryFlat::searchNonPaged_(int n,
Tensor<int32_t, 2, true> outDistances(outDistancesData, {n, k}); Tensor<int32_t, 2, true> outDistances(outDistancesData, {n, k});
Tensor<int, 2, true> outIndices(outIndicesData, {n, k}); Tensor<int, 2, true> outIndices(outIndicesData, {n, k});
auto stream = resources_->getDefaultStream(config_.device); auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
// Make sure arguments are on the device we desire; use temporary // Make sure arguments are on the device we desire; use temporary
// memory allocations to move it if necessary // memory allocations to move it if necessary
auto vecs = toDeviceTemporary<uint8_t, 2>(resources_.get(), auto vecs = toDeviceTemporary<uint8_t, 2>(resources_.get(),
config_.device, binaryFlatConfig_.device,
const_cast<uint8_t*>(x), const_cast<uint8_t*>(x),
stream, stream,
{n, (int) (this->d / 8)}); {n, (int) (this->d / 8)});
@ -272,10 +277,10 @@ GpuIndexBinaryFlat::searchFromCpuPaged_(int n,
void void
GpuIndexBinaryFlat::reconstruct(faiss::IndexBinary::idx_t key, GpuIndexBinaryFlat::reconstruct(faiss::IndexBinary::idx_t key,
uint8_t* out) const { uint8_t* out) const {
DeviceScope scope(config_.device); DeviceScope scope(binaryFlatConfig_.device);
FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds"); FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
auto stream = resources_->getDefaultStream(config_.device); auto stream = resources_->getDefaultStream(binaryFlatConfig_.device);
auto& vecs = data_->getVectorsRef(); auto& vecs = data_->getVectorsRef();
auto vec = vecs[key]; auto vec = vecs[key];


@@ -38,6 +38,9 @@ class GpuIndexBinaryFlat : public IndexBinary {
   ~GpuIndexBinaryFlat() override;
 
+  /// Returns the device that this index is resident on
+  int getDevice() const;
+
   /// Initialize ourselves from the given CPU index; will overwrite
   /// all data in ourselves
   void copyFrom(const faiss::IndexBinaryFlat* index);
@@ -80,7 +83,7 @@ class GpuIndexBinaryFlat : public IndexBinary {
   std::shared_ptr<GpuResources> resources_;
 
   /// Configuration options
-  GpuIndexBinaryFlatConfig config_;
+  const GpuIndexBinaryFlatConfig binaryFlatConfig_;
 
   /// Holds our GPU data containing the list of vectors
   std::unique_ptr<BinaryFlatIndex> data_;


@ -27,7 +27,7 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
index->metric_type, index->metric_type,
index->metric_arg, index->metric_arg,
config), config),
config_(std::move(config)) { flatConfig_(config) {
// Flat index doesn't need training // Flat index doesn't need training
this->is_trained = true; this->is_trained = true;
@ -42,7 +42,7 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
index->metric_type, index->metric_type,
index->metric_arg, index->metric_arg,
config), config),
config_(std::move(config)) { flatConfig_(config) {
// Flat index doesn't need training // Flat index doesn't need training
this->is_trained = true; this->is_trained = true;
@ -58,17 +58,17 @@ GpuIndexFlat::GpuIndexFlat(GpuResourcesProvider* provider,
metric, metric,
0, 0,
config), config),
config_(std::move(config)) { flatConfig_(config) {
// Flat index doesn't need training // Flat index doesn't need training
this->is_trained = true; this->is_trained = true;
// Construct index // Construct index
DeviceScope scope(device_); DeviceScope scope(config_.device);
data_.reset(new FlatIndex(resources_.get(), data_.reset(new FlatIndex(resources_.get(),
dims, dims,
config_.useFloat16, flatConfig_.useFloat16,
config_.storeTransposed, flatConfig_.storeTransposed,
memorySpace_)); config_.memorySpace));
} }
GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources, GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
@ -76,17 +76,17 @@ GpuIndexFlat::GpuIndexFlat(std::shared_ptr<GpuResources> resources,
faiss::MetricType metric, faiss::MetricType metric,
GpuIndexFlatConfig config) : GpuIndexFlatConfig config) :
GpuIndex(resources, dims, metric, 0, config), GpuIndex(resources, dims, metric, 0, config),
config_(std::move(config)) { flatConfig_(config) {
// Flat index doesn't need training // Flat index doesn't need training
this->is_trained = true; this->is_trained = true;
// Construct index // Construct index
DeviceScope scope(device_); DeviceScope scope(config_.device);
data_.reset(new FlatIndex(resources_.get(), data_.reset(new FlatIndex(resources_.get(),
dims, dims,
config_.useFloat16, flatConfig_.useFloat16,
config_.storeTransposed, flatConfig_.storeTransposed,
memorySpace_)); config_.memorySpace));
} }
GpuIndexFlat::~GpuIndexFlat() { GpuIndexFlat::~GpuIndexFlat() {
@ -94,7 +94,7 @@ GpuIndexFlat::~GpuIndexFlat() {
void void
GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) { GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
GpuIndex::copyFrom(index); GpuIndex::copyFrom(index);
@ -109,21 +109,21 @@ GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) {
data_.reset(); data_.reset();
data_.reset(new FlatIndex(resources_.get(), data_.reset(new FlatIndex(resources_.get(),
this->d, this->d,
config_.useFloat16, flatConfig_.useFloat16,
config_.storeTransposed, flatConfig_.storeTransposed,
memorySpace_)); config_.memorySpace));
// The index could be empty // The index could be empty
if (index->ntotal > 0) { if (index->ntotal > 0) {
data_->add(index->xb.data(), data_->add(index->xb.data(),
index->ntotal, index->ntotal,
resources_->getDefaultStream(device_)); resources_->getDefaultStream(config_.device));
} }
} }
void void
GpuIndexFlat::copyTo(faiss::IndexFlat* index) const { GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
GpuIndex::copyTo(index); GpuIndex::copyTo(index);
@ -131,10 +131,10 @@ GpuIndexFlat::copyTo(faiss::IndexFlat* index) const {
FAISS_ASSERT(data_->getSize() == this->ntotal); FAISS_ASSERT(data_->getSize() == this->ntotal);
index->xb.resize(this->ntotal * this->d); index->xb.resize(this->ntotal * this->d);
auto stream = resources_->getDefaultStream(device_); auto stream = resources_->getDefaultStream(config_.device);
if (this->ntotal > 0) { if (this->ntotal > 0) {
if (config_.useFloat16) { if (flatConfig_.useFloat16) {
auto vecFloat32 = data_->getVectorsFloat32Copy(stream); auto vecFloat32 = data_->getVectorsFloat32Copy(stream);
fromDevice(vecFloat32, index->xb.data(), stream); fromDevice(vecFloat32, index->xb.data(), stream);
} else { } else {
@ -150,7 +150,7 @@ GpuIndexFlat::getNumVecs() const {
void void
GpuIndexFlat::reset() { GpuIndexFlat::reset() {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// Free the underlying memory // Free the underlying memory
data_->reset(); data_->reset();
@ -176,15 +176,15 @@ GpuIndexFlat::add(Index::idx_t n, const float* x) {
return; return;
} }
DeviceScope scope(device_); DeviceScope scope(config_.device);
// To avoid multiple re-allocations, ensure we have enough storage // To avoid multiple re-allocations, ensure we have enough storage
// available // available
data_->reserve(n, resources_->getDefaultStream(device_)); data_->reserve(n, resources_->getDefaultStream(config_.device));
// If we're not operating in float16 mode, we don't need the input // If we're not operating in float16 mode, we don't need the input
// data to be resident on our device; we can add directly. // data to be resident on our device; we can add directly.
if (!config_.useFloat16) { if (!flatConfig_.useFloat16) {
addImpl_(n, x, nullptr); addImpl_(n, x, nullptr);
} else { } else {
// Otherwise, perform the paging // Otherwise, perform the paging
@ -214,7 +214,7 @@ GpuIndexFlat::addImpl_(int n,
"GPU index only supports up to %zu indices", "GPU index only supports up to %zu indices",
(size_t) std::numeric_limits<int>::max()); (size_t) std::numeric_limits<int>::max());
data_->add(x, n, resources_->getDefaultStream(device_)); data_->add(x, n, resources_->getDefaultStream(config_.device));
this->ntotal += n; this->ntotal += n;
} }
@ -224,7 +224,7 @@ GpuIndexFlat::searchImpl_(int n,
int k, int k,
float* distances, float* distances,
Index::idx_t* labels) const { Index::idx_t* labels) const {
auto stream = resources_->getDefaultStream(device_); auto stream = resources_->getDefaultStream(config_.device);
// Input and output data are already resident on the GPU // Input and output data are already resident on the GPU
Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d}); Tensor<float, 2, true> queries(const_cast<float*>(x), {n, (int) this->d});
@ -249,12 +249,12 @@ GpuIndexFlat::searchImpl_(int n,
void void
GpuIndexFlat::reconstruct(faiss::Index::idx_t key, GpuIndexFlat::reconstruct(faiss::Index::idx_t key,
float* out) const { float* out) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds"); FAISS_THROW_IF_NOT_MSG(key < this->ntotal, "index out of bounds");
auto stream = resources_->getDefaultStream(device_); auto stream = resources_->getDefaultStream(config_.device);
if (config_.useFloat16) { if (flatConfig_.useFloat16) {
// FIXME jhj: kernel for copy // FIXME jhj: kernel for copy
auto vec = data_->getVectorsFloat32Copy(key, 1, stream); auto vec = data_->getVectorsFloat32Copy(key, 1, stream);
fromDevice(vec.data(), out, this->d, stream); fromDevice(vec.data(), out, this->d, stream);
@ -268,13 +268,13 @@ void
GpuIndexFlat::reconstruct_n(faiss::Index::idx_t i0, GpuIndexFlat::reconstruct_n(faiss::Index::idx_t i0,
faiss::Index::idx_t num, faiss::Index::idx_t num,
float* out) const { float* out) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
FAISS_THROW_IF_NOT_MSG(i0 < this->ntotal, "index out of bounds"); FAISS_THROW_IF_NOT_MSG(i0 < this->ntotal, "index out of bounds");
FAISS_THROW_IF_NOT_MSG(i0 + num - 1 < this->ntotal, "num out of bounds"); FAISS_THROW_IF_NOT_MSG(i0 + num - 1 < this->ntotal, "num out of bounds");
auto stream = resources_->getDefaultStream(device_); auto stream = resources_->getDefaultStream(config_.device);
if (config_.useFloat16) { if (flatConfig_.useFloat16) {
// FIXME jhj: kernel for copy // FIXME jhj: kernel for copy
auto vec = data_->getVectorsFloat32Copy(i0, num, stream); auto vec = data_->getVectorsFloat32Copy(i0, num, stream);
fromDevice(vec.data(), out, num * this->d, stream); fromDevice(vec.data(), out, num * this->d, stream);
@ -301,23 +301,24 @@ GpuIndexFlat::compute_residual_n(faiss::Index::idx_t n,
"GPU index only supports up to %zu indices", "GPU index only supports up to %zu indices",
(size_t) std::numeric_limits<int>::max()); (size_t) std::numeric_limits<int>::max());
auto stream = resources_->getDefaultStream(device_); auto stream = resources_->getDefaultStream(config_.device);
DeviceScope scope(device_); DeviceScope scope(config_.device);
auto vecsDevice = auto vecsDevice =
toDeviceTemporary<float, 2>(resources_.get(), device_, toDeviceTemporary<float, 2>(resources_.get(), config_.device,
const_cast<float*>(xs), stream, const_cast<float*>(xs), stream,
{(int) n, (int) this->d}); {(int) n, (int) this->d});
auto idsDevice = auto idsDevice =
toDeviceTemporary<faiss::Index::idx_t, 1>( toDeviceTemporary<faiss::Index::idx_t, 1>(
resources_.get(), device_, resources_.get(), config_.device,
const_cast<faiss::Index::idx_t*>(keys), const_cast<faiss::Index::idx_t*>(keys),
stream, stream,
{(int) n}); {(int) n});
auto residualDevice = auto residualDevice =
toDeviceTemporary<float, 2>(resources_.get(), device_, residuals, stream, toDeviceTemporary<float, 2>(resources_.get(),
config_.device, residuals, stream,
{(int) n, (int) this->d}); {(int) n, (int) this->d});
// Convert idx_t to int // Convert idx_t to int


@@ -129,8 +129,8 @@ class GpuIndexFlat : public GpuIndex {
                    faiss::Index::idx_t* labels) const override;
 
  protected:
-  /// Our config object
-  const GpuIndexFlatConfig config_;
+  /// Our configuration options
+  const GpuIndexFlatConfig flatConfig_;
 
   /// Holds our GPU data containing the list of vectors
   std::unique_ptr<FlatIndex> data_;


@ -22,12 +22,11 @@ GpuIndexIVF::GpuIndexIVF(GpuResourcesProvider* provider,
float metricArg, float metricArg,
int nlistIn, int nlistIn,
GpuIndexIVFConfig config) : GpuIndexIVFConfig config) :
GpuIndex(provider->getResources(), GpuIndex(provider->getResources(), dims, metric, metricArg, config),
dims, metric, metricArg, config),
nlist(nlistIn), nlist(nlistIn),
nprobe(1), nprobe(1),
quantizer(nullptr), quantizer(nullptr),
ivfConfig_(std::move(config)) { ivfConfig_(config) {
init_(); init_();
// Only IP and L2 are supported for now // Only IP and L2 are supported for now
@ -55,7 +54,7 @@ GpuIndexIVF::init_() {
// Construct an empty quantizer // Construct an empty quantizer
GpuIndexFlatConfig config = ivfConfig_.flatConfig; GpuIndexFlatConfig config = ivfConfig_.flatConfig;
// FIXME: inherit our same device // FIXME: inherit our same device
config.device = device_; config.device = config_.device;
if (metric_type == faiss::METRIC_L2) { if (metric_type == faiss::METRIC_L2) {
quantizer = new GpuIndexFlatL2(resources_, d, config); quantizer = new GpuIndexFlatL2(resources_, d, config);
@ -79,7 +78,7 @@ GpuIndexIVF::getQuantizer() {
void void
GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) { GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
GpuIndex::copyFrom(index); GpuIndex::copyFrom(index);
@ -105,7 +104,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
// Construct an empty quantizer // Construct an empty quantizer
GpuIndexFlatConfig config = ivfConfig_.flatConfig; GpuIndexFlatConfig config = ivfConfig_.flatConfig;
// FIXME: inherit our same device // FIXME: inherit our same device
config.device = device_; config.device = config_.device;
if (index->metric_type == faiss::METRIC_L2) { if (index->metric_type == faiss::METRIC_L2) {
// FIXME: 2 different float16 options? // FIXME: 2 different float16 options?
@ -143,7 +142,7 @@ GpuIndexIVF::copyFrom(const faiss::IndexIVF* index) {
void void
GpuIndexIVF::copyTo(faiss::IndexIVF* index) const { GpuIndexIVF::copyTo(faiss::IndexIVF* index) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// //
// Index information // Index information
@ -228,7 +227,7 @@ GpuIndexIVF::trainQuantizer_(faiss::Index::idx_t n, const float* x) {
printf ("Training IVF quantizer on %ld vectors in %dD\n", n, d); printf ("Training IVF quantizer on %ld vectors in %dD\n", n, d);
} }
DeviceScope scope(device_); DeviceScope scope(config_.device);
// leverage the CPU-side k-means code, which works for the GPU // leverage the CPU-side k-means code, which works for the GPU
// flat index as well // flat index as well


@@ -83,7 +83,8 @@ class GpuIndexIVF : public GpuIndex {
   GpuIndexFlat* quantizer;
 
  protected:
-  GpuIndexIVFConfig ivfConfig_;
+  /// Our configuration options
+  const GpuIndexIVFConfig ivfConfig_;
 };
 
 } } // namespace


@ -57,14 +57,14 @@ void
GpuIndexIVFFlat::reserveMemory(size_t numVecs) { GpuIndexIVFFlat::reserveMemory(size_t numVecs) {
reserveMemoryVecs_ = numVecs; reserveMemoryVecs_ = numVecs;
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reserveMemory(numVecs); index_->reserveMemory(numVecs);
} }
} }
void void
GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) { GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
GpuIndexIVF::copyFrom(index); GpuIndexIVF::copyFrom(index);
@ -88,7 +88,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
false, // no residual false, // no residual
nullptr, // no scalar quantizer nullptr, // no scalar quantizer
ivfFlatConfig_.indicesOptions, ivfFlatConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
// Copy all of the IVF data // Copy all of the IVF data
index_->copyInvertedListsFrom(index->invlists); index_->copyInvertedListsFrom(index->invlists);
@ -96,7 +96,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
void void
GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const { GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// We must have the indices in order to copy to ourselves // We must have the indices in order to copy to ourselves
FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF,
@ -118,7 +118,7 @@ GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const {
size_t size_t
GpuIndexIVFFlat::reclaimMemory() { GpuIndexIVFFlat::reclaimMemory() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
return index_->reclaimMemory(); return index_->reclaimMemory();
} }
@ -129,7 +129,7 @@ GpuIndexIVFFlat::reclaimMemory() {
void void
GpuIndexIVFFlat::reset() { GpuIndexIVFFlat::reset() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reset(); index_->reset();
this->ntotal = 0; this->ntotal = 0;
@ -140,7 +140,7 @@ GpuIndexIVFFlat::reset() {
void void
GpuIndexIVFFlat::train(Index::idx_t n, const float* x) { GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
if (this->is_trained) { if (this->is_trained) {
FAISS_ASSERT(quantizer->is_trained); FAISS_ASSERT(quantizer->is_trained);
@ -161,7 +161,7 @@ GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
false, // no residual false, // no residual
nullptr, // no scalar quantizer nullptr, // no scalar quantizer
ivfFlatConfig_.indicesOptions, ivfFlatConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
if (reserveMemoryVecs_) { if (reserveMemoryVecs_) {
index_->reserveMemory(reserveMemoryVecs_); index_->reserveMemory(reserveMemoryVecs_);


@@ -73,8 +73,9 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
                    float* distances,
                    Index::idx_t* labels) const override;
 
- private:
-  GpuIndexIVFFlatConfig ivfFlatConfig_;
+ protected:
+  /// Our configuration options
+  const GpuIndexIVFFlatConfig ivfFlatConfig_;
 
   /// Desired inverted list memory reservation
   size_t reserveMemoryVecs_;


@ -30,6 +30,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
index->nlist, index->nlist,
config), config),
ivfpqConfig_(config), ivfpqConfig_(config),
usePrecomputedTables_(config.usePrecomputedTables),
subQuantizers_(0), subQuantizers_(0),
bitsPerCode_(0), bitsPerCode_(0),
reserveMemoryVecs_(0) { reserveMemoryVecs_(0) {
@ -50,6 +51,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
nlist, nlist,
config), config),
ivfpqConfig_(config), ivfpqConfig_(config),
usePrecomputedTables_(config.usePrecomputedTables),
subQuantizers_(subQuantizers), subQuantizers_(subQuantizers),
bitsPerCode_(bitsPerCode), bitsPerCode_(bitsPerCode),
reserveMemoryVecs_(0) { reserveMemoryVecs_(0) {
@ -64,7 +66,7 @@ GpuIndexIVFPQ::~GpuIndexIVFPQ() {
void void
GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) { GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
GpuIndexIVF::copyFrom(index); GpuIndexIVF::copyFrom(index);
@ -105,9 +107,9 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
ivfpqConfig_.alternativeLayout, ivfpqConfig_.alternativeLayout,
(float*) index->pq.centroids.data(), (float*) index->pq.centroids.data(),
ivfpqConfig_.indicesOptions, ivfpqConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
// Doesn't make sense to reserve memory here // Doesn't make sense to reserve memory here
index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables); index_->setPrecomputedCodes(usePrecomputedTables_);
// Copy all of the IVF data // Copy all of the IVF data
index_->copyInvertedListsFrom(index->invlists); index_->copyInvertedListsFrom(index->invlists);
@ -115,7 +117,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
void void
GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const { GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// We must have the indices in order to copy to ourselves // We must have the indices in order to copy to ourselves
FAISS_THROW_IF_NOT_MSG(ivfpqConfig_.indicesOptions != INDICES_IVF, FAISS_THROW_IF_NOT_MSG(ivfpqConfig_.indicesOptions != INDICES_IVF,
@ -153,9 +155,9 @@ GpuIndexIVFPQ::copyTo(faiss::IndexIVFPQ* index) const {
fromDevice<float, 3>(devPQCentroids, fromDevice<float, 3>(devPQCentroids,
index->pq.centroids.data(), index->pq.centroids.data(),
resources_->getDefaultStream(device_)); resources_->getDefaultStream(config_.device));
if (ivfpqConfig_.usePrecomputedTables) { if (usePrecomputedTables_) {
index->precompute_table(); index->precompute_table();
} }
} }
@ -165,16 +167,16 @@ void
GpuIndexIVFPQ::reserveMemory(size_t numVecs) { GpuIndexIVFPQ::reserveMemory(size_t numVecs) {
reserveMemoryVecs_ = numVecs; reserveMemoryVecs_ = numVecs;
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reserveMemory(numVecs); index_->reserveMemory(numVecs);
} }
} }
void void
GpuIndexIVFPQ::setPrecomputedCodes(bool enable) { GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
ivfpqConfig_.usePrecomputedTables = enable; usePrecomputedTables_ = enable;
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->setPrecomputedCodes(enable); index_->setPrecomputedCodes(enable);
} }
@ -183,7 +185,7 @@ GpuIndexIVFPQ::setPrecomputedCodes(bool enable) {
bool bool
GpuIndexIVFPQ::getPrecomputedCodes() const { GpuIndexIVFPQ::getPrecomputedCodes() const {
return ivfpqConfig_.usePrecomputedTables; return usePrecomputedTables_;
} }
int int
@ -204,7 +206,7 @@ GpuIndexIVFPQ::getCentroidsPerSubQuantizer() const {
size_t size_t
GpuIndexIVFPQ::reclaimMemory() { GpuIndexIVFPQ::reclaimMemory() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
return index_->reclaimMemory(); return index_->reclaimMemory();
} }
@ -214,7 +216,7 @@ GpuIndexIVFPQ::reclaimMemory() {
void void
GpuIndexIVFPQ::reset() { GpuIndexIVFPQ::reset() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reset(); index_->reset();
this->ntotal = 0; this->ntotal = 0;
@ -264,17 +266,17 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
ivfpqConfig_.alternativeLayout, ivfpqConfig_.alternativeLayout,
pq.centroids.data(), pq.centroids.data(),
ivfpqConfig_.indicesOptions, ivfpqConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
if (reserveMemoryVecs_) { if (reserveMemoryVecs_) {
index_->reserveMemory(reserveMemoryVecs_); index_->reserveMemory(reserveMemoryVecs_);
} }
index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables); index_->setPrecomputedCodes(usePrecomputedTables_);
} }
void void
GpuIndexIVFPQ::train(Index::idx_t n, const float* x) { GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
if (this->is_trained) { if (this->is_trained) {
FAISS_ASSERT(quantizer->is_trained); FAISS_ASSERT(quantizer->is_trained);
@ -289,7 +291,7 @@ GpuIndexIVFPQ::train(Index::idx_t n, const float* x) {
// First, make sure that the data is resident on the CPU, if it is not on the // First, make sure that the data is resident on the CPU, if it is not on the
// CPU, as we depend upon parts of the CPU code // CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>((float*) x, auto hostData = toHost<float, 2>((float*) x,
resources_->getDefaultStream(device_), resources_->getDefaultStream(config_.device),
{(int) n, (int) this->d}); {(int) n, (int) this->d});
trainQuantizer_(n, hostData.data()); trainQuantizer_(n, hostData.data());
@ -351,7 +353,7 @@ GpuIndexIVFPQ::getListLength(int listId) const {
std::vector<unsigned char> std::vector<unsigned char>
GpuIndexIVFPQ::getListCodes(int listId) const { GpuIndexIVFPQ::getListCodes(int listId) const {
FAISS_ASSERT(index_); FAISS_ASSERT(index_);
DeviceScope scope(device_); DeviceScope scope(config_.device);
return index_->getListCodes(listId); return index_->getListCodes(listId);
} }
@ -359,7 +361,7 @@ GpuIndexIVFPQ::getListCodes(int listId) const {
std::vector<long> std::vector<long>
GpuIndexIVFPQ::getListIndices(int listId) const { GpuIndexIVFPQ::getListIndices(int listId) const {
FAISS_ASSERT(index_); FAISS_ASSERT(index_);
DeviceScope scope(device_); DeviceScope scope(config_.device);
return index_->getListIndices(listId); return index_->getListIndices(listId);
} }
@ -398,15 +400,15 @@ GpuIndexIVFPQ::verifySettings_() const {
// codes per subquantizer // codes per subquantizer
size_t requiredSmemSize = size_t requiredSmemSize =
lookupTableSize * subQuantizers_ * utils::pow2(bitsPerCode_); lookupTableSize * subQuantizers_ * utils::pow2(bitsPerCode_);
size_t smemPerBlock = getMaxSharedMemPerBlock(device_); size_t smemPerBlock = getMaxSharedMemPerBlock(config_.device);
FAISS_THROW_IF_NOT_FMT(requiredSmemSize FAISS_THROW_IF_NOT_FMT(requiredSmemSize
<= getMaxSharedMemPerBlock(device_), <= getMaxSharedMemPerBlock(config_.device),
"Device %d has %zu bytes of shared memory, while " "Device %d has %zu bytes of shared memory, while "
"%d bits per code and %d sub-quantizers requires %zu " "%d bits per code and %d sub-quantizers requires %zu "
"bytes. Consider useFloat16LookupTables and/or " "bytes. Consider useFloat16LookupTables and/or "
"reduce parameters", "reduce parameters",
device_, smemPerBlock, bitsPerCode_, subQuantizers_, config_.device, smemPerBlock, bitsPerCode_, subQuantizers_,
requiredSmemSize); requiredSmemSize);
} }


@@ -135,13 +135,16 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
                    float* distances,
                    Index::idx_t* labels) const override;
 
- private:
   void verifySettings_() const;
 
   void trainResidualQuantizer_(Index::idx_t n, const float* x);
 
- private:
-  GpuIndexIVFPQConfig ivfpqConfig_;
+ protected:
+  /// Our configuration options that we were initialized with
+  const GpuIndexIVFPQConfig ivfpqConfig_;
+
+  /// Runtime override: whether or not we use precomputed tables
+  bool usePrecomputedTables_;
 
   /// Number of sub-quantizers per encoded vector
   int subQuantizers_;


@ -66,7 +66,7 @@ void
GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) { GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
reserveMemoryVecs_ = numVecs; reserveMemoryVecs_ = numVecs;
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reserveMemory(numVecs); index_->reserveMemory(numVecs);
} }
} }
@ -74,7 +74,7 @@ GpuIndexIVFScalarQuantizer::reserveMemory(size_t numVecs) {
void void
GpuIndexIVFScalarQuantizer::copyFrom( GpuIndexIVFScalarQuantizer::copyFrom(
const faiss::IndexIVFScalarQuantizer* index) { const faiss::IndexIVFScalarQuantizer* index) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// Clear out our old data // Clear out our old data
index_.reset(); index_.reset();
@ -101,7 +101,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
by_residual, by_residual,
&sq, &sq,
ivfSQConfig_.indicesOptions, ivfSQConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
// Copy all of the IVF data // Copy all of the IVF data
index_->copyInvertedListsFrom(index->invlists); index_->copyInvertedListsFrom(index->invlists);
@ -110,7 +110,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
void void
GpuIndexIVFScalarQuantizer::copyTo( GpuIndexIVFScalarQuantizer::copyTo(
faiss::IndexIVFScalarQuantizer* index) const { faiss::IndexIVFScalarQuantizer* index) const {
DeviceScope scope(device_); DeviceScope scope(config_.device);
// We must have the indices in order to copy to ourselves // We must have the indices in order to copy to ourselves
FAISS_THROW_IF_NOT_MSG( FAISS_THROW_IF_NOT_MSG(
@ -135,7 +135,7 @@ GpuIndexIVFScalarQuantizer::copyTo(
size_t size_t
GpuIndexIVFScalarQuantizer::reclaimMemory() { GpuIndexIVFScalarQuantizer::reclaimMemory() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
return index_->reclaimMemory(); return index_->reclaimMemory();
} }
@ -146,7 +146,7 @@ GpuIndexIVFScalarQuantizer::reclaimMemory() {
void void
GpuIndexIVFScalarQuantizer::reset() { GpuIndexIVFScalarQuantizer::reset() {
if (index_) { if (index_) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
index_->reset(); index_->reset();
this->ntotal = 0; this->ntotal = 0;
@ -163,7 +163,7 @@ GpuIndexIVFScalarQuantizer::trainResiduals_(Index::idx_t n, const float* x) {
void void
GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) { GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
DeviceScope scope(device_); DeviceScope scope(config_.device);
if (this->is_trained) { if (this->is_trained) {
FAISS_ASSERT(quantizer->is_trained); FAISS_ASSERT(quantizer->is_trained);
@ -178,7 +178,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
// First, make sure that the data is resident on the CPU, if it is not on the // First, make sure that the data is resident on the CPU, if it is not on the
// CPU, as we depend upon parts of the CPU code // CPU, as we depend upon parts of the CPU code
auto hostData = toHost<float, 2>((float*) x, auto hostData = toHost<float, 2>((float*) x,
resources_->getDefaultStream(device_), resources_->getDefaultStream(config_.device),
{(int) n, (int) this->d}); {(int) n, (int) this->d});
trainQuantizer_(n, hostData.data()); trainQuantizer_(n, hostData.data());
@ -192,7 +192,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
by_residual, by_residual,
&sq, &sq,
ivfSQConfig_.indicesOptions, ivfSQConfig_.indicesOptions,
memorySpace_)); config_.memorySpace));
if (reserveMemoryVecs_) { if (reserveMemoryVecs_) {
index_->reserveMemory(reserveMemoryVecs_); index_->reserveMemory(reserveMemoryVecs_);


@@ -88,8 +88,9 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF {
   /// Exposed like the CPU version
   bool by_residual;
 
- private:
-  GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;
+ protected:
+  /// Our configuration options
+  const GpuIndexIVFScalarQuantizerConfig ivfSQConfig_;
 
   /// Desired inverted list memory reservation
   size_t reserveMemoryVecs_;


@@ -101,12 +101,8 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
   for (auto& entry : defaultStreams_) {
     DeviceScope scope(entry.first);
 
-    auto it = userDefaultStreams_.find(entry.first);
-    if (it == userDefaultStreams_.end()) {
-      // The user did not specify this stream, thus we are the ones
-      // who have created it
-      CUDA_VERIFY(cudaStreamDestroy(entry.second));
-    }
+    // We created these streams, so are responsible for destroying them
+    CUDA_VERIFY(cudaStreamDestroy(entry.second));
   }
 
   for (auto& entry : alternateStreams_) {
@@ -210,16 +206,14 @@ StandardGpuResourcesImpl::setPinnedMemory(size_t size) {
 void
 StandardGpuResourcesImpl::setDefaultStream(int device, cudaStream_t stream) {
-  auto it = defaultStreams_.find(device);
-  if (it != defaultStreams_.end()) {
-    // Replace this stream with the user stream
-    CUDA_VERIFY(cudaStreamDestroy(it->second));
-    it->second = stream;
-  }
-
   userDefaultStreams_[device] = stream;
 }
 
+void
+StandardGpuResourcesImpl::revertDefaultStream(int device) {
+  userDefaultStreams_.erase(device);
+}
+
 void
 StandardGpuResourcesImpl::setDefaultNullStreamAllDevices() {
   for (int dev = 0; dev < getNumDevices(); ++dev) {
@@ -274,14 +268,8 @@ StandardGpuResourcesImpl::initializeForDevice(int device) {
   // Create streams
   cudaStream_t defaultStream = 0;
-  auto it = userDefaultStreams_.find(device);
-  if (it != userDefaultStreams_.end()) {
-    // We already have a stream provided by the user
-    defaultStream = it->second;
-  } else {
-    CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
-                                          cudaStreamNonBlocking));
-  }
+  CUDA_VERIFY(cudaStreamCreateWithFlags(&defaultStream,
+                                        cudaStreamNonBlocking));
 
   defaultStreams_[device] = defaultStream;
@@ -341,6 +329,14 @@ StandardGpuResourcesImpl::getBlasHandle(int device) {
 cudaStream_t
 StandardGpuResourcesImpl::getDefaultStream(int device) {
   initializeForDevice(device);
+
+  auto it = userDefaultStreams_.find(device);
+  if (it != userDefaultStreams_.end()) {
+    // There is a user override stream set
+    return it->second;
+  }
+
+  // Otherwise, our base default stream
   return defaultStreams_[device];
 }
@@ -539,6 +535,11 @@ StandardGpuResources::setDefaultStream(int device, cudaStream_t stream) {
   res_->setDefaultStream(device, stream);
 }
 
+void
+StandardGpuResources::revertDefaultStream(int device) {
+  res_->revertDefaultStream(device);
+}
+
 void
 StandardGpuResources::setDefaultNullStreamAllDevices() {
   res_->setDefaultNullStreamAllDevices();


@@ -41,9 +41,23 @@ class StandardGpuResourcesImpl : public GpuResources {
   /// transfers
   void setPinnedMemory(size_t size);
 
-  /// Called to change the stream for work ordering
+  /// Called to change the stream for work ordering. We do not own `stream`;
+  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+  /// up.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
   void setDefaultStream(int device, cudaStream_t stream);
 
+  /// Revert the default stream to the original stream managed by this resources
+  /// object, in case someone called `setDefaultStream`.
+  void revertDefaultStream(int device);
+
+  /// Returns the stream for the given device on which all Faiss GPU work is
+  /// ordered.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
+  cudaStream_t getDefaultStream(int device) override;
+
   /// Called to change the work ordering streams to the null stream
   /// for all devices
   void setDefaultNullStreamAllDevices();
@@ -60,8 +74,6 @@ class StandardGpuResourcesImpl : public GpuResources {
   cublasHandle_t getBlasHandle(int device) override;
 
-  cudaStream_t getDefaultStream(int device) override;
-
   std::vector<cudaStream_t> getAlternateStreams(int device) override;
 
   /// Allocate non-temporary GPU memory
@@ -128,7 +140,9 @@ class StandardGpuResourcesImpl : public GpuResources {
 };
 
 /// Default implementation of GpuResources that allocates a cuBLAS
-/// stream and 2 streams for use, as well as temporary memory
+/// stream and 2 streams for use, as well as temporary memory.
+/// Internally, the Faiss GPU code uses the instance managed by getResources,
+/// but this is the user-facing object that is internally reference counted.
 class StandardGpuResources : public GpuResourcesProvider {
  public:
   StandardGpuResources();
@@ -151,9 +165,17 @@ class StandardGpuResources : public GpuResourcesProvider {
   /// transfers
   void setPinnedMemory(size_t size);
 
-  /// Called to change the stream for work ordering
+  /// Called to change the stream for work ordering. We do not own `stream`;
+  /// i.e., it will not be destroyed when the GpuResources object gets cleaned
+  /// up.
+  /// We are guaranteed that all Faiss GPU work is ordered with respect to
+  /// this stream upon exit from an index or other Faiss GPU call.
   void setDefaultStream(int device, cudaStream_t stream);
 
+  /// Revert the default stream to the original stream managed by this resources
+  /// object, in case someone called `setDefaultStream`.
+  void revertDefaultStream(int device);
+
   /// Called to change the work ordering streams to the null stream
   /// for all devices
   void setDefaultNullStreamAllDevices();


@@ -10,7 +10,7 @@ import unittest
 import faiss
 import torch
 
-from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch
+from faiss.contrib.pytorch_tensors import search_index_pytorch, search_raw_array_pytorch, using_stream
 
 def to_column_major(x):
     if hasattr(torch, 'contiguous_format'):
@@ -22,40 +22,38 @@ def to_column_major(x):
 class PytorchFaissInterop(unittest.TestCase):
 
     def test_interop(self):
-
-        d = 16
-        nq = 5
-        nb = 20
+        d = 128
+        nq = 100
+        nb = 1000
+        k = 10
 
         xq = faiss.randn(nq * d, 1234).reshape(nq, d)
         xb = faiss.randn(nb * d, 1235).reshape(nb, d)
 
         res = faiss.StandardGpuResources()
-        index = faiss.GpuIndexFlatIP(res, d)
-        index.add(xb)
 
-        # reference CPU result
-        Dref, Iref = index.search(xq, 5)
+        # Let's run on a non-default stream
+        s = torch.cuda.Stream()
 
-        # query is pytorch tensor (CPU)
-        xq_torch = torch.FloatTensor(xq)
+        # Torch will run on this stream
+        with torch.cuda.stream(s):
+            # query is pytorch tensor (CPU and GPU)
+            xq_torch_cpu = torch.FloatTensor(xq)
+            xq_torch_gpu = xq_torch_cpu.cuda()
 
-        D2, I2 = search_index_pytorch(index, xq_torch, 5)
+            index = faiss.GpuIndexFlatIP(res, d)
+            index.add(xb)
 
-        assert np.all(Iref == I2.numpy())
+            # Query with GPU tensor (this will be done on the current pytorch stream)
+            D2, I2 = search_index_pytorch(res, index, xq_torch_gpu, k)
+            Dref, Iref = index.search(xq, k)
 
-        # query is pytorch tensor (GPU)
-        xq_torch = xq_torch.cuda()
-        # no need for a sync here
-        D3, I3 = search_index_pytorch(index, xq_torch, 5)
+            assert np.all(Iref == I2.cpu().numpy())
 
-        # D3 and I3 are on torch tensors on GPU as well.
-        # this does a sync, which is useful because faiss and
-        # pytorch use different Cuda streams.
-        res.syncDefaultStreamCurrentDevice()
-
-        assert np.all(Iref == I3.cpu().numpy())
+            # Query with CPU tensor
+            D3, I3 = search_index_pytorch(res, index, xq_torch_cpu, k)
+            assert np.all(Iref == I3.numpy())
 
     def test_raw_array_search(self):
         d = 32
@@ -74,55 +72,51 @@ class PytorchFaissInterop(unittest.TestCase):
         # resource object, can be re-used over calls
         res = faiss.StandardGpuResources()
 
-        # put on same stream as pytorch to avoid synchronizing streams
-        res.setDefaultNullStreamAllDevices()
-
-        for xq_row_major in True, False:
-            for xb_row_major in True, False:
-
-                # move to pytorch & GPU
-                xq_t = torch.from_numpy(xq).cuda()
-                xb_t = torch.from_numpy(xb).cuda()
-
-                if not xq_row_major:
-                    xq_t = to_column_major(xq_t)
-                    assert not xq_t.is_contiguous()
-
-                if not xb_row_major:
-                    xb_t = to_column_major(xb_t)
-                    assert not xb_t.is_contiguous()
-
-                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
-
-                # back to CPU for verification
-                D = D.cpu().numpy()
-                I = I.cpu().numpy()
-
-                assert np.all(I == gt_I)
-                assert np.all(np.abs(D - gt_D).max() < 1e-4)
-
-                # test on subset
-                try:
-                    D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
-                except TypeError:
-                    if not xq_row_major:
-                        # then it is expected
-                        continue
-                    # otherwise it is an error
-                    raise
-
-                # back to CPU for verification
-                D = D.cpu().numpy()
-                I = I.cpu().numpy()
-                assert np.all(I == gt_I[60:80])
-                assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
+        # Let's have pytorch use a non-default stream
+        s = torch.cuda.Stream()
+        with torch.cuda.stream(s):
+            for xq_row_major in True, False:
+                for xb_row_major in True, False:
+
+                    # move to pytorch & GPU
+                    xq_t = torch.from_numpy(xq).cuda()
+                    xb_t = torch.from_numpy(xb).cuda()
+
+                    if not xq_row_major:
+                        xq_t = to_column_major(xq_t)
+                        assert not xq_t.is_contiguous()
+
+                    if not xb_row_major:
+                        xb_t = to_column_major(xb_t)
+                        assert not xb_t.is_contiguous()
+
+                    D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
+
+                    # back to CPU for verification
+                    D = D.cpu().numpy()
+                    I = I.cpu().numpy()
+
+                    assert np.all(I == gt_I)
+                    assert np.all(np.abs(D - gt_D).max() < 1e-4)
+
+                    # test on subset
+                    try:
+                        # This internally uses the current pytorch stream
+                        D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
+                    except TypeError:
+                        if not xq_row_major:
+                            # then it is expected
+                            continue
+                        # otherwise it is an error
+                        raise
+
+                    # back to CPU for verification
+                    D = D.cpu().numpy()
+                    I = I.cpu().numpy()
+                    assert np.all(I == gt_I[60:80])
+                    assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
 
 
 if __name__ == '__main__':


@@ -274,7 +274,6 @@ void gpu_sync_all_devices()
 %}
 
-%template() std::pair<int, unsigned long>;
 %template() std::map<std::string, std::pair<int, unsigned long> >;
 %template() std::map<int, std::map<std::string, std::pair<int, unsigned long> > >;
 
@@ -287,6 +286,21 @@
 %include <faiss/gpu/GpuResources.h>
 %include <faiss/gpu/StandardGpuResources.h>
 
+typedef CUstream_st* cudaStream_t;
+
+%inline %{
+
+// interop between pytorch exposed cudaStream_t and faiss
+cudaStream_t cast_integer_to_cudastream_t(long x) {
+  return (cudaStream_t) x;
+}
+
+long cast_cudastream_t_to_integer(cudaStream_t x) {
+  return (long) x;
+}
+
+%}
+
 #else
 %{