From 8f2a72a8e6ddaa92442b7b8f81c431491006d4c2 Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 15 Mar 2022 17:35:11 -0700 Subject: [PATCH] AVX2 optimized IVFPQ scanning code (#2253) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2253 add a specialized AVX2 version for IVFPQScannerT::scan_list_with_table Reviewed By: mdouze Differential Revision: D34733503 fbshipit-source-id: a428de04548426b39bc5a092b9f6802eadbd184d --- faiss/IndexIVFPQ.cpp | 137 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/faiss/IndexIVFPQ.cpp b/faiss/IndexIVFPQ.cpp index c9c61c6d0..5ec1ad3ec 100644 --- a/faiss/IndexIVFPQ.cpp +++ b/faiss/IndexIVFPQ.cpp @@ -30,6 +30,12 @@ #include +#include + +#ifdef __AVX2__ +#include +#endif + namespace faiss { /***************************************** @@ -864,6 +870,136 @@ struct IVFPQScannerT : QueryTables { * Scaning the codes: simple PQ scan. *****************************************************/ +#ifdef __AVX2__ + /// version of the scan where we use precomputed tables. + /// non PQDecoder8 version. + template + typename std::enable_if::value), void>::type + scan_list_with_table( + size_t ncode, + const uint8_t* codes, + SearchResultType& res) const { + for (size_t j = 0; j < ncode; j++) { + PQDecoder decoder(codes, pq.nbits); + codes += pq.code_size; + float dis = dis0; + const float* tab = sim_table; + + for (size_t m = 0; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + + res.add(j, dis); + } + } + + /// version of the scan where we use precomputed tables. + /// AVX2 PQDecoder8 version. + template + typename std::enable_if::value, void>::type + scan_list_with_table( + size_t ncode, + const uint8_t* codes, + SearchResultType& res) const { + for (size_t j = 0; j < ncode; j++) { + float dis = dis0; + + // + size_t m = 0; + const size_t pqM16 = pq.M / 16; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + + const __m256i ksub = _mm256_set1_epi32(pq.ksub); + __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + offsets_0 = _mm256_mullo_epi32(offsets_0, ksub); + + // accumulators of partial sums + __m256 partialSum = _mm256_setzero_ps(); + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + const __m128i mm1 = + _mm_loadu_si128((const __m128i_u*)(codes + m)); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm1); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += pq.ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + + // move high 8 uint8 to low ones + const __m128i mm2 = + _mm_unpackhi_epi64(mm1, _mm_setzero_si128()); + { + // convert uint8 values (low part of __m128i) to int32 + // values + const __m256i idx1 = _mm256_cvtepu8_epi32(mm2); + + // add offsets + const __m256i indices_to_read_from = + _mm256_add_epi32(idx1, offsets_0); + + // gather 8 values, similar to 8 operations of tab[idx] + __m256 collected = _mm256_i32gather_ps( + tab, indices_to_read_from, sizeof(float)); + tab += pq.ksub * 8; + + // collect partial sums + partialSum = _mm256_add_ps(partialSum, collected); + } + } + + // horizontal sum for partialSum + const __m256 h0 = _mm256_hadd_ps(partialSum, partialSum); + const __m256 h1 = _mm256_hadd_ps(h0, h0); + + // extract high and low __m128 regs from __m256 + const __m128 h2 = _mm256_extractf128_ps(h1, 1); + const __m128 h3 = _mm256_castps256_ps128(h1); + + // get a final hsum into all 4 regs + const __m128 h4 = _mm_add_ss(h2, h3); + + // extract f[0] from __m128 + const float hsum = _mm_cvtss_f32(h4); + dis += hsum; + } + + // + if (m < pq.M) { + // process leftovers + PQDecoder decoder(codes + m, pq.nbits); + + for (; m < pq.M; m++) { + dis += tab[decoder.decode()]; + tab += pq.ksub; + } + } + + codes += pq.code_size; + + // done + res.add(j, dis); + } + } +#else /// version of the scan where we use precomputed tables template void scan_list_with_table( @@ -884,6 +1020,7 @@ struct IVFPQScannerT : QueryTables { res.add(j, dis); } } +#endif /// tables are not precomputed, but pointers are provided to the /// relevant X_c|x_r tables