mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Summary: PyTorch GPU in general is free to use whatever stream it currently wants, based on `torch.cuda.current_stream()`. Due to C++/Python language-barrier issues, we previously could not pass the actual `cudaStream_t` that is currently in use on a given device from PyTorch C++ to Faiss C++ via Python. This diff adds conversion functions that convert a Python integer representing a pointer to a `cudaStream_t` (which is itself a `CUstream_st*`), so we can pass the stream specified in `torch.cuda.current_stream()` to `StandardGpuResources::setDefaultStream`. We thus guarantee that all Faiss work is ordered on the same stream that is in use in PyTorch. For use in Python, there is now the `faiss.contrib.pytorch_tensors.using_stream` context object, which automatically sets and unsets the current PyTorch stream within Faiss. It takes a `StandardGpuResources` object in Python and an optional `torch.cuda.Stream` if one wants to use a different stream; otherwise it uses the current one. This is how it is used:
```
# Create a non-default stream
s = torch.cuda.Stream()

# Have Torch use it
with torch.cuda.stream(s):
    # Have Faiss use the same stream as the above
    with faiss.contrib.pytorch_tensors.using_stream(res):
        # Do some work on the GPU
        faiss.bfKnn(res, args)
```
`using_stream` uses the same pattern as the PyTorch `torch.cuda.stream` object. This replaces any brute-force GPU/CPU synchronization work that was necessary before. Other changes in this diff:
- cleans up the config objects in the GpuIndex subclasses, to distinguish between read-only parameters that can only be set upon index construction and those that can be changed at runtime.
- StandardGpuResources now more properly distinguishes between user-supplied streams (like the PyTorch one), which will not be destroyed upon resources destruction, and internal streams.
- `search_index_pytorch` now needs to take a `StandardGpuResources` object as well; there is no way to get this from an index instance otherwise (or at least, I would have to return a `shared_ptr`, in which case we should just update the Python SWIG bindings to use `shared_ptr` for GpuResources or something similar).

Reviewed By: mdouze
Differential Revision: D24260026
fbshipit-source-id: b18bb0eb34eb012584b1c923088228776c10b720
130 lines
3.8 KiB
Python
130 lines
3.8 KiB
Python
import faiss
|
|
import torch
|
|
import contextlib
|
|
|
|
def swig_ptr_from_FloatTensor(x):
    """Get a Faiss SWIG ``float *`` pointer from a PyTorch tensor (CPU or GPU).

    The tensor must be contiguous and have dtype ``torch.float32``. The
    pointer is built from the tensor's data address; no copy is made here,
    so the tensor must outlive every use of the returned pointer.
    """
    assert x.is_contiguous(), 'tensor must be contiguous'
    assert x.dtype == torch.float32, 'dtype=%s' % x.dtype
    # Tensor.data_ptr() already points at this tensor's first element, i.e.
    # it includes storage_offset() * element_size(). This avoids the
    # deprecated Tensor.storage() accessor and a hard-coded element size.
    return faiss.cast_integer_to_float_ptr(x.data_ptr())
|
|
|
|
def swig_ptr_from_LongTensor(x):
    """Get a Faiss SWIG ``long *`` pointer from a PyTorch tensor (CPU or GPU).

    The tensor must be contiguous and have dtype ``torch.int64``. The
    pointer is built from the tensor's data address; no copy is made here,
    so the tensor must outlive every use of the returned pointer.
    """
    assert x.is_contiguous(), 'tensor must be contiguous'
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    # Tensor.data_ptr() already points at this tensor's first element, i.e.
    # it includes storage_offset() * element_size(). This avoids the
    # deprecated Tensor.storage() accessor and a hard-coded element size.
    return faiss.cast_integer_to_long_ptr(x.data_ptr())
