Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MathOperators.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include "Float16.cuh"
12 
13 //
14 // Templated wrappers to express math for different scalar and vector
15 // types, so kernels can have the same written form but can operate
16 // over half and float, and on vector types transparently
17 //
18 
19 namespace faiss { namespace gpu {
20 
21 template <typename T>
22 struct Math {
23  typedef T ScalarType;
24 
25  static inline __device__ T add(T a, T b) {
26  return a + b;
27  }
28 
29  static inline __device__ T sub(T a, T b) {
30  return a - b;
31  }
32 
33  static inline __device__ T mul(T a, T b) {
34  return a * b;
35  }
36 
37  static inline __device__ T neg(T v) {
38  return -v;
39  }
40 
41  /// For a vector type, this is a horizontal add, returning sum(v_i)
42  static inline __device__ T reduceAdd(T v) {
43  return v;
44  }
45 
46  static inline __device__ bool lt(T a, T b) {
47  return a < b;
48  }
49 
50  static inline __device__ bool gt(T a, T b) {
51  return a > b;
52  }
53 
54  static inline __device__ bool eq(T a, T b) {
55  return a == b;
56  }
57 
58  static inline __device__ T zero() {
59  return (T) 0;
60  }
61 };
62 
63 template <>
64 struct Math<float2> {
65  typedef float ScalarType;
66 
67  static inline __device__ float2 add(float2 a, float2 b) {
68  float2 v;
69  v.x = a.x + b.x;
70  v.y = a.y + b.y;
71  return v;
72  }
73 
74  static inline __device__ float2 sub(float2 a, float2 b) {
75  float2 v;
76  v.x = a.x - b.x;
77  v.y = a.y - b.y;
78  return v;
79  }
80 
81  static inline __device__ float2 add(float2 a, float b) {
82  float2 v;
83  v.x = a.x + b;
84  v.y = a.y + b;
85  return v;
86  }
87 
88  static inline __device__ float2 sub(float2 a, float b) {
89  float2 v;
90  v.x = a.x - b;
91  v.y = a.y - b;
92  return v;
93  }
94 
95  static inline __device__ float2 mul(float2 a, float2 b) {
96  float2 v;
97  v.x = a.x * b.x;
98  v.y = a.y * b.y;
99  return v;
100  }
101 
102  static inline __device__ float2 mul(float2 a, float b) {
103  float2 v;
104  v.x = a.x * b;
105  v.y = a.y * b;
106  return v;
107  }
108 
109  static inline __device__ float2 neg(float2 v) {
110  v.x = -v.x;
111  v.y = -v.y;
112  return v;
113  }
114 
115  /// For a vector type, this is a horizontal add, returning sum(v_i)
116  static inline __device__ float reduceAdd(float2 v) {
117  return v.x + v.y;
118  }
119 
120  // not implemented for vector types
121  // static inline __device__ bool lt(float2 a, float2 b);
122  // static inline __device__ bool gt(float2 a, float2 b);
123  // static inline __device__ bool eq(float2 a, float2 b);
124 
125  static inline __device__ float2 zero() {
126  float2 v;
127  v.x = 0.0f;
128  v.y = 0.0f;
129  return v;
130  }
131 };
132 
133 template <>
134 struct Math<float4> {
135  typedef float ScalarType;
136 
137  static inline __device__ float4 add(float4 a, float4 b) {
138  float4 v;
139  v.x = a.x + b.x;
140  v.y = a.y + b.y;
141  v.z = a.z + b.z;
142  v.w = a.w + b.w;
143  return v;
144  }
145 
146  static inline __device__ float4 sub(float4 a, float4 b) {
147  float4 v;
148  v.x = a.x - b.x;
149  v.y = a.y - b.y;
150  v.z = a.z - b.z;
151  v.w = a.w - b.w;
152  return v;
153  }
154 
155  static inline __device__ float4 add(float4 a, float b) {
156  float4 v;
157  v.x = a.x + b;
158  v.y = a.y + b;
159  v.z = a.z + b;
160  v.w = a.w + b;
161  return v;
162  }
163 
164  static inline __device__ float4 sub(float4 a, float b) {
165  float4 v;
166  v.x = a.x - b;
167  v.y = a.y - b;
168  v.z = a.z - b;
169  v.w = a.w - b;
170  return v;
171  }
172 
173  static inline __device__ float4 mul(float4 a, float4 b) {
174  float4 v;
175  v.x = a.x * b.x;
176  v.y = a.y * b.y;
177  v.z = a.z * b.z;
178  v.w = a.w * b.w;
179  return v;
180  }
181 
182  static inline __device__ float4 mul(float4 a, float b) {
183  float4 v;
184  v.x = a.x * b;
185  v.y = a.y * b;
186  v.z = a.z * b;
187  v.w = a.w * b;
188  return v;
189  }
190 
191  static inline __device__ float4 neg(float4 v) {
192  v.x = -v.x;
193  v.y = -v.y;
194  v.z = -v.z;
195  v.w = -v.w;
196  return v;
197  }
198 
199  /// For a vector type, this is a horizontal add, returning sum(v_i)
200  static inline __device__ float reduceAdd(float4 v) {
201  return v.x + v.y + v.z + v.w;
202  }
203 
204  // not implemented for vector types
205  // static inline __device__ bool lt(float4 a, float4 b);
206  // static inline __device__ bool gt(float4 a, float4 b);
207  // static inline __device__ bool eq(float4 a, float4 b);
208 
209  static inline __device__ float4 zero() {
210  float4 v;
211  v.x = 0.0f;
212  v.y = 0.0f;
213  v.z = 0.0f;
214  v.w = 0.0f;
215  return v;
216  }
217 };
218 
219 #ifdef FAISS_USE_FLOAT16
220 
221 template <>
222 struct Math<half> {
223  typedef half ScalarType;
224 
225  static inline __device__ half add(half a, half b) {
226 #ifdef FAISS_USE_FULL_FLOAT16
227  return __hadd(a, b);
228 #else
229  return __float2half(__half2float(a) + __half2float(b));
230 #endif
231  }
232 
233  static inline __device__ half sub(half a, half b) {
234 #ifdef FAISS_USE_FULL_FLOAT16
235  return __hsub(a, b);
236 #else
237  return __float2half(__half2float(a) - __half2float(b));
238 #endif
239  }
240 
241  static inline __device__ half mul(half a, half b) {
242 #ifdef FAISS_USE_FULL_FLOAT16
243  return __hmul(a, b);
244 #else
245  return __float2half(__half2float(a) * __half2float(b));
246 #endif
247  }
248 
249  static inline __device__ half neg(half v) {
250 #ifdef FAISS_USE_FULL_FLOAT16
251  return __hneg(v);
252 #else
253  return __float2half(-__half2float(v));
254 #endif
255  }
256 
257  static inline __device__ half reduceAdd(half v) {
258  return v;
259  }
260 
261  static inline __device__ bool lt(half a, half b) {
262 #ifdef FAISS_USE_FULL_FLOAT16
263  return __hlt(a, b);
264 #else
265  return __half2float(a) < __half2float(b);
266 #endif
267  }
268 
269  static inline __device__ bool gt(half a, half b) {
270 #ifdef FAISS_USE_FULL_FLOAT16
271  return __hgt(a, b);
272 #else
273  return __half2float(a) > __half2float(b);
274 #endif
275  }
276 
277  static inline __device__ bool eq(half a, half b) {
278 #ifdef FAISS_USE_FULL_FLOAT16
279  return __heq(a, b);
280 #else
281  return __half2float(a) == __half2float(b);
282 #endif
283  }
284 
285  static inline __device__ half zero() {
286 #if CUDA_VERSION >= 9000
287  return 0;
288 #else
289  half h;
290  h.x = 0;
291  return h;
292 #endif
293  }
294 };
295 
296 template <>
297 struct Math<half2> {
298  typedef half ScalarType;
299 
300  static inline __device__ half2 add(half2 a, half2 b) {
301 #ifdef FAISS_USE_FULL_FLOAT16
302  return __hadd2(a, b);
303 #else
304  float2 af = __half22float2(a);
305  float2 bf = __half22float2(b);
306 
307  af.x += bf.x;
308  af.y += bf.y;
309 
310  return __float22half2_rn(af);
311 #endif
312  }
313 
314  static inline __device__ half2 sub(half2 a, half2 b) {
315 #ifdef FAISS_USE_FULL_FLOAT16
316  return __hsub2(a, b);
317 #else
318  float2 af = __half22float2(a);
319  float2 bf = __half22float2(b);
320 
321  af.x -= bf.x;
322  af.y -= bf.y;
323 
324  return __float22half2_rn(af);
325 #endif
326  }
327 
328  static inline __device__ half2 add(half2 a, half b) {
329 #ifdef FAISS_USE_FULL_FLOAT16
330  half2 b2 = __half2half2(b);
331  return __hadd2(a, b2);
332 #else
333  float2 af = __half22float2(a);
334  float bf = __half2float(b);
335 
336  af.x += bf;
337  af.y += bf;
338 
339  return __float22half2_rn(af);
340 #endif
341  }
342 
343  static inline __device__ half2 sub(half2 a, half b) {
344 #ifdef FAISS_USE_FULL_FLOAT16
345  half2 b2 = __half2half2(b);
346  return __hsub2(a, b2);
347 #else
348  float2 af = __half22float2(a);
349  float bf = __half2float(b);
350 
351  af.x -= bf;
352  af.y -= bf;
353 
354  return __float22half2_rn(af);
355 #endif
356  }
357 
358  static inline __device__ half2 mul(half2 a, half2 b) {
359 #ifdef FAISS_USE_FULL_FLOAT16
360  return __hmul2(a, b);
361 #else
362  float2 af = __half22float2(a);
363  float2 bf = __half22float2(b);
364 
365  af.x *= bf.x;
366  af.y *= bf.y;
367 
368  return __float22half2_rn(af);
369 #endif
370  }
371 
372  static inline __device__ half2 mul(half2 a, half b) {
373 #ifdef FAISS_USE_FULL_FLOAT16
374  half2 b2 = __half2half2(b);
375  return __hmul2(a, b2);
376 #else
377  float2 af = __half22float2(a);
378  float bf = __half2float(b);
379 
380  af.x *= bf;
381  af.y *= bf;
382 
383  return __float22half2_rn(af);
384 #endif
385  }
386 
387  static inline __device__ half2 neg(half2 v) {
388 #ifdef FAISS_USE_FULL_FLOAT16
389  return __hneg2(v);
390 #else
391  float2 vf = __half22float2(v);
392  vf.x = -vf.x;
393  vf.y = -vf.y;
394 
395  return __float22half2_rn(vf);
396 #endif
397  }
398 
399  static inline __device__ half reduceAdd(half2 v) {
400 #ifdef FAISS_USE_FULL_FLOAT16
401  half hv = __high2half(v);
402  half lv = __low2half(v);
403 
404  return __hadd(hv, lv);
405 #else
406  float2 vf = __half22float2(v);
407  vf.x += vf.y;
408 
409  return __float2half(vf.x);
410 #endif
411  }
412 
413  // not implemented for vector types
414  // static inline __device__ bool lt(half2 a, half2 b);
415  // static inline __device__ bool gt(half2 a, half2 b);
416  // static inline __device__ bool eq(half2 a, half2 b);
417 
418  static inline __device__ half2 zero() {
419  return __half2half2(Math<half>::zero());
420  }
421 };
422 
423 template <>
424 struct Math<Half4> {
425  typedef half ScalarType;
426 
427  static inline __device__ Half4 add(Half4 a, Half4 b) {
428  Half4 h;
429  h.a = Math<half2>::add(a.a, b.a);
430  h.b = Math<half2>::add(a.b, b.b);
431  return h;
432  }
433 
434  static inline __device__ Half4 sub(Half4 a, Half4 b) {
435  Half4 h;
436  h.a = Math<half2>::sub(a.a, b.a);
437  h.b = Math<half2>::sub(a.b, b.b);
438  return h;
439  }
440 
441  static inline __device__ Half4 add(Half4 a, half b) {
442  Half4 h;
443  h.a = Math<half2>::add(a.a, b);
444  h.b = Math<half2>::add(a.b, b);
445  return h;
446  }
447 
448  static inline __device__ Half4 sub(Half4 a, half b) {
449  Half4 h;
450  h.a = Math<half2>::sub(a.a, b);
451  h.b = Math<half2>::sub(a.b, b);
452  return h;
453  }
454 
455  static inline __device__ Half4 mul(Half4 a, Half4 b) {
456  Half4 h;
457  h.a = Math<half2>::mul(a.a, b.a);
458  h.b = Math<half2>::mul(a.b, b.b);
459  return h;
460  }
461 
462  static inline __device__ Half4 mul(Half4 a, half b) {
463  Half4 h;
464  h.a = Math<half2>::mul(a.a, b);
465  h.b = Math<half2>::mul(a.b, b);
466  return h;
467  }
468 
469  static inline __device__ Half4 neg(Half4 v) {
470  Half4 h;
471  h.a = Math<half2>::neg(v.a);
472  h.b = Math<half2>::neg(v.b);
473  return h;
474  }
475 
476  static inline __device__ half reduceAdd(Half4 v) {
477  half hx = Math<half2>::reduceAdd(v.a);
478  half hy = Math<half2>::reduceAdd(v.b);
479  return Math<half>::add(hx, hy);
480  }
481 
482  // not implemented for vector types
483  // static inline __device__ bool lt(Half4 a, Half4 b);
484  // static inline __device__ bool gt(Half4 a, Half4 b);
485  // static inline __device__ bool eq(Half4 a, Half4 b);
486 
487  static inline __device__ Half4 zero() {
488  Half4 h;
489  h.a = Math<half2>::zero();
490  h.b = Math<half2>::zero();
491  return h;
492  }
493 };
494 
495 template <>
496 struct Math<Half8> {
497  typedef half ScalarType;
498 
499  static inline __device__ Half8 add(Half8 a, Half8 b) {
500  Half8 h;
501  h.a = Math<Half4>::add(a.a, b.a);
502  h.b = Math<Half4>::add(a.b, b.b);
503  return h;
504  }
505 
506  static inline __device__ Half8 sub(Half8 a, Half8 b) {
507  Half8 h;
508  h.a = Math<Half4>::sub(a.a, b.a);
509  h.b = Math<Half4>::sub(a.b, b.b);
510  return h;
511  }
512 
513  static inline __device__ Half8 add(Half8 a, half b) {
514  Half8 h;
515  h.a = Math<Half4>::add(a.a, b);
516  h.b = Math<Half4>::add(a.b, b);
517  return h;
518  }
519 
520  static inline __device__ Half8 sub(Half8 a, half b) {
521  Half8 h;
522  h.a = Math<Half4>::sub(a.a, b);
523  h.b = Math<Half4>::sub(a.b, b);
524  return h;
525  }
526 
527  static inline __device__ Half8 mul(Half8 a, Half8 b) {
528  Half8 h;
529  h.a = Math<Half4>::mul(a.a, b.a);
530  h.b = Math<Half4>::mul(a.b, b.b);
531  return h;
532  }
533 
534  static inline __device__ Half8 mul(Half8 a, half b) {
535  Half8 h;
536  h.a = Math<Half4>::mul(a.a, b);
537  h.b = Math<Half4>::mul(a.b, b);
538  return h;
539  }
540 
541  static inline __device__ Half8 neg(Half8 v) {
542  Half8 h;
543  h.a = Math<Half4>::neg(v.a);
544  h.b = Math<Half4>::neg(v.b);
545  return h;
546  }
547 
548  static inline __device__ half reduceAdd(Half8 v) {
549  half hx = Math<Half4>::reduceAdd(v.a);
550  half hy = Math<Half4>::reduceAdd(v.b);
551  return Math<half>::add(hx, hy);
552  }
553 
554  // not implemented for vector types
555  // static inline __device__ bool lt(Half8 a, Half8 b);
556  // static inline __device__ bool gt(Half8 a, Half8 b);
557  // static inline __device__ bool eq(Half8 a, Half8 b);
558 
559  static inline __device__ Half8 zero() {
560  Half8 h;
561  h.a = Math<Half4>::zero();
562  h.b = Math<Half4>::zero();
563  return h;
564  }
565 };
566 
567 #endif // FAISS_USE_FLOAT16
568 
569 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)