diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index f0d901a60..a8f2b84cf 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -181,6 +181,16 @@ static inline uint32_t cmp_xe32( return d0_mask | static_cast(d1_mask) << 16; } +template +static inline uint16x8_t vshlq(uint16x8_t vec) { + return vshlq_n_u16(vec, Shift); +} + +template +static inline uint16x8_t vshrq(uint16x8_t vec) { + return vshrq_n_u16(vec, Shift); +} + } // namespace simdlib } // namespace detail @@ -252,14 +262,112 @@ struct simd16uint16 { // shift must be known at compile time simd16uint16 operator>>(const int shift) const { - return simd16uint16{detail::simdlib::unary_func( - data, [shift](uint16x8_t a) { return vshrq_n_u16(a, shift); })}; + switch (shift) { + case 0: + return *this; + case 1: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<1>)}; + case 2: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<2>)}; + case 3: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<3>)}; + case 4: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<4>)}; + case 5: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<5>)}; + case 6: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<6>)}; + case 7: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<7>)}; + case 8: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<8>)}; + case 9: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<9>)}; + case 10: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<10>)}; + case 11: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<11>)}; + case 12: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<12>)}; + case 13: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<13>)}; + case 14: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<14>)}; + case 15: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshrq<15>)}; + default: + FAISS_THROW_FMT("Invalid shift %d", shift); + } } // shift must be known at compile time simd16uint16 operator<<(const int shift) const { - return simd16uint16{detail::simdlib::unary_func( - data, [shift](uint16x8_t a) { return vshlq_n_u16(a, shift); })}; + switch (shift) { + case 0: + return *this; + case 1: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<1>)}; + case 2: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<2>)}; + case 3: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<3>)}; + case 4: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<4>)}; + case 5: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<5>)}; + case 6: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<6>)}; + case 7: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<7>)}; + case 8: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<8>)}; + case 9: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<9>)}; + case 10: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<10>)}; + case 11: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<11>)}; + case 12: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<12>)}; + case 13: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<13>)}; + case 14: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<14>)}; + case 15: + return simd16uint16{detail::simdlib::unary_func( + data, detail::simdlib::vshlq<15>)}; + default: + FAISS_THROW_FMT("Invalid shift %d", shift); + } } simd16uint16 operator+=(const simd16uint16& other) {