# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import unittest

import numpy as np

import faiss
import torch


def swig_ptr_from_FloatTensor(x):
    """get a SWIG pointer to the data of a contiguous float32 torch tensor
    (CPU or GPU)"""
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    return faiss.cast_integer_to_float_ptr(
        x.storage().data_ptr() + x.storage_offset() * 4)  # 4 bytes per float32


def swig_ptr_from_LongTensor(x):
    """get a SWIG pointer to the data of a contiguous int64 torch tensor
    (CPU or GPU)"""
    assert x.is_contiguous()
    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
    return faiss.cast_integer_to_long_ptr(
        x.storage().data_ptr() + x.storage_offset() * 8)  # 8 bytes per int64


def search_index_pytorch(index, x, k, D=None, I=None):
    """call the search function of an index with pytorch tensor I/O
    (CPU and GPU supported)"""
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)

    torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr, k, Dptr, Iptr)
    torch.cuda.synchronize()
    return D, I


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,
                             metric=faiss.METRIC_L2):
    """brute-force k-nearest-neighbor search of the queries xq in the
    database xb, both given as raw pytorch tensors, using the GPU
    resources res. Row-major and column-major (transposed) inputs are
    both supported."""
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()    # I initially wrote xq:t(), Lua is still haunting me :-)
        xq_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError('matrix should be row or column-major')
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(res, metric,
                        xb_ptr, xb_row_major, nb,
                        xq_ptr, xq_row_major, nq,
                        d, k, D_ptr, I_ptr)

    return D, I


def to_column_major(x):
    """return a column-major copy of x (the result is not contiguous,
    but its transpose is)"""
    if hasattr(torch, 'contiguous_format'):
        return x.t().clone(memory_format=torch.contiguous_format).t()
    else:
        # was the default setting before memory_format was introduced
        return x.t().clone().t()


class PytorchFaissInterop(unittest.TestCase):

    def test_interop(self):
        d = 16
        nq = 5
        nb = 20

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(res, d)
        index.add(xb)

        # reference result, computed via the standard numpy interface
        Dref, Iref = index.search(xq, 5)

        # query is a pytorch tensor (CPU)
        xq_torch = torch.FloatTensor(xq)

        D2, I2 = search_index_pytorch(index, xq_torch, 5)

        assert np.all(Iref == I2.numpy())

        # query is a pytorch tensor (GPU)
        xq_torch = xq_torch.cuda()
        # no need for a sync here

        D3, I3 = search_index_pytorch(index, xq_torch, 5)

        # D3 and I3 are torch tensors on the GPU as well.
        # this does a sync, which is useful because faiss and
        # pytorch use different CUDA streams.
        res.syncDefaultStreamCurrentDevice()

        assert np.all(Iref == I3.cpu().numpy())

    def test_raw_array_search(self):
        d = 32
        nb = 1024
        nq = 128
        k = 10

        # make the ground truth with Faiss on CPU
        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        gt_D, gt_I = index.search(xq, k)

        # resource object, can be re-used over calls
        res = faiss.StandardGpuResources()
        # put on same stream as pytorch to avoid synchronizing streams
        res.setDefaultNullStreamAllDevices()

        for xq_row_major in True, False:
            for xb_row_major in True, False:

                # move to pytorch & GPU
                xq_t = torch.from_numpy(xq).cuda()
                xb_t = torch.from_numpy(xb).cuda()

                if not xq_row_major:
                    xq_t = to_column_major(xq_t)
                    assert not xq_t.is_contiguous()

                if not xb_row_major:
                    xb_t = to_column_major(xb_t)
                    assert not xb_t.is_contiguous()

                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I)
                assert np.all(np.abs(D - gt_D).max() < 1e-4)

                # test on a subset of the queries
                try:
                    D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
                except TypeError:
                    if not xq_row_major:
                        # expected: a row slice of a column-major tensor is
                        # neither row- nor column-major
                        continue
                    # otherwise it is an error
                    raise

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I[60:80])
                assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)


if __name__ == '__main__':
    unittest.main()