take into account torch offset when getting ptr
parent
5555ae7f88
commit
fa91c13980
|
@ -11,16 +11,18 @@ import unittest
|
|||
import faiss
|
||||
import torch
|
||||
|
||||
|
||||
def swig_ptr_from_FloatTensor(x):
|
||||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.float32
|
||||
return faiss.cast_integer_to_float_ptr(x.storage().data_ptr())
|
||||
return faiss.cast_integer_to_float_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 4)
|
||||
|
||||
def swig_ptr_from_LongTensor(x):
|
||||
assert x.is_contiguous()
|
||||
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
|
||||
return faiss.cast_integer_to_long_ptr(x.storage().data_ptr())
|
||||
|
||||
return faiss.cast_integer_to_long_ptr(
|
||||
x.storage().data_ptr() + x.storage_offset() * 8)
|
||||
|
||||
|
||||
def search_index_pytorch(index, x, k, D=None, I=None):
|
||||
|
@ -155,10 +157,15 @@ class PytorchFaissInterop(unittest.TestCase):
|
|||
assert np.all(I == gt_I)
|
||||
assert np.all(np.abs(D - gt_D).max() < 1e-4)
|
||||
|
||||
# test on subset
|
||||
D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
|
||||
|
||||
# back to CPU for verification
|
||||
D = D.cpu().numpy()
|
||||
I = I.cpu().numpy()
|
||||
|
||||
|
||||
|
||||
assert np.all(I == gt_I[60:80])
|
||||
assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue