Faiss + Torch fixes, re-enable k = 2048

Summary:
This diff fixes four separate issues:

- Using the PyTorch bridge produces the following deprecation warning; we switch to `_storage()` instead.
```
torch_utils.py:51: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor._storage() instead of tensor.storage()
  x.storage().data_ptr() + x.storage_offset() * 4)
```
- The `storage_offset` multiplier for certain dtypes was wrong (`* 4` for float16 and `* 8` for int32, instead of the true element sizes of 2 and 4 bytes). This only affects torch tensors that are a view into a storage that does not start at element zero; see the sketch after this list.
- The numpy bridge's `reconstruct_n` allowed passing `-1` for `ni`, indicating that all vectors should be reconstructed. The torch bridge did not follow this convention and instead threw an error:
```
TypeError: torch_replacement_reconstruct_n() missing 2 required positional arguments: 'n0' and 'ni'
```
- Values of `k` or `nprobe` in the range (1024, 2048] were broken by D37777979; they now work again.
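
To illustrate the offset fix, here is a minimal sketch of the byte arithmetic the bridge performs (the tensor and sizes are illustrative; recent PyTorch exposes the untyped storage publicly as `untyped_storage()`, which the sketch uses in place of `_storage()`):

```python
import torch

# A float16 view that does not start at the beginning of its storage.
base = torch.zeros(16, dtype=torch.float16)
view = base[4:]  # storage_offset() == 4, counted in elements

# The pointer handed to Faiss must be the storage base address plus the
# view's offset converted to bytes: 4 elements * 2 bytes for float16.
byte_offset = view.storage_offset() * view.element_size()
ptr = view.untyped_storage().data_ptr() + byte_offset

# data_ptr() already applies this arithmetic, so the two must agree; the
# old `* 4` multiplier for float16 would land 8 bytes past the right spot.
assert ptr == view.data_ptr()
```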

Reviewed By: alexanderguzhva

Differential Revision: D42041239

fbshipit-source-id: c7d9b4aba63db8ac73e271c8ef34e231002963d9

`torch_utils.py`:

```diff
@@ -33,7 +33,7 @@ def swig_ptr_from_UInt8Tensor(x):
     assert x.is_contiguous()
     assert x.dtype == torch.uint8
     return faiss.cast_integer_to_uint8_ptr(
-        x.storage().data_ptr() + x.storage_offset())
+        x._storage().data_ptr() + x.storage_offset())
 
 def swig_ptr_from_HalfTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
@@ -41,28 +41,28 @@ def swig_ptr_from_HalfTensor(x):
     assert x.dtype == torch.float16
     # no canonical half type in C/C++
     return faiss.cast_integer_to_void_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 4)
+        x._storage().data_ptr() + x.storage_offset() * 2)
 
 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.float32
     return faiss.cast_integer_to_float_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 4)
+        x._storage().data_ptr() + x.storage_offset() * 4)
 
 def swig_ptr_from_IntTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_int_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 8)
+        x._storage().data_ptr() + x.storage_offset() * 4)
 
 def swig_ptr_from_IndicesTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_idx_t_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 8)
+        x._storage().data_ptr() + x.storage_offset() * 8)
 
 @contextlib.contextmanager
 def using_stream(res, pytorch_stream=None):
@@ -319,7 +319,10 @@ def handle_torch_Index(the_class):
         return x
 
-    def torch_replacement_reconstruct_n(self, n0, ni, x=None):
+    def torch_replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
+        if ni == -1:
+            ni = self.ntotal
+
         # No tensor inputs are required, but with importing this module, we
         # assume that the default should be torch tensors. If we are passed a
         # numpy array, however, assume that the user is overriding this default
```
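
A short usage sketch of the new `reconstruct_n` defaults (index type and sizes are illustrative; assumes a Faiss build with the torch bridge available):

```python
import torch

import faiss
import faiss.contrib.torch_utils  # patches the index classes to accept torch tensors

d = 32
index = faiss.IndexFlatL2(d)
index.add(torch.randn(100, d))

# With the fix, the torch bridge follows the numpy bridge's convention:
# no arguments (n0=0, ni=-1) means "reconstruct every stored vector".
vectors = index.reconstruct_n()  # equivalent to reconstruct_n(0, index.ntotal)
assert vectors.shape == (100, d)
```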

The GPU k-selection / nprobe validation (C++):

```diff
@@ -24,7 +24,7 @@ int getMaxKSelection() {
 void validateKSelect(idx_t k) {
     FAISS_THROW_IF_NOT_FMT(
-            k > 0 && k < (idx_t)getMaxKSelection(),
+            k > 0 && k <= (idx_t)getMaxKSelection(),
             "GPU index only supports min/max-K selection up to %d (requested %zu)",
             getMaxKSelection(),
             k);
@@ -32,7 +32,7 @@ void validateKSelect(idx_t k) {
 void validateNProbe(idx_t nprobe) {
     FAISS_THROW_IF_NOT_FMT(
-            nprobe > 0 && nprobe < (idx_t)getMaxKSelection(),
+            nprobe > 0 && nprobe <= (idx_t)getMaxKSelection(),
             "GPU IVF index only supports nprobe selection up to %d (requested %zu)",
             getMaxKSelection(),
             nprobe);
```
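
The off-by-one matters because `getMaxKSelection()` returns the largest supported value, so the strict `<` rejected exactly that value (2048, as exercised by the test below). A sketch of the now-accepted call (sizes mirror the new test; requires a CUDA-enabled Faiss build):

```python
import numpy as np

import faiss

d, n, k = 128, 10000, 2048
xb = np.random.rand(n, d).astype('float32')

res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatL2(res, d)
index.add(xb)

# k == 2048 hit the strict `<` bound and was rejected before this fix.
D, I = index.search(xb[:5], k)
assert I.shape == (5, k)
```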

The GPU flat index tests (C++):

```diff
@@ -141,6 +141,20 @@ TEST(TestGpuIndexFlat, L2_Float32) {
     }
 }
 
+// At least one test for the k > 1024 select
+TEST(TestGpuIndexFlat, L2_k_2048) {
+    if (faiss::gpu::getMaxKSelection() >= 2048) {
+        TestFlatOptions opt;
+        opt.metric = faiss::MetricType::METRIC_L2;
+        opt.useFloat16 = false;
+        opt.kOverride = 2048;
+        opt.dimOverride = 128;
+        opt.numVecsOverride = 10000;
+
+        testFlat(opt);
+    }
+}
+
 // test specialized k == 1 codepath
 TEST(TestGpuIndexFlat, L2_Float32_K1) {
     for (int tries = 0; tries < 3; ++tries) {
```