fast C++ templates for sa_decode (#2354)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2354

Specialized code that provides a 2x-3x faster Index::sa_decode for
* IVF256,PQ[1]x8np
* Residual[1]x8,PQ[2]x8

Reviewed By: mdouze

Differential Revision: D37092134

fbshipit-source-id: d848b6cf1aefa826a5ca01e41935aa5d46f5dcc7
pull/2362/head
Alexandr Guzhva 2022-06-16 09:20:19 -07:00 committed by Facebook GitHub Bot
parent 578fbc9a8e
commit 3986ebffca
5 changed files with 1187 additions and 0 deletions

View File

@ -0,0 +1,573 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#pragma once
#include <immintrin.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
namespace faiss {
namespace cppcontrib {
namespace {
// Processes 4 float values.
// Returns {
//   [0..3] = *coarse[0..3] + *fine[0..3];
// }
inline __m128 elementaryBlock4x1b(
        const float* const __restrict coarse,
        const float* const __restrict fine) {
    // Elementwise sum of one 4-float coarse chunk and one 4-float fine
    // chunk. Unaligned loads are used, so neither pointer needs any
    // particular alignment.
    return _mm_add_ps(_mm_loadu_ps(fine), _mm_loadu_ps(coarse));
}
// Processes 4 float values.
// Returns {
//   [0..3] = existingValue[0..3] + weight * (*coarse[0..3] + *fine[0..3]);
// }
inline __m128 elementaryBlock4x1bAccum(
        const float* const __restrict coarse,
        const float* const __restrict fine,
        const float weight,
        const __m128 existingValue) {
    // add coarse and fine
    const __m128 combinedValue = elementaryBlock4x1b(coarse, fine);
    // broadcast the scalar weight into all 4 lanes;
    // this operation is expected to be optimized by a compiler
    const __m128 weightAvx = _mm_set1_ps(weight);
    // do fma: existingValue + weight * (coarse + fine), one rounding step
    return _mm_fmadd_ps(combinedValue, weightAvx, existingValue);
}
// Processes 8 float values.
// Returns {
//   [0..3] = *coarse[0..3] + *fine0[0..3];
//   [4..7] = *coarse[4..7] + *fine1[0..3];
// }
inline __m256 elementaryBlock4x2b(
        const float* const __restrict coarse,
        const float* const __restrict fine0,
        const float* const __restrict fine1) {
    // load fine: two 4-float chunks, typically from two different
    // fine centroids
    const __m128 fineValue0 = _mm_loadu_ps(fine0);
    const __m128 fineValue1 = _mm_loadu_ps(fine1);
    // load coarse: 8 contiguous floats
    const __m256 coarseValue = _mm256_loadu_ps(coarse);
    // combine two 4b into a single 8b; fine0 lands in the low 128-bit lane
    const __m256 combinedFineValue = _mm256_set_m128(fineValue1, fineValue0);
    // add coarse and fine
    return _mm256_add_ps(combinedFineValue, coarseValue);
}
// Processes 8 float values.
// Returns {
//   [0..3] = existingValue[0..3] + weight * (*coarse[0..3] + *fine0[0..3]);
//   [4..7] = existingValue[4..7] + weight * (*coarse[4..7] + *fine1[0..3]);
// }
inline __m256 elementaryBlock4x2bAccum(
        const float* const __restrict coarse,
        const float* const __restrict fine0,
        const float* const __restrict fine1,
        const float weight,
        const __m256 existingValue) {
    // add coarse and fine
    const __m256 combinedValue = elementaryBlock4x2b(coarse, fine0, fine1);
    // broadcast the scalar weight into all 8 lanes;
    // this operation is expected to be optimized by a compiler
    const __m256 weightAvx2 = _mm256_set1_ps(weight);
    // do fma: existingValue + weight * (coarse + fine)
    return _mm256_fmadd_ps(combinedValue, weightAvx2, existingValue);
}
// Processes 8 float values.
// Returns {
//   [0..7] = *coarse[0..7] + *fine[0..7];
// }
inline __m256 elementaryBlock8x1b(
        const float* const __restrict coarse,
        const float* const __restrict fine) {
    // load fine (unaligned load, no alignment requirement)
    const __m256 fineValue = _mm256_loadu_ps(fine);
    // load coarse
    const __m256 coarseValue = _mm256_loadu_ps(coarse);
    // add coarse and fine
    return _mm256_add_ps(fineValue, coarseValue);
}
// Processes 8 float values.
// Returns {
//   [0..7] = existingValue[0..7] + weight * (*coarse[0..7] + *fine[0..7]);
// }
inline __m256 elementaryBlock8x1bAccum(
        const float* const __restrict coarse,
        const float* const __restrict fine,
        const float weight,
        const __m256 existingValue) {
    // add coarse and fine
    const __m256 combinedValue = elementaryBlock8x1b(coarse, fine);
    // broadcast the scalar weight into all 8 lanes;
    // this operation is expected to be optimized by a compiler
    const __m256 weightAvx2 = _mm256_set1_ps(weight);
    // do fma: existingValue + weight * (coarse + fine)
    return _mm256_fmadd_ps(combinedValue, weightAvx2, existingValue);
}
// Reduces the number of read operations from RAM: reads a whole 32-bit
// word and extracts the requested code byte from it, instead of issuing
// a separate byte load per code.
// DIM / CODE_SIZE is the number of code bytes in the buffer; CPOS is the
// (compile-time) position of the byte to read.
template <intptr_t DIM, intptr_t CODE_SIZE, intptr_t CPOS>
struct Uint8Reader {
    // Returns the CPOS-th code byte of `codes`, zero-extended.
    static intptr_t get(const uint8_t* const __restrict codes) {
        constexpr intptr_t nCodeWords = DIM / CODE_SIZE;
        if constexpr (nCodeWords <= 3) {
            // Too few bytes to benefit from a wide load.
            // Read 1 byte (movzx).
            return codes[CPOS];
        } else {
            // Read using 4-byte words.
            // Reading using 8-byte words takes too many registers somehow.
            constexpr intptr_t ELEMENT_TO_READ = CPOS / 4;
            constexpr intptr_t SUB_ELEMENT = CPOS % 4;
            // memcpy is the well-defined way to do an unaligned load;
            // compilers lower it to a single mov. The previous
            // reinterpret_cast<const uint32_t*> read was an unaligned,
            // strict-aliasing-violating access.
            uint32_t code32;
            std::memcpy(&code32, codes + ELEMENT_TO_READ * 4, sizeof(code32));
            // Extract the requested byte. The shift-and-mask form also
            // guarantees a return on every path; the original switch had
            // no trailing return, so control could formally flow off the
            // end of a non-void function (-Wreturn-type).
            return (code32 >> (8 * SUB_ELEMENT)) & 0xFF;
        }
    }
};
// The following code uses template-based for-loop unrolling,
// because the compiler does not do that on its own as needed.
// The idea is the following:
// template<int I, int MAX>
// struct Foo {
// static void bar() {
// doSomething(I);
// Foo<I + 1, MAX>::bar();
// }
// };
//
// template<int MAX>
// struct Foo<MAX, MAX> {
// static void bar() {}
// };
//
// Initiate the loop:
// Foo<0, MAX>::bar();
// Suitable for IVF256,PQ[1]x8
// Suitable for Residual[1]x8,PQ[2]x8
// Decodes the slice of dimensions starting at CPOS for one (or two)
// encoded samples, then recurses at compile time to the next slice
// (CPOS + 8 or CPOS + 4) until the terminal specialization at
// CPOS == DIM is reached. The net effect is a fully unrolled decode
// of all DIM dimensions.
// The `256` below is the number of centroids per sub-quantizer,
// matching the 8-bit codes these kernels target (IVF256 / PQ..x8).
template <intptr_t DIM, intptr_t COARSE_SIZE, intptr_t FINE_SIZE, intptr_t CPOS>
struct Index2LevelDecoderImpl {
    // which coarse sub-quantizer covers dimension CPOS,
    // and the offset of CPOS inside that coarse centroid
    static constexpr intptr_t coarseCentroidIdx = CPOS / COARSE_SIZE;
    static constexpr intptr_t coarseCentroidOffset = CPOS % COARSE_SIZE;
    // which fine (PQ) sub-quantizer covers dimension CPOS,
    // and the offset of CPOS inside that fine centroid
    static constexpr intptr_t fineCentroidIdx = CPOS / FINE_SIZE;
    static constexpr intptr_t fineCentroidOffset = CPOS % FINE_SIZE;
    // number of dimensions left in the current fine centroid;
    // decides whether an 8-float or a 4-float step is possible
    static constexpr intptr_t QPOS_LEFT = FINE_SIZE - fineCentroidOffset;

    // process 1 sample:
    // outputStore[CPOS...] = decoded(code0)[CPOS...]
    static void store(
            const float* const __restrict pqCoarseCentroids0,
            const float* const __restrict pqFineCentroids0,
            const uint8_t* const __restrict code0,
            float* const __restrict outputStore) {
        // coarse quantizer: codes are at the beginning of the buffer
        const uint8_t* const __restrict coarse0 = code0;
        // fine quantizer: codes follow the DIM / COARSE_SIZE coarse bytes
        const uint8_t* const __restrict fine0 = code0 + (DIM / COARSE_SIZE);

        if constexpr (FINE_SIZE == 4) {
            // clang-format off

            // process chunks, 4 float
            // but 8 floats per loop: two consecutive fine centroids at once

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0a = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 0>::get(fine0);
            const intptr_t fineCode0b = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 1>::get(fine0);

            const __m256 storeValue = elementaryBlock4x2b(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 0) * 256 + fineCode0a) * FINE_SIZE + fineCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 1) * 256 + fineCode0b) * FINE_SIZE + fineCentroidOffset);

            _mm256_storeu_ps(outputStore + CPOS, storeValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::store(
                pqCoarseCentroids0, pqFineCentroids0, code0,
                outputStore);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 8) {
            // clang-format off

            // process chunks, 8 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);

            const __m256 storeValue = elementaryBlock8x1b(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset);

            _mm256_storeu_ps(outputStore + CPOS, storeValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::store(
                pqCoarseCentroids0, pqFineCentroids0, code0,
                outputStore);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 4) {
            // clang-format off

            // process chunks, 4 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);

            const __m128 storeValue = elementaryBlock4x1b(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset);

            _mm_storeu_ps(outputStore + CPOS, storeValue);

            // next 4 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 4>::store(
                pqCoarseCentroids0, pqFineCentroids0, code0,
                outputStore);

            // clang-format on
        }
    }

    // process 1 sample:
    // outputAccum[CPOS...] += weight0 * decoded(code0)[CPOS...]
    static void accum(
            const float* const __restrict pqCoarseCentroids0,
            const float* const __restrict pqFineCentroids0,
            const uint8_t* const __restrict code0,
            const float weight0,
            float* const __restrict outputAccum) {
        // coarse quantizer
        const uint8_t* const __restrict coarse0 = code0;
        // fine quantizer
        const uint8_t* const __restrict fine0 = code0 + (DIM / COARSE_SIZE);

        if constexpr (FINE_SIZE == 4) {
            // clang-format off

            // process chunks, 4 float
            // but 8 floats per loop

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0a = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 0>::get(fine0);
            const intptr_t fineCode0b = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 1>::get(fine0);

            __m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock4x2bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 0) * 256 + fineCode0a) * FINE_SIZE + fineCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 1) * 256 + fineCode0b) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);

            _mm256_storeu_ps(outputAccum + CPOS, existingValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                outputAccum);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 8) {
            // clang-format off

            // process chunks, 8 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);

            __m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock8x1bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);

            _mm256_storeu_ps(outputAccum + CPOS, existingValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                outputAccum);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 4) {
            // clang-format off

            // process chunks, 4 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);

            __m128 existingValue = _mm_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock4x1bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);

            _mm_storeu_ps(outputAccum + CPOS, existingValue);

            // next 4 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 4>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                outputAccum);

            // clang-format on
        }
    }

    // process 2 samples (better superscalar utilization than two
    // independent 1-sample calls):
    // outputAccum[CPOS...] += weight0 * decoded(code0)[CPOS...]
    //                       + weight1 * decoded(code1)[CPOS...]
    static void accum(
            const float* const __restrict pqCoarseCentroids0,
            const float* const __restrict pqFineCentroids0,
            const uint8_t* const __restrict code0,
            const float weight0,
            const float* const __restrict pqCoarseCentroids1,
            const float* const __restrict pqFineCentroids1,
            const uint8_t* const __restrict code1,
            const float weight1,
            float* const __restrict outputAccum) {
        // coarse quantizer
        const uint8_t* const __restrict coarse0 = code0;
        const uint8_t* const __restrict coarse1 = code1;
        // fine quantizer
        const uint8_t* const __restrict fine0 = code0 + (DIM / COARSE_SIZE);
        const uint8_t* const __restrict fine1 = code1 + (DIM / COARSE_SIZE);

        if constexpr (FINE_SIZE == 4) {
            // clang-format off

            // process chunks, 4 float
            // but 8 floats per loop

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0a = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 0>::get(fine0);
            const intptr_t fineCode0b = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 1>::get(fine0);
            const intptr_t coarseCode1 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse1);
            const intptr_t fineCode1a = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 0>::get(fine1);
            const intptr_t fineCode1b = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx + 1>::get(fine1);

            __m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock4x2bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 0) * 256 + fineCode0a) * FINE_SIZE + fineCentroidOffset,
                pqFineCentroids0 + ((fineCentroidIdx + 1) * 256 + fineCode0b) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);
            existingValue = elementaryBlock4x2bAccum(
                pqCoarseCentroids1 + (coarseCentroidIdx * 256 + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids1 + ((fineCentroidIdx + 0) * 256 + fineCode1a) * FINE_SIZE + fineCentroidOffset,
                pqFineCentroids1 + ((fineCentroidIdx + 1) * 256 + fineCode1b) * FINE_SIZE + fineCentroidOffset,
                weight1,
                existingValue);

            _mm256_storeu_ps(outputAccum + CPOS, existingValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
                outputAccum);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 8) {
            // clang-format off

            // process chunks, 8 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);
            const intptr_t coarseCode1 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse1);
            const intptr_t fineCode1 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine1);

            __m256 existingValue = _mm256_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock8x1bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);
            existingValue = elementaryBlock8x1bAccum(
                pqCoarseCentroids1 + (coarseCentroidIdx * 256 + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids1 + (fineCentroidIdx * 256 + fineCode1) * FINE_SIZE + fineCentroidOffset,
                weight1,
                existingValue);

            _mm256_storeu_ps(outputAccum + CPOS, existingValue);

            // next 8 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 8>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
                outputAccum);

            // clang-format on
        } else if constexpr (QPOS_LEFT >= 4) {
            // clang-format off

            // process chunks, 4 float

            const intptr_t coarseCode0 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse0);
            const intptr_t fineCode0 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine0);
            const intptr_t coarseCode1 = Uint8Reader<DIM, COARSE_SIZE, coarseCentroidIdx>::get(coarse1);
            const intptr_t fineCode1 = Uint8Reader<DIM, FINE_SIZE, fineCentroidIdx>::get(fine1);

            __m128 existingValue = _mm_loadu_ps(outputAccum + CPOS);
            existingValue = elementaryBlock4x1bAccum(
                pqCoarseCentroids0 + (coarseCentroidIdx * 256 + coarseCode0) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids0 + (fineCentroidIdx * 256 + fineCode0) * FINE_SIZE + fineCentroidOffset,
                weight0,
                existingValue);
            existingValue = elementaryBlock4x1bAccum(
                pqCoarseCentroids1 + (coarseCentroidIdx * 256 + coarseCode1) * COARSE_SIZE + coarseCentroidOffset,
                pqFineCentroids1 + (fineCentroidIdx * 256 + fineCode1) * FINE_SIZE + fineCentroidOffset,
                weight1,
                existingValue);

            _mm_storeu_ps(outputAccum + CPOS, existingValue);

            // next 4 dimensions
            Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, CPOS + 4>::accum(
                pqCoarseCentroids0, pqFineCentroids0, code0, weight0,
                pqCoarseCentroids1, pqFineCentroids1, code1, weight1,
                outputAccum);

            // clang-format on
        }
    }
};
// Suitable for IVF256,PQ[1]x8
// Suitable for Residual[1]x8,PQ[2]x8
// Terminator of the compile-time recursion: instantiated when CPOS
// reaches DIM, i.e. all dimensions have been processed, so every
// method is intentionally a no-op.
template <intptr_t DIM, intptr_t COARSE_SIZE, intptr_t FINE_SIZE>
struct Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, DIM> {
    // process 1 sample: nothing left to decode
    static void store(
            const float* const __restrict /* pqCoarseCentroids0 */,
            const float* const __restrict /* pqFineCentroids0 */,
            const uint8_t* const __restrict /* code0 */,
            float* const __restrict /* outputStore */) {}

