utils_simd.cpp
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "utils.h"

#include <cstdio>
#include <cassert>
#include <cstring>
#include <cmath>

#ifdef __SSE__
#include <immintrin.h>
#endif

#ifdef __aarch64__
#include <arm_neon.h>
#endif

#include <omp.h>


/**************************************************
 * Get some stats about the system
 **************************************************/

namespace faiss {

#ifdef __AVX__
#define USE_AVX
#endif

/*********************************************************
 * Optimized distance computations
 *********************************************************/


/* Functions to compute:
   - L2 distance between 2 vectors
   - inner product between 2 vectors
   - L2 norm of a vector

   These functions should probably not be invoked when a large number of
   vectors are processed in batch (in which case matrix multiplication
   is faster), but they may be useful for comparing isolated vectors in
   memory.

   They work with vectors of any dimension, even unaligned (in which
   case the computation is slower).
*/
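
/* Illustrative usage sketch (not part of the library): the buffers xq and yb
   and the dimension dim below are hypothetical names.

       std::vector<float> xq(dim), yb(dim);
       // ... fill xq and yb ...
       float l2  = faiss::fvec_L2sqr (xq.data(), yb.data(), dim);
       float ip  = faiss::fvec_inner_product (xq.data(), yb.data(), dim);
       float nrm = faiss::fvec_norm_L2sqr (xq.data(), dim);
*/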

/*********************************************************
 * Reference implementations
 */

/* same without SSE */
float fvec_L2sqr_ref (const float * x,
                      const float * y,
                      size_t d)
{
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}

float fvec_inner_product_ref (const float * x,
                              const float * y,
                              size_t d)
{
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * y[i];
    return res;
}

float fvec_norm_L2sqr_ref (const float *x, size_t d)
{
    size_t i;
    double res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * x[i];
    return res;
}


void fvec_L2sqr_ny_ref (float * dis,
                        const float * x,
                        const float * y,
                        size_t d, size_t ny)
{
    for (size_t i = 0; i < ny; i++) {
        dis[i] = fvec_L2sqr (x, y, d);
        y += d;
    }
}


/*********************************************************
 * SSE and AVX implementations
 */

#ifdef __SSE__

// reads 0 <= d < 4 floats as __m128
static inline __m128 masked_read (int d, const float *x)
{
    assert (0 <= d && d < 4);
    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
    switch (d) {
      case 3:
        buf[2] = x[2];
      case 2:
        buf[1] = x[1];
      case 1:
        buf[0] = x[0];
    }
    return _mm_load_ps (buf);
    // cannot use AVX2 _mm_mask_set1_epi32
}
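
/* Worked example of the fall-through switch above (illustration only):
   masked_read (3, x) returns {x[0], x[1], x[2], 0} and masked_read (0, x)
   returns {0, 0, 0, 0}; copying the d valid floats into the zeroed, aligned
   stack buffer keeps the 128-bit load from reading past the end of x. */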

float fvec_norm_L2sqr (const float * x,
                       size_t d)
{
    __m128 mx;
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
        d -= 4;
    }

    mx = masked_read (d, x);
    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}
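
// Note on the reduction above: each _mm_hadd_ps (m, m) turns {a, b, c, d}
// into {a+b, c+d, a+b, c+d}, so two of them leave the total a+b+c+d in
// lane 0, which _mm_cvtss_f32 extracts.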

namespace {

float sqr (float x) {
    return x * x;
}


void fvec_L2sqr_ny_D1 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    float x0s = x[0];
    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);

    size_t i;
    for (i = 0; i + 3 < ny; i += 4) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        dis[i] = _mm_cvtss_f32 (accu);
        tmp = _mm_shuffle_ps (accu, accu, 1);
        dis[i + 1] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 2);
        dis[i + 2] = _mm_cvtss_f32 (tmp);
        tmp = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 3] = _mm_cvtss_f32 (tmp);
    }
    while (i < ny) { // handle non-multiple-of-4 case
        dis[i++] = sqr(x0s - *y++);
    }
}


void fvec_L2sqr_ny_D2 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);

    size_t i;
    for (i = 0; i + 1 < ny; i += 2) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
        accu = _mm_shuffle_ps (accu, accu, 3);
        dis[i + 1] = _mm_cvtss_f32 (accu);
    }
    if (i < ny) { // handle odd case
        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
    }
}


void fvec_L2sqr_ny_D4 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


void fvec_L2sqr_ny_D8 (float * dis, const float * x,
                       const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


void fvec_L2sqr_ny_D12 (float * dis, const float * x,
                        const float * y, size_t ny)
{
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);
    __m128 x2 = _mm_loadu_ps(x + 8);

    for (size_t i = 0; i < ny; i++) {
        __m128 tmp, accu;
        tmp = x0 - _mm_loadu_ps (y); y += 4;
        accu = tmp * tmp;
        tmp = x1 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        tmp = x2 - _mm_loadu_ps (y); y += 4;
        accu += tmp * tmp;
        accu = _mm_hadd_ps (accu, accu);
        accu = _mm_hadd_ps (accu, accu);
        dis[i] = _mm_cvtss_f32 (accu);
    }
}


} // anonymous namespace

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    // optimized for a few special cases
    switch (d) {
    case 1:
        fvec_L2sqr_ny_D1 (dis, x, y, ny);
        return;
    case 2:
        fvec_L2sqr_ny_D2 (dis, x, y, ny);
        return;
    case 4:
        fvec_L2sqr_ny_D4 (dis, x, y, ny);
        return;
    case 8:
        fvec_L2sqr_ny_D8 (dis, x, y, ny);
        return;
    case 12:
        fvec_L2sqr_ny_D12 (dis, x, y, ny);
        return;
    default:
        fvec_L2sqr_ny_ref (dis, x, y, d, ny);
        return;
    }
}
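
/* Usage sketch for the batched kernel above (illustrative only; q, ys, nb
   and d are hypothetical names):

       std::vector<float> q(d), ys(nb * d), dis(nb);
       faiss::fvec_L2sqr_ny (dis.data(), q.data(), ys.data(), d, nb);
       // dis[j] is the squared L2 distance between q and row j of ys
*/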

#endif

#ifdef USE_AVX

// reads 0 <= d < 8 floats as __m256
static inline __m256 masked_read_8 (int d, const float *x)
{
    assert (0 <= d && d < 8);
    if (d < 4) {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
        return res;
    } else {
        __m256 res = _mm256_setzero_ps ();
        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
        return res;
    }
}

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
        d -= 8;
    }

    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
        d -= 4;
    }

    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}
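
// Reduction strategy used above (and in fvec_L2sqr below): accumulate 8-wide
// while d >= 8, add the two 128-bit halves of the 256-bit accumulator
// together via _mm256_extractf128_ps, handle a remaining block of 4 and a
// masked tail of 1-3 values with 128-bit operations, then finish with two
// horizontal adds.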

float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    __m256 msum1 = _mm256_setzero_ps();

    while (d >= 8) {
        __m256 mx = _mm256_loadu_ps (x); x += 8;
        __m256 my = _mm256_loadu_ps (y); y += 8;
        const __m256 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 8;
    }

    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
    msum2 += _mm256_extractf128_ps(msum1, 0);

    if (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum2 += a_m_b1 * a_m_b1;
    }

    msum2 = _mm_hadd_ps (msum2, msum2);
    msum2 = _mm_hadd_ps (msum2, msum2);
    return _mm_cvtss_f32 (msum2);
}

#elif defined(__SSE__)

/* SSE-implementation of L2 distance */
float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        __m128 mx = _mm_loadu_ps (x); x += 4;
        __m128 my = _mm_loadu_ps (y); y += 4;
        const __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
        d -= 4;
    }

    if (d > 0) {
        // add the last 1, 2 or 3 values
        __m128 mx = masked_read (d, x);
        __m128 my = masked_read (d, y);
        __m128 a_m_b1 = mx - my;
        msum1 += a_m_b1 * a_m_b1;
    }

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}


float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    __m128 mx, my;
    __m128 msum1 = _mm_setzero_ps();

    while (d >= 4) {
        mx = _mm_loadu_ps (x); x += 4;
        my = _mm_loadu_ps (y); y += 4;
        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
        d -= 4;
    }

    // add the last 1, 2, or 3 values
    mx = masked_read (d, x);
    my = masked_read (d, y);
    __m128 prod = _mm_mul_ps (mx, my);

    msum1 = _mm_add_ps (msum1, prod);

    msum1 = _mm_hadd_ps (msum1, msum1);
    msum1 = _mm_hadd_ps (msum1, msum1);
    return _mm_cvtss_f32 (msum1);
}

#elif defined(__aarch64__)


float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    if (d & 3) return fvec_L2sqr_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        float32x4_t sq = vsubq_f32 (xi, yi);
        accu = vfmaq_f32 (accu, sq, sq);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}
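
// vpaddq_f32 (accu, accu) pairwise-adds the accumulator lanes into
// {a0+a1, a2+a3, a0+a1, a2+a3}; summing lanes 0 and 1 then yields the full
// horizontal sum. The two NEON functions below use the same reduction.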

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    if (d & 3) return fvec_inner_product_ref (x, y, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        float32x4_t yi = vld1q_f32 (y + i);
        accu = vfmaq_f32 (accu, xi, yi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}

float fvec_norm_L2sqr (const float *x, size_t d)
{
    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
    float32x4_t accu = vdupq_n_f32 (0);
    for (size_t i = 0; i < d; i += 4) {
        float32x4_t xi = vld1q_f32 (x + i);
        accu = vfmaq_f32 (accu, xi, xi);
    }
    float32x4_t a2 = vpaddq_f32 (accu, accu);
    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
}

// not optimized for ARM
void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}


#else
// scalar implementation

float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    return fvec_L2sqr_ref (x, y, d);
}

float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    return fvec_inner_product_ref (x, y, d);
}

float fvec_norm_L2sqr (const float *x, size_t d)
{
    return fvec_norm_L2sqr_ref (x, d);
}

void fvec_L2sqr_ny (float * dis, const float * x,
                    const float * y, size_t d, size_t ny) {
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}


#endif


/***************************************************************************
 * heavily optimized table computations
 ***************************************************************************/


static inline void fvec_madd_ref (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    for (size_t i = 0; i < n; i++)
        c[i] = a[i] + bf * b[i];
}

#ifdef __SSE__

static inline void fvec_madd_sse (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        b4++;
        a4++;
        c4++;
    }
}

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        fvec_madd_sse (n, a, bf, b, c);
    else
        fvec_madd_ref (n, a, bf, b, c);
}
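
/* The SSE path above requires n to be a multiple of 4 and a, b and c to be
   16-byte aligned, because fvec_madd_sse dereferences them as __m128*;
   any other input falls back to the reference loop. Minimal usage sketch
   (hypothetical buffers of length n):

       faiss::fvec_madd (n, a.data(), 0.5f, b.data(), c.data());
       // c[i] = a[i] + 0.5f * b[i]
*/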

#else

void fvec_madd (size_t n, const float *a,
                float bf, const float *b, float *c)
{
    fvec_madd_ref (n, a, bf, b, c);
}

#endif

static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    float vmin = 1e20;
    int imin = -1;

    for (size_t i = 0; i < n; i++) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = i;
        }
    }
    return imin;
}
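
/* fvec_madd_and_argmin computes c[i] = a[i] + bf * b[i] like fvec_madd and
   additionally returns the index of the smallest resulting c[i] (or -1 if
   no value is below 1e20). Illustrative call (hypothetical buffers of
   length n):

       int imin = faiss::fvec_madd_and_argmin (n, a.data(), bf, b.data(), c.data());
*/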

#ifdef __SSE__

static inline int fvec_madd_and_argmin_sse (
        size_t n, const float *a,
        float bf, const float *b, float *c) {
    n >>= 2;
    __m128 bf4 = _mm_set_ps1 (bf);
    __m128 vmin4 = _mm_set_ps1 (1e20);
    __m128i imin4 = _mm_set1_epi32 (-1);
    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
    __m128i inc4 = _mm_set1_epi32 (4);
    __m128 * a4 = (__m128*)a;
    __m128 * b4 = (__m128*)b;
    __m128 * c4 = (__m128*)c;

    while (n--) {
        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
        *c4 = vc4;
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!

        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
        b4++;
        a4++;
        c4++;
        idx4 = _mm_add_epi32 (idx4, inc4);
    }

    // 4 values -> 2
    {
        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        vmin4 = _mm_min_ps (vmin4, vc4);
    }
    // 2 values -> 1
    {
        idx4 = _mm_shuffle_epi32 (imin4, 1);
        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
                              _mm_andnot_si128 (mask, imin4));
        // vmin4 = _mm_min_ps (vmin4, vc4);
    }
    return _mm_cvtsi128_si32 (imin4);
}


int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    if ((n & 3) == 0 &&
        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
    else
        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#else

int fvec_madd_and_argmin (size_t n, const float *a,
                          float bf, const float *b, float *c)
{
    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
}

#endif


} // namespace faiss