# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import unittest
import faiss
import torch


def swig_ptr_from_FloatTensor(x):
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)


def swig_ptr_from_LongTensor(x):
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    return faiss.cast_integer_to_long_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8)
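
# Note on the two helpers above: they hand Faiss a raw pointer into the
# tensor's storage via the SWIG casts. The storage offset is counted in
# elements, so it is scaled by the element size (4 bytes for
# torch.float32, 8 bytes for torch.int64). The tensors must stay
# contiguous and alive for as long as Faiss uses the pointer.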


def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)"""
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)
    torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr,
                   k, Dptr, Iptr)
    torch.cuda.synchronize()
    return D, I
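
# Illustrative usage sketch (a hypothetical, already-populated `index` and
# a float32 query tensor `xq` with index.d columns are assumed):
#
#   D, I = search_index_pytorch(index, xq, k=10)
#
# D and I come back as torch tensors on the same device as the query.
# The torch.cuda.synchronize() calls above bracket the search because
# faiss and pytorch use different CUDA streams (see also the comments in
# PytorchFaissInterop.test_interop below).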


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()    # I initially wrote xq:t(), Lua is still haunting me :-)
        xq_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(res, metric,
                        xb_ptr, xb_row_major, nb,
                        xq_ptr, xq_row_major, nq,
                        d, k, D_ptr, I_ptr)

    return D, I
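
# search_raw_array_pytorch runs a brute-force k-nearest-neighbor search
# directly on GPU-resident torch data via faiss.bruteForceKnn, so no
# Faiss index needs to be built; both row-major (contiguous) and
# column-major (transpose-contiguous) inputs are accepted.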


def to_column_major(x):
    if hasattr(torch, 'contiguous_format'):
        return x.t().clone(memory_format=torch.contiguous_format).t()
    else:
        # was default setting before memory_format was introduced
        return x.t().clone().t()
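
# to_column_major returns a tensor with the same shape and values as x
# but column-major strides: the transpose is cloned contiguously and then
# transposed back. The non-contiguous test cases below rely on this.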


class PytorchFaissInterop(unittest.TestCase):

    def test_interop(self):

        d = 16
        nq = 5
        nb = 20

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(res, d)
        index.add(xb)

        # reference CPU result
        Dref, Iref = index.search(xq, 5)

        # query is pytorch tensor (CPU)
        xq_torch = torch.FloatTensor(xq)

        D2, I2 = search_index_pytorch(index, xq_torch, 5)

        assert np.all(Iref == I2.numpy())

        # query is pytorch tensor (GPU)
        xq_torch = xq_torch.cuda()
        # no need for a sync here

        D3, I3 = search_index_pytorch(index, xq_torch, 5)

        # D3 and I3 are torch tensors on GPU as well.
        # this does a sync, which is useful because faiss and
        # pytorch use different Cuda streams.
        res.syncDefaultStreamCurrentDevice()

        assert np.all(Iref == I3.cpu().numpy())

    def test_raw_array_search(self):
        d = 32
        nb = 1024
        nq = 128
        k = 10

        # make GT on Faiss CPU

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        gt_D, gt_I = index.search(xq, k)

        # resource object, can be re-used over calls
        res = faiss.StandardGpuResources()
        # put on same stream as pytorch to avoid synchronizing streams
        res.setDefaultNullStreamAllDevices()

        for xq_row_major in True, False:
            for xb_row_major in True, False:

                # move to pytorch & GPU
                xq_t = torch.from_numpy(xq).cuda()
                xb_t = torch.from_numpy(xb).cuda()

                if not xq_row_major:
                    xq_t = to_column_major(xq_t)
                    assert not xq_t.is_contiguous()

                if not xb_row_major:
                    xb_t = to_column_major(xb_t)
                    assert not xb_t.is_contiguous()

                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I)
                assert np.all(np.abs(D - gt_D).max() < 1e-4)

                # test on subset
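                # (slicing rows of a column-major xq gives a tensor that
                # is neither contiguous nor transpose-contiguous, so a
                # TypeError from search_raw_array_pytorch is expected in
                # that case)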
                try:
                    D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
                except TypeError:
                    if not xq_row_major:
                        # then it is expected
                        continue
                    # otherwise it is an error
                    raise

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I[60:80])
                assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)


if __name__ == '__main__':
    unittest.main()
|