ProductQuantizer::compute_code tracks the nearest vector index in a register instead of storing the distances in a buffer. (#2280)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2280

Add a new function, fvec_L2sqr_ny_nearest, and a demonstration implementation (AVX2) specialized for 4-dimensional vectors (d == 4).

Reviewed By: mdouze

Differential Revision: D35189945

fbshipit-source-id: d1b2ba42851df195123c7e318a8dcf26f775eaba
pull/2274/head
Alexandr Guzhva 2022-03-29 10:21:23 -07:00 committed by Facebook GitHub Bot
parent 438b64cd8b
commit b32abc95c2
3 changed files with 256 additions and 12 deletions

View File

@ -321,28 +321,54 @@ void ProductQuantizer::train(int n, const float* x) {
// Encode one input vector x into `code`: for each of the pq.M sub-vectors
// (pq.dsub floats each), find the index of the nearest of the pq.ksub
// centroids of that sub-quantizer and hand it to the PQEncoder.
//
// NOTE(review): the body below appears to interleave the pre-change and
// post-change statements of a diff (idxm is declared twice and the
// `fvec_L2sqr_ny(` call is never closed) -- confirm which version is intended.
template <class PQEncoder>
void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
std::vector<float> distances(pq.ksub);
// Allocating std::vector<float> distances may look pointless, but it is
// done to work around inefficient compiler-generated code. Basically,
// a single fused loop such as
//
//   float min_distance = HUGE_VALF;
//   size_t idxm = 0;
//   for (size_t i = 0; i < N; i++) {
//     const float distance = compute_distance(x, y + i * d, d);
//     if (distance < min_distance) {
//       min_distance = distance;
//       idxm = i;
//     }
//   }
//
// generates significantly more CPU instructions than the two-pass baseline
//
//   std::vector<float> distances_cached(N);
//   for (size_t i = 0; i < N; i++) {
//     distances_cached[i] = compute_distance(x, y + i * d, d);
//   }
//   float min_distance = HUGE_VALF;
//   size_t idxm = 0;
//   for (size_t i = 0; i < N; i++) {
//     const float distance = distances_cached[i];
//     if (distance < min_distance) {
//       min_distance = distance;
//       idxm = i;
//     }
//   }
//
// So, the baseline is faster, because its first loop gets vectorized.
// The branch predictor might affect the performance as well.
// Hence the buffer is always allocated, even though manually optimized
// kernels may never touch it. Let's hope that the compiler is smart enough
// to get rid of the std::vector allocation in such a case.
PQEncoder encoder(code, pq.nbits);
for (size_t m = 0; m < pq.M; m++) {
float mindis = 1e20;
uint64_t idxm = 0;
// m-th sub-vector of x
const float* xsub = x + m * pq.dsub;
fvec_L2sqr_ny(
uint64_t idxm = fvec_L2sqr_ny_nearest(
distances.data(),
xsub,
pq.get_centroids(m, 0),
pq.dsub,
pq.ksub);
/* Find best centroid */
for (size_t i = 0; i < pq.ksub; i++) {
float dis = distances[i];
if (dis < mindis) {
mindis = dis;
idxm = i;
}
}
// append the index of the selected centroid to the output code
encoder.encode(idxm);
}
}

View File

@ -71,6 +71,16 @@ void fvec_L2sqr_ny(
size_t d,
size_t ny);
/** Compute the squared L2 distances between x and a set of ny contiguous
 * y vectors and return the index of the nearest vector.
 *
 * @param distances_tmp_buffer  scratch buffer with room for ny floats; the
 *                              reference implementation stores all ny
 *                              distances in it, whereas the AVX2 d == 4
 *                              specialization does not use it
 * @param x   query vector, d floats
 * @param y   ny contiguous vectors, d floats each
 * @param d   dimension of x and of each y vector
 * @param ny  number of y vectors
 * @return    index of the nearest y vector; 0 if ny == 0
 */
size_t fvec_L2sqr_ny_nearest(
float* distances_tmp_buffer,
const float* x,
const float* y,
size_t d,
size_t ny);
/** squared norm of a vector */
float fvec_norm_L2sqr(const float* x, size_t d);

View File

@ -113,6 +113,27 @@ void fvec_L2sqr_ny_ref(
}
}
/// Reference nearest-vector search: fills distances_tmp_buffer with the ny
/// squared L2 distances from x to each d-dimensional vector in y, then scans
/// the buffer for the smallest entry.
///
/// @return index of the first smallest distance; 0 when ny == 0 or when no
///         distance compares below HUGE_VALF.
size_t fvec_L2sqr_ny_nearest_ref(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    // pass 1: all distances into the scratch buffer
    fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);

    // pass 2: linear scan for the minimum (strict '<' keeps the first hit)
    float best_dis = HUGE_VALF;
    size_t best_idx = 0;
    for (size_t j = 0; j != ny; ++j) {
        const float dis = distances_tmp_buffer[j];
        if (!(dis < best_dis)) {
            continue;
        }
        best_dis = dis;
        best_idx = j;
    }
    return best_idx;
}
void fvec_inner_products_ny_ref(
float* ip,
const float* x,
@ -514,6 +535,175 @@ void fvec_inner_products_ny(
#undef DISPATCH
}
#ifdef __AVX2__
/// AVX2 nearest-vector search specialized for 4-dimensional vectors.
///
/// Computes the squared L2 distance between x (4 floats) and each of the ny
/// contiguous 4-float vectors in y, and returns the index of the nearest one.
/// The running minimum and its index are tracked in registers, 8 vectors at a
/// time; the remaining (ny % 8) vectors are processed one by one.
///
/// Ties resolve to the lowest index and NaN distances never win, which
/// matches fvec_L2sqr_ny_nearest_ref.
///
/// @param distances_tmp_buffer  unused by this implementation
/// @param x   query vector, 4 floats
/// @param y   ny contiguous vectors, 4 floats each
/// @param ny  number of y vectors; per-lane indices are tracked as 32-bit
///            integers, so ny is assumed to fit in 32 bits
/// @return    index of the nearest y vector; 0 if ny == 0
size_t fvec_L2sqr_ny_nearest_D4(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 D4-vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // _mm_prefetch is declared with a char const* argument
        _mm_prefetch(reinterpret_cast<const char*>(y), _MM_HINT_NTA);
        _mm_prefetch(reinterpret_cast<const char*>(y + 16), _MM_HINT_NTA);

        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
        const __m256 m0 = _mm256_set1_ps(x[0]);
        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
        const __m256 m1 = _mm256_set1_ps(x[1]);
        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
        const __m256 m2 = _mm256_set1_ps(x[2]);
        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
        const __m256 m3 = _mm256_set1_ps(x[3]);

        // byte offsets of 8 consecutive D4-vectors (16 bytes each);
        // used with a gather scale of 1.
        const __m256i indices0 =
                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);

        for (; i < ny8 * 8; i += 8) {
            _mm_prefetch(reinterpret_cast<const char*>(y + 32), _MM_HINT_NTA);
            _mm_prefetch(reinterpret_cast<const char*>(y + 48), _MM_HINT_NTA);

            // Below, y denotes the current (already advanced) pointer, so
            // y[k * 4 + j] is component j of vector (i + k).

            // collect dim 0 for 8 D4-vectors.
            // v0 = (y[0 * 4 + 0], ..., y[7 * 4 + 0])
            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
            // collect dim 1 for 8 D4-vectors.
            // v1 = (y[0 * 4 + 1], ..., y[7 * 4 + 1])
            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
            // collect dim 2 for 8 D4-vectors.
            // v2 = (y[0 * 4 + 2], ..., y[7 * 4 + 2])
            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
            // collect dim 3 for 8 D4-vectors.
            // v3 = (y[0 * 4 + 3], ..., y[7 * 4 + 3])
            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);

            // distances[k] = (x[0] - y[k * 4 + 0]) ^ 2 +
            //                (x[1] - y[k * 4 + 1]) ^ 2 +
            //                (x[2] - y[k * 4 + 2]) ^ 2 +
            //                (x[3] - y[k * 4 + 3]) ^ 2,   k = 0..7

            // Lanes whose new distance is STRICTLY smaller than the running
            // minimum. Using (distances < min) rather than !(min < distances)
            // keeps the earliest index on ties and ignores NaN distances.
            const __m256 comparison =
                    _mm256_cmp_ps(distances, min_distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed
            // (blendv picks its second operand where the mask is set).
            min_distances =
                    _mm256_blendv_ps(min_distances, distances, comparison);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(min_indices),
                    _mm256_castsi256_ps(current_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y forward (8 vectors 4 DIM each).
            y += 32;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            const float lane_distance = min_distances_scalar[j];
            const size_t lane_index = min_indices_scalar[j];
            // Lane order is not index order: on equal distances prefer the
            // smaller index so that the first minimum is returned.
            if (lane_distance < current_min_distance ||
                (lane_distance == current_min_distance &&
                 lane_index < current_min_index)) {
                current_min_distance = lane_distance;
                current_min_index = lane_index;
            }
        }
    }

    if (i < ny) {
        // process leftovers
        __m128 x0 = _mm_loadu_ps(x);

        for (; i < ny; i++) {
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
            y += 4;
            // horizontal sum of the 4 components
            accu = _mm_hadd_ps(accu, accu);
            accu = _mm_hadd_ps(accu, accu);

            const auto distance = _mm_cvtss_f32(accu);

            // strict comparison: earlier indices win on ties
            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }
        }
    }

    return current_min_index;
}
#else
/// Fallback used when AVX2 is not available: nearest-vector search for
/// 4-dimensional vectors, forwarded to the generic reference implementation
/// with d fixed to 4.
size_t fvec_L2sqr_ny_nearest_D4(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t dim = 4;
    const size_t nearest =
            fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, dim, ny);
    return nearest;
}
#endif
/// Return the index of the y vector nearest to x (squared L2 distance).
/// Dispatches on the dimension d: d == 4 has a dedicated kernel, every other
/// dimension goes through the reference implementation.
size_t fvec_L2sqr_ny_nearest(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    // optimized for a few special cases
    if (d == 4) {
        return fvec_L2sqr_ny_nearest_D4(distances_tmp_buffer, x, y, ny);
    }
    // generic path
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
}
#endif
#ifdef USE_AVX
@ -816,6 +1006,15 @@ void fvec_L2sqr_ny(
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
}
/// Return the index of the y vector nearest to x (squared L2 distance).
/// This variant has no specialized kernel and always forwards to the
/// reference implementation.
size_t fvec_L2sqr_ny_nearest(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    const size_t nearest =
            fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
    return nearest;
}
/// Forwards to the reference implementation fvec_L1_ref with the same
/// arguments and returns its result.
float fvec_L1(const float* x, const float* y, size_t d) {
    const float result = fvec_L1_ref(x, y, d);
    return result;
}
@ -865,6 +1064,15 @@ void fvec_L2sqr_ny(
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
}
/// Return the index of the y vector nearest to x (squared L2 distance).
/// This variant has no specialized kernel and always forwards to the
/// reference implementation.
size_t fvec_L2sqr_ny_nearest(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    const size_t nearest_index =
            fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
    return nearest_index;
}
void fvec_inner_products_ny(
float* dis,
const float* x,