Remove useless function
Summary: Removed an unused function that caused compile errors in some configurations. Added contrib function (exhaustive_search.knn) to compute the k nearest neighbors without constructing an index. Renamed the equivalent GPU function as exhaustive_search.knn_gpu (it does not make much sense to mention numpy in the name as all functions take numpy arguments by default). Reviewed By: beauby Differential Revision: D24215427 fbshipit-source-id: 6d8e1eafa7c57593304b7b76f83b3015e4d2a2bbpull/1449/head
parent
0412d761e5
commit
8b05434a50
|
@ -39,8 +39,42 @@ def knn_ground_truth(xq, db_iterator, k):
|
|||
|
||||
return rh.D, rh.I
|
||||
|
||||
# Brute-force k-nearest neighbor on the GPU using CPU-resident numpy arrays
|
||||
def knn_numpy_gpu(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
|
||||
def knn(xq, xb, k, distance_type=faiss.METRIC_L2):
|
||||
""" wrapper around the faiss knn functions without index """
|
||||
nq, d = xq.shape
|
||||
nb, d2 = xb.shape
|
||||
assert d == d2
|
||||
|
||||
I = np.empty((nq, k), dtype='int64')
|
||||
D = np.empty((nq, k), dtype='float32')
|
||||
|
||||
if distance_type == faiss.METRIC_L2:
|
||||
heaps = faiss.float_maxheap_array_t()
|
||||
heaps.k = k
|
||||
heaps.nh = nq
|
||||
heaps.val = faiss.swig_ptr(D)
|
||||
heaps.ids = faiss.swig_ptr(I)
|
||||
faiss.knn_L2sqr(
|
||||
faiss.swig_ptr(xq), faiss.swig_ptr(xb),
|
||||
d, nq, nb, heaps
|
||||
)
|
||||
elif distance_type == faiss.METRIC_INNER_PRODUCT:
|
||||
heaps = faiss.float_minheap_array_t()
|
||||
heaps.k = k
|
||||
heaps.nh = nq
|
||||
heaps.val = faiss.swig_ptr(D)
|
||||
heaps.ids = faiss.swig_ptr(I)
|
||||
faiss.knn_inner_product(
|
||||
faiss.swig_ptr(xq), faiss.swig_ptr(xb),
|
||||
d, nq, nb, heaps
|
||||
)
|
||||
return D, I
|
||||
|
||||
|
||||
def knn_gpu(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
|
||||
"""Brute-force k-nearest neighbor on the GPU using CPU-resident numpy arrays
|
||||
Supports float16 arrays and Fortran-order arrays.
|
||||
"""
|
||||
if xb.ndim != 2 or xq.ndim != 2:
|
||||
raise TypeError('xb and xq must be matrices')
|
||||
|
||||
|
|
|
@ -351,12 +351,6 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|||
};
|
||||
|
||||
|
||||
template <bool store_pairs>
|
||||
BinaryInvertedListScanner *select_IVFBinaryScannerL2 (size_t code_size) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
|
||||
size_t n,
|
||||
const uint8_t *x,
|
||||
|
|
|
@ -3,8 +3,7 @@ import unittest
|
|||
import numpy as np
|
||||
|
||||
from faiss.contrib import datasets
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth
|
||||
from faiss.contrib.exhaustive_search import knn_numpy_gpu
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth, knn_gpu
|
||||
|
||||
from common import get_dataset_2
|
||||
|
||||
|
@ -31,7 +30,8 @@ class TestComputeGT(unittest.TestCase):
|
|||
np.testing.assert_almost_equal(Dref, Dnew, decimal=4)
|
||||
|
||||
class TestBfKnnNumpy(unittest.TestCase):
|
||||
def test_bf_knn_numpy(self):
|
||||
|
||||
def test_bf_knn(self):
|
||||
d = 64
|
||||
k = 10
|
||||
xt, xb, xq = get_dataset_2(d, 0, 10000, 100)
|
||||
|
@ -42,7 +42,7 @@ class TestBfKnnNumpy(unittest.TestCase):
|
|||
|
||||
res = faiss.StandardGpuResources()
|
||||
|
||||
D, I = knn_numpy_gpu(res, xb, xq, k)
|
||||
D, I = knn_gpu(res, xb, xq, k)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
@ -50,19 +50,19 @@ class TestBfKnnNumpy(unittest.TestCase):
|
|||
# Test transpositions
|
||||
xbt = np.ascontiguousarray(xb.T)
|
||||
|
||||
D, I = knn_numpy_gpu(res, xbt.T, xq, k)
|
||||
D, I = knn_gpu(res, xbt.T, xq, k)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
||||
xqt = np.ascontiguousarray(xq.T)
|
||||
|
||||
D, I = knn_numpy_gpu(res, xb, xqt.T, k)
|
||||
D, I = knn_gpu(res, xb, xqt.T, k)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
||||
D, I = knn_numpy_gpu(res, xbt.T, xqt.T, k)
|
||||
D, I = knn_gpu(res, xbt.T, xqt.T, k)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
@ -71,7 +71,7 @@ class TestBfKnnNumpy(unittest.TestCase):
|
|||
xb16 = xb.astype(np.float16)
|
||||
xq16 = xq.astype(np.float16)
|
||||
|
||||
D, I = knn_numpy_gpu(res, xb, xq, k)
|
||||
D, I = knn_gpu(res, xb, xq, k)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
@ -79,7 +79,7 @@ class TestBfKnnNumpy(unittest.TestCase):
|
|||
# Test i32 indices
|
||||
I32 = np.empty((xq.shape[0], k), dtype=np.int32)
|
||||
|
||||
D, _ = knn_numpy_gpu(res, xb, xq, k, I=I32)
|
||||
D, _ = knn_gpu(res, xb, xq, k, I=I32)
|
||||
|
||||
np.testing.assert_array_equal(Iref, I32)
|
||||
np.testing.assert_almost_equal(Dref, D, decimal=4)
|
||||
|
|
|
@ -7,7 +7,8 @@ from faiss.contrib import datasets
|
|||
|
||||
from common import get_dataset_2
|
||||
try:
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth
|
||||
from faiss.contrib.exhaustive_search import knn_ground_truth, knn
|
||||
|
||||
except:
|
||||
pass # Submodule import broken in python 2.
|
||||
|
||||
|
@ -58,3 +59,31 @@ class TestDatasets(unittest.TestCase):
|
|||
xb2.append(xbi)
|
||||
xb2 = np.vstack(xb2)
|
||||
np.testing.assert_array_equal(xb, xb2)
|
||||
|
||||
|
||||
class TestExhaustiveSearch(unittest.TestCase):
|
||||
|
||||
def test_knn_cpu(self):
|
||||
|
||||
xb = np.random.rand(200, 32).astype('float32')
|
||||
xq = np.random.rand(100, 32).astype('float32')
|
||||
|
||||
|
||||
index = faiss.IndexFlatL2(32)
|
||||
index.add(xb)
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
|
||||
Dnew, Inew = knn(xq, xb, 10)
|
||||
|
||||
assert np.all(Inew == Iref)
|
||||
assert np.allclose(Dref, Dnew)
|
||||
|
||||
|
||||
index = faiss.IndexFlatIP(32)
|
||||
index.add(xb)
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
|
||||
Dnew, Inew = knn(xq, xb, 10, distance_type=faiss.METRIC_INNER_PRODUCT)
|
||||
|
||||
assert np.all(Inew == Iref)
|
||||
assert np.allclose(Dref, Dnew)
|
||||
|
|
Loading…
Reference in New Issue