#include <cassert>
#include <cstddef>

#include <immintrin.h>
float fvec_L2sqr_ref (const float * x, const float * y, size_t d)
{
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}
float fvec_inner_product_ref (const float * x, const float * y, size_t d)
{
    float res = 0;
    for (size_t i = 0; i < d; i++)
        res += x[i] * y[i];
    return res;
}
float fvec_norm_L2sqr_ref (const float *x, size_t d)
{
    double res = 0;  // accumulate in double for accuracy
    for (size_t i = 0; i < d; i++)
        res += x[i] * x[i];
    return res;
}
void fvec_L2sqr_ny_ref (float * dis, const float * x,
                        const float * y, size_t d, size_t ny)
{
    for (size_t i = 0; i < ny; i++) {
        dis[i] = fvec_L2sqr (x, y, d);
        y += d;
    }
}
// reads 0 <= d < 4 floats as __m128, padding the unused lanes with zeros
static inline __m128 masked_read (int d, const float *x)
{
    assert (0 <= d && d < 4);
    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
    switch (d) {   // intentional fall-through: copies exactly d values
      case 3: buf[2] = x[2];
      case 2: buf[1] = x[1];
      case 1: buf[0] = x[0];
    }
    return _mm_load_ps (buf);
}
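masked_read is what lets the SSE kernels below handle a dimension that is not a multiple of 4 without reading past the end of an array: the last 1-3 floats are copied into a zeroed, 16-byte-aligned buffer and loaded from there, so the padded lanes contribute nothing to the accumulators. A small illustration (not part of the original listing; the values are made up):

// Illustration only: summing the squares of a 3-element tail.
// masked_read (3, tail) behaves like loading {tail[0], tail[1], tail[2], 0.0f}.
float tail[3] = {1.0f, 2.0f, 3.0f};
__m128 v  = masked_read (3, tail);
__m128 sq = _mm_mul_ps (v, v);         // {1, 4, 9, 0}
sq = _mm_hadd_ps (sq, sq);             // {5, 9, 5, 9}
sq = _mm_hadd_ps (sq, sq);             // {14, 14, 14, 14}
float tail_sum = _mm_cvtss_f32 (sq);   // 14 = 1 + 4 + 9; the padded lane adds 0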
float fvec_norm_L2sqr (const float * x, size_t d)
{
    __m128 mx;
    __m128 msum1 = _mm_setzero_ps();
    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
        d -= 4;
    }
    // add the last 1, 2 or 3 values
    mx = masked_read (d, x);
    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
static float sqr (float x) {
    return x * x;
}
void fvec_L2sqr_ny_D1 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    float x0s = x[0];
    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);

    size_t i;
    for (i = 0; i + 3 < ny; i += 4) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        dis[i] = _mm_cvtss_f32 (accu);
        tmp = _mm_shuffle_ps (accu, accu, 1);
        dis[i + 1] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 2);
        dis[i + 2] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 3] = _mm_cvtss_f32 (tmp);
    }
    while (i < ny) {   // scalar tail for the last 0-3 elements
        dis[i++] = sqr(x0s - *y++);
    }
}
void fvec_L2sqr_ny_D2 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);

    size_t i;
    for (i = 0; i + 1 < ny; i += 2) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
        accu = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 1] = _mm_cvtss_f32 (accu);
    }
    if (i < ny) {   // scalar tail for the odd element
        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
    }
}
void fvec_L2sqr_ny_D4 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny_D8 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny_D12 (float * dis, const float * x,
                        const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);
    __m128 x2 = _mm_loadu_ps(x + 8);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        tmp = x2 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny)
{
    // dispatch to the kernels specialized for the most common dimensions
    switch (d) {
    case 1:  fvec_L2sqr_ny_D1 (dis, x, y, ny); return;
    case 2:  fvec_L2sqr_ny_D2 (dis, x, y, ny); return;
    case 4:  fvec_L2sqr_ny_D4 (dis, x, y, ny); return;
    case 8:  fvec_L2sqr_ny_D8 (dis, x, y, ny); return;
    case 12: fvec_L2sqr_ny_D12 (dis, x, y, ny); return;
    default: fvec_L2sqr_ny_ref (dis, x, y, d, ny); return;
    }
}
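As an illustration of how this batch entry point is meant to be called (not part of the original listing; the data is made up), one query x of dimension d is compared against ny database vectors stored contiguously in y:

// Hypothetical data: one 4-d query against 3 database vectors,
// stored contiguously as y[0..3], y[4..7], y[8..11].
float x[4]  = {1, 0, 0, 0};
float y[12] = {1, 0, 0, 0,
               0, 1, 0, 0,
               2, 0, 0, 0};
float dis[3];
fvec_L2sqr_ny (dis, x, y, 4, 3);   // d == 4, so this dispatches to fvec_L2sqr_ny_D4
// dis is now {0, 2, 1}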
#ifdef USE_AVX

// reads 0 <= d < 8 floats as __m256
static inline __m256 masked_read_8 (int d, const float *x)
{
    assert (0 <= d && d < 8);
    if (d < 4) {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
        return res;
    } else {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
        return res;
    }
}
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
        d -= 8;
    }

    // fold the two 128-bit halves of the 256-bit accumulator
    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
        d -= 4;
    }

    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}
float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        const __m256 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 8;
    }

    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}
#elif defined(__SSE__) // SSE only, no AVX

float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
    }

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    __m128 mx, my;
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        my = _mm_loadu_ps (y); y += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
        d -= 4;
    }

    // add the last 1, 2 or 3 values
    mx = masked_read (d, x);
    my = masked_read (d, y);
    __m128 prod = _mm_mul_ps (mx, my);
    msum1 = _mm_add_ps (msum1, prod);

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
#elif defined(__aarch64__)

float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    // fall back to the scalar reference when d is not a multiple of 4
    if (d & 3) return fvec_L2sqr_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        float32x4_t sq = vsubq_f32 (xi, yi);
        accu = vfmaq_f32 (accu, sq, sq);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
float fvec_inner_product (const float * x, const float * y, size_t d)
{
    if (d & 3) return fvec_inner_product_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        accu = vfmaq_f32 (accu, xi, yi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
float fvec_norm_L2sqr (const float * x, size_t d)
{
    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        accu = vfmaq_f32 (accu, xi, xi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
// not specialized on ARM: use the reference implementation
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny)
{
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}
#else // no SIMD available: scalar fallbacks

float fvec_L2sqr (const float * x, const float * y, size_t d)
{
    return fvec_L2sqr_ref (x, y, d);
}

float fvec_inner_product (const float * x, const float * y, size_t d)
{
    return fvec_inner_product_ref (x, y, d);
}

float fvec_norm_L2sqr (const float * x, size_t d)
{
    return fvec_norm_L2sqr_ref (x, d);
}

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny)
{
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}

#endif
static inline void fvec_madd_ref (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    // c := a + bf * b, element-wise
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + bf * b[i];
}
#ifdef __SSE__

static inline void fvec_madd_sse (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    n >>= 2;   // process 4 floats per iteration (n is a multiple of 4)
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        b4++;
        a4++;
        c4++;
    }
}
void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    // the SSE path requires a multiple-of-4 length and 16-byte aligned pointers
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        fvec_madd_sse (n, a, bf, b, c);
    else
        fvec_madd_ref (n, a, bf, b, c);
}

#else

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    fvec_madd_ref (n, a, bf, b, c);
}

#endif
static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    float vmin = 1e20;
    int imin = -1;

    for (size_t i = 0; i < n; i++) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = i;
        }
    }
    return imin;
}
#ifdef __SSE__

static inline int fvec_madd_and_argmin_sse (
        size_t n, const float *a,
        float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 vmin4 = _mm_set_ps1 (1e20);
    __m128i imin4 = _mm_set1_epi32 (-1);
    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
    __m128i inc4 = _mm_set1_epi32 (4);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        *c4 = vc4;
        // mask is all-ones in the lanes where vc4 improves on the current minimum
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
        b4++;
        a4++;
        c4++;
        idx4 = _mm_add_epi32 (idx4, inc4);
    }

    // reduce the 4 lanes to 2
    {
        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
    }
    // reduce the 2 lanes to 1
    {
        idx4 = _mm_shuffle_epi32 (imin4, 1);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
    }
    return _mm_cvtsi128_si32 (imin4);
}
int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
    else
        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#else

int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#endif
float fvec_L2sqr(const float *x, const float *y, size_t d)
    Squared L2 distance between two vectors.

void fvec_madd(size_t n, const float *a, float bf, const float *b, float *c)
    Computes c[i] = a[i] + bf * b[i] for i in 0..n-1.

float fvec_norm_L2sqr(const float *x, size_t d)
    Squared L2 norm of a vector.

int fvec_madd_and_argmin(size_t n, const float *a, float bf, const float *b, float *c)
    Same as fvec_madd, and also returns the index of the smallest element of c.
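To make the documented entry points concrete, here is a minimal usage sketch (not from the original source; the data and expected outputs are made up for illustration, and it assumes the declarations above are visible, e.g. through the corresponding header):

#include <cstdio>
#include <vector>

int main () {
    const size_t d = 8;
    std::vector<float> x(d, 1.0f), y(d, 2.0f), c(d);

    // squared distance and squared norm
    printf ("L2sqr(x, y)   = %f\n", fvec_L2sqr (x.data(), y.data(), d));   // 8 * (1 - 2)^2 = 8
    printf ("norm_L2sqr(x) = %f\n", fvec_norm_L2sqr (x.data(), d));        // 8 * 1^2 = 8

    // c = x + 0.5 * y, then the index of the smallest entry of c
    fvec_madd (d, x.data(), 0.5f, y.data(), c.data());
    int imin = fvec_madd_and_argmin (d, x.data(), 0.5f, y.data(), c.data());
    printf ("c[0] = %f, argmin = %d\n", c[0], imin);   // every c[i] is 2.0, so argmin is 0
    return 0;
}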