    // process 1 sample: nothing left to accumulate
    static void accum(
            const float* const __restrict /* pqCoarseCentroids0 */,
            const float* const __restrict /* pqFineCentroids0 */,
            const uint8_t* const __restrict /* code0 */,
            const float /* weight0 */,
            float* const __restrict /* outputAccum */) {}

    // process 2 samples: nothing left to accumulate
    static void accum(
            const float* const __restrict /* pqCoarseCentroids0 */,
            const float* const __restrict /* pqFineCentroids0 */,
            const uint8_t* const __restrict /* code0 */,
            const float /* weight0 */,
            const float* const __restrict /* pqCoarseCentroids1 */,
            const float* const __restrict /* pqFineCentroids1 */,
            const uint8_t* const __restrict /* code1 */,
            const float /* weight1 */,
            float* const __restrict /* outputAccum */) {}
};
} // namespace
// Suitable for IVF256,PQ[1]x8
// Suitable for Residual[1]x8,PQ[2]x8
template <intptr_t DIM, intptr_t COARSE_SIZE, intptr_t FINE_SIZE>
struct Index2LevelDecoder {
// Process 1 sample.
static void store(
const float* const __restrict pqCoarseCentroids,
const float* const __restrict pqFineCentroids,
const uint8_t* const __restrict code,
float* const __restrict outputStore) {
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, 0>::store(
pqCoarseCentroids, pqFineCentroids, code, outputStore);
}
// Process 1 sample.
// Performs outputAccum += weight * decoded(code)
static void accum(
const float* const __restrict pqCoarseCentroids,
const float* const __restrict pqFineCentroids,
const uint8_t* const __restrict code,
const float weight,
float* const __restrict outputAccum) {
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, 0>::accum(
pqCoarseCentroids, pqFineCentroids, code, weight, outputAccum);
}
// process 2 samples
// Performs outputAccum += weight0 * decoded(code0) + weight1 *
// decoded(code1)
static void accum(
const float* const __restrict pqCoarseCentroids0,
const float* const __restrict pqFineCentroids0,
const uint8_t* const __restrict code0,
const float weight0,
const float* const __restrict pqCoarseCentroids1,
const float* const __restrict pqFineCentroids1,
const uint8_t* const __restrict code1,
const float weight1,
float* const __restrict outputAccum) {
Index2LevelDecoderImpl<DIM, COARSE_SIZE, FINE_SIZE, 0>::accum(
pqCoarseCentroids0,
pqFineCentroids0,
code0,
weight0,
pqCoarseCentroids1,
pqFineCentroids1,
code1,
weight1,
outputAccum);
}
};
} // namespace cppcontrib
} // namespace faiss

