#include <cassert>
#include <immintrin.h>
float fvec_L2sqr_ref (const float * x, const float * y, size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}
float fvec_inner_product_ref (const float * x, const float * y, size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++)
        res += x[i] * y[i];
    return res;
}
float fvec_norm_L2sqr_ref (const float *x, size_t d) {
    double res = 0;
    for (size_t i = 0; i < d; i++)
        res += x[i] * x[i];
    return res;
}
void fvec_L2sqr_ny_ref (float * dis, const float * x,
                        const float * y, size_t d, size_t ny) {
    for (size_t i = 0; i < ny; i++) {
        dis[i] = fvec_L2sqr (x, y, d);
        y += d;
    }
}
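// Usage sketch (illustrative, not from the original source): the *_ny kernels
// take one query x of dimension d and ny database vectors stored back-to-back
// in y, and write the ny squared distances into dis.
static void example_L2sqr_ny_usage () {
    const float x[4]  = {1, 0, 0, 0};           // one query, d = 4
    const float y[12] = {1, 0, 0, 0,            // ny = 3 database vectors
                         0, 1, 0, 0,
                         2, 0, 0, 0};
    float dis[3];
    fvec_L2sqr_ny_ref (dis, x, y, 4, 3);        // dis == {0, 2, 1}
}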
// reads 0 <= d < 4 floats as __m128, zero-padding the unused lanes
static inline __m128 masked_read (int d, const float *x) {
    assert (0 <= d && d < 4);
    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
    switch (d) {   // intentional fall-through
      case 3: buf[2] = x[2];
      case 2: buf[1] = x[1];
      case 1: buf[0] = x[0];
    }
    return _mm_load_ps (buf);
}
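// Illustration (not in the original source): masked_read lets a tail of 1-3
// floats enter the SIMD loops without reading past the end of the allocation;
// the unused lanes come back as zero and so do not affect the accumulated sum.
static void example_masked_read () {
    const float tail[2] = {5.0f, 6.0f};
    float out[4];
    _mm_storeu_ps (out, masked_read (2, tail));   // lanes: {5, 6, 0, 0}
    assert (out[0] == 5.0f && out[1] == 6.0f && out[2] == 0.0f && out[3] == 0.0f);
}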
float fvec_norm_L2sqr (const float * x, size_t d) {
    __m128 mx;
    __m128 msum1 = _mm_setzero_ps();
    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
        d -= 4;
    }
    mx = masked_read (d, x);   // 1-3 remaining values, zero-padded
    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
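// Reduction sketch (illustrative): the pair of _mm_hadd_ps calls collapses the
// four partial sums of an __m128 into lane 0. With lanes {a, b, c, d}, the
// first hadd gives {a+b, c+d, a+b, c+d} and the second {a+b+c+d, ...}, which
// _mm_cvtss_f32 then extracts. The same idiom ends every SSE kernel below.
static float example_hsum_m128 (__m128 v) {
    v = _mm_hadd_ps (v, v);
    v = _mm_hadd_ps (v, v);
    return _mm_cvtss_f32 (v);
}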
float sqr (float x) {
    return x * x;
}
void fvec_L2sqr_ny_D1 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    float x0s = x[0];
    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);

    size_t i;
    for (i = 0; i + 3 < ny; i += 4) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        dis[i] = _mm_cvtss_f32 (accu);
        tmp = _mm_shuffle_ps (accu, accu, 1);
        dis[i + 1] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 2);
        dis[i + 2] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 3] = _mm_cvtss_f32 (tmp);
    }
    while (i < ny) {   // scalar tail for the last ny % 4 elements
        dis[i++] = sqr(x0s - *y++);
    }
}
void fvec_L2sqr_ny_D2 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);

    size_t i;
    for (i = 0; i + 1 < ny; i += 2) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
        accu = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 1] = _mm_cvtss_f32 (accu);
    }
    if (i < ny) {   // odd ny: handle the last vector in scalar code
        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
    }
}
void fvec_L2sqr_ny_D4 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny_D8 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny_D12 (float * dis, const float * x,
                        const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);
    __m128 x2 = _mm_loadu_ps(x + 8);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        tmp = x2 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    // dispatch to the dimensions that have a specialized kernel
    switch (d) {
        case 1:  fvec_L2sqr_ny_D1  (dis, x, y, ny); return;
        case 2:  fvec_L2sqr_ny_D2  (dis, x, y, ny); return;
        case 4:  fvec_L2sqr_ny_D4  (dis, x, y, ny); return;
        case 8:  fvec_L2sqr_ny_D8  (dis, x, y, ny); return;
        case 12: fvec_L2sqr_ny_D12 (dis, x, y, ny); return;
        default: fvec_L2sqr_ny_ref (dis, x, y, d, ny); return;
    }
}
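// Consistency sketch (illustrative): every specialized kernel above should
// agree with the reference loop up to float rounding; tmp0 and tmp1 are
// caller-provided scratch buffers of ny floats.
static bool example_check_L2sqr_ny (const float *x, const float *y,
                                    size_t d, size_t ny,
                                    float *tmp0, float *tmp1) {
    fvec_L2sqr_ny (tmp0, x, y, d, ny);
    fvec_L2sqr_ny_ref (tmp1, x, y, d, ny);
    for (size_t i = 0; i < ny; i++) {
        float diff = tmp0[i] - tmp1[i];
        if (diff < 0) diff = -diff;
        if (diff > 1e-5f * tmp1[i] + 1e-6f) return false;
    }
    return true;
}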
#if defined(__AVX__)

// reads 0 <= d < 8 floats as __m256, zero-padding the unused lanes
static inline __m256 masked_read_8 (int d, const float *x) {
    assert (0 <= d && d < 8);
    if (d < 4) {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
        return res;
    } else {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
        return res;
    }
}
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();
    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
        d -= 8;
    }
    // fold the 256-bit accumulator into a 128-bit one
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);
    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
        d -= 4;
    }
    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
    }
    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}
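// Reduction sketch (illustrative): the AVX kernels accumulate in a 256-bit
// register, add its two 128-bit halves, and then reuse the same hadd sequence
// as the SSE code. As a standalone helper the reduction would read:
static float example_hsum_m256 (__m256 v) {
    __m128 s = _mm_add_ps (_mm256_extractf128_ps (v, 0),
                           _mm256_extractf128_ps (v, 1));
    s = _mm_hadd_ps (s, s);
    s = _mm_hadd_ps (s, s);
    return _mm_cvtss_f32 (s);
}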
float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();
    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        const __m256 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 8;
    }
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);
    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
        d -= 4;
    }
    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
    }
    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}
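// Relation sketch (illustrative): fvec_L2sqr and fvec_inner_product are tied by
// the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, so search code can
// trade one for the other when the norms are precomputed. Up to float rounding:
static float example_L2sqr_via_inner_product (const float *x, const float *y, size_t d) {
    return fvec_norm_L2sqr (x, d) + fvec_norm_L2sqr (y, d)
           - 2.0f * fvec_inner_product (x, y, d);
}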
#elif defined(__SSE__)
float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    __m128 msum1 = _mm_setzero_ps();
    while (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 4;
    }
    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
    }
    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    __m128 mx, my;
    __m128 msum1 = _mm_setzero_ps();
    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        my = _mm_loadu_ps (y); y += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
        d -= 4;
    }
    // add the last 1, 2 or 3 values
    mx = masked_read (d, x);
    my = masked_read (d, y);
    __m128 prod = _mm_mul_ps (mx, my);
    msum1 = _mm_add_ps (msum1, prod);
    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
#elif defined(__aarch64__)
float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    if (d & 3) return fvec_L2sqr_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        float32x4_t sq = vsubq_f32 (xi, yi);
        accu = vfmaq_f32 (accu, sq, sq);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
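// Reduction sketch (illustrative): vpaddq_f32 followed by the two
// vdups_laneq_f32 extractions sums the four accumulator lanes; on AArch64 the
// single intrinsic vaddvq_f32 computes the same horizontal sum.
static float example_hsum_f32x4 (float32x4_t v) {
    return vaddvq_f32 (v);   // == v[0] + v[1] + v[2] + v[3]
}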
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    if (d & 3) return fvec_inner_product_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        accu = vfmaq_f32 (accu, xi, yi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
float fvec_norm_L2sqr (const float * x, size_t d)
{
    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        accu = vfmaq_f32 (accu, xi, xi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
// not specialized on ARM: fall back to the reference implementation
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}
#else
// scalar fallback: forward everything to the reference implementations

float fvec_L2sqr (const float * x, const float * y, size_t d) {
    return fvec_L2sqr_ref (x, y, d);
}

float fvec_inner_product (const float * x, const float * y, size_t d) {
    return fvec_inner_product_ref (x, y, d);
}

float fvec_norm_L2sqr (const float * x, size_t d) {
    return fvec_norm_L2sqr_ref (x, d);
}

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}

#endif
static inline void fvec_madd_ref (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + bf * b[i];
}
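// Usage sketch (illustrative): fvec_madd computes c[i] = a[i] + bf * b[i] for
// i = 0..n-1, i.e. it adds a scaled copy of one table to another.
static void example_fvec_madd_usage () {
    const float a[4] = {1, 2, 3, 4};
    const float b[4] = {10, 10, 10, 10};
    float c[4];
    fvec_madd_ref (4, a, 0.5f, b, c);    // c == {6, 7, 8, 9}
}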
#ifdef __SSE__

static inline void fvec_madd_sse (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    n >>= 2;   // caller guarantees n % 4 == 0 and 16-byte aligned pointers
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        b4++;
        a4++;
        c4++;
    }
}
void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        fvec_madd_sse (n, a, bf, b, c);
    else
        fvec_madd_ref (n, a, bf, b, c);
}
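// Dispatch sketch (illustrative): ORing the three addresses and masking with 15
// tests 16-byte alignment of all of them at once; together with (n & 3) == 0 it
// guarantees the whole-register SSE path is safe. A portable variant would cast
// through uintptr_t instead of long.
static bool example_sse_path_ok (size_t n, const float *a, const float *b, const float *c) {
    return (n & 3) == 0 &&
           ((((long)a) | ((long)b) | ((long)c)) & 15) == 0;
}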
#else

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c) {
    fvec_madd_ref (n, a, bf, b, c);
}

#endif
static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    float vmin = 1e20;
    int imin = -1;
    for (size_t i = 0; i < n; i++) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = i;
        }
    }
    return imin;
}
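// Usage sketch (illustrative): fvec_madd_and_argmin performs the same update as
// fvec_madd and additionally returns the index of the smallest resulting c[i].
static void example_fvec_madd_and_argmin_usage () {
    const float a[4] = {4, 3, 2, 1};
    const float b[4] = {0, 0, 0, 1};
    float c[4];
    int imin = fvec_madd_and_argmin_ref (4, a, 5.0f, b, c);
    assert (imin == 2);   // c == {4, 3, 2, 6}, minimum at index 2
}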
#ifdef __SSE__

static inline int fvec_madd_and_argmin_sse (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 vmin4 = _mm_set_ps1 (1e20);
    __m128i imin4 = _mm_set1_epi32 (-1);
    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
    __m128i inc4 = _mm_set1_epi32 (4);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        *c4 = vc4;
        // lanes where the new value beats the current minimum
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
        b4++;
        a4++;
        c4++;
        idx4 = _mm_add_epi32 (idx4, inc4);
    }

    // reduce the 4 (value, index) candidates to 2, then to 1
    {
        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
    }
    {
        idx4 = _mm_shuffle_epi32 (imin4, 1);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
    }
    return _mm_cvtsi128_si32 (imin4);
}
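// Select sketch (illustrative): the and/andnot/or triple is a lane-wise
// "mask ? idx4 : imin4" select that works without SSE4.1; with SSE4.1 the same
// operation is a single _mm_blendv_epi8 (imin4, idx4, mask).
static __m128i example_select_epi32 (__m128i mask, __m128i if_true, __m128i if_false) {
    return _mm_or_si128 (_mm_and_si128 (mask, if_true),
                         _mm_andnot_si128 (mask, if_false));
}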
int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
    else
        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}
#else

int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c) {
    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#endif
float fvec_L2sqr(const float *x, const float *y, size_t d)
    Squared L2 distance between two vectors.
void fvec_madd(size_t n, const float *a, float bf, const float *b, float *c)
    Computes c[i] = a[i] + bf * b[i] for i = 0..n-1.
float fvec_norm_L2sqr(const float *x, size_t d)
    Squared L2 norm of a vector.
int fvec_madd_and_argmin(size_t n, const float *a, float bf, const float *b, float *c)
    Same as fvec_madd, and returns the index of the smallest resulting c[i].
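As a usage sketch (illustrative, not part of the library), a brute-force nearest-neighbor scan can be written directly on top of the entry points above; the helper name below is hypothetical:

#include <cstddef>

// Return the index of the database vector (among ny contiguous d-dimensional
// vectors in y) closest to the query x, using the public fvec_L2sqr kernel.
static int example_nearest_neighbor (const float *x, const float *y,
                                     size_t d, size_t ny) {
    int best = -1;
    float best_dis = 1e20f;
    for (size_t i = 0; i < ny; i++) {
        float dis = fvec_L2sqr (x, y + i * d, d);
        if (dis < best_dis) { best_dis = dis; best = (int)i; }
    }
    return best;
}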