Faiss
/data/users/hoss/faiss/utils_simd.cpp
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "utils.h"

#include <cstdio>
#include <cassert>
#include <cstring>
#include <cmath>

#ifdef __SSE__
#include <immintrin.h>
#endif

#ifdef __aarch64__
#include <arm_neon.h>
#endif

#include <omp.h>


/**************************************************
 * Get some stats about the system
 **************************************************/

namespace faiss {

#ifdef __AVX__
#define USE_AVX
#endif


/*********************************************************
 * Optimized distance computations
 *********************************************************/

/* Functions to compute:
   - the squared L2 distance between 2 vectors
   - the inner product between 2 vectors
   - the squared L2 norm of a vector

   These functions should probably not be invoked when a large number of
   vectors is processed in batch (in which case a matrix multiplication
   is faster), but they are useful for comparing vectors that are
   isolated in memory.

   They work with vectors of any dimension, even unaligned ones (in
   which case they are slower).
*/
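
/* Illustrative usage sketch (not part of the original source): how a caller
   might use the three routines above. The vectors and values below are made
   up for demonstration.

       float x[5] = {1, 2, 3, 4, 5};
       float y[5] = {2, 2, 2, 2, 2};
       float d2  = faiss::fvec_L2sqr (x, y, 5);          // 1 + 0 + 1 + 4 + 9 = 15
       float ip  = faiss::fvec_inner_product (x, y, 5);  // 2 + 4 + 6 + 8 + 10 = 30
       float nrm = faiss::fvec_norm_L2sqr (x, 5);        // 1 + 4 + 9 + 16 + 25 = 55
*/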

/*********************************************************
 * Reference implementations
 */

/* same without SSE */
float fvec_L2sqr_ref (const float * x,
                      const float * y,
                      size_t d)
{
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}

float fvec_inner_product_ref (const float * x,
                              const float * y,
                              size_t d)
{
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * y[i];
    return res;
}

float fvec_norm_L2sqr_ref (const float *x, size_t d)
{
    size_t i;
    double res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * x[i];
    return res;
}


void fvec_L2sqr_ny_ref (float * dis,
                        const float * x,
                        const float * y,
                        size_t d, size_t ny)
{
    for (size_t i = 0; i < ny; i++) {
        dis[i] = fvec_L2sqr (x, y, d);
        y += d;
    }
}


/*********************************************************
 * SSE and AVX implementations
 */

#ifdef __SSE__

// reads 0 <= d < 4 floats as __m128
static inline __m128 masked_read (int d, const float *x)
{
    assert (0 <= d && d < 4);
    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
    switch (d) {
      case 3:
        buf[2] = x[2];
      case 2:
        buf[1] = x[1];
      case 1:
        buf[0] = x[0];
    }
    return _mm_load_ps (buf);
    // cannot use AVX2 _mm_mask_set1_epi32
}
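
/* Note (added for clarity, not in the original): the switch above falls
   through intentionally, so e.g. masked_read (2, x) copies x[1] and x[0]
   into the zero-initialized buffer and returns {x[0], x[1], 0, 0}. This
   lets the SIMD kernels below handle the d % 4 tail of a vector without
   reading past the end of the input array.
*/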

float fvec_norm_L2sqr (const float * x,
                       size_t d)
{
    __m128 mx;
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
        d -= 4;
    }

    mx = masked_read (d, x);
    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
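
/* Note (added for clarity, not in the original): the pair of _mm_hadd_ps
   calls above is the horizontal reduction used throughout this file. If
   msum1 holds {a, b, c, d}, the first hadd yields {a+b, c+d, a+b, c+d} and
   the second yields a+b+c+d in every lane, so _mm_cvtss_f32 returns the
   full sum from lane 0.
*/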

namespace {

float sqr (float x) {
    return x * x;
}


void fvec_L2sqr_ny_D1 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    float x0s = x[0];
    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);

    size_t i;
    for (i = 0; i + 3 < ny; i += 4) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        dis[i] = _mm_cvtss_f32 (accu);
        tmp = _mm_shuffle_ps (accu, accu, 1);
        dis[i + 1] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 2);
        dis[i + 2] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 3] = _mm_cvtss_f32 (tmp);
    }
    while (i < ny) { // handle non-multiple-of-4 case
        dis[i++] = sqr(x0s - *y++);
    }
}


void fvec_L2sqr_ny_D2 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);

    size_t i;
    for (i = 0; i + 1 < ny; i += 2) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
        accu = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 1] = _mm_cvtss_f32 (accu);
    }
    if (i < ny) { // handle odd case
        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
    }
}


void fvec_L2sqr_ny_D4 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


void fvec_L2sqr_ny_D8 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


void fvec_L2sqr_ny_D12 (float * dis, const float * x,
                        const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);
    __m128 x2 = _mm_loadu_ps(x + 8);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        tmp = x2 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


} // anonymous namespace

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    // optimized for a few special cases
    switch(d) {
      case 1:
        fvec_L2sqr_ny_D1 (dis, x, y, ny);
        return;
      case 2:
        fvec_L2sqr_ny_D2 (dis, x, y, ny);
        return;
      case 4:
        fvec_L2sqr_ny_D4 (dis, x, y, ny);
        return;
      case 8:
        fvec_L2sqr_ny_D8 (dis, x, y, ny);
        return;
      case 12:
        fvec_L2sqr_ny_D12 (dis, x, y, ny);
        return;
      default:
        fvec_L2sqr_ny_ref (dis, x, y, d, ny);
        return;
    }
}
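
/* Illustrative usage sketch (not part of the original source): computing the
   squared L2 distances from one query vector to several contiguously stored
   database vectors via the dispatcher above. The values below are made up.

       float x[4]  = {0, 0, 0, 0};
       float y[12] = {1, 0, 0, 0,   0, 2, 0, 0,   1, 1, 1, 1};  // 3 vectors of dimension 4
       float dis[3];
       faiss::fvec_L2sqr_ny (dis, x, y, 4, 3);   // dis = {1, 4, 4}, using the D4 kernel
*/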


#endif

#ifdef USE_AVX

// reads 0 <= d < 8 floats as __m256
static inline __m256 masked_read_8 (int d, const float *x)
{
    assert (0 <= d && d < 8);
    if (d < 4) {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
        return res;
    } else {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
        return res;
    }
}

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
        d -= 8;
    }

    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
        d -= 4;
    }

    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}

float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        const __m256 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 8;
    }

    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}

#elif defined(__SSE__)

/* SSE-implementation of L2 distance */
float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
    }

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}


float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    __m128 mx, my;
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        my = _mm_loadu_ps (y); y += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
        d -= 4;
    }

    // add the last 1, 2, or 3 values
    mx = masked_read (d, x);
    my = masked_read (d, y);
    __m128 prod = _mm_mul_ps (mx, my);

    msum1 = _mm_add_ps (msum1, prod);

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}

#elif defined(__aarch64__)


float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    if (d & 3) return fvec_L2sqr_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        float32x4_t sq = vsubq_f32 (xi, yi);
        accu = vfmaq_f32 (accu, sq, sq);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    if (d & 3) return fvec_inner_product_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        accu = vfmaq_f32 (accu, xi, yi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}

float fvec_norm_L2sqr (const float *x, size_t d)
{
    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        accu = vfmaq_f32 (accu, xi, xi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}

// not optimized for ARM
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}


#else
// scalar implementation

float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    return fvec_L2sqr_ref (x, y, d);
}

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    return fvec_inner_product_ref (x, y, d);
}

float fvec_norm_L2sqr (const float *x, size_t d)
{
    return fvec_norm_L2sqr_ref (x, d);
}

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}


#endif



/***************************************************************************
 * heavily optimized table computations
 ***************************************************************************/


static inline void fvec_madd_ref (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + bf * b[i];
}

#ifdef __SSE__

static inline void fvec_madd_sse (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        b4++;
        a4++;
        c4++;
    }
}

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        fvec_madd_sse (n, a, bf, b, c);
    else
        fvec_madd_ref (n, a, bf, b, c);
}
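
/* Illustrative usage sketch (not part of the original source): fvec_madd
   computes c[i] = a[i] + bf * b[i]. The SSE path is only taken when n is a
   multiple of 4 and a, b and c are all 16-byte aligned; otherwise the scalar
   reference loop is used. The values below are made up.

       float a[4] = {1, 2, 3, 4};
       float b[4] = {10, 10, 10, 10};
       float c[4];
       faiss::fvec_madd (4, a, 0.5f, b, c);   // c = {6, 7, 8, 9}
*/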

#else

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    fvec_madd_ref (n, a, bf, b, c);
}

#endif

static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    float vmin = 1e20;
    int imin = -1;

    for (size_t i = 0; i < n; i++) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = i;
        }
    }
    return imin;
}

#ifdef __SSE__

static inline int fvec_madd_and_argmin_sse (
        size_t n, const float *a,
        float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 vmin4 = _mm_set_ps1 (1e20);
    __m128i imin4 = _mm_set1_epi32 (-1);
    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
    __m128i inc4 = _mm_set1_epi32 (4);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        *c4 = vc4;
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!

        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
        b4++;
        a4++;
        c4++;
        idx4 = _mm_add_epi32 (idx4, inc4);
    }

    // 4 values -> 2
    {
        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
    }
    // 2 values -> 1
    {
        idx4 = _mm_shuffle_epi32 (imin4, 1);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        // vmin4 = _mm_min_ps (vmin4, vc4);
    }
    return _mm_cvtsi128_si32 (imin4);
}


int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
    else
        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}
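
/* Illustrative usage sketch (not part of the original source): the same
   multiply-add as fvec_madd, but the return value is the index of the
   smallest element written to c, which is handy when updating a distance
   table and tracking the current best candidate in one pass. The values
   below are made up.

       float a[4] = {5, 1, 7, 3};
       float b[4] = {2, 2, 2, 2};
       float c[4];
       int imin = faiss::fvec_madd_and_argmin (4, a, 1.0f, b, c);
       // c = {7, 3, 9, 5}, imin = 1
*/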

#else

int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#endif


} // namespace faiss