diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py
index 83bd2ac8e..027739958 100644
--- a/contrib/torch_utils.py
+++ b/contrib/torch_utils.py
@@ -33,7 +33,7 @@ def swig_ptr_from_UInt8Tensor(x):
     assert x.is_contiguous()
     assert x.dtype == torch.uint8
     return faiss.cast_integer_to_uint8_ptr(
-        x.storage().data_ptr() + x.storage_offset())
+        x._storage().data_ptr() + x.storage_offset())
 
 def swig_ptr_from_HalfTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
@@ -41,28 +41,28 @@ def swig_ptr_from_HalfTensor(x):
     assert x.dtype == torch.float16
     # no canonical half type in C/C++
     return faiss.cast_integer_to_void_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 4)
+        x._storage().data_ptr() + x.storage_offset() * 2)
 
 def swig_ptr_from_FloatTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.float32
     return faiss.cast_integer_to_float_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 4)
+        x._storage().data_ptr() + x.storage_offset() * 4)
 
 def swig_ptr_from_IntTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_int_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 8)
+        x._storage().data_ptr() + x.storage_offset() * 4)
 
 def swig_ptr_from_IndicesTensor(x):
     """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
     assert x.is_contiguous()
     assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
     return faiss.cast_integer_to_idx_t_ptr(
-        x.storage().data_ptr() + x.storage_offset() * 8)
+        x._storage().data_ptr() + x.storage_offset() * 8)
 
 @contextlib.contextmanager
 def using_stream(res, pytorch_stream=None):
@@ -319,7 +319,10 @@ def handle_torch_Index(the_class):
 
         return x
 
-    def torch_replacement_reconstruct_n(self, n0, ni, x=None):
+    def torch_replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
+        if ni == -1:
+            ni = self.ntotal
+
         # No tensor inputs are required, but with importing this module, we
         # assume that the default should be torch tensors. If we are passed a
         # numpy array, however, assume that the user is overriding this default
diff --git a/faiss/gpu/impl/IndexUtils.cu b/faiss/gpu/impl/IndexUtils.cu
index 6e4f23448..2389e1ad2 100644
--- a/faiss/gpu/impl/IndexUtils.cu
+++ b/faiss/gpu/impl/IndexUtils.cu
@@ -24,7 +24,7 @@ int getMaxKSelection() {
 
 void validateKSelect(idx_t k) {
     FAISS_THROW_IF_NOT_FMT(
-            k > 0 && k < (idx_t)getMaxKSelection(),
+            k > 0 && k <= (idx_t)getMaxKSelection(),
             "GPU index only supports min/max-K selection up to %d (requested %zu)",
             getMaxKSelection(),
             k);
@@ -32,7 +32,7 @@ void validateKSelect(idx_t k) {
 
 void validateNProbe(idx_t nprobe) {
     FAISS_THROW_IF_NOT_FMT(
-            nprobe > 0 && nprobe < (idx_t)getMaxKSelection(),
+            nprobe > 0 && nprobe <= (idx_t)getMaxKSelection(),
             "GPU IVF index only supports nprobe selection up to %d (requested %zu)",
             getMaxKSelection(),
             nprobe);
diff --git a/faiss/gpu/test/TestGpuIndexFlat.cpp b/faiss/gpu/test/TestGpuIndexFlat.cpp
index 617db62d9..bcd49c68d 100644
--- a/faiss/gpu/test/TestGpuIndexFlat.cpp
+++ b/faiss/gpu/test/TestGpuIndexFlat.cpp
@@ -141,6 +141,20 @@ TEST(TestGpuIndexFlat, L2_Float32) {
     }
 }
 
+// At least one test for the k > 1024 select
+TEST(TestGpuIndexFlat, L2_k_2048) {
+    if (faiss::gpu::getMaxKSelection() >= 2048) {
+        TestFlatOptions opt;
+        opt.metric = faiss::MetricType::METRIC_L2;
+        opt.useFloat16 = false;
+        opt.kOverride = 2048;
+        opt.dimOverride = 128;
+        opt.numVecsOverride = 10000;
+
+        testFlat(opt);
+    }
+}
+
 // test specialized k == 1 codepath
 TEST(TestGpuIndexFlat, L2_Float32_K1) {
     for (int tries = 0; tries < 3; ++tries) {