# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

#! /usr/bin/env python2

import numpy as np
import unittest
import faiss
import torch

def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)"""
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

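    # Allocate the distance (D) and label (I) output buffers on the same
    # device as the query x when the caller does not supply them.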
    if D is None:
        if x.is_cuda:
            D = torch.cuda.FloatTensor(n, k)
        else:
            D = torch.FloatTensor(n, k)
    else:
        assert D.__class__ in (torch.FloatTensor, torch.cuda.FloatTensor)
        assert D.size() == (n, k)
        assert D.is_contiguous()

    if I is None:
        if x.is_cuda:
            I = torch.cuda.LongTensor(n, k)
        else:
            I = torch.LongTensor(n, k)
    else:
        assert I.__class__ in (torch.LongTensor, torch.cuda.LongTensor)
        assert I.size() == (n, k)
        assert I.is_contiguous()

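    # Pass raw data pointers to the index's low-level search_c method; the
    # cast_integer_to_*_ptr helpers turn the integer addresses into typed
    # pointers for the wrapped C++ API.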
    xptr = x.storage().data_ptr()
    Iptr = I.storage().data_ptr()
    Dptr = D.storage().data_ptr()
    index.search_c(n, faiss.cast_integer_to_float_ptr(xptr),
                   k, faiss.cast_integer_to_float_ptr(Dptr),
                   faiss.cast_integer_to_long_ptr(Iptr))

    return D, I

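# Usage sketch (illustrative only, assuming a CPU index and float32 numpy
# arrays xb and xq built elsewhere):
#
#   index_cpu = faiss.IndexFlatL2(d)
#   index_cpu.add(xb)
#   D, I = search_index_pytorch(index_cpu, torch.from_numpy(xq), k)
#
# The same call works with a GPU index and a CUDA tensor, as exercised by the
# test below.
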
class PytorchFaissInterop(unittest.TestCase):

    def test_interop(self):

        d = 16
        nq = 5
        nb = 20

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(res, d)
        index.add(xb)

        # reference result, computed with numpy I/O
        Dref, Iref = index.search(xq, 5)

        # query is pytorch tensor (CPU)
        xq_torch = torch.FloatTensor(xq)

        D2, I2 = search_index_pytorch(index, xq_torch, 5)

        assert np.all(Iref == I2.numpy())

        # query is pytorch tensor (GPU)
        xq_torch = xq_torch.cuda()
        # no need for a sync here
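        # (presumably because the default .cuda() copy above is blocking with
        # respect to the host, so the data is in place before faiss reads it)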
        D3, I3 = search_index_pytorch(index, xq_torch, 5)

        # D3 and I3 are torch tensors on the GPU as well.
        # This does a sync, which is useful because faiss and
        # pytorch use different CUDA streams.
        res.syncDefaultStreamCurrentDevice()

        assert np.all(Iref == I3.cpu().numpy())


if __name__ == '__main__':
    unittest.main()