ProductQuantizer::compute_code tracks the nearest vector index in a register rather than stores the distances in a buffer. (#2280)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2280 Add a new function call fvec_L2sqr_ny_nearest and a demonstration of its implementation for 4 bits Reviewed By: mdouze Differential Revision: D35189945 fbshipit-source-id: d1b2ba42851df195123c7e318a8dcf26f775eabapull/2274/head
parent
438b64cd8b
commit
b32abc95c2
|
@ -321,28 +321,54 @@ void ProductQuantizer::train(int n, const float* x) {
|
|||
template <class PQEncoder>
|
||||
void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
|
||||
std::vector<float> distances(pq.ksub);
|
||||
|
||||
// It seems to be meaningless to allocate std::vector<float> distances.
|
||||
// But it is done in order to cope the ineffectiveness of the way
|
||||
// the compiler generates the code. Basically, doing something like
|
||||
//
|
||||
// size_t min_distance = HUGE_VALF;
|
||||
// size_t idxm = 0;
|
||||
// for (size_t i = 0; i < N; i++) {
|
||||
// const float distance = compute_distance(x, y + i * d, d);
|
||||
// if (distance < min_distance) {
|
||||
// min_distance = distance;
|
||||
// idxm = i;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// generates significantly more CPU instructions than the baseline
|
||||
//
|
||||
// std::vector<float> distances_cached(N);
|
||||
// for (size_t i = 0; i < N; i++) {
|
||||
// distances_cached[i] = compute_distance(x, y + i * d, d);
|
||||
// }
|
||||
// size_t min_distance = HUGE_VALF;
|
||||
// size_t idxm = 0;
|
||||
// for (size_t i = 0; i < N; i++) {
|
||||
// const float distance = distances_cached[i];
|
||||
// if (distance < min_distance) {
|
||||
// min_distance = distance;
|
||||
// idxm = i;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// So, the baseline is faster. This is because of the vectorization.
|
||||
// I suppose that the branch predictor might affect the performance as well.
|
||||
// So, the buffer is allocated, but it might be unused in
|
||||
// manually optimized code. Let's hope that the compiler is smart enough to
|
||||
// get rid of std::vector allocation in such a case.
|
||||
|
||||
PQEncoder encoder(code, pq.nbits);
|
||||
for (size_t m = 0; m < pq.M; m++) {
|
||||
float mindis = 1e20;
|
||||
uint64_t idxm = 0;
|
||||
const float* xsub = x + m * pq.dsub;
|
||||
|
||||
fvec_L2sqr_ny(
|
||||
uint64_t idxm = fvec_L2sqr_ny_nearest(
|
||||
distances.data(),
|
||||
xsub,
|
||||
pq.get_centroids(m, 0),
|
||||
pq.dsub,
|
||||
pq.ksub);
|
||||
|
||||
/* Find best centroid */
|
||||
for (size_t i = 0; i < pq.ksub; i++) {
|
||||
float dis = distances[i];
|
||||
if (dis < mindis) {
|
||||
mindis = dis;
|
||||
idxm = i;
|
||||
}
|
||||
}
|
||||
|
||||
encoder.encode(idxm);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,6 +71,16 @@ void fvec_L2sqr_ny(
|
|||
size_t d,
|
||||
size_t ny);
|
||||
|
||||
/* compute ny square L2 distance between x and a set of contiguous y vectors
|
||||
and return the index of the nearest vector.
|
||||
return 0 if ny == 0. */
|
||||
size_t fvec_L2sqr_ny_nearest(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t ny);
|
||||
|
||||
/** squared norm of a vector */
|
||||
float fvec_norm_L2sqr(const float* x, size_t d);
|
||||
|
||||
|
|
|
@ -113,6 +113,27 @@ void fvec_L2sqr_ny_ref(
|
|||
}
|
||||
}
|
||||
|
||||
size_t fvec_L2sqr_ny_nearest_ref(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t ny) {
|
||||
fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
|
||||
|
||||
size_t nearest_idx = 0;
|
||||
float min_dis = HUGE_VALF;
|
||||
|
||||
for (size_t i = 0; i < ny; i++) {
|
||||
if (distances_tmp_buffer[i] < min_dis) {
|
||||
min_dis = distances_tmp_buffer[i];
|
||||
nearest_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
return nearest_idx;
|
||||
}
|
||||
|
||||
void fvec_inner_products_ny_ref(
|
||||
float* ip,
|
||||
const float* x,
|
||||
|
@ -514,6 +535,175 @@ void fvec_inner_products_ny(
|
|||
#undef DISPATCH
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
size_t fvec_L2sqr_ny_nearest_D4(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t ny) {
|
||||
// this implementation does not use distances_tmp_buffer.
|
||||
|
||||
// current index being processed
|
||||
size_t i = 0;
|
||||
|
||||
// min distance and the index of the closest vector so far
|
||||
float current_min_distance = HUGE_VALF;
|
||||
size_t current_min_index = 0;
|
||||
|
||||
// process 8 D4-vectors per loop.
|
||||
const size_t ny8 = ny / 8;
|
||||
|
||||
if (ny8 > 0) {
|
||||
// track min distance and the closest vector independently
|
||||
// for each of 8 AVX2 components.
|
||||
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
||||
__m256i min_indices = _mm256_set1_epi32(0);
|
||||
|
||||
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
||||
const __m256i indices_increment = _mm256_set1_epi32(8);
|
||||
|
||||
//
|
||||
_mm_prefetch(y, _MM_HINT_NTA);
|
||||
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
||||
|
||||
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
||||
const __m256 m0 = _mm256_set1_ps(x[0]);
|
||||
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
||||
const __m256 m1 = _mm256_set1_ps(x[1]);
|
||||
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
||||
const __m256 m2 = _mm256_set1_ps(x[2]);
|
||||
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
||||
const __m256 m3 = _mm256_set1_ps(x[3]);
|
||||
|
||||
const __m256i indices0 =
|
||||
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
||||
|
||||
for (; i < ny8 * 8; i += 8) {
|
||||
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
||||
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
||||
|
||||
// collect dim 0 for 8 D4-vectors.
|
||||
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
||||
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
||||
// collect dim 1 for 8 D4-vectors.
|
||||
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
||||
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
||||
// collect dim 2 for 8 D4-vectors.
|
||||
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
||||
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
||||
// collect dim 3 for 8 D4-vectors.
|
||||
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
||||
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
||||
|
||||
// compute differences
|
||||
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
||||
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
||||
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
||||
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
||||
|
||||
// compute squares of differences
|
||||
__m256 distances = _mm256_mul_ps(d0, d0);
|
||||
distances = _mm256_fmadd_ps(d1, d1, distances);
|
||||
distances = _mm256_fmadd_ps(d2, d2, distances);
|
||||
distances = _mm256_fmadd_ps(d3, d3, distances);
|
||||
|
||||
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
||||
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
||||
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
||||
// (x[3] - y[(i * 8 + 0) * 4 + 3])
|
||||
// ...
|
||||
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
||||
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
||||
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
||||
// (x[3] - y[(i * 8 + 7) * 4 + 3])
|
||||
|
||||
// compare the new distances to the min distances
|
||||
// for each of 8 AVX2 components.
|
||||
__m256 comparison =
|
||||
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
||||
|
||||
// update min distances and indices with closest vectors if needed.
|
||||
min_distances =
|
||||
_mm256_blendv_ps(distances, min_distances, comparison);
|
||||
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
||||
_mm256_castsi256_ps(current_indices),
|
||||
_mm256_castsi256_ps(min_indices),
|
||||
comparison));
|
||||
|
||||
// update current indices values. Basically, +8 to each of the
|
||||
// 8 AVX2 components.
|
||||
current_indices =
|
||||
_mm256_add_epi32(current_indices, indices_increment);
|
||||
|
||||
// scroll y forward (8 vectors 4 DIM each).
|
||||
y += 32;
|
||||
}
|
||||
|
||||
// dump values and find the minimum distance / minimum index
|
||||
float min_distances_scalar[8];
|
||||
uint32_t min_indices_scalar[8];
|
||||
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
||||
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
||||
|
||||
for (size_t j = 0; j < 8; j++) {
|
||||
if (current_min_distance > min_distances_scalar[j]) {
|
||||
current_min_distance = min_distances_scalar[j];
|
||||
current_min_index = min_indices_scalar[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (i < ny) {
|
||||
// process leftovers
|
||||
__m128 x0 = _mm_loadu_ps(x);
|
||||
|
||||
for (; i < ny; i++) {
|
||||
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
||||
y += 4;
|
||||
accu = _mm_hadd_ps(accu, accu);
|
||||
accu = _mm_hadd_ps(accu, accu);
|
||||
|
||||
const auto distance = _mm_cvtss_f32(accu);
|
||||
|
||||
if (current_min_distance > distance) {
|
||||
current_min_distance = distance;
|
||||
current_min_index = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return current_min_index;
|
||||
}
|
||||
#else
|
||||
size_t fvec_L2sqr_ny_nearest_D4(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t ny) {
|
||||
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t fvec_L2sqr_ny_nearest(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t ny) {
|
||||
// optimized for a few special cases
|
||||
|
||||
#define DISPATCH(dval) \
|
||||
case dval: \
|
||||
return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
|
||||
|
||||
switch (d) {
|
||||
DISPATCH(4)
|
||||
default:
|
||||
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
||||
}
|
||||
#undef DISPATCH
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef USE_AVX
|
||||
|
@ -816,6 +1006,15 @@ void fvec_L2sqr_ny(
|
|||
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
||||
}
|
||||
|
||||
size_t fvec_L2sqr_ny_nearest(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t ny) {
|
||||
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
||||
}
|
||||
|
||||
float fvec_L1(const float* x, const float* y, size_t d) {
|
||||
return fvec_L1_ref(x, y, d);
|
||||
}
|
||||
|
@ -865,6 +1064,15 @@ void fvec_L2sqr_ny(
|
|||
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
||||
}
|
||||
|
||||
size_t fvec_L2sqr_ny_nearest(
|
||||
float* distances_tmp_buffer,
|
||||
const float* x,
|
||||
const float* y,
|
||||
size_t d,
|
||||
size_t ny) {
|
||||
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
||||
}
|
||||
|
||||
void fvec_inner_products_ny(
|
||||
float* dis,
|
||||
const float* x,
|
||||
|
|
Loading…
Reference in New Issue