AVX2 version of faiss::HNSW::MinimaxHeap::pop_min() (#2874)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2874
Reviewed By: mdouze
Differential Revision: D46125506
fbshipit-source-id: 4099e5c95bfb168b2097a42f5308c4bea1f72ca8
pull/2882/head
parent
6800ebef83
commit
e8b7575e93
|
@ -16,6 +16,15 @@
|
|||
#include <faiss/impl/IDSelector.h>
|
||||
#include <faiss/utils/prefetch.h>
|
||||
|
||||
#include <faiss/impl/platform_macros.h>
|
||||
|
||||
#ifdef __AVX2__
|
||||
#include <immintrin.h>
|
||||
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#endif
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/**************************************************************
|
||||
|
@ -1010,17 +1019,105 @@ void HNSW::MinimaxHeap::clear() {
|
|||
nvalid = k = 0;
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
||||
assert(k > 0);
|
||||
static_assert(
|
||||
std::is_same<storage_idx_t, int32_t>::value,
|
||||
"This code expects storage_idx_t to be int32_t");
|
||||
|
||||
int32_t min_idx = -1;
|
||||
float min_dis = std::numeric_limits<float>::infinity();
|
||||
|
||||
size_t iii = 0;
|
||||
|
||||
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
||||
__m256 min_distances =
|
||||
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
||||
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
||||
__m256i offset = _mm256_set1_epi32(8);
|
||||
|
||||
// The baseline version is available in non-AVX2 branch.
|
||||
|
||||
// The following loop tracks the rightmost index with the min distance.
|
||||
// -1 index values are ignored.
|
||||
const int k8 = (k / 8) * 8;
|
||||
for (; iii < k8; iii += 8) {
|
||||
__m256i indices =
|
||||
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
||||
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
||||
|
||||
// This mask filters out -1 values among indices.
|
||||
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
||||
|
||||
__m256i dmask = _mm256_castps_si256(
|
||||
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
|
||||
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
||||
|
||||
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
||||
_mm256_castsi256_ps(current_indices),
|
||||
_mm256_castsi256_ps(min_indices),
|
||||
finalmask));
|
||||
|
||||
const __m256 min_distances_new =
|
||||
_mm256_blendv_ps(distances, min_distances, finalmask);
|
||||
|
||||
min_indices = min_indices_new;
|
||||
min_distances = min_distances_new;
|
||||
|
||||
current_indices = _mm256_add_epi32(current_indices, offset);
|
||||
}
|
||||
|
||||
// Vectorizing is doable, but is not practical
|
||||
int32_t vidx8[8];
|
||||
float vdis8[8];
|
||||
_mm256_storeu_ps(vdis8, min_distances);
|
||||
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
|
||||
|
||||
for (size_t j = 0; j < 8; j++) {
|
||||
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
|
||||
min_idx = vidx8[j];
|
||||
min_dis = vdis8[j];
|
||||
}
|
||||
}
|
||||
|
||||
// process last values. Vectorizing is doable, but is not practical
|
||||
for (; iii < k; iii++) {
|
||||
if (ids[iii] != -1 && dis[iii] <= min_dis) {
|
||||
min_dis = dis[iii];
|
||||
min_idx = iii;
|
||||
}
|
||||
}
|
||||
|
||||
if (min_idx == -1) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (vmin_out) {
|
||||
*vmin_out = min_dis;
|
||||
}
|
||||
int ret = ids[min_idx];
|
||||
ids[min_idx] = -1;
|
||||
--nvalid;
|
||||
return ret;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// baseline non-vectorized version
|
||||
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
||||
assert(k > 0);
|
||||
// returns min. This is an O(n) operation
|
||||
int i = k - 1;
|
||||
while (i >= 0) {
|
||||
if (ids[i] != -1)
|
||||
if (ids[i] != -1) {
|
||||
break;
|
||||
}
|
||||
i--;
|
||||
}
|
||||
if (i == -1)
|
||||
if (i == -1) {
|
||||
return -1;
|
||||
}
|
||||
int imin = i;
|
||||
float vmin = dis[i];
|
||||
i--;
|
||||
|
@ -1031,14 +1128,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
|
|||
}
|
||||
i--;
|
||||
}
|
||||
if (vmin_out)
|
||||
if (vmin_out) {
|
||||
*vmin_out = vmin;
|
||||
}
|
||||
int ret = ids[imin];
|
||||
ids[imin] = -1;
|
||||
--nvalid;
|
||||
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
int HNSW::MinimaxHeap::count_below(float thresh) {
|
||||
int n_below = 0;
|
||||
|
|
|
@ -28,6 +28,7 @@ set(FAISS_TEST_SRC
|
|||
test_distances_simd.cpp
|
||||
test_heap.cpp
|
||||
test_code_distance.cpp
|
||||
test_hnsw.cpp
|
||||
)
|
||||
|
||||
add_executable(faiss_test ${FAISS_TEST_SRC})
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
/**
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/impl/HNSW.h>
|
||||
|
||||
/// Scalar reference implementation of pop_min, used to cross-check
/// faiss::HNSW::MinimaxHeap::pop_min().
///
/// Scans the first heap.k slots from right to left, skipping slots whose id
/// is -1, and selects the slot with the smallest distance; on ties the
/// rightmost slot wins because only a strictly smaller distance replaces the
/// current candidate. The selected slot's id is set to -1 and heap.nvalid is
/// decremented. This is an O(n) operation.
///
/// @param heap      heap to pop from (modified in place).
/// @param vmin_out  if non-null, receives the distance of the popped entry.
/// @return the id stored in the selected slot, or -1 if every slot is empty.
int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) {
    assert(heap.k > 0);

    // Find the rightmost occupied slot; it is the initial candidate.
    int last = heap.k - 1;
    for (; last >= 0; --last) {
        if (heap.ids[last] != -1) {
            break;
        }
    }
    if (last < 0) {
        return -1;
    }

    int best_pos = last;
    float best_dis = heap.dis[last];

    // Walk the remaining slots leftwards looking for a strictly smaller one.
    for (int pos = last - 1; pos >= 0; --pos) {
        if (heap.ids[pos] == -1) {
            continue;
        }
        if (heap.dis[pos] < best_dis) {
            best_dis = heap.dis[pos];
            best_pos = pos;
        }
    }

    if (vmin_out != nullptr) {
        *vmin_out = best_dis;
    }

    const int popped = heap.ids[best_pos];
    heap.ids[best_pos] = -1;
    --heap.nvalid;

    return popped;
}
|
||||
|
||||
void test_popmin(int heap_size, int amount_to_put) {
|
||||
// create a heap
|
||||
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
|
||||
|
||||
using storage_idx_t = faiss::HNSW::storage_idx_t;
|
||||
|
||||
std::default_random_engine rng(123 + heap_size * amount_to_put);
|
||||
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
|
||||
std::uniform_real_distribution<float> uf(0, 1);
|
||||
|
||||
// generate random unique indices
|
||||
std::unordered_set<storage_idx_t> indices;
|
||||
while (indices.size() < amount_to_put) {
|
||||
const storage_idx_t index = u(rng);
|
||||
indices.insert(index);
|
||||
}
|
||||
|
||||
// put ones into the heap
|
||||
for (const auto index : indices) {
|
||||
float distance = uf(rng);
|
||||
if (distance >= 0.7f) {
|
||||
// add infinity values from time to time
|
||||
distance = std::numeric_limits<float>::infinity();
|
||||
}
|
||||
mm_heap.push(index, distance);
|
||||
}
|
||||
|
||||
// clone the heap
|
||||
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
|
||||
|
||||
// takes ones out one by one
|
||||
while (mm_heap.size() > 0) {
|
||||
// compare heaps
|
||||
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
|
||||
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
|
||||
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
|
||||
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
|
||||
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
|
||||
|
||||
// use the reference pop_min for the cloned heap
|
||||
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
|
||||
storage_idx_t cloned_vmin_idx =
|
||||
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
|
||||
|
||||
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
|
||||
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
|
||||
|
||||
// compare returns
|
||||
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
|
||||
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
|
||||
}
|
||||
|
||||
// compare heaps again
|
||||
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
|
||||
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
|
||||
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
|
||||
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
|
||||
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
|
||||
}
|
||||
|
||||
void test_popmin_identical_distances(
|
||||
int heap_size,
|
||||
int amount_to_put,
|
||||
const float distance) {
|
||||
// create a heap
|
||||
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
|
||||
|
||||
using storage_idx_t = faiss::HNSW::storage_idx_t;
|
||||
|
||||
std::default_random_engine rng(123 + heap_size * amount_to_put);
|
||||
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
|
||||
|
||||
// generate random unique indices
|
||||
std::unordered_set<storage_idx_t> indices;
|
||||
while (indices.size() < amount_to_put) {
|
||||
const storage_idx_t index = u(rng);
|
||||
indices.insert(index);
|
||||
}
|
||||
|
||||
// put ones into the heap
|
||||
for (const auto index : indices) {
|
||||
mm_heap.push(index, distance);
|
||||
}
|
||||
|
||||
// clone the heap
|
||||
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
|
||||
|
||||
// takes ones out one by one
|
||||
while (mm_heap.size() > 0) {
|
||||
// compare heaps
|
||||
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
|
||||
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
|
||||
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
|
||||
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
|
||||
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
|
||||
|
||||
// use the reference pop_min for the cloned heap
|
||||
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
|
||||
storage_idx_t cloned_vmin_idx =
|
||||
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
|
||||
|
||||
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
|
||||
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
|
||||
|
||||
// compare returns
|
||||
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
|
||||
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
|
||||
}
|
||||
|
||||
// compare heaps again
|
||||
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
|
||||
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
|
||||
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
|
||||
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
|
||||
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
|
||||
}
|
||||
|
||||
// Runs test_popmin() over a range of heap sizes, including sizes that are
// not multiples of 8. For each size the fill amount starts at the size itself
// and is halved until it reaches zero.
TEST(HNSW, Test_popmin) {
    const std::vector<size_t> heap_sizes = {
            1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32, 64, 128};

    for (const size_t heap_size : heap_sizes) {
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin(heap_size, fill);
            fill /= 2;
        }
    }
}
|
||||
|
||||
// Runs test_popmin_identical_distances() with every entry at distance 1.0
// over a range of heap sizes. For each size the fill amount starts at the
// size itself and is halved until it reaches zero.
TEST(HNSW, Test_popmin_identical_distances) {
    const std::vector<size_t> heap_sizes = {
            1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};

    for (const size_t heap_size : heap_sizes) {
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin_identical_distances(heap_size, fill, 1.0f);
            fill /= 2;
        }
    }
}
|
||||
|
||||
// Runs test_popmin_identical_distances() with every entry at +infinity over
// a range of heap sizes. For each size the fill amount starts at the size
// itself and is halved until it reaches zero.
TEST(HNSW, Test_popmin_infinite_distances) {
    const float infinite_distance = std::numeric_limits<float>::infinity();
    const std::vector<size_t> heap_sizes = {
            1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};

    for (const size_t heap_size : heap_sizes) {
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin_identical_distances(
                    heap_size, fill, infinite_distance);
            fill /= 2;
        }
    }
}
|
Loading…
Reference in New Issue