Make simdlib_emulated.h faster (#1814)

Summary:
related: https://github.com/facebookresearch/faiss/issues/1812

This PR improves the performance of contents in `simdlib_emulated.h` .
`IndexPQFastScan` and `IndexIVFPQFastScan` will become faster on non-AVX2 environments, e.g., 4x faster on SIFT1M.
This PR contains below changes:

- Use `template` instead of `std::function` on argument of `unary_func` and `binary_func`
    - Because `std::function` hinders some optimizations like function inlining
- Use `const T&` instead of `T` for vector classes like `simd16uint16` on argument of functions
    - Vector classes on `simdlib_emulated.h` has the data member as array, so the runtime cost for copying is not so low.
    - Passing by const lvalue-ref prevents copy.

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1814

Reviewed By: beauby

Differential Revision: D27760072

Pulled By: mdouze

fbshipit-source-id: cbc5a14658d1960b24ce55a395e71c80998742dc
This commit is contained in:
Y.Imaizumi 2021-04-16 00:23:56 -07:00 committed by Facebook GitHub Bot
parent c62ab3a696
commit b85b4308f2

View File

@ -10,7 +10,6 @@
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
namespace faiss {
@ -72,7 +71,7 @@ struct simd16uint16 : simd256bit {
set1(x);
}
explicit simd16uint16(simd256bit x) : simd256bit(x) {}
explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
@ -94,9 +93,8 @@ struct simd16uint16 : simd256bit {
return elements_to_string("%3d,");
}
static simd16uint16 unary_func(
simd16uint16 a,
std::function<uint16_t(uint16_t)> f) {
template <typename F>
static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
simd16uint16 c;
for (int j = 0; j < 16; j++) {
c.u16[j] = f(a.u16[j]);
@ -104,10 +102,11 @@ struct simd16uint16 : simd256bit {
return c;
}
template <typename F>
static simd16uint16 binary_func(
simd16uint16 a,
simd16uint16 b,
std::function<uint16_t(uint16_t, uint16_t)> f) {
const simd16uint16& a,
const simd16uint16& b,
F&& f) {
simd16uint16 c;
for (int j = 0; j < 16; j++) {
c.u16[j] = f(a.u16[j], b.u16[j]);
@ -131,34 +130,34 @@ struct simd16uint16 : simd256bit {
return unary_func(*this, [shift](uint16_t a) { return a << shift; });
}
simd16uint16 operator+=(simd16uint16 other) {
simd16uint16 operator+=(const simd16uint16& other) {
*this = *this + other;
return *this;
}
simd16uint16 operator-=(simd16uint16 other) {
simd16uint16 operator-=(const simd16uint16& other) {
*this = *this - other;
return *this;
}
simd16uint16 operator+(simd16uint16 other) const {
simd16uint16 operator+(const simd16uint16& other) const {
return binary_func(
*this, other, [](uint16_t a, uint16_t b) { return a + b; });
}
simd16uint16 operator-(simd16uint16 other) const {
simd16uint16 operator-(const simd16uint16& other) const {
return binary_func(
*this, other, [](uint16_t a, uint16_t b) { return a - b; });
}
simd16uint16 operator&(simd256bit other) const {
simd16uint16 operator&(const simd256bit& other) const {
return binary_func(
*this, simd16uint16(other), [](uint16_t a, uint16_t b) {
return a & b;
});
}
simd16uint16 operator|(simd256bit other) const {
simd16uint16 operator|(const simd256bit& other) const {
return binary_func(
*this, simd16uint16(other), [](uint16_t a, uint16_t b) {
return a | b;
@ -166,7 +165,7 @@ struct simd16uint16 : simd256bit {
}
// returns binary masks
simd16uint16 operator==(simd16uint16 other) const {
simd16uint16 operator==(const simd16uint16& other) const {
return binary_func(*this, other, [](uint16_t a, uint16_t b) {
return a == b ? 0xffff : 0;
});
@ -183,7 +182,7 @@ struct simd16uint16 : simd256bit {
// mask of elements where this >= thresh
// 2 bit per component: 16 * 2 = 32 bit
uint32_t ge_mask(simd16uint16 thresh) const {
uint32_t ge_mask(const simd16uint16& thresh) const {
uint32_t gem = 0;
for (int j = 0; j < 16; j++) {
if (u16[j] >= thresh.u16[j]) {
@ -193,15 +192,15 @@ struct simd16uint16 : simd256bit {
return gem;
}
uint32_t le_mask(simd16uint16 thresh) const {
uint32_t le_mask(const simd16uint16& thresh) const {
return thresh.ge_mask(*this);
}
uint32_t gt_mask(simd16uint16 thresh) const {
uint32_t gt_mask(const simd16uint16& thresh) const {
return ~le_mask(thresh);
}
bool all_gt(simd16uint16 thresh) const {
bool all_gt(const simd16uint16& thresh) const {
return le_mask(thresh) == 0;
}
@ -210,7 +209,7 @@ struct simd16uint16 : simd256bit {
return u16[i];
}
void accu_min(simd16uint16 incoming) {
void accu_min(const simd16uint16& incoming) {
for (int j = 0; j < 16; j++) {
if (incoming.u16[j] < u16[j]) {
u16[j] = incoming.u16[j];
@ -218,7 +217,7 @@ struct simd16uint16 : simd256bit {
}
}
void accu_max(simd16uint16 incoming) {
void accu_max(const simd16uint16& incoming) {
for (int j = 0; j < 16; j++) {
if (incoming.u16[j] > u16[j]) {
u16[j] = incoming.u16[j];
@ -228,12 +227,12 @@ struct simd16uint16 : simd256bit {
};
// not really a std::min because it returns an elementwise min
inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
return simd16uint16::binary_func(
av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
}
inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
return simd16uint16::binary_func(
av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
}
@ -241,7 +240,7 @@ inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
// decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
// return (a0 + a1, b0 + b1)
// TODO find a better name
inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
simd16uint16 c;
for (int j = 0; j < 8; j++) {
c.u16[j] = a.u16[j] + a.u16[j + 8];
@ -252,7 +251,10 @@ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
// compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
// of d0 and d1 with thr
inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
inline uint32_t cmp_ge32(
const simd16uint16& d0,
const simd16uint16& d1,
const simd16uint16& thr) {
uint32_t gem = 0;
for (int j = 0; j < 16; j++) {
if (d0.u16[j] >= thr.u16[j]) {
@ -265,7 +267,10 @@ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
return gem;
}
inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
inline uint32_t cmp_le32(
const simd16uint16& d0,
const simd16uint16& d1,
const simd16uint16& thr) {
uint32_t gem = 0;
for (int j = 0; j < 16; j++) {
if (d0.u16[j] <= thr.u16[j]) {
@ -290,7 +295,7 @@ struct simd32uint8 : simd256bit {
set1(x);
}
explicit simd32uint8(simd256bit x) : simd256bit(x) {}
explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
@ -318,10 +323,11 @@ struct simd32uint8 : simd256bit {
}
}
template <typename F>
static simd32uint8 binary_func(
simd32uint8 a,
simd32uint8 b,
std::function<uint8_t(uint8_t, uint8_t)> f) {
const simd32uint8& a,
const simd32uint8& b,
F&& f) {
simd32uint8 c;
for (int j = 0; j < 32; j++) {
c.u8[j] = f(a.u8[j], b.u8[j]);
@ -329,19 +335,19 @@ struct simd32uint8 : simd256bit {
return c;
}
simd32uint8 operator&(simd256bit other) const {
simd32uint8 operator&(const simd256bit& other) const {
return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
return a & b;
});
}
simd32uint8 operator+(simd32uint8 other) const {
simd32uint8 operator+(const simd32uint8& other) const {
return binary_func(
*this, other, [](uint8_t a, uint8_t b) { return a + b; });
}
// The very important operation that everything relies on
simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
simd32uint8 c;
for (int j = 0; j < 32; j++) {
if (idx.u8[j] & 0x80) {
@ -361,7 +367,7 @@ struct simd32uint8 : simd256bit {
// extract + 0-extend lane
// this operation is slow (3 cycles)
simd32uint8 operator+=(simd32uint8 other) {
simd32uint8 operator+=(const simd32uint8& other) {
*this = *this + other;
return *this;
}
@ -374,7 +380,9 @@ struct simd32uint8 : simd256bit {
// convert with saturation
// careful: this does not cross lanes, so the order is weird
inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
inline simd32uint8 uint16_to_uint8_saturate(
const simd16uint16& a,
const simd16uint16& b) {
simd32uint8 c;
auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
@ -389,7 +397,7 @@ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
}
/// get most significant bit of each byte
inline uint32_t get_MSBs(simd32uint8 a) {
inline uint32_t get_MSBs(const simd32uint8& a) {
uint32_t res = 0;
for (int i = 0; i < 32; i++) {
if (a.u8[i] & 0x80) {
@ -400,7 +408,10 @@ inline uint32_t get_MSBs(simd32uint8 a) {
}
/// use MSB of each byte of mask to select a byte between a and b
inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
inline simd32uint8 blendv(
const simd32uint8& a,
const simd32uint8& b,
const simd32uint8& mask) {
simd32uint8 c;
for (int i = 0; i < 32; i++) {
if (mask.u8[i] & 0x80) {
@ -420,7 +431,7 @@ struct simd8uint32 : simd256bit {
set1(x);
}
explicit simd8uint32(simd256bit x) : simd256bit(x) {}
explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
@ -452,7 +463,7 @@ struct simd8uint32 : simd256bit {
struct simd8float32 : simd256bit {
simd8float32() {}
explicit simd8float32(simd256bit x) : simd256bit(x) {}
explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
explicit simd8float32(float x) {
set1(x);
@ -468,10 +479,11 @@ struct simd8float32 : simd256bit {
}
}
template <typename F>
static simd8float32 binary_func(
simd8float32 a,
simd8float32 b,
std::function<float(float, float)> f) {
const simd8float32& a,
const simd8float32& b,
F&& f) {
simd8float32 c;
for (int j = 0; j < 8; j++) {
c.f32[j] = f(a.f32[j], b.f32[j]);
@ -479,17 +491,17 @@ struct simd8float32 : simd256bit {
return c;
}
simd8float32 operator*(simd8float32 other) const {
simd8float32 operator*(const simd8float32& other) const {
return binary_func(
*this, other, [](float a, float b) { return a * b; });
}
simd8float32 operator+(simd8float32 other) const {
simd8float32 operator+(const simd8float32& other) const {
return binary_func(
*this, other, [](float a, float b) { return a + b; });
}
simd8float32 operator-(simd8float32 other) const {
simd8float32 operator-(const simd8float32& other) const {
return binary_func(
*this, other, [](float a, float b) { return a - b; });
}
@ -506,7 +518,7 @@ struct simd8float32 : simd256bit {
};
// hadd does not cross lanes
inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
simd8float32 c;
c.f32[0] = a.f32[0] + a.f32[1];
c.f32[1] = a.f32[2] + a.f32[3];
@ -521,7 +533,7 @@ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
return c;
}
inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
simd8float32 c;
c.f32[0] = a.f32[0];
c.f32[1] = b.f32[0];
@ -536,7 +548,7 @@ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
return c;
}
inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
simd8float32 c;
c.f32[0] = a.f32[2];
c.f32[1] = b.f32[2];
@ -552,7 +564,10 @@ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
}
// compute a * b + c
inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
inline simd8float32 fmadd(
const simd8float32& a,
const simd8float32& b,
const simd8float32& c) {
simd8float32 res;
for (int i = 0; i < 8; i++) {
res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];