take into account torch offset when getting ptr

pull/782/head
Matthijs Douze 2019-04-08 20:26:26 -07:00
parent 5555ae7f88
commit fa91c13980
1 changed files with 12 additions and 5 deletions

View File

@ -11,16 +11,18 @@ import unittest
import faiss
import torch
def swig_ptr_from_FloatTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.float32
return faiss.cast_integer_to_float_ptr(x.storage().data_ptr())
return faiss.cast_integer_to_float_ptr(
x.storage().data_ptr() + x.storage_offset() * 4)
def swig_ptr_from_LongTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
return faiss.cast_integer_to_long_ptr(x.storage().data_ptr())
return faiss.cast_integer_to_long_ptr(
x.storage().data_ptr() + x.storage_offset() * 8)
def search_index_pytorch(index, x, k, D=None, I=None):
@ -155,10 +157,15 @@ class PytorchFaissInterop(unittest.TestCase):
assert np.all(I == gt_I)
assert np.all(np.abs(D - gt_D).max() < 1e-4)
# test on subset
D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
# back to CPU for verification
D = D.cpu().numpy()
I = I.cpu().numpy()
assert np.all(I == gt_I[60:80])
assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
if __name__ == '__main__':