View File

@ -0,0 +1,137 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#pragma once
#include <cstddef>
#include <cstdint>
namespace faiss {
namespace cppcontrib {
// Scalar (non-SIMD) reference implementation.
// Suitable for IVF256,PQ[1]x8
// Suitable for Residual[1]x8,PQ[2]x8
// Each decoded dimension is the sum of one coarse-centroid component
// and one fine-centroid component; every sub-quantizer has 256
// centroids (8-bit codes).
template <intptr_t DIM, intptr_t COARSE_SIZE, intptr_t FINE_SIZE>
struct Index2LevelDecoder {
    // Process 1 sample.
    // Performs outputStore = decoded(code)
    static void store(
            const float* const __restrict pqCoarseCentroids,
            const float* const __restrict pqFineCentroids,
            const uint8_t* const __restrict code,
            float* const __restrict outputStore) {
        // code layout: [DIM / COARSE_SIZE coarse bytes][fine bytes]
        const uint8_t* const __restrict coarse = code;
        const uint8_t* const __restrict fine = code + (DIM / COARSE_SIZE);

#pragma unroll
        for (intptr_t dim = 0; dim < DIM; dim++) {
            // which sub-quantizers cover this dimension, and the offset
            // of the dimension inside the corresponding centroid
            const intptr_t cIdx = dim / COARSE_SIZE;
            const intptr_t cOff = dim % COARSE_SIZE;
            const intptr_t fIdx = dim / FINE_SIZE;
            const intptr_t fOff = dim % FINE_SIZE;

            // flat indices into the centroid tables
            const intptr_t cPos =
                    (cIdx * 256 + coarse[cIdx]) * COARSE_SIZE + cOff;
            const intptr_t fPos = (fIdx * 256 + fine[fIdx]) * FINE_SIZE + fOff;

            outputStore[dim] = pqCoarseCentroids[cPos] + pqFineCentroids[fPos];
        }
    }

