From b32abc95c21de4b737e219b51223b8c44922988d Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Tue, 29 Mar 2022 10:21:23 -0700 Subject: [PATCH] ProductQuantizer::compute_code tracks the nearest vector index in a register rather than storing the distances in a buffer. (#2280) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2280 Add a new function called fvec_L2sqr_ny_nearest and a demonstration of its implementation for 4 bits Reviewed By: mdouze Differential Revision: D35189945 fbshipit-source-id: d1b2ba42851df195123c7e318a8dcf26f775eaba --- faiss/impl/ProductQuantizer.cpp | 50 ++++++-- faiss/utils/distances.h | 10 ++ faiss/utils/distances_simd.cpp | 208 ++++++++++++++++++++++++++++++++ 3 files changed, 256 insertions(+), 12 deletions(-) diff --git a/faiss/impl/ProductQuantizer.cpp b/faiss/impl/ProductQuantizer.cpp index 81cc6dcf1..169332ce9 100644 --- a/faiss/impl/ProductQuantizer.cpp +++ b/faiss/impl/ProductQuantizer.cpp @@ -321,28 +321,54 @@ void ProductQuantizer::train(int n, const float* x) { template <class PQEncoder> void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) { std::vector<float> distances(pq.ksub); + + // It seems to be meaningless to allocate std::vector<float> distances. + // But it is done in order to cope with the ineffectiveness of the way + // the compiler generates the code. 
Basically, doing something like + // + // float min_distance = HUGE_VALF; + // size_t idxm = 0; + // for (size_t i = 0; i < N; i++) { + // const float distance = compute_distance(x, y + i * d, d); + // if (distance < min_distance) { + // min_distance = distance; + // idxm = i; + // } + // } + // + // generates significantly more CPU instructions than the baseline + // + // std::vector<float> distances_cached(N); + // for (size_t i = 0; i < N; i++) { + // distances_cached[i] = compute_distance(x, y + i * d, d); + // } + // float min_distance = HUGE_VALF; + // size_t idxm = 0; + // for (size_t i = 0; i < N; i++) { + // const float distance = distances_cached[i]; + // if (distance < min_distance) { + // min_distance = distance; + // idxm = i; + // } + // } + // + // So, the baseline is faster. This is because of the vectorization. + // I suppose that the branch predictor might affect the performance as well. + // So, the buffer is allocated, but it might be unused in + // manually optimized code. Let's hope that the compiler is smart enough to + // get rid of std::vector allocation in such a case. + PQEncoder encoder(code, pq.nbits); for (size_t m = 0; m < pq.M; m++) { - float mindis = 1e20; - uint64_t idxm = 0; const float* xsub = x + m * pq.dsub; - fvec_L2sqr_ny( + uint64_t idxm = fvec_L2sqr_ny_nearest( distances.data(), xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub); - /* Find best centroid */ - for (size_t i = 0; i < pq.ksub; i++) { - float dis = distances[i]; - if (dis < mindis) { - mindis = dis; - idxm = i; - } - } - encoder.encode(idxm); } } diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h index f612ea06f..91ad53f18 100644 --- a/faiss/utils/distances.h +++ b/faiss/utils/distances.h @@ -71,6 +71,16 @@ void fvec_L2sqr_ny( size_t d, size_t ny); +/* compute ny square L2 distances between x and a set of contiguous y vectors and return the index of the nearest vector. return 0 if ny == 0. 
*/ +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index 41429bcd2..c360ed30e 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -113,6 +113,27 @@ void fvec_L2sqr_ny_ref( } } +size_t fvec_L2sqr_ny_nearest_ref( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); + + size_t nearest_idx = 0; + float min_dis = HUGE_VALF; + + for (size_t i = 0; i < ny; i++) { + if (distances_tmp_buffer[i] < min_dis) { + min_dis = distances_tmp_buffer[i]; + nearest_idx = i; + } + } + + return nearest_idx; +} + void fvec_inner_products_ny_ref( float* ip, const float* x, @@ -514,6 +535,175 @@ void fvec_inner_products_ny( #undef DISPATCH } +#ifdef __AVX2__ +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D4-vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. 
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // + _mm_prefetch(y, _MM_HINT_NTA); + _mm_prefetch(y + 16, _MM_HINT_NTA); + + // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0]) + const __m256 m0 = _mm256_set1_ps(x[0]); + // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1]) + const __m256 m1 = _mm256_set1_ps(x[1]); + // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2]) + const __m256 m2 = _mm256_set1_ps(x[2]); + // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3]) + const __m256 m3 = _mm256_set1_ps(x[3]); + + const __m256i indices0 = + _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112); + + for (; i < ny8 * 8; i += 8) { + _mm_prefetch(y + 32, _MM_HINT_NTA); + _mm_prefetch(y + 48, _MM_HINT_NTA); + + // collect dim 0 for 8 D4-vectors. + // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0]) + const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1); + // collect dim 1 for 8 D4-vectors. + // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1]) + const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1); + // collect dim 2 for 8 D4-vectors. + // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2]) + const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1); + // collect dim 3 for 8 D4-vectors. 
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3]) + const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 + + // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 + + // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 + + // (x[3] - y[(i * 8 + 0) * 4 + 3]) ^ 2 + // ... + // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 + + // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 + + // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 + + // (x[3] - y[(i * 8 + 7) * 4 + 3]) ^ 2 + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = + _mm256_blendv_ps(distances, min_distances, comparison); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 4 DIM each). 
+ y += 32; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_hadd_ps(accu, accu); + accu = _mm_hadd_ps(accu, accu); + + const auto distance = _mm_cvtss_f32(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} +#else +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny); +} +#endif + +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + // optimized for a few special cases + +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny); + + switch (d) { + DISPATCH(4) + default: + return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); + } +#undef DISPATCH +} + #endif #ifdef USE_AVX @@ -816,6 +1006,15 @@ void fvec_L2sqr_ny( fvec_L2sqr_ny_ref(dis, x, y, d, ny); } +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); +} + float fvec_L1(const float* x, const float* y, size_t d) { return fvec_L1_ref(x, y, d); } @@ -865,6 +1064,15 @@ void fvec_L2sqr_ny( fvec_L2sqr_ny_ref(dis, x, y, d, ny); } 
+size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); +} + void fvec_inner_products_ny( float* dis, const float* x,