mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
Make simdlib_emulated.h faster (#1814)
Summary: related: https://github.com/facebookresearch/faiss/issues/1812 This PR improves the performance of contents in `simdlib_emulated.h` . `IndexPQFastScan` and `IndexIVFPQFastScan` will become faster on non-AVX2 environments, e.g., 4x faster on SIFT1M. This PR contains below changes: - Use `template` instead of `std::function` on argument of `unary_func` and `binary_func` - Because `std::function` hinders some optimizations like function inlining - Use `const T&` instead of `T` for vector classes like `simd16uint16` on argument of functions - Vector classes on `simdlib_emulated.h` has the data member as array, so the runtime cost for copying is not so low. - Passing by const lvalue-ref prevents copy. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1814 Reviewed By: beauby Differential Revision: D27760072 Pulled By: mdouze fbshipit-source-id: cbc5a14658d1960b24ce55a395e71c80998742dc
This commit is contained in:
parent
c62ab3a696
commit
b85b4308f2
@ -10,7 +10,6 @@
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
namespace faiss {
|
||||
@ -72,7 +71,7 @@ struct simd16uint16 : simd256bit {
|
||||
set1(x);
|
||||
}
|
||||
|
||||
explicit simd16uint16(simd256bit x) : simd256bit(x) {}
|
||||
explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
|
||||
|
||||
explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
|
||||
|
||||
@ -94,9 +93,8 @@ struct simd16uint16 : simd256bit {
|
||||
return elements_to_string("%3d,");
|
||||
}
|
||||
|
||||
static simd16uint16 unary_func(
|
||||
simd16uint16 a,
|
||||
std::function<uint16_t(uint16_t)> f) {
|
||||
template <typename F>
|
||||
static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
|
||||
simd16uint16 c;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
c.u16[j] = f(a.u16[j]);
|
||||
@ -104,10 +102,11 @@ struct simd16uint16 : simd256bit {
|
||||
return c;
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static simd16uint16 binary_func(
|
||||
simd16uint16 a,
|
||||
simd16uint16 b,
|
||||
std::function<uint16_t(uint16_t, uint16_t)> f) {
|
||||
const simd16uint16& a,
|
||||
const simd16uint16& b,
|
||||
F&& f) {
|
||||
simd16uint16 c;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
c.u16[j] = f(a.u16[j], b.u16[j]);
|
||||
@ -131,34 +130,34 @@ struct simd16uint16 : simd256bit {
|
||||
return unary_func(*this, [shift](uint16_t a) { return a << shift; });
|
||||
}
|
||||
|
||||
simd16uint16 operator+=(simd16uint16 other) {
|
||||
simd16uint16 operator+=(const simd16uint16& other) {
|
||||
*this = *this + other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
simd16uint16 operator-=(simd16uint16 other) {
|
||||
simd16uint16 operator-=(const simd16uint16& other) {
|
||||
*this = *this - other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
simd16uint16 operator+(simd16uint16 other) const {
|
||||
simd16uint16 operator+(const simd16uint16& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](uint16_t a, uint16_t b) { return a + b; });
|
||||
}
|
||||
|
||||
simd16uint16 operator-(simd16uint16 other) const {
|
||||
simd16uint16 operator-(const simd16uint16& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](uint16_t a, uint16_t b) { return a - b; });
|
||||
}
|
||||
|
||||
simd16uint16 operator&(simd256bit other) const {
|
||||
simd16uint16 operator&(const simd256bit& other) const {
|
||||
return binary_func(
|
||||
*this, simd16uint16(other), [](uint16_t a, uint16_t b) {
|
||||
return a & b;
|
||||
});
|
||||
}
|
||||
|
||||
simd16uint16 operator|(simd256bit other) const {
|
||||
simd16uint16 operator|(const simd256bit& other) const {
|
||||
return binary_func(
|
||||
*this, simd16uint16(other), [](uint16_t a, uint16_t b) {
|
||||
return a | b;
|
||||
@ -166,7 +165,7 @@ struct simd16uint16 : simd256bit {
|
||||
}
|
||||
|
||||
// returns binary masks
|
||||
simd16uint16 operator==(simd16uint16 other) const {
|
||||
simd16uint16 operator==(const simd16uint16& other) const {
|
||||
return binary_func(*this, other, [](uint16_t a, uint16_t b) {
|
||||
return a == b ? 0xffff : 0;
|
||||
});
|
||||
@ -183,7 +182,7 @@ struct simd16uint16 : simd256bit {
|
||||
|
||||
// mask of elements where this >= thresh
|
||||
// 2 bit per component: 16 * 2 = 32 bit
|
||||
uint32_t ge_mask(simd16uint16 thresh) const {
|
||||
uint32_t ge_mask(const simd16uint16& thresh) const {
|
||||
uint32_t gem = 0;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (u16[j] >= thresh.u16[j]) {
|
||||
@ -193,15 +192,15 @@ struct simd16uint16 : simd256bit {
|
||||
return gem;
|
||||
}
|
||||
|
||||
uint32_t le_mask(simd16uint16 thresh) const {
|
||||
uint32_t le_mask(const simd16uint16& thresh) const {
|
||||
return thresh.ge_mask(*this);
|
||||
}
|
||||
|
||||
uint32_t gt_mask(simd16uint16 thresh) const {
|
||||
uint32_t gt_mask(const simd16uint16& thresh) const {
|
||||
return ~le_mask(thresh);
|
||||
}
|
||||
|
||||
bool all_gt(simd16uint16 thresh) const {
|
||||
bool all_gt(const simd16uint16& thresh) const {
|
||||
return le_mask(thresh) == 0;
|
||||
}
|
||||
|
||||
@ -210,7 +209,7 @@ struct simd16uint16 : simd256bit {
|
||||
return u16[i];
|
||||
}
|
||||
|
||||
void accu_min(simd16uint16 incoming) {
|
||||
void accu_min(const simd16uint16& incoming) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (incoming.u16[j] < u16[j]) {
|
||||
u16[j] = incoming.u16[j];
|
||||
@ -218,7 +217,7 @@ struct simd16uint16 : simd256bit {
|
||||
}
|
||||
}
|
||||
|
||||
void accu_max(simd16uint16 incoming) {
|
||||
void accu_max(const simd16uint16& incoming) {
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (incoming.u16[j] > u16[j]) {
|
||||
u16[j] = incoming.u16[j];
|
||||
@ -228,12 +227,12 @@ struct simd16uint16 : simd256bit {
|
||||
};
|
||||
|
||||
// not really a std::min because it returns an elementwise min
|
||||
inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
|
||||
inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
|
||||
return simd16uint16::binary_func(
|
||||
av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
|
||||
}
|
||||
|
||||
inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
|
||||
inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
|
||||
return simd16uint16::binary_func(
|
||||
av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
|
||||
}
|
||||
@ -241,7 +240,7 @@ inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
|
||||
// decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
|
||||
// return (a0 + a1, b0 + b1)
|
||||
// TODO find a better name
|
||||
inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
|
||||
inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
|
||||
simd16uint16 c;
|
||||
for (int j = 0; j < 8; j++) {
|
||||
c.u16[j] = a.u16[j] + a.u16[j + 8];
|
||||
@ -252,7 +251,10 @@ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
|
||||
|
||||
// compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
|
||||
// of d0 and d1 with thr
|
||||
inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
|
||||
inline uint32_t cmp_ge32(
|
||||
const simd16uint16& d0,
|
||||
const simd16uint16& d1,
|
||||
const simd16uint16& thr) {
|
||||
uint32_t gem = 0;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (d0.u16[j] >= thr.u16[j]) {
|
||||
@ -265,7 +267,10 @@ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
|
||||
return gem;
|
||||
}
|
||||
|
||||
inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
|
||||
inline uint32_t cmp_le32(
|
||||
const simd16uint16& d0,
|
||||
const simd16uint16& d1,
|
||||
const simd16uint16& thr) {
|
||||
uint32_t gem = 0;
|
||||
for (int j = 0; j < 16; j++) {
|
||||
if (d0.u16[j] <= thr.u16[j]) {
|
||||
@ -290,7 +295,7 @@ struct simd32uint8 : simd256bit {
|
||||
set1(x);
|
||||
}
|
||||
|
||||
explicit simd32uint8(simd256bit x) : simd256bit(x) {}
|
||||
explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
|
||||
|
||||
explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
|
||||
|
||||
@ -318,10 +323,11 @@ struct simd32uint8 : simd256bit {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static simd32uint8 binary_func(
|
||||
simd32uint8 a,
|
||||
simd32uint8 b,
|
||||
std::function<uint8_t(uint8_t, uint8_t)> f) {
|
||||
const simd32uint8& a,
|
||||
const simd32uint8& b,
|
||||
F&& f) {
|
||||
simd32uint8 c;
|
||||
for (int j = 0; j < 32; j++) {
|
||||
c.u8[j] = f(a.u8[j], b.u8[j]);
|
||||
@ -329,19 +335,19 @@ struct simd32uint8 : simd256bit {
|
||||
return c;
|
||||
}
|
||||
|
||||
simd32uint8 operator&(simd256bit other) const {
|
||||
simd32uint8 operator&(const simd256bit& other) const {
|
||||
return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
|
||||
return a & b;
|
||||
});
|
||||
}
|
||||
|
||||
simd32uint8 operator+(simd32uint8 other) const {
|
||||
simd32uint8 operator+(const simd32uint8& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](uint8_t a, uint8_t b) { return a + b; });
|
||||
}
|
||||
|
||||
// The very important operation that everything relies on
|
||||
simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
|
||||
simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
|
||||
simd32uint8 c;
|
||||
for (int j = 0; j < 32; j++) {
|
||||
if (idx.u8[j] & 0x80) {
|
||||
@ -361,7 +367,7 @@ struct simd32uint8 : simd256bit {
|
||||
// extract + 0-extend lane
|
||||
// this operation is slow (3 cycles)
|
||||
|
||||
simd32uint8 operator+=(simd32uint8 other) {
|
||||
simd32uint8 operator+=(const simd32uint8& other) {
|
||||
*this = *this + other;
|
||||
return *this;
|
||||
}
|
||||
@ -374,7 +380,9 @@ struct simd32uint8 : simd256bit {
|
||||
|
||||
// convert with saturation
|
||||
// careful: this does not cross lanes, so the order is weird
|
||||
inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
|
||||
inline simd32uint8 uint16_to_uint8_saturate(
|
||||
const simd16uint16& a,
|
||||
const simd16uint16& b) {
|
||||
simd32uint8 c;
|
||||
|
||||
auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
|
||||
@ -389,7 +397,7 @@ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
|
||||
}
|
||||
|
||||
/// get most significant bit of each byte
|
||||
inline uint32_t get_MSBs(simd32uint8 a) {
|
||||
inline uint32_t get_MSBs(const simd32uint8& a) {
|
||||
uint32_t res = 0;
|
||||
for (int i = 0; i < 32; i++) {
|
||||
if (a.u8[i] & 0x80) {
|
||||
@ -400,7 +408,10 @@ inline uint32_t get_MSBs(simd32uint8 a) {
|
||||
}
|
||||
|
||||
/// use MSB of each byte of mask to select a byte between a and b
|
||||
inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
|
||||
inline simd32uint8 blendv(
|
||||
const simd32uint8& a,
|
||||
const simd32uint8& b,
|
||||
const simd32uint8& mask) {
|
||||
simd32uint8 c;
|
||||
for (int i = 0; i < 32; i++) {
|
||||
if (mask.u8[i] & 0x80) {
|
||||
@ -420,7 +431,7 @@ struct simd8uint32 : simd256bit {
|
||||
set1(x);
|
||||
}
|
||||
|
||||
explicit simd8uint32(simd256bit x) : simd256bit(x) {}
|
||||
explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
|
||||
|
||||
explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
|
||||
|
||||
@ -452,7 +463,7 @@ struct simd8uint32 : simd256bit {
|
||||
struct simd8float32 : simd256bit {
|
||||
simd8float32() {}
|
||||
|
||||
explicit simd8float32(simd256bit x) : simd256bit(x) {}
|
||||
explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
|
||||
|
||||
explicit simd8float32(float x) {
|
||||
set1(x);
|
||||
@ -468,10 +479,11 @@ struct simd8float32 : simd256bit {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static simd8float32 binary_func(
|
||||
simd8float32 a,
|
||||
simd8float32 b,
|
||||
std::function<float(float, float)> f) {
|
||||
const simd8float32& a,
|
||||
const simd8float32& b,
|
||||
F&& f) {
|
||||
simd8float32 c;
|
||||
for (int j = 0; j < 8; j++) {
|
||||
c.f32[j] = f(a.f32[j], b.f32[j]);
|
||||
@ -479,17 +491,17 @@ struct simd8float32 : simd256bit {
|
||||
return c;
|
||||
}
|
||||
|
||||
simd8float32 operator*(simd8float32 other) const {
|
||||
simd8float32 operator*(const simd8float32& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](float a, float b) { return a * b; });
|
||||
}
|
||||
|
||||
simd8float32 operator+(simd8float32 other) const {
|
||||
simd8float32 operator+(const simd8float32& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](float a, float b) { return a + b; });
|
||||
}
|
||||
|
||||
simd8float32 operator-(simd8float32 other) const {
|
||||
simd8float32 operator-(const simd8float32& other) const {
|
||||
return binary_func(
|
||||
*this, other, [](float a, float b) { return a - b; });
|
||||
}
|
||||
@ -506,7 +518,7 @@ struct simd8float32 : simd256bit {
|
||||
};
|
||||
|
||||
// hadd does not cross lanes
|
||||
inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
|
||||
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
|
||||
simd8float32 c;
|
||||
c.f32[0] = a.f32[0] + a.f32[1];
|
||||
c.f32[1] = a.f32[2] + a.f32[3];
|
||||
@ -521,7 +533,7 @@ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
|
||||
inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
|
||||
simd8float32 c;
|
||||
c.f32[0] = a.f32[0];
|
||||
c.f32[1] = b.f32[0];
|
||||
@ -536,7 +548,7 @@ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
|
||||
return c;
|
||||
}
|
||||
|
||||
inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
|
||||
inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
|
||||
simd8float32 c;
|
||||
c.f32[0] = a.f32[2];
|
||||
c.f32[1] = b.f32[2];
|
||||
@ -552,7 +564,10 @@ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
|
||||
}
|
||||
|
||||
// compute a * b + c
|
||||
inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
|
||||
inline simd8float32 fmadd(
|
||||
const simd8float32& a,
|
||||
const simd8float32& b,
|
||||
const simd8float32& c) {
|
||||
simd8float32 res;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
|
||||
|
Loading…
x
Reference in New Issue
Block a user