    // Process 1 sample.
    // Performs outputAccum += weight * decoded(code)
    static void accum(
            const float* const __restrict pqCoarseCentroids,
            const float* const __restrict pqFineCentroids,
            const uint8_t* const __restrict code,
            const float weight,
            float* const __restrict outputAccum) {
        const uint8_t* const __restrict coarse = code;
        const uint8_t* const __restrict fine = code + (DIM / COARSE_SIZE);

#pragma unroll
        for (intptr_t dim = 0; dim < DIM; dim++) {
            const intptr_t cIdx = dim / COARSE_SIZE;
            const intptr_t cOff = dim % COARSE_SIZE;
            const intptr_t fIdx = dim / FINE_SIZE;
            const intptr_t fOff = dim % FINE_SIZE;

            const intptr_t cPos =
                    (cIdx * 256 + coarse[cIdx]) * COARSE_SIZE + cOff;
            const intptr_t fPos = (fIdx * 256 + fine[fIdx]) * FINE_SIZE + fOff;

            outputAccum[dim] +=
                    weight * (pqCoarseCentroids[cPos] + pqFineCentroids[fPos]);
        }
    }

    // Process 2 samples.
    // Performs
    //   outputAccum += weight0 * decoded(code0) + weight1 * decoded(code1)
    static void accum(
            const float* const __restrict pqCoarseCentroids0,
            const float* const __restrict pqFineCentroids0,
            const uint8_t* const __restrict code0,
            const float weight0,
            const float* const __restrict pqCoarseCentroids1,
            const float* const __restrict pqFineCentroids1,
            const uint8_t* const __restrict code1,
            const float weight1,
            float* const __restrict outputAccum) {
        const uint8_t* const __restrict coarse0 = code0;
        const uint8_t* const __restrict coarse1 = code1;
        const uint8_t* const __restrict fine0 = code0 + (DIM / COARSE_SIZE);
        const uint8_t* const __restrict fine1 = code1 + (DIM / COARSE_SIZE);

#pragma unroll
        for (intptr_t dim = 0; dim < DIM; dim++) {
            const intptr_t cIdx = dim / COARSE_SIZE;
            const intptr_t cOff = dim % COARSE_SIZE;
            const intptr_t fIdx = dim / FINE_SIZE;
            const intptr_t fOff = dim % FINE_SIZE;

            const intptr_t cPos0 =
                    (cIdx * 256 + coarse0[cIdx]) * COARSE_SIZE + cOff;
            const intptr_t fPos0 =
                    (fIdx * 256 + fine0[fIdx]) * FINE_SIZE + fOff;
            const intptr_t cPos1 =
                    (cIdx * 256 + coarse1[cIdx]) * COARSE_SIZE + cOff;
            const intptr_t fPos1 =
                    (fIdx * 256 + fine1[fIdx]) * FINE_SIZE + fOff;

            // same association as the 1-sample version to keep results
            // bit-identical per sample
            outputAccum[dim] +=
                    weight0 * (pqCoarseCentroids0[cPos0] + pqFineCentroids0[fPos0]) +
                    weight1 * (pqCoarseCentroids1[cPos1] + pqFineCentroids1[fPos1]);
        }
    }
};
} // namespace cppcontrib
} // namespace faiss

