AVX2 optimized IVFPQ scanning code (#2253)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2253

Add a specialized AVX2 version of IVFPQScannerT::scan_list_with_table.

Reviewed By: mdouze

Differential Revision: D34733503

fbshipit-source-id: a428de04548426b39bc5a092b9f6802eadbd184d
pull/2255/head
Alexandr Guzhva 2022-03-15 17:35:11 -07:00 committed by Facebook GitHub Bot
parent 80bf6a2bc6
commit 8f2a72a8e6
1 changed files with 137 additions and 0 deletions

View File

@ -30,6 +30,12 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/ProductQuantizer.h>
#ifdef __AVX2__
#include <immintrin.h>
#endif
namespace faiss {
/*****************************************
@ -864,6 +870,136 @@ struct IVFPQScannerT : QueryTables {
* Scanning the codes: simple PQ scan.
*****************************************************/
#ifdef __AVX2__
/// version of the scan where we use precomputed tables.
/// non PQDecoder8 version.
template <class SearchResultType, typename T = PQDecoder>
typename std::enable_if<!(std::is_same<T, PQDecoder8>::value), void>::type
scan_list_with_table(
size_t ncode,
const uint8_t* codes,
SearchResultType& res) const {
for (size_t j = 0; j < ncode; j++) {
PQDecoder decoder(codes, pq.nbits);
codes += pq.code_size;
float dis = dis0;
const float* tab = sim_table;
for (size_t m = 0; m < pq.M; m++) {
dis += tab[decoder.decode()];
tab += pq.ksub;
}
res.add(j, dis);
}
}
/// Scan of a list of codes using the precomputed look-up table.
/// AVX2 overload, selected only when the decoder type is PQDecoder8.
///
/// Sub-quantizers are processed 16 at a time: 16 code bytes are loaded,
/// widened to 32-bit indices and used to gather 2 x 8 table entries, which
/// are accumulated lane-wise and reduced once per code. The remaining
/// (pq.M % 16) sub-quantizers go through the scalar decoder. Because the
/// vector part is summed lane-wise, rounding may differ slightly from a
/// strictly sequential sum.
///
/// @param ncode  number of codes to scan
/// @param codes  start of the codes; consecutive codes are pq.code_size
///               bytes apart
/// @param res    result handler; receives (index in list, distance) for
///               each scanned code
template <class SearchResultType, typename T = PQDecoder>
typename std::enable_if<std::is_same<T, PQDecoder8>::value, void>::type
scan_list_with_table(
        size_t ncode,
        const uint8_t* codes,
        SearchResultType& res) const {
    // Everything below is independent of the code being scanned, so it is
    // computed once instead of once per code.

    // number of sub-quantizers covered by the vectorized loop
    const size_t m_simd = (pq.M / 16) * 16;
    // distance, in floats, between the sub-tables of m and m + 8
    const size_t table_stride8 = pq.ksub * 8;
    // lane k holds k * ksub: the offset of the k-th sub-table relative to
    // the current table pointer
    const __m256i lane_offsets = _mm256_mullo_epi32(
            _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
            _mm256_set1_epi32(static_cast<int>(pq.ksub)));

    // Gathers table[k * ksub + byte_k] for k = 0..7, where byte_k are the
    // low 8 bytes of `packed`; equivalent to 8 scalar tab[idx] look-ups.
    const auto gather8 = [&lane_offsets](const float* table, __m128i packed) {
        // zero-extend the low 8 uint8 values to int32 and add the offsets
        const __m256i indices = _mm256_add_epi32(
                _mm256_cvtepu8_epi32(packed), lane_offsets);
        return _mm256_i32gather_ps(table, indices, sizeof(float));
    };

    for (size_t j = 0; j < ncode; j++) {
        float dis = dis0;
        const float* tab = sim_table;
        size_t m = 0;

        if (m_simd > 0) {
            // accumulator of 8 partial sums
            __m256 partialSum = _mm256_setzero_ps();

            for (; m < m_simd; m += 16) {
                // unaligned load of 16 uint8 codes; the standard
                // `const __m128i*` parameter type is used rather than
                // the compiler-specific `__m128i_u` typedef
                const __m128i mm1 = _mm_loadu_si128(
                        reinterpret_cast<const __m128i*>(codes + m));
                partialSum = _mm256_add_ps(partialSum, gather8(tab, mm1));
                tab += table_stride8;

                // move the high 8 uint8 values into the low half
                const __m128i mm2 =
                        _mm_unpackhi_epi64(mm1, _mm_setzero_si128());
                partialSum = _mm256_add_ps(partialSum, gather8(tab, mm2));
                tab += table_stride8;
            }

            // horizontal sum of partialSum: after two hadd steps every
            // element of a 128-bit half holds the sum of that half
            const __m256 h0 = _mm256_hadd_ps(partialSum, partialSum);
            const __m256 h1 = _mm256_hadd_ps(h0, h0);
            // add the sums of the high and low halves
            const __m128 hi = _mm256_extractf128_ps(h1, 1);
            const __m128 lo = _mm256_castps256_ps128(h1);
            dis += _mm_cvtss_f32(_mm_add_ss(hi, lo));
        }

        if (m < pq.M) {
            // leftovers: fewer than 16 sub-quantizers remain, decode them
            // one at a time
            PQDecoder decoder(codes + m, pq.nbits);
            for (; m < pq.M; m++) {
                dis += tab[decoder.decode()];
                tab += pq.ksub;
            }
        }

        codes += pq.code_size;
        res.add(j, dis);
    }
}
#else
/// version of the scan where we use precomputed tables
template <class SearchResultType>
void scan_list_with_table(
@ -884,6 +1020,7 @@ struct IVFPQScannerT : QueryTables {
res.add(j, dis);
}
}
#endif
/// tables are not precomputed, but pointers are provided to the
/// relevant X_c|x_r tables