|
|
|
|
@contextlib.contextmanager
def using_stream(res, pytorch_stream=None):
    r"""Make Faiss GPU work use the same CUDA stream as PyTorch.

    Inside the ``with`` block, the default stream that ``res`` uses for the
    stream's device is set to ``pytorch_stream``, which defaults to
    ``torch.cuda.current_stream()``. Faiss work issued through ``res`` on
    that device is then ordered with the PyTorch work on the same stream.
    The previous default stream for that device is restored on exit, even
    if the block raises.

    Args:
        res: Faiss GPU resources object (e.g. ``StandardGpuResources``)
            providing ``getDefaultStream`` / ``setDefaultStream``.
        pytorch_stream: optional ``torch.cuda.Stream`` to use instead of the
            current PyTorch stream.
    """
    if pytorch_stream is None:
        pytorch_stream = torch.cuda.current_stream()

    # Key the override on the device that owns the stream. An explicitly
    # passed stream may belong to a device other than the current one, and
    # a stream should only be installed as the default for its own device.
    device = pytorch_stream.device_index

    # This is the cudaStream_t that we wish to use
    cuda_stream_s = faiss.cast_integer_to_cudastream_t(
        pytorch_stream.cuda_stream)

    # Remember the prior stream so GpuResources state can be reverted on exit
    prior_stream = res.getDefaultStream(device)
    res.setDefaultStream(device, cuda_stream_s)

    # Do the user work
    try:
        yield
    finally:
        res.setDefaultStream(device, prior_stream)
|
|
|
|
def search_index_pytorch(res, index, x, k, D=None, I=None):
    """Call ``index.search_c`` with PyTorch tensor I/O (CPU and GPU supported).

    Args:
        res: Faiss GPU resources used to order GPU work on the current
            PyTorch stream (see ``using_stream``). It is only consulted when
            one of ``x``, ``D`` or ``I`` is a CUDA tensor, so it may be
            ``None`` for all-CPU searches.
        index: Faiss index to search; ``index.d`` must equal ``x.size(1)``.
        x: contiguous float32 query matrix of shape (n, d).
        k: number of nearest neighbors to return per query.
        D: optional preallocated float32 distance output of shape (n, k);
            allocated on ``x.device`` when omitted.
        I: optional preallocated int64 label output of shape (n, k);
            allocated on ``x.device`` when omitted.

    Returns:
        The tuple ``(D, I)``.
    """
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)

    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)

    # Stream ordering only matters when GPU memory is involved. Skipping it
    # for all-CPU tensors keeps this path usable without CUDA:
    # using_stream() calls torch.cuda.current_stream(), which fails on
    # CPU-only PyTorch builds.
    if x.is_cuda or D.is_cuda or I.is_cuda:
        with using_stream(res):
            index.search_c(n, xptr, k, Dptr, Iptr)
    else:
        index.search_c(n, xptr, k, Dptr, Iptr)

    return D, I
|
|
|
|
|
|
def _row_or_col_major(mat):
    """Return ``(contiguous_view, row_major)`` for a row- or column-major 2-D tensor.

    A row-major tensor is returned as-is with ``True``. A column-major tensor
    is returned transposed (so the view is contiguous) with ``False``. Any
    other layout raises ``TypeError``.
    """
    if mat.is_contiguous():
        return mat, True
    if mat.t().is_contiguous():
        return mat.t(), False
    raise TypeError('matrix should be row or column-major')


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    """Search ``xq`` in ``xb`` with Faiss ``bfKnn``, without building an index.

    Args:
        res: Faiss GPU resources passed to ``faiss.bfKnn``. The call runs
            inside ``using_stream(res)``, so it is ordered on the current
            PyTorch stream.
        xb: float32 database matrix of shape (nb, d), row- or column-major;
            must be on the same device as ``xq``.
        xq: float32 query matrix of shape (nq, d), row- or column-major.
        k: number of nearest neighbors to return per query.
        D: optional preallocated float32 distance output of shape (nq, k) on
            ``xb.device``; allocated when omitted.
        I: optional preallocated int64 label output of shape (nq, k) on
            ``xb.device``; allocated when omitted.
        metric: Faiss metric type (default ``faiss.METRIC_L2``).

    Returns:
        The tuple ``(D, I)``.

    Raises:
        TypeError: if ``xq`` or ``xb`` is neither row- nor column-major.
    """
    assert xb.device == xq.device

    # Shapes are read before any transpose, so they are always (rows, d).
    nq, d = xq.size()
    xq, xq_row_major = _row_or_col_major(xq)
    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    xb, xb_row_major = _row_or_col_major(xb)
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    args = faiss.GpuDistanceParams()
    args.metric = metric
    args.k = k
    args.dims = d
    args.vectors = xb_ptr
    args.vectorsRowMajor = xb_row_major
    args.numVectors = nb
    args.queries = xq_ptr
    args.queriesRowMajor = xq_row_major
    args.numQueries = nq
    args.outDistances = D_ptr
    args.outIndices = I_ptr

    with using_stream(res):
        faiss.bfKnn(res, args)

    return D, I
|