View File

@ -0,0 +1,85 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#pragma once
// This file contains a custom fast implementation of faiss::Index::sa_decode()
// function for the following index families:
// * IVF256,PQ[1]x8np
// * Residual[1]x8,PQ[2]x8
//
// The goal was to achieve the maximum performance, so the template version it
// is. The provided index families share the same code for sa_decode.
// The front-end code looks the following:
// {
// template <intptr_t DIM, intptr_t COARSE_SIZE, intptr_t FINE_SIZE>
// struct Index2LevelDecoder { /*...*/ };
// }
// * DIM is the dimensionality of data
// * COARSE_SIZE is the dimensionality of the coarse quantizer (IVF, Residual)
// * FINE_SIZE is the dimensionality of a ProductQuantizer sub-quantizer (dsub)
// For example, "IVF256,PQ8np" for 160-dim data translates into
// Index2LevelDecoder<160,160,20>
// For example, "Residual4x8,PQ16" for 256-dim data translates into
// Index2LevelDecoder<256,64,16>
//
// Unlike the general purpose version in faiss::Index::sa_decode(),
// this version provides the following functions:
// * ::store, which is similar to sa_decode(1, input, output),
// The method signature is the following:
// {
// void store(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code,
// float* const __restrict outputStore);
// }
// * ::accum, which is used to create a linear combination
// of decoded vectors:
// {
// faiss::Index* index;
// float weight;
//
// std::vector<float> buffer(d, 0);
//
// index->sa_decode(1, input, buffer.data());
// for (size_t iDim = 0; iDim < d; iDim++)
//     output[iDim] += weight * buffer[iDim];
// }
// The method signature is the following:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids,
// const float* const __restrict pqFineCentroids,
// const uint8_t* const __restrict code,
// const float weight,
// float* const __restrict outputAccum);
// }
// * There is an additional overload for ::accum that decodes two vectors
// per call. This provides an additional speedup because of a CPU
// superscalar architecture. Doing more vectors per call is less attractive
// because of the possible lack of available CPU registers, but it is still
// doable.
// The method signature is the following:
// {
// static void accum(
// const float* const __restrict pqCoarseCentroids0,
// const float* const __restrict pqFineCentroids0,
// const uint8_t* const __restrict code0,
// const float weight0,
// const float* const __restrict pqCoarseCentroids1,
// const float* const __restrict pqFineCentroids1,
// const uint8_t* const __restrict code1,
// const float weight1,
// float* const __restrict outputAccum);
// }
// The provided version is not multithreaded.
//
// Currently, an AVX2+FMA implementation is available. AVX512 version is also
// doable, but it was found to be slower than AVX2 for real world applications
// that I needed.
#ifdef __AVX2__
#include <faiss/cppcontrib/SaDecodeKernels-avx2-inl.h>
#else
#include <faiss/cppcontrib/SaDecodeKernels-inl.h>
#endif

