Faiss + Torch fixes, re-enable k = 2048
Summary: This diff fixes four separate issues: - Using the pytorch bridge produces the following deprecation warning. We switch to `_storage()` instead. ``` torch_utils.py:51: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor._storage() instead of tensor.storage() x.storage().data_ptr() + x.storage_offset() * 4) ``` - The `storage_offset` for certain types was wrong, but this would only affect torch tensors that were a view into a storage that didn't begin at the beginning. - The `reconstruct_n` numpy pytorch bridge function allowed passing `-1` for `ni` which indicated that all vectors should be reconstructed. The torch bridge didn't follow this and throw an error: ``` TypeError: torch_replacement_reconstruct_n() missing 2 required positional arguments: 'n0' and 'ni' ``` - Choosing values in the range (1024, 2048] for `k` or `nprobe` were broken in D37777979; this is now fixed again. Reviewed By: alexanderguzhva Differential Revision: D42041239 fbshipit-source-id: c7d9b4aba63db8ac73e271c8ef34e231002963d9pull/2623/head
parent
eee58b3319
commit
4bb7aa4b77
|
@ -33,7 +33,7 @@ def swig_ptr_from_UInt8Tensor(x):
|
|||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.uint8
|
||||
return faiss.cast_integer_to_uint8_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset())
|
||||
x._storage().data_ptr() + x.storage_offset())
|
||||
|
||||
def swig_ptr_from_HalfTensor(x):
|
||||
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
|
||||
|
@ -41,28 +41,28 @@ def swig_ptr_from_HalfTensor(x):
|
|||
assert x.dtype == torch.float16
|
||||
# no canonical half type in C/C++
|
||||
return faiss.cast_integer_to_void_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 4)
|
||||
x._storage().data_ptr() + x.storage_offset() * 2)
|
||||
|
||||
def swig_ptr_from_FloatTensor(x):
|
||||
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
|
||||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.float32
|
||||
return faiss.cast_integer_to_float_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 4)
|
||||
x._storage().data_ptr() + x.storage_offset() * 4)
|
||||
|
||||
def swig_ptr_from_IntTensor(x):
|
||||
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
|
||||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
|
||||
return faiss.cast_integer_to_int_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 8)
|
||||
x._storage().data_ptr() + x.storage_offset() * 4)
|
||||
|
||||
def swig_ptr_from_IndicesTensor(x):
|
||||
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
|
||||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
|
||||
return faiss.cast_integer_to_idx_t_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 8)
|
||||
x._storage().data_ptr() + x.storage_offset() * 8)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def using_stream(res, pytorch_stream=None):
|
||||
|
@ -319,7 +319,10 @@ def handle_torch_Index(the_class):
|
|||
|
||||
return x
|
||||
|
||||
def torch_replacement_reconstruct_n(self, n0, ni, x=None):
|
||||
def torch_replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
|
||||
if ni == -1:
|
||||
ni = self.ntotal
|
||||
|
||||
# No tensor inputs are required, but with importing this module, we
|
||||
# assume that the default should be torch tensors. If we are passed a
|
||||
# numpy array, however, assume that the user is overriding this default
|
||||
|
|
|
@ -24,7 +24,7 @@ int getMaxKSelection() {
|
|||
|
||||
void validateKSelect(idx_t k) {
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
k > 0 && k < (idx_t)getMaxKSelection(),
|
||||
k > 0 && k <= (idx_t)getMaxKSelection(),
|
||||
"GPU index only supports min/max-K selection up to %d (requested %zu)",
|
||||
getMaxKSelection(),
|
||||
k);
|
||||
|
@ -32,7 +32,7 @@ void validateKSelect(idx_t k) {
|
|||
|
||||
void validateNProbe(idx_t nprobe) {
|
||||
FAISS_THROW_IF_NOT_FMT(
|
||||
nprobe > 0 && nprobe < (idx_t)getMaxKSelection(),
|
||||
nprobe > 0 && nprobe <= (idx_t)getMaxKSelection(),
|
||||
"GPU IVF index only supports nprobe selection up to %d (requested %zu)",
|
||||
getMaxKSelection(),
|
||||
nprobe);
|
||||
|
|
|
@ -141,6 +141,20 @@ TEST(TestGpuIndexFlat, L2_Float32) {
|
|||
}
|
||||
}
|
||||
|
||||
// At least one test for the k > 1024 select
|
||||
TEST(TestGpuIndexFlat, L2_k_2048) {
|
||||
if (faiss::gpu::getMaxKSelection() >= 2048) {
|
||||
TestFlatOptions opt;
|
||||
opt.metric = faiss::MetricType::METRIC_L2;
|
||||
opt.useFloat16 = false;
|
||||
opt.kOverride = 2048;
|
||||
opt.dimOverride = 128;
|
||||
opt.numVecsOverride = 10000;
|
||||
|
||||
testFlat(opt);
|
||||
}
|
||||
}
|
||||
|
||||
// test specialized k == 1 codepath
|
||||
TEST(TestGpuIndexFlat, L2_Float32_K1) {
|
||||
for (int tries = 0; tries < 3; ++tries) {
|
||||
|
|
Loading…
Reference in New Issue