diff --git a/gpu/test/test_pytorch_faiss.py b/gpu/test/test_pytorch_faiss.py index 3a47a0188..80468a761 100644 --- a/gpu/test/test_pytorch_faiss.py +++ b/gpu/test/test_pytorch_faiss.py @@ -11,16 +11,18 @@ import unittest import faiss import torch + def swig_ptr_from_FloatTensor(x): assert x.is_contiguous() assert x.dtype == torch.float32 - return faiss.cast_integer_to_float_ptr(x.storage().data_ptr()) + return faiss.cast_integer_to_float_ptr( + x.storage().data_ptr() + x.storage_offset() * 4) def swig_ptr_from_LongTensor(x): assert x.is_contiguous() assert x.dtype == torch.int64, 'dtype=%s' % x.dtype - return faiss.cast_integer_to_long_ptr(x.storage().data_ptr()) - + return faiss.cast_integer_to_long_ptr( + x.storage().data_ptr() + x.storage_offset() * 8) def search_index_pytorch(index, x, k, D=None, I=None): @@ -155,10 +157,15 @@ class PytorchFaissInterop(unittest.TestCase): assert np.all(I == gt_I) assert np.all(np.abs(D - gt_D).max() < 1e-4) + # test on subset + D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k) + # back to CPU for verification + D = D.cpu().numpy() + I = I.cpu().numpy() - - + assert np.all(I == gt_I[60:80]) + assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4) if __name__ == '__main__':