Faster versions of fvec_op_ny_Dx for AVX2 (#2811)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2811

Use transpose_AxB kernels to speed up the computations.
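
The core trick: the previous kernels computed one distance per y vector and paid for a horizontal reduction on every iteration. The new kernels load 8 y vectors at a time, transpose them in registers so that each __m256 holds a single coordinate across all 8 vectors, and then accumulate 8 distances in parallel with FMAs, so no horizontal sum is needed on the hot path. A minimal scalar model of the D2 inner-product case (an illustrative sketch; the names below are not from this diff):

#include <cstddef>

// Scalar model of the transposed D2 inner-product kernel.
// v0/v1 play the role of the two __m256 registers produced by
// transpose_8x2: one coordinate across 8 consecutive y vectors.
void ip_ny_d2_model(float* dis, const float* x, const float* y, size_t ny) {
    size_t i = 0;
    for (; i + 8 <= ny; i += 8, y += 16) {
        float v0[8], v1[8];
        for (int j = 0; j < 8; j++) {
            v0[j] = y[2 * j]; // coordinate 0 of vector j
            v1[j] = y[2 * j + 1]; // coordinate 1 of vector j
        }
        for (int j = 0; j < 8; j++) {
            // vertical multiply-add: 8 distances at once, no reduction
            dis[i + j] = x[0] * v0[j] + x[1] * v1[j];
        }
    }
    for (; i < ny; i++, y += 2) { // leftovers, one vector at a time
        dis[i] = x[0] * y[0] + x[1] * y[1];
    }
}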

Reviewed By: mdouze

Differential Revision: D44726814

fbshipit-source-id: a1dd3e289f4ed564a5bece699bee0af88c9925b0
pull/2835/head
Alexandr Guzhva 2023-04-24 15:03:45 -07:00 committed by Facebook GitHub Bot
parent d87888b13e
commit d0ba4c04ca
3 changed files with 458 additions and 49 deletions

faiss/utils/distances_simd.cpp

@@ -247,6 +247,33 @@ static inline __m128 masked_read(int d, const float* x) {
namespace {
/// helper function
inline float horizontal_sum(const __m128 v) {
// say, v is [x0, x1, x2, x3]
// v0 is [x2, x3, ..., ...]
const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
// v1 is [x0 + x2, x1 + x3, ..., ...]
const __m128 v1 = _mm_add_ps(v, v0);
// v2 is [x1 + x3, ..., ..., ...]
__m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
// v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
const __m128 v3 = _mm_add_ps(v1, v2);
// return v3[0]
return _mm_cvtss_f32(v3);
}
#ifdef __AVX2__
/// helper function for AVX2
inline float horizontal_sum(const __m256 v) {
// add high and low parts
const __m128 v0 =
_mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
// perform horizontal sum on v0
return horizontal_sum(v0);
}
#endif
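A quick standalone sanity check of the two overloads against hand-computed sums (a sketch; it assumes the definitions above are in scope and the translation unit is compiled with AVX2 enabled):

#include <immintrin.h>
#include <cassert>

void check_horizontal_sum() {
    const float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    // SSE overload reduces 4 lanes: 1 + 2 + 3 + 4 == 10
    assert(horizontal_sum(_mm_loadu_ps(data)) == 10.0f);
    // AVX2 overload reduces 8 lanes: 1 + ... + 8 == 36
    assert(horizontal_sum(_mm256_loadu_ps(data)) == 36.0f);
}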
/// Function that does a component-wise operation between x and y
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
/// functions below
@@ -260,6 +287,13 @@ struct ElementOpL2 {
__m128 tmp = _mm_sub_ps(x, y);
return _mm_mul_ps(tmp, tmp);
}
#ifdef __AVX2__
static __m256 op(__m256 x, __m256 y) {
__m256 tmp = _mm256_sub_ps(x, y);
return _mm256_mul_ps(tmp, tmp);
}
#endif
};
/// Function that does a component-wise operation between x and y
@@ -272,6 +306,12 @@ struct ElementOpIP {
static __m128 op(__m128 x, __m128 y) {
return _mm_mul_ps(x, y);
}
#ifdef __AVX2__
static __m256 op(__m256 x, __m256 y) {
return _mm256_mul_ps(x, y);
}
#endif
};
template <class ElementOp>
@@ -314,6 +354,131 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
}
}
#ifdef __AVX2__
template <>
void fvec_op_ny_D2<ElementOpIP>(
float* dis,
const float* x,
const float* y,
size_t ny) {
const size_t ny8 = ny / 8;
size_t i = 0;
if (ny8 > 0) {
// process 8 D2-vectors per loop.
_mm_prefetch(y, _MM_HINT_T0);
_mm_prefetch(y + 16, _MM_HINT_T0);
const __m256 m0 = _mm256_set1_ps(x[0]);
const __m256 m1 = _mm256_set1_ps(x[1]);
for (i = 0; i < ny8 * 8; i += 8) {
_mm_prefetch(y + 32, _MM_HINT_T0);
// load 8x2 matrix and transpose it in registers.
// the typical bottleneck is memory access, so
// let's trade instructions for bandwidth.
__m256 v0;
__m256 v1;
transpose_8x2(
_mm256_loadu_ps(y + 0 * 8),
_mm256_loadu_ps(y + 1 * 8),
v0,
v1);
// compute distances
__m256 distances = _mm256_mul_ps(m0, v0);
distances = _mm256_fmadd_ps(m1, v1, distances);
// store
_mm256_storeu_ps(dis + i, distances);
y += 16;
}
}
if (i < ny) {
// process leftovers
float x0 = x[0];
float x1 = x[1];
for (; i < ny; i++) {
float distance = x0 * y[0] + x1 * y[1];
y += 2;
dis[i] = distance;
}
}
}
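transpose_8x2 is defined elsewhere in this file. One plausible AVX2 implementation of the de-interleave it performs (a sketch under that assumption, not necessarily the exact instruction sequence FAISS uses):

#include <immintrin.h>

// De-interleave 8 (x, y) pairs:
//   i0 = [x0 y0 x1 y1 | x2 y2 x3 y3], i1 = [x4 y4 x5 y5 | x6 y6 x7 y7]
// into
//   o0 = [x0 x1 x2 x3 x4 x5 x6 x7],  o1 = [y0 y1 y2 y3 y4 y5 y6 y7]
inline void transpose_8x2_sketch(__m256 i0, __m256 i1, __m256& o0, __m256& o1) {
    // pick even (x) and odd (y) elements within each 128-bit lane
    const __m256 ev = _mm256_shuffle_ps(i0, i1, _MM_SHUFFLE(2, 0, 2, 0));
    // ev = [x0 x1 x4 x5 | x2 x3 x6 x7]
    const __m256 od = _mm256_shuffle_ps(i0, i1, _MM_SHUFFLE(3, 1, 3, 1));
    // od = [y0 y1 y4 y5 | y2 y3 y6 y7]
    // restore ascending order across the two lanes
    const __m256i idx = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
    o0 = _mm256_permutevar8x32_ps(ev, idx);
    o1 = _mm256_permutevar8x32_ps(od, idx);
}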
template <>
void fvec_op_ny_D2<ElementOpL2>(
float* dis,
const float* x,
const float* y,
size_t ny) {
const size_t ny8 = ny / 8;
size_t i = 0;
if (ny8 > 0) {
// process 8 D2-vectors per loop.
_mm_prefetch(y, _MM_HINT_T0);
_mm_prefetch(y + 16, _MM_HINT_T0);
const __m256 m0 = _mm256_set1_ps(x[0]);
const __m256 m1 = _mm256_set1_ps(x[1]);
for (i = 0; i < ny8 * 8; i += 8) {
_mm_prefetch(y + 32, _MM_HINT_T0);
// load 8x2 matrix and transpose it in registers.
// the typical bottleneck is memory access, so
// let's trade instructions for bandwidth.
__m256 v0;
__m256 v1;
transpose_8x2(
_mm256_loadu_ps(y + 0 * 8),
_mm256_loadu_ps(y + 1 * 8),
v0,
v1);
// compute differences
const __m256 d0 = _mm256_sub_ps(m0, v0);
const __m256 d1 = _mm256_sub_ps(m1, v1);
// compute squares of differences
__m256 distances = _mm256_mul_ps(d0, d0);
distances = _mm256_fmadd_ps(d1, d1, distances);
// store
_mm256_storeu_ps(dis + i, distances);
y += 16;
}
}
if (i < ny) {
// process leftovers
float x0 = x[0];
float x1 = x[1];
for (; i < ny; i++) {
float sub0 = x0 - y[0];
float sub1 = x1 - y[1];
float distance = sub0 * sub0 + sub1 * sub1;
y += 2;
dis[i] = distance;
}
}
}
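For reference, the scalar contract of the L2 kernel above; note that each _mm256_fmadd_ps(d, d, acc) fuses the squaring and the accumulation into a single instruction (an illustrative reference mirroring the leftover loop, not code from this diff):

#include <cstddef>

// Scalar reference for the D2 L2 kernel.
void l2sqr_ny_d2_ref(float* dis, const float* x, const float* y, size_t ny) {
    for (size_t i = 0; i < ny; i++, y += 2) {
        const float d0 = x[0] - y[0];
        const float d1 = x[1] - y[1];
        dis[i] = d0 * d0 + d1 * d1; // the AVX2 path computes 8 of these at once
    }
}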
#endif
template <class ElementOp>
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
__m128 x0 = _mm_loadu_ps(x);
@@ -321,17 +486,12 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
for (size_t i = 0; i < ny; i++) {
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
}
}
#ifdef __AVX2__
// Specialized versions for AVX2 for any CPUs that support gather/scatter.
// Todo: implement fvec_op_ny_Dxxx in the same way.
template <>
void fvec_op_ny_D4<ElementOpIP>(
float* dis,
@@ -343,16 +503,9 @@ void fvec_op_ny_D4<ElementOpIP>(
if (ny8 > 0) {
// process 8 D4-vectors per loop.
_mm_prefetch(y, _MM_HINT_NTA);
_mm_prefetch(y + 16, _MM_HINT_NTA);
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
const __m256 m0 = _mm256_set1_ps(x[0]);
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
const __m256 m1 = _mm256_set1_ps(x[1]);
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
const __m256 m2 = _mm256_set1_ps(x[2]);
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
const __m256 m3 = _mm256_set1_ps(x[3]);
for (i = 0; i < ny8 * 8; i += 8) {
@@ -395,9 +548,7 @@ void fvec_op_ny_D4<ElementOpIP>(
for (; i < ny; i++) {
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
}
}
}
@@ -413,16 +564,9 @@ void fvec_op_ny_D4<ElementOpL2>(
if (ny8 > 0) {
// process 8 D4-vectors per loop.
_mm_prefetch(y, _MM_HINT_NTA);
_mm_prefetch(y + 16, _MM_HINT_NTA);
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
const __m256 m0 = _mm256_set1_ps(x[0]);
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
const __m256 m1 = _mm256_set1_ps(x[1]);
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
const __m256 m2 = _mm256_set1_ps(x[2]);
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
const __m256 m3 = _mm256_set1_ps(x[3]);
for (i = 0; i < ny8 * 8; i += 8) {
@@ -471,9 +615,7 @@ void fvec_op_ny_D4<ElementOpL2>(
for (; i < ny; i++) {
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
}
}
}
@@ -496,6 +638,182 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
}
}
#ifdef __AVX2__
template <>
void fvec_op_ny_D8<ElementOpIP>(
float* dis,
const float* x,
const float* y,
size_t ny) {
const size_t ny8 = ny / 8;
size_t i = 0;
if (ny8 > 0) {
// process 8 D8-vectors per loop.
const __m256 m0 = _mm256_set1_ps(x[0]);
const __m256 m1 = _mm256_set1_ps(x[1]);
const __m256 m2 = _mm256_set1_ps(x[2]);
const __m256 m3 = _mm256_set1_ps(x[3]);
const __m256 m4 = _mm256_set1_ps(x[4]);
const __m256 m5 = _mm256_set1_ps(x[5]);
const __m256 m6 = _mm256_set1_ps(x[6]);
const __m256 m7 = _mm256_set1_ps(x[7]);
for (i = 0; i < ny8 * 8; i += 8) {
// load 8x8 matrix and transpose it in registers.
// the typical bottleneck is memory access, so
// let's trade instructions for bandwidth.
__m256 v0;
__m256 v1;
__m256 v2;
__m256 v3;
__m256 v4;
__m256 v5;
__m256 v6;
__m256 v7;
transpose_8x8(
_mm256_loadu_ps(y + 0 * 8),
_mm256_loadu_ps(y + 1 * 8),
_mm256_loadu_ps(y + 2 * 8),
_mm256_loadu_ps(y + 3 * 8),
_mm256_loadu_ps(y + 4 * 8),
_mm256_loadu_ps(y + 5 * 8),
_mm256_loadu_ps(y + 6 * 8),
_mm256_loadu_ps(y + 7 * 8),
v0,
v1,
v2,
v3,
v4,
v5,
v6,
v7);
// compute distances
__m256 distances = _mm256_mul_ps(m0, v0);
distances = _mm256_fmadd_ps(m1, v1, distances);
distances = _mm256_fmadd_ps(m2, v2, distances);
distances = _mm256_fmadd_ps(m3, v3, distances);
distances = _mm256_fmadd_ps(m4, v4, distances);
distances = _mm256_fmadd_ps(m5, v5, distances);
distances = _mm256_fmadd_ps(m6, v6, distances);
distances = _mm256_fmadd_ps(m7, v7, distances);
// store
_mm256_storeu_ps(dis + i, distances);
y += 64;
}
}
if (i < ny) {
// process leftovers
__m256 x0 = _mm256_loadu_ps(x);
for (; i < ny; i++) {
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
y += 8;
dis[i] = horizontal_sum(accu);
}
}
}
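transpose_8x8 generalizes the same de-interleave to a full 8x8 register transpose. Its contract, in scalar terms (a model of the semantics only, not of the intrinsic sequence):

// Treat the 8 input registers as rows of an 8x8 matrix
// (8 consecutive D8 y vectors); the outputs are its columns,
// so output k holds coordinate k of all 8 vectors.
void transpose_8x8_model(const float in[8][8], float out[8][8]) {
    for (int r = 0; r < 8; r++) {
        for (int c = 0; c < 8; c++) {
            out[c][r] = in[r][c];
        }
    }
}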
template <>
void fvec_op_ny_D8<ElementOpL2>(
float* dis,
const float* x,
const float* y,
size_t ny) {
const size_t ny8 = ny / 8;
size_t i = 0;
if (ny8 > 0) {
// process 8 D8-vectors per loop.
const __m256 m0 = _mm256_set1_ps(x[0]);
const __m256 m1 = _mm256_set1_ps(x[1]);
const __m256 m2 = _mm256_set1_ps(x[2]);
const __m256 m3 = _mm256_set1_ps(x[3]);
const __m256 m4 = _mm256_set1_ps(x[4]);
const __m256 m5 = _mm256_set1_ps(x[5]);
const __m256 m6 = _mm256_set1_ps(x[6]);
const __m256 m7 = _mm256_set1_ps(x[7]);
for (i = 0; i < ny8 * 8; i += 8) {
// load 8x8 matrix and transpose it in registers.
// the typical bottleneck is memory access, so
// let's trade instructions for bandwidth.
__m256 v0;
__m256 v1;
__m256 v2;
__m256 v3;
__m256 v4;
__m256 v5;
__m256 v6;
__m256 v7;
transpose_8x8(
_mm256_loadu_ps(y + 0 * 8),
_mm256_loadu_ps(y + 1 * 8),
_mm256_loadu_ps(y + 2 * 8),
_mm256_loadu_ps(y + 3 * 8),
_mm256_loadu_ps(y + 4 * 8),
_mm256_loadu_ps(y + 5 * 8),
_mm256_loadu_ps(y + 6 * 8),
_mm256_loadu_ps(y + 7 * 8),
v0,
v1,
v2,
v3,
v4,
v5,
v6,
v7);
// compute differences
const __m256 d0 = _mm256_sub_ps(m0, v0);
const __m256 d1 = _mm256_sub_ps(m1, v1);
const __m256 d2 = _mm256_sub_ps(m2, v2);
const __m256 d3 = _mm256_sub_ps(m3, v3);
const __m256 d4 = _mm256_sub_ps(m4, v4);
const __m256 d5 = _mm256_sub_ps(m5, v5);
const __m256 d6 = _mm256_sub_ps(m6, v6);
const __m256 d7 = _mm256_sub_ps(m7, v7);
// compute squares of differences
__m256 distances = _mm256_mul_ps(d0, d0);
distances = _mm256_fmadd_ps(d1, d1, distances);
distances = _mm256_fmadd_ps(d2, d2, distances);
distances = _mm256_fmadd_ps(d3, d3, distances);
distances = _mm256_fmadd_ps(d4, d4, distances);
distances = _mm256_fmadd_ps(d5, d5, distances);
distances = _mm256_fmadd_ps(d6, d6, distances);
distances = _mm256_fmadd_ps(d7, d7, distances);
// store
_mm256_storeu_ps(dis + i, distances);
y += 64;
}
}
if (i < ny) {
// process leftovers
__m256 x0 = _mm256_loadu_ps(x);
for (; i < ny; i++) {
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
y += 8;
dis[i] = horizontal_sum(accu);
}
}
}
#endif
template <class ElementOp>
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
__m128 x0 = _mm_loadu_ps(x);
@@ -509,9 +827,7 @@ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
y += 4;
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
}
}
@@ -892,10 +1208,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
for (; i < ny; i++) {
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- const auto distance = _mm_cvtss_f32(accu);
+ const float distance = horizontal_sum(accu);
if (current_min_distance > distance) {
current_min_distance = distance;
@@ -1031,23 +1344,9 @@ size_t fvec_L2sqr_ny_nearest_D8(
__m256 x0 = _mm256_loadu_ps(x);
for (; i < ny; i++) {
- __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
- __m256 accu = _mm256_mul_ps(sub, sub);
+ __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
y += 8;
- // horitontal sum
- const __m256 h0 = _mm256_hadd_ps(accu, accu);
- const __m256 h1 = _mm256_hadd_ps(h0, h0);
- // extract high and low __m128 regs from __m256
- const __m128 h2 = _mm256_extractf128_ps(h1, 1);
- const __m128 h3 = _mm256_castps256_ps128(h1);
- // get a final hsum into all 4 regs
- const __m128 h4 = _mm_add_ss(h2, h3);
- // extract f[0] from __m128
- const float distance = _mm_cvtss_f32(h4);
+ const float distance = horizontal_sum(accu);
if (current_min_distance > distance) {
current_min_distance = distance;

tests/CMakeLists.txt

@@ -25,6 +25,7 @@ set(FAISS_TEST_SRC
test_simdlib.cpp
test_approx_topk.cpp
test_RCQ_cropping.cpp
test_distances_simd.cpp
)
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_distances_simd.cpp

@@ -0,0 +1,109 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gtest/gtest.h>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <random>
#include <vector>
#include <faiss/utils/distances.h>
// reference implementations
void fvec_inner_products_ny_ref(
float* ip,
const float* x,
const float* y,
size_t d,
size_t ny) {
for (size_t i = 0; i < ny; i++) {
ip[i] = faiss::fvec_inner_product(x, y, d);
y += d;
}
}
void fvec_L2sqr_ny_ref(
float* dis,
const float* x,
const float* y,
size_t d,
size_t ny) {
for (size_t i = 0; i < ny; i++) {
dis[i] = faiss::fvec_L2sqr(x, y, d);
y += d;
}
}
// test templated versions of fvec_L2sqr_ny
TEST(TEST_FVEC_L2SQR_NY, D2) {
// integer values are used so that all float computations are exact
// and ASSERT_EQ can compare the results bit-for-bit.
std::default_random_engine rng(123);
std::uniform_int_distribution<int32_t> u(0, 32);
for (const auto dim : {2, 4, 8, 12}) {
std::vector<float> x(dim, 0);
for (size_t i = 0; i < x.size(); i++) {
x[i] = u(rng);
}
for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) {
std::vector<float> y(nrows * dim);
for (size_t i = 0; i < y.size(); i++) {
y[i] = u(rng);
}
std::vector<float> distances(nrows, 0);
faiss::fvec_L2sqr_ny(
distances.data(), x.data(), y.data(), dim, nrows);
std::vector<float> distances_ref(nrows, 0);
fvec_L2sqr_ny_ref(
distances_ref.data(), x.data(), y.data(), dim, nrows);
ASSERT_EQ(distances, distances_ref)
<< "Mismatching results for dim = " << dim
<< ", nrows = " << nrows;
}
}
}
// fvec_inner_products_ny
TEST(TEST_FVEC_INNER_PRODUCTS_NY, D2) {
// integer values are used so that all float computations are exact
// and ASSERT_EQ can compare the results bit-for-bit.
std::default_random_engine rng(123);
std::uniform_int_distribution<int32_t> u(0, 32);
for (const auto dim : {2, 4, 8, 12}) {
std::vector<float> x(dim, 0);
for (size_t i = 0; i < x.size(); i++) {
x[i] = u(rng);
}
for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) {
std::vector<float> y(nrows * dim);
for (size_t i = 0; i < y.size(); i++) {
y[i] = u(rng);
}
std::vector<float> distances(nrows, 0);
faiss::fvec_inner_products_ny(
distances.data(), x.data(), y.data(), dim, nrows);
std::vector<float> distances_ref(nrows, 0);
fvec_inner_products_ny_ref(
distances_ref.data(), x.data(), y.data(), dim, nrows);
ASSERT_EQ(distances, distances_ref)
<< "Mismatching results for dim = " << dim
<< ", nrows = " << nrows;
}
}
}
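
Assuming the usual FAISS test build, the new checks can be run in isolation with a standard GoogleTest filter, e.g. ./faiss_test --gtest_filter='TEST_FVEC_L2SQR_NY.*:TEST_FVEC_INNER_PRODUCTS_NY.*'.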