PyTorch tensor / Faiss index interoperability (#1484)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1484
This diff allows native use of PyTorch tensors with Faiss indexes on both CPU and GPU. It is currently implemented only for classes that inherit from `faiss.Index`, which covers the non-binary indexes, and it patches the same functions on `faiss.Index` that `__init__.py` already wraps for numpy interoperability.
There must be uniformity among the inputs: if any array input is a Torch tensor, then all array inputs must be Torch tensors. Similarly, if any array input is a numpy ndarray, then all array inputs must be numpy ndarrays.
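A minimal sketch of the mixed-input case (the index type and shapes are illustrative; the exact failure mode is an assumption based on the wrappers' type checks, which the tests below exercise via `assertRaises(AssertionError)`):

```python
import faiss
import faiss.contrib.torch_utils
import numpy as np
import torch

index = faiss.IndexFlatL2(32)
xb = torch.rand(100, 32)
ids = np.arange(100, dtype=np.int64)

# torch data with numpy ids mixes input types, so the patched wrapper is
# expected to reject the call before it ever reaches the index
try:
    index.add_with_ids(xb, ids)
except AssertionError:
    print("mixed torch/numpy inputs are rejected")
```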
If `faiss.contrib.torch_utils` is imported, it first ensures that `import faiss` has run, so that all functions carry the base `__init__.py` numpy wrappers, and then patches the following functions again:
```
add
add_with_ids
assign
train
search
remove_ids
reconstruct
reconstruct_n
range_search
update_vectors
search_and_reconstruct
sa_encode
sa_decode
```
to allow usage of PyTorch CPU tensors, and additionally PyTorch GPU tensors if the index being used is on the GPU.
numpy functionality is still available when `faiss.contrib.torch_utils` is imported; we pass through to the original patched numpy function when we detect numpy inputs.
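For example, after the import a plain CPU index accepts either input type (a minimal sketch; the index type and sizes are illustrative):

```python
import faiss
import faiss.contrib.torch_utils  # patches faiss.Index methods for torch
import numpy as np
import torch

d = 64
index = faiss.IndexFlatL2(d)

# torch CPU tensors are used natively, and torch outputs are returned
xb = torch.rand(1000, d, dtype=torch.float32)
index.add(xb)
D, I = index.search(torch.rand(10, d), 5)  # D, I are torch.Tensor

# numpy inputs still pass through to the original numpy wrappers
D_np, I_np = index.search(np.random.rand(10, d).astype(np.float32), 5)
```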
In addition, to allow better (asynchronous) GPU usage without requiring the CPU to be involved, all of the functions that construct tensors/arrays for output now take optional arguments for output storage (numpy ndarray or torch.Tensor) to be provided by the caller. `range_search` is the only exception, since the size of its output data is indeterminate; the eventual GPU implementation will likely require the user to pass a maximum cap on the output size instead. If pre-allocated output values are provided, they are used; otherwise, new ndarrays/Tensors are constructed and returned as before. Without this feature on the GPU, every execution would be completely serial, since we would depend on the CPU to allocate GPU memory before every operation. With it, Faiss calls can behave much like NN graph execution on the GPU: assuming all data requirements are pre-allocated, execution runs at full GPU speed instead of stalling to launch kernels sequentially.
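The pre-allocated output path looks like this (a sketch; the optional distance/label arguments are positional after `k`, as exercised in the tests below):

```python
import faiss
import faiss.contrib.torch_utils
import torch

d, k = 64, 5
index = faiss.IndexFlatL2(d)
index.add(torch.rand(1000, d))

xq = torch.rand(10, d, dtype=torch.float32)

# caller-provided output storage: the wrapper writes results here instead
# of allocating new tensors, which on the GPU avoids a CPU-side allocation
# before every call
D_out = torch.empty(10, k, dtype=torch.float32)
I_out = torch.empty(10, k, dtype=torch.int64)
index.search(xq, k, D_out, I_out)
```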
This diff also exposes the `GpuResources` shared_ptr object owned by a GPU index. This is required for pytorch GPU support, so that Faiss can perform proper stream ordering with respect to the current pytorch stream; as a result, Faiss indexes now behave more or less like any NN operation in Torch.
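A rough sketch of the exposed accessor (assumes a GPU build; `getResources()` mirrors the C++ `GpuIndex` method that this diff exposes to Python):

```python
import faiss  # GPU build assumed
import torch

res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatL2(res, 64)

# the GpuResources shared_ptr owned by the index is now reachable from
# Python; the torch wrappers can use it to order Faiss work on the same
# CUDA stream that pytorch is currently using
index_res = index.getResources()
```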
Note, however, that a Faiss index has its own current-device setting, and if pytorch GPU tensor inputs reside on a different device than the one the Faiss index expects, a cross-device copy will be initiated. I may choose to make this an error in the future and require matching devices.
This diff also uncovered a bug when passing GPU data directly to `train()` for `GpuIndexIVFFlat` and `GpuIndexIVFScalarQuantizer`; we apparently never tested passing GPU data directly to these functions before. `GpuIndexIVFPQ` was already doing the right thing.
The `assign` function is now implemented on the GPU as well, and is marked `const` to be in line with the `search` function.
Also added better checking of non-contiguous inputs for both Torch tensors and numpy ndarrays.
Updated the `knn_gpu` function so that a base implementation supporting numpy arrays is always present; it is overridden when `torch_utils` is imported to also allow torch usage. It supports row-/column-major layout, float32/float16 data, and int64/int32 indices for both numpy and torch.
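For instance (a sketch assuming a GPU build; the `res, xq, xb, k` argument order shown here is an assumption, and `D`/`I` can likewise be pre-allocated and passed in):

```python
import faiss
import numpy as np

res = faiss.StandardGpuResources()

xb = np.random.rand(10000, 64).astype(np.float32)
xq = np.random.rand(100, 64).astype(np.float32)

# brute-force k-nearest-neighbor on the GPU from numpy inputs; with
# faiss.contrib.torch_utils imported, torch tensors (CPU or GPU) work too
D, I = faiss.knn_gpu(res, xq, xb, 4)
```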
Reviewed By: mdouze
Differential Revision: D24299400
fbshipit-source-id: b4f117b9c120bd1ad83e7702087051ab7b303b29
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faiss
import torch
import unittest
import numpy as np
import faiss.contrib.torch_utils


class TestTorchUtilsCPU(unittest.TestCase):
    # tests add, search
    def test_lookup(self):
        d = 128
        index = faiss.IndexFlatL2(d)

        # Add to CPU index with torch CPU
        xb_torch = torch.rand(10000, d)
        index.add(xb_torch)

        # Test reconstruct
        y_torch = index.reconstruct(10)
        self.assertTrue(torch.equal(y_torch, xb_torch[10]))

        # Add to CPU index with numpy CPU
        xb_np = torch.rand(500, d).numpy()
        index.add(xb_np)
        self.assertEqual(index.ntotal, 10500)

        y_np = np.zeros(d, dtype=np.float32)
        index.reconstruct(10100, y_np)
        self.assertTrue(np.array_equal(y_np, xb_np[100]))

        # Search with np cpu
        xq_torch = torch.rand(10, d, dtype=torch.float32)
        d_np, I_np = index.search(xq_torch.numpy(), 5)

        # Search with torch cpu
        d_torch, I_torch = index.search(xq_torch, 5)

        # The two should be equivalent
        self.assertTrue(np.array_equal(d_np, d_torch.numpy()))
        self.assertTrue(np.array_equal(I_np, I_torch.numpy()))

        # Search with np cpu using pre-allocated arrays
        d_np_input = np.zeros((10, 5), dtype=np.float32)
        I_np_input = np.zeros((10, 5), dtype=np.int64)
        index.search(xq_torch.numpy(), 5, d_np_input, I_np_input)

        self.assertTrue(np.array_equal(d_np, d_np_input))
        self.assertTrue(np.array_equal(I_np, I_np_input))

        # Search with torch cpu using pre-allocated arrays
        d_torch_input = torch.zeros(10, 5, dtype=torch.float32)
        I_torch_input = torch.zeros(10, 5, dtype=torch.int64)
        index.search(xq_torch, 5, d_torch_input, I_torch_input)

        self.assertTrue(np.array_equal(d_torch_input.numpy(), d_np))
        self.assertTrue(np.array_equal(I_torch_input.numpy(), I_np))

    # tests train, add_with_ids
    def test_train_add_with_ids(self):
        d = 32
        nlist = 5

        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # Test add_with_ids with torch cpu
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
        index.add_with_ids(xb, ids)
        _, I = index.search(xb[10:20], 1)
        self.assertTrue(torch.equal(I.view(10), ids[10:20]))

        # Test add_with_ids with numpy
        index.reset()
        index.train(xb.numpy())
        index.add_with_ids(xb.numpy(), ids.numpy())
        _, I = index.search(xb.numpy()[10:20], 1)
        self.assertTrue(np.array_equal(I.reshape(10), ids.numpy()[10:20]))

    # tests reconstruct, reconstruct_n
    def test_reconstruct(self):
        d = 32
        index = faiss.IndexFlatL2(d)

        xb = torch.rand(100, d, dtype=torch.float32)
        index.add(xb)

        # Test reconstruct with torch cpu (native return)
        y = index.reconstruct(7)
        self.assertTrue(torch.equal(xb[7], y))

        # Test reconstruct with numpy output provided
        y = np.empty(d, dtype=np.float32)
        index.reconstruct(11, y)
        self.assertTrue(np.array_equal(xb.numpy()[11], y))

        # Test reconstruct with torch cpu output provided
        y = torch.empty(d, dtype=torch.float32)
        index.reconstruct(12, y)
        self.assertTrue(torch.equal(xb[12], y))

        # Test reconstruct_n with torch cpu (native return)
        y = index.reconstruct_n(10, 10)
        self.assertTrue(torch.equal(xb[10:20], y))

        # Test reconstruct_n with numpy output provided
        y = np.empty((10, d), dtype=np.float32)
        index.reconstruct_n(20, 10, y)
        self.assertTrue(np.array_equal(xb.numpy()[20:30], y))

        # Test reconstruct_n with torch cpu output provided
        y = torch.empty(10, d, dtype=torch.float32)
        index.reconstruct_n(40, 10, y)
        self.assertTrue(torch.equal(xb[40:50], y))

    # tests assign
    def test_assign(self):
        d = 32
        index = faiss.IndexFlatL2(d)
        xb = torch.rand(1000, d, dtype=torch.float32)
        index.add(xb)

        index_ref = faiss.IndexFlatL2(d)
        index_ref.add(xb.numpy())

        # Test assign with native cpu output
        xq = torch.rand(10, d, dtype=torch.float32)
        labels = index.assign(xq, 5)
        labels_ref = index_ref.assign(xq, 5)

        self.assertTrue(torch.equal(labels, labels_ref))

        # Test assign with np input
        labels = index.assign(xq.numpy(), 5)
        labels_ref = index_ref.assign(xq.numpy(), 5)
        self.assertTrue(np.array_equal(labels, labels_ref))

        # Test assign with numpy output provided
        labels = np.empty((xq.shape[0], 5), dtype=np.int64)
        index.assign(xq.numpy(), 5, labels)
        self.assertTrue(np.array_equal(labels, labels_ref))

        # Test assign with torch cpu output provided
        labels = torch.empty(xq.shape[0], 5, dtype=torch.int64)
        index.assign(xq, 5, labels)
        labels_ref = index_ref.assign(xq, 5)
        self.assertTrue(torch.equal(labels, labels_ref))

    # tests remove_ids
    def test_remove_ids(self):
        # only implemented for cpu index + numpy ids at the moment
        d = 32
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFFlat(quantizer, d, 5)
        index.make_direct_map()
        index.set_direct_map_type(faiss.DirectMap.Hashtable)

        xb = torch.rand(1000, d, dtype=torch.float32)
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)
        index.train(xb)
        index.add_with_ids(xb, ids)

        ids_remove = np.array([1010], dtype=np.int64)
        index.remove_ids(ids_remove)

        # We should find this
        y = index.reconstruct(1011)
        self.assertTrue(torch.equal(xb[11], y))

        # We should not find this
        with self.assertRaises(RuntimeError):
            y = index.reconstruct(1010)

        # Torch tensor ids not yet supported
        ids_remove = torch.tensor([1012], dtype=torch.int64)
        with self.assertRaises(AssertionError):
            index.remove_ids(ids_remove)

    # tests update_vectors
    def test_update_vectors(self):
        d = 32
        quantizer_np = faiss.IndexFlatL2(d)
        index_np = faiss.IndexIVFFlat(quantizer_np, d, 5)
        index_np.make_direct_map()
        index_np.set_direct_map_type(faiss.DirectMap.Hashtable)

        quantizer_torch = faiss.IndexFlatL2(d)
        index_torch = faiss.IndexIVFFlat(quantizer_torch, d, 5)
        index_torch.make_direct_map()
        index_torch.set_direct_map_type(faiss.DirectMap.Hashtable)

        xb = torch.rand(1000, d, dtype=torch.float32)
        ids = torch.arange(1000, 1000 + xb.shape[0], dtype=torch.int64)

        index_np.train(xb.numpy())
        index_np.add_with_ids(xb.numpy(), ids.numpy())

        index_torch.train(xb)
        index_torch.add_with_ids(xb, ids)

        xb_up = torch.rand(10, d, dtype=torch.float32)
        ids_up = ids[0:10]

        index_np.update_vectors(ids_up.numpy(), xb_up.numpy())
        index_torch.update_vectors(ids_up, xb_up)

        xq = torch.rand(10, d, dtype=torch.float32)

        D_np, I_np = index_np.search(xq.numpy(), 5)
        D_torch, I_torch = index_torch.search(xq, 5)

        self.assertTrue(np.array_equal(D_np, D_torch.numpy()))
        self.assertTrue(np.array_equal(I_np, I_torch.numpy()))

    # tests range_search
    def test_range_search(self):
        torch.manual_seed(10)
        d = 32
        index = faiss.IndexFlatL2(d)
        xb = torch.rand(100, d, dtype=torch.float32)
        index.add(xb)

        # torch cpu as ground truth
        thresh = 2.9
        xq = torch.rand(10, d, dtype=torch.float32)
        lims, D, I = index.range_search(xq, thresh)

        # compare against np
        lims_np, D_np, I_np = index.range_search(xq.numpy(), thresh)

        self.assertTrue(np.array_equal(lims.numpy(), lims_np))
        self.assertTrue(np.array_equal(D.numpy(), D_np))
        self.assertTrue(np.array_equal(I.numpy(), I_np))

    # tests search_and_reconstruct
    def test_search_and_reconstruct(self):
        d = 32
        nlist = 10
        M = 4
        k = 5
        quantizer = faiss.IndexFlatL2(d)
        index = faiss.IndexIVFPQ(quantizer, d, nlist, M, 4)

        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # different set
        xb = torch.rand(500, d, dtype=torch.float32)
        index.add(xb)

        # torch cpu as ground truth
        xq = torch.rand(10, d, dtype=torch.float32)
        D, I, R = index.search_and_reconstruct(xq, k)

        # compare against numpy
        D_np, I_np, R_np = index.search_and_reconstruct(xq.numpy(), k)

        self.assertTrue(np.array_equal(D.numpy(), D_np))
        self.assertTrue(np.array_equal(I.numpy(), I_np))
        self.assertTrue(np.array_equal(R.numpy(), R_np))

        # numpy pre-allocated output values
        D_input = np.zeros((xq.shape[0], k), dtype=np.float32)
        I_input = np.zeros((xq.shape[0], k), dtype=np.int64)
        R_input = np.zeros((xq.shape[0], k, d), dtype=np.float32)

        index.search_and_reconstruct(xq.numpy(), k, D_input, I_input, R_input)

        self.assertTrue(np.array_equal(D.numpy(), D_input))
        self.assertTrue(np.array_equal(I.numpy(), I_input))
        self.assertTrue(np.array_equal(R.numpy(), R_input))

        # torch pre-allocated output values
        D_input = torch.zeros(xq.shape[0], k, dtype=torch.float32)
        I_input = torch.zeros(xq.shape[0], k, dtype=torch.int64)
        R_input = torch.zeros(xq.shape[0], k, d, dtype=torch.float32)

        index.search_and_reconstruct(xq, k, D_input, I_input, R_input)

        self.assertTrue(torch.equal(D, D_input))
        self.assertTrue(torch.equal(I, I_input))
        self.assertTrue(torch.equal(R, R_input))

    # tests sa_encode, sa_decode
    def test_sa_encode_decode(self):
        d = 16
        index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit)

        xb = torch.rand(1000, d, dtype=torch.float32)
        index.train(xb)

        # torch cpu as ground truth
        nq = 10
        xq = torch.rand(nq, d, dtype=torch.float32)
        encoded_torch = index.sa_encode(xq)

        # numpy cpu
        encoded_np = index.sa_encode(xq.numpy())

        self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np))

        decoded_torch = index.sa_decode(encoded_torch)
        decoded_np = index.sa_decode(encoded_np)

        self.assertTrue(torch.equal(decoded_torch, torch.from_numpy(decoded_np)))

        # torch cpu as output parameter
        encoded_torch_param = torch.zeros(nq, d, dtype=torch.uint8)
        index.sa_encode(xq, encoded_torch_param)

        self.assertTrue(torch.equal(encoded_torch, encoded_torch_param))

        decoded_torch_param = torch.zeros(nq, d, dtype=torch.float32)
        index.sa_decode(encoded_torch, decoded_torch_param)

        self.assertTrue(torch.equal(decoded_torch, decoded_torch_param))

        # np as output parameter
        encoded_np_param = np.zeros((nq, d), dtype=np.uint8)
        index.sa_encode(xq.numpy(), encoded_np_param)

        self.assertTrue(np.array_equal(encoded_torch.numpy(), encoded_np_param))

        decoded_np_param = np.zeros((nq, d), dtype=np.float32)
        index.sa_decode(encoded_np_param, decoded_np_param)

        self.assertTrue(np.array_equal(decoded_np, decoded_np_param))

    def test_non_contiguous(self):
        d = 128
        index = faiss.IndexFlatL2(d)

        xb = torch.rand(d, 100).transpose(0, 1)

        with self.assertRaises(AssertionError):
            index.add(xb)

        # disabled since we now accept non-contiguous arrays
        # with self.assertRaises(ValueError):
        #     index.add(xb.numpy())