AVX2 optimized IVFPQ scanning code (#2253)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2253
Add a specialized AVX2 version of IVFPQScannerT::scan_list_with_table.
Reviewed By: mdouze
Differential Revision: D34733503
fbshipit-source-id: a428de04548426b39bc5a092b9f6802eadbd184d
pull/2255/head
parent
80bf6a2bc6
commit
8f2a72a8e6
|
@ -30,6 +30,12 @@
|
|||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
#include <faiss/impl/ProductQuantizer.h>
|
||||
|
||||
#ifdef __AVX2__
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
namespace faiss {
|
||||
|
||||
/*****************************************
|
||||
|
@ -864,6 +870,136 @@ struct IVFPQScannerT : QueryTables {
|
|||
* Scanning the codes: simple PQ scan.
|
||||
*****************************************************/
|
||||
|
||||
#ifdef __AVX2__
|
||||
/// version of the scan where we use precomputed tables.
|
||||
/// non PQDecoder8 version.
|
||||
template <class SearchResultType, typename T = PQDecoder>
|
||||
typename std::enable_if<!(std::is_same<T, PQDecoder8>::value), void>::type
|
||||
scan_list_with_table(
|
||||
size_t ncode,
|
||||
const uint8_t* codes,
|
||||
SearchResultType& res) const {
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
PQDecoder decoder(codes, pq.nbits);
|
||||
codes += pq.code_size;
|
||||
float dis = dis0;
|
||||
const float* tab = sim_table;
|
||||
|
||||
for (size_t m = 0; m < pq.M; m++) {
|
||||
dis += tab[decoder.decode()];
|
||||
tab += pq.ksub;
|
||||
}
|
||||
|
||||
res.add(j, dis);
|
||||
}
|
||||
}
|
||||
|
||||
/// version of the scan where we use precomputed tables.
|
||||
/// AVX2 PQDecoder8 version.
|
||||
template <class SearchResultType, typename T = PQDecoder>
|
||||
typename std::enable_if<std::is_same<T, PQDecoder8>::value, void>::type
|
||||
scan_list_with_table(
|
||||
size_t ncode,
|
||||
const uint8_t* codes,
|
||||
SearchResultType& res) const {
|
||||
for (size_t j = 0; j < ncode; j++) {
|
||||
float dis = dis0;
|
||||
|
||||
//
|
||||
size_t m = 0;
|
||||
const size_t pqM16 = pq.M / 16;
|
||||
|
||||
const float* tab = sim_table;
|
||||
|
||||
if (pqM16 > 0) {
|
||||
// process 16 values per loop
|
||||
|
||||
const __m256i ksub = _mm256_set1_epi32(pq.ksub);
|
||||
__m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
||||
offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
|
||||
|
||||
// accumulators of partial sums
|
||||
__m256 partialSum = _mm256_setzero_ps();
|
||||
|
||||
// loop
|
||||
for (m = 0; m < pqM16 * 16; m += 16) {
|
||||
// load 16 uint8 values
|
||||
const __m128i mm1 =
|
||||
_mm_loadu_si128((const __m128i_u*)(codes + m));
|
||||
{
|
||||
// convert uint8 values (low part of __m128i) to int32
|
||||
// values
|
||||
const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
|
||||
|
||||
// add offsets
|
||||
const __m256i indices_to_read_from =
|
||||
_mm256_add_epi32(idx1, offsets_0);
|
||||
|
||||
// gather 8 values, similar to 8 operations of tab[idx]
|
||||
__m256 collected = _mm256_i32gather_ps(
|
||||
tab, indices_to_read_from, sizeof(float));
|
||||
tab += pq.ksub * 8;
|
||||
|
||||
// collect partial sums
|
||||
partialSum = _mm256_add_ps(partialSum, collected);
|
||||
}
|
||||
|
||||
// move high 8 uint8 to low ones
|
||||
const __m128i mm2 =
|
||||
_mm_unpackhi_epi64(mm1, _mm_setzero_si128());
|
||||
{
|
||||
// convert uint8 values (low part of __m128i) to int32
|
||||
// values
|
||||
const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);
|
||||
|
||||
// add offsets
|
||||
const __m256i indices_to_read_from =
|
||||
_mm256_add_epi32(idx1, offsets_0);
|
||||
|
||||
// gather 8 values, similar to 8 operations of tab[idx]
|
||||
__m256 collected = _mm256_i32gather_ps(
|
||||
tab, indices_to_read_from, sizeof(float));
|
||||
tab += pq.ksub * 8;
|
||||
|
||||
// collect partial sums
|
||||
partialSum = _mm256_add_ps(partialSum, collected);
|
||||
}
|
||||
}
|
||||
|
||||
// horizontal sum for partialSum
|
||||
const __m256 h0 = _mm256_hadd_ps(partialSum, partialSum);
|
||||
const __m256 h1 = _mm256_hadd_ps(h0, h0);
|
||||
|
||||
// extract high and low __m128 regs from __m256
|
||||
const __m128 h2 = _mm256_extractf128_ps(h1, 1);
|
||||
const __m128 h3 = _mm256_castps256_ps128(h1);
|
||||
|
||||
// get a final hsum into all 4 regs
|
||||
const __m128 h4 = _mm_add_ss(h2, h3);
|
||||
|
||||
// extract f[0] from __m128
|
||||
const float hsum = _mm_cvtss_f32(h4);
|
||||
dis += hsum;
|
||||
}
|
||||
|
||||
//
|
||||
if (m < pq.M) {
|
||||
// process leftovers
|
||||
PQDecoder decoder(codes + m, pq.nbits);
|
||||
|
||||
for (; m < pq.M; m++) {
|
||||
dis += tab[decoder.decode()];
|
||||
tab += pq.ksub;
|
||||
}
|
||||
}
|
||||
|
||||
codes += pq.code_size;
|
||||
|
||||
// done
|
||||
res.add(j, dis);
|
||||
}
|
||||
}
|
||||
#else
|
||||
/// version of the scan where we use precomputed tables
|
||||
template <class SearchResultType>
|
||||
void scan_list_with_table(
|
||||
|
@ -884,6 +1020,7 @@ struct IVFPQScannerT : QueryTables {
|
|||
res.add(j, dis);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/// tables are not precomputed, but pointers are provided to the
|
||||
/// relevant X_c|x_r tables
|
||||
|
|
Loading…
Reference in New Issue