View File

@ -20,6 +20,7 @@ set(FAISS_TEST_SRC
test_threaded_index.cpp
test_transfer_invlists.cpp
test_mem_leak.cpp
test_cppcontrib_sa_decode.cpp
)
add_executable(faiss_test ${FAISS_TEST_SRC})

View File

@ -0,0 +1,391 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <tuple>
#include <vector>
#include <faiss/Index.h>
#include <faiss/Index2Layer.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexPQ.h>
#include <faiss/impl/io.h>
#include <faiss/index_factory.h>
#include <faiss/index_io.h>
#include <faiss/cppcontrib/SaDecodeKernels.h>
using namespace ::testing;
using ::testing::TestWithParam;
using ::testing::Values;
std::tuple<std::shared_ptr<faiss::Index>, std::vector<uint8_t>> trainDataset(
const std::vector<float>& input,
const uint64_t n,
const uint64_t d,
const std::string& description) {
// train an index
auto index = std::shared_ptr<faiss::Index>(
faiss::index_factory((int)d, description.c_str()));
index->train((int)n, input.data());
// encode
const size_t codeSize = index->sa_code_size();
std::vector<uint8_t> encodedData(n * codeSize);
index->sa_encode(n, input.data(), encodedData.data());
return std::make_tuple(std::move(index), std::move(encodedData));
}
// If `index` is an IndexIVFPQ with a flat-codes coarse quantizer,
// fills the two output pointers with its coarse / fine centroid tables
// and returns true; otherwise leaves them untouched and returns false.
bool testIfIVFPQ(
        const std::shared_ptr<faiss::Index>& index,
        float** pqCoarseCentroidsQ,
        float** pqFineCentroidsQ) {
    // both output slots are required
    if (pqCoarseCentroidsQ == nullptr || pqFineCentroidsQ == nullptr) {
        return false;
    }

    auto* const ivfpqIndex = dynamic_cast<faiss::IndexIVFPQ*>(index.get());
    if (ivfpqIndex == nullptr) {
        return false;
    }

    // the coarse centroids live in the flat quantizer's code storage
    auto* const flatQuantizer =
            dynamic_cast<faiss::IndexFlatCodes*>(ivfpqIndex->quantizer);
    if (flatQuantizer == nullptr) {
        return false;
    }

    *pqFineCentroidsQ = ivfpqIndex->pq.centroids.data();
    // IndexFlatCodes stores float centroids as raw bytes
    *pqCoarseCentroidsQ =
            reinterpret_cast<float*>(flatQuantizer->codes.data());
    return true;
}
// If `index` is an Index2Layer whose first level is a
// MultiIndexQuantizer, fills the two output pointers with its coarse /
// fine centroid tables and returns true; otherwise leaves them
// untouched and returns false.
bool testIfResidualPQ(
        const std::shared_ptr<faiss::Index>& index,
        float** pqCoarseCentroidsQ,
        float** pqFineCentroidsQ) {
    // both output slots are required
    if (pqCoarseCentroidsQ == nullptr || pqFineCentroidsQ == nullptr) {
        return false;
    }

    auto* const residualIndex = dynamic_cast<faiss::Index2Layer*>(index.get());
    if (residualIndex == nullptr) {
        return false;
    }

    auto* const multiQuantizer =
            dynamic_cast<faiss::MultiIndexQuantizer*>(residualIndex->q1.quantizer);
    if (multiQuantizer == nullptr) {
        return false;
    }

    *pqFineCentroidsQ = residualIndex->pq.centroids.data();
    *pqCoarseCentroidsQ = multiQuantizer->pq.centroids.data();
    return true;
}
// Checks the templated decoder T against the reference
// faiss::Index::sa_decode() on all n encoded samples:
// 1) T::store must match sa_decode exactly;
// 2) the 1-sample T::accum must reproduce the weighted sum computed
//    from the reference decode exactly;
// 3) the 2-samples T::accum must match within a tolerance (it reorders
//    the floating-point additions).
// The same rng seed is reused for every pass so all passes draw
// identical weights.
template <typename T>
void verify(
        const uint64_t n,
        const uint64_t d,
        const std::shared_ptr<faiss::Index>& index,
        const std::vector<uint8_t>& encodedData) {
    // centroid tables; exactly one of the two extractors is expected to
    // succeed, depending on the index type
    float* pqFineCentroidsQ = nullptr;
    float* pqCoarseCentroidsQ = nullptr;
    testIfIVFPQ(index, &pqCoarseCentroidsQ, &pqFineCentroidsQ);
    testIfResidualPQ(index, &pqCoarseCentroidsQ, &pqFineCentroidsQ);

    const size_t codeSize = index->sa_code_size();

    // fixed seed: every pass below draws the same weight sequence
    std::default_random_engine rng(123);
    std::uniform_real_distribution<float> u(0, 1);

    // test general purpose version vs contrib::store
    std::vector<float> outputFaiss(d, 0);
    std::vector<float> tmpFaiss(d, 0);
    std::vector<float> tmpContrib(d, 0);
    for (size_t i = 0; i < n; i++) {
        // compute using faiss
        index->sa_decode(1, encodedData.data() + i * codeSize, tmpFaiss.data());

        // compute using contrib
        T::store(
                pqCoarseCentroidsQ,
                pqFineCentroidsQ,
                encodedData.data() + i * codeSize,
                tmpContrib.data());

        // compare
        for (size_t j = 0; j < d; j++)
            ASSERT_FLOAT_EQ(tmpFaiss[j], tmpContrib[j]);

        // save the reference weighted sum for the accum comparisons
        const float weight = u(rng);
        for (size_t j = 0; j < d; j++)
            outputFaiss[j] += weight * tmpFaiss[j];
    }

    // test contrib::accum, 1 sample per iteration
    rng.seed(123);

    std::vector<float> outputContrib1s(d, 0);
    for (size_t i = 0; i < n; i++) {
        const float weight0 = u(rng);
        T::accum(
                pqCoarseCentroidsQ,
                pqFineCentroidsQ,
                encodedData.data() + (i + 0) * codeSize,
                weight0,
                outputContrib1s.data());
    }

    // verify
    for (size_t j = 0; j < d; j++) {
        ASSERT_FLOAT_EQ(outputFaiss[j], outputContrib1s[j]);
    }

    // test contrib::accum, 2 samples per iteration
    rng.seed(123);

    std::vector<float> outputContrib2s(d, 0);
    size_t i = 0;
    for (; i + 1 < n; i += 2) {
        const float weight0 = u(rng);
        const float weight1 = u(rng);
        T::accum(
                pqCoarseCentroidsQ,
                pqFineCentroidsQ,
                encodedData.data() + (i + 0) * codeSize,
                weight0,
                pqCoarseCentroidsQ,
                pqFineCentroidsQ,
                encodedData.data() + (i + 1) * codeSize,
                weight1,
                outputContrib2s.data());
    }
    // Process the leftover sample for odd n with the 1-sample accum.
    // The previous loop condition (i < n with a step of 2) read code
    // (i + 1) past the end of encodedData when n was odd.
    if (i < n) {
        const float weight0 = u(rng);
        T::accum(
                pqCoarseCentroidsQ,
                pqFineCentroidsQ,
                encodedData.data() + i * codeSize,
                weight0,
                outputContrib2s.data());
    }

    // verify; a looser tolerance because the 2-sample accumulation
    // reorders the floating-point additions
    for (size_t j = 0; j < d; j++) {
        ASSERT_NEAR(outputFaiss[j], outputContrib2s[j], 1e-2);
    }
}
// Produces n * d uniform [0, 1) float samples from a fixed seed, so
// every test run sees exactly the same dataset.
std::vector<float> generate(const size_t n, const size_t d) {
    std::minstd_rand rng(345);
    std::uniform_real_distribution<float> ux(0, 1);

    // fill row-by-row == plain linear fill of the flat buffer
    std::vector<float> data(n * d);
    for (auto& value : data) {
        value = ux(rng);
    }

    return data;
}
// End-to-end check for decoder T: generate a deterministic dataset,
// train & encode it with the requested faiss index type, then verify
// the contrib decoder against the reference sa_decode.
template <typename T>
void test(const uint64_t n, const uint64_t d, const std::string& description) {
    const auto data = generate(n, d);
    auto [index, encodedData] = trainDataset(data, n, d, description);
    verify<T>(n, d, index, encodedData);
}
constexpr size_t NSAMPLES = 4096;
//
// Template parameters are <dim, coarse dims per code, fine dims per code>;
// dim / fine-size gives the number of PQ subquantizers in the factory string.
TEST(TEST_CPPCONTRIB_SA_DECODE, D256_IVF256_PQ64) {
    // 256 / 4 = 64 fine subquantizers.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 256, 4>;
    test<Decoder>(NSAMPLES, 256, "IVF256,PQ64np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_IVF256_PQ32) {
    // 256 / 8 = 32 fine subquantizers.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 256, 8>;
    test<Decoder>(NSAMPLES, 256, "IVF256,PQ32np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_IVF256_PQ16) {
    // 256 / 16 = 16 fine subquantizers.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 256, 16>;
    test<Decoder>(NSAMPLES, 256, "IVF256,PQ16np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_IVF256_PQ8) {
    // 256 / 32 = 8 fine subquantizers.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 256, 32>;
    test<Decoder>(NSAMPLES, 256, "IVF256,PQ8np");
}
//
// dim = 192 variants; 192 / fine-size = number of PQ subquantizers.
TEST(TEST_CPPCONTRIB_SA_DECODE, D192_IVF256_PQ48) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<192, 192, 4>;
    test<Decoder>(NSAMPLES, 192, "IVF256,PQ48np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D192_IVF256_PQ24) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<192, 192, 8>;
    test<Decoder>(NSAMPLES, 192, "IVF256,PQ24np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D192_IVF256_PQ16) {
    // Odd fine size (12 dims per fine code) exercises the non-power-of-two path.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<192, 192, 12>;
    test<Decoder>(NSAMPLES, 192, "IVF256,PQ16np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D192_IVF256_PQ12) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<192, 192, 16>;
    test<Decoder>(NSAMPLES, 192, "IVF256,PQ12np");
}
//
// dim = 160 variants; 160 / fine-size = number of PQ subquantizers.
TEST(TEST_CPPCONTRIB_SA_DECODE, D160_IVF256_PQ40) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 160, 4>;
    test<Decoder>(NSAMPLES, 160, "IVF256,PQ40np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D160_IVF256_PQ20) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 160, 8>;
    test<Decoder>(NSAMPLES, 160, "IVF256,PQ20np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D160_IVF256_PQ10) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 160, 16>;
    test<Decoder>(NSAMPLES, 160, "IVF256,PQ10np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D160_IVF256_PQ8) {
    // Fine size 20 is not a power of two.
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 160, 20>;
    test<Decoder>(NSAMPLES, 160, "IVF256,PQ8np");
}
//
// dim = 128 variants; 128 / fine-size = number of PQ subquantizers.
TEST(TEST_CPPCONTRIB_SA_DECODE, D128_IVF256_PQ32) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 128, 4>;
    test<Decoder>(NSAMPLES, 128, "IVF256,PQ32np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_IVF256_PQ16) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 128, 8>;
    test<Decoder>(NSAMPLES, 128, "IVF256,PQ16np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_IVF256_PQ8) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 128, 16>;
    test<Decoder>(NSAMPLES, 128, "IVF256,PQ8np");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_IVF256_PQ4) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 128, 32>;
    test<Decoder>(NSAMPLES, 128, "IVF256,PQ4np");
}
//
// dim = 64 variants; 64 / fine-size = number of PQ subquantizers.
// NOTE(review): the test names were previously off by one step
// (PQ32/PQ16/PQ8) and did not match the factory strings; they are renamed
// here to follow the same convention as every other test in this file,
// where the PQ count in the name equals dim / fine-size.
TEST(TEST_CPPCONTRIB_SA_DECODE, D64_IVF256_PQ16) {
    // 64 / 4 = 16 fine subquantizers.
    using T = faiss::cppcontrib::Index2LevelDecoder<64, 64, 4>;
    test<T>(NSAMPLES, 64, "IVF256,PQ16np");
}
TEST(TEST_CPPCONTRIB_SA_DECODE, D64_IVF256_PQ8) {
    // 64 / 8 = 8 fine subquantizers.
    using T = faiss::cppcontrib::Index2LevelDecoder<64, 64, 8>;
    test<T>(NSAMPLES, 64, "IVF256,PQ8np");
}
TEST(TEST_CPPCONTRIB_SA_DECODE, D64_IVF256_PQ4) {
    // 64 / 16 = 4 fine subquantizers.
    using T = faiss::cppcontrib::Index2LevelDecoder<64, 64, 16>;
    test<T>(NSAMPLES, 64, "IVF256,PQ4np");
}
//
// Residual coarse quantizer variants for dim = 256: four coarse codes,
// each spanning 64 dims (second template parameter).
TEST(TEST_CPPCONTRIB_SA_DECODE, D256_Residual4x8_PQ64) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 64, 4>;
    test<Decoder>(NSAMPLES, 256, "Residual4x8,PQ64");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_Residual4x8_PQ32) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 64, 8>;
    test<Decoder>(NSAMPLES, 256, "Residual4x8,PQ32");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_Residual4x8_PQ16) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 64, 16>;
    test<Decoder>(NSAMPLES, 256, "Residual4x8,PQ16");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D256_Residual4x8_PQ8) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<256, 64, 32>;
    test<Decoder>(NSAMPLES, 256, "Residual4x8,PQ8");
}
//
// dim = 160 with 4, 2 and 1 residual coarse codes (coarse sizes 40/80/160),
// all decoding to PQ10.
TEST(TEST_CPPCONTRIB_SA_DECODE, D160_Residual4x8_PQ10) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 40, 16>;
    test<Decoder>(NSAMPLES, 160, "Residual4x8,PQ10");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D160_Residual2x8_PQ10) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 80, 16>;
    test<Decoder>(NSAMPLES, 160, "Residual2x8,PQ10");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D160_Residual1x8_PQ10) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<160, 160, 16>;
    test<Decoder>(NSAMPLES, 160, "Residual1x8,PQ10");
}
//
// Residual coarse quantizer variants for dim = 128: four coarse codes of
// 32 dims each.
TEST(TEST_CPPCONTRIB_SA_DECODE, D128_Residual4x8_PQ32) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 32, 4>;
    test<Decoder>(NSAMPLES, 128, "Residual4x8,PQ32");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_Residual4x8_PQ16) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 32, 8>;
    test<Decoder>(NSAMPLES, 128, "Residual4x8,PQ16");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_Residual4x8_PQ8) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 32, 16>;
    test<Decoder>(NSAMPLES, 128, "Residual4x8,PQ8");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D128_Residual4x8_PQ4) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<128, 32, 32>;
    test<Decoder>(NSAMPLES, 128, "Residual4x8,PQ4");
}
//
// Residual coarse quantizer variants for dim = 64: four coarse codes of
// 16 dims each.
TEST(TEST_CPPCONTRIB_SA_DECODE, D64_Residual4x8_PQ16) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<64, 16, 4>;
    test<Decoder>(NSAMPLES, 64, "Residual4x8,PQ16");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D64_Residual4x8_PQ8) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<64, 16, 8>;
    test<Decoder>(NSAMPLES, 64, "Residual4x8,PQ8");
}

TEST(TEST_CPPCONTRIB_SA_DECODE, D64_Residual4x8_PQ4) {
    using Decoder = faiss::cppcontrib::Index2LevelDecoder<64, 16, 16>;
    test<Decoder>(NSAMPLES, 64, "Residual4x8,PQ4");
}