Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MathOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include "Float16.cuh"
13 
14 //
15 // Templated wrappers to express math for different scalar and vector
16 // types, so kernels can have the same written form but can operate
17 // over half and float, and on vector types transparently
18 //
19 
20 namespace faiss { namespace gpu {
21 
22 template <typename T>
23 struct Math {
24  typedef T ScalarType;
25 
26  static inline __device__ T add(T a, T b) {
27  return a + b;
28  }
29 
30  static inline __device__ T sub(T a, T b) {
31  return a - b;
32  }
33 
34  static inline __device__ T mul(T a, T b) {
35  return a * b;
36  }
37 
38  static inline __device__ T neg(T v) {
39  return -v;
40  }
41 
42  /// For a vector type, this is a horizontal add, returning sum(v_i)
43  static inline __device__ T reduceAdd(T v) {
44  return v;
45  }
46 
47  static inline __device__ bool lt(T a, T b) {
48  return a < b;
49  }
50 
51  static inline __device__ bool gt(T a, T b) {
52  return a > b;
53  }
54 
55  static inline __device__ bool eq(T a, T b) {
56  return a == b;
57  }
58 
59  static inline __device__ T zero() {
60  return (T) 0;
61  }
62 };
63 
64 template <>
65 struct Math<float2> {
66  typedef float ScalarType;
67 
68  static inline __device__ float2 add(float2 a, float2 b) {
69  float2 v;
70  v.x = a.x + b.x;
71  v.y = a.y + b.y;
72  return v;
73  }
74 
75  static inline __device__ float2 sub(float2 a, float2 b) {
76  float2 v;
77  v.x = a.x - b.x;
78  v.y = a.y - b.y;
79  return v;
80  }
81 
82  static inline __device__ float2 add(float2 a, float b) {
83  float2 v;
84  v.x = a.x + b;
85  v.y = a.y + b;
86  return v;
87  }
88 
89  static inline __device__ float2 sub(float2 a, float b) {
90  float2 v;
91  v.x = a.x - b;
92  v.y = a.y - b;
93  return v;
94  }
95 
96  static inline __device__ float2 mul(float2 a, float2 b) {
97  float2 v;
98  v.x = a.x * b.x;
99  v.y = a.y * b.y;
100  return v;
101  }
102 
103  static inline __device__ float2 mul(float2 a, float b) {
104  float2 v;
105  v.x = a.x * b;
106  v.y = a.y * b;
107  return v;
108  }
109 
110  static inline __device__ float2 neg(float2 v) {
111  v.x = -v.x;
112  v.y = -v.y;
113  return v;
114  }
115 
116  /// For a vector type, this is a horizontal add, returning sum(v_i)
117  static inline __device__ float reduceAdd(float2 v) {
118  return v.x + v.y;
119  }
120 
121  // not implemented for vector types
122  // static inline __device__ bool lt(float2 a, float2 b);
123  // static inline __device__ bool gt(float2 a, float2 b);
124  // static inline __device__ bool eq(float2 a, float2 b);
125 
126  static inline __device__ float2 zero() {
127  float2 v;
128  v.x = 0.0f;
129  v.y = 0.0f;
130  return v;
131  }
132 };
133 
134 template <>
135 struct Math<float4> {
136  typedef float ScalarType;
137 
138  static inline __device__ float4 add(float4 a, float4 b) {
139  float4 v;
140  v.x = a.x + b.x;
141  v.y = a.y + b.y;
142  v.z = a.z + b.z;
143  v.w = a.w + b.w;
144  return v;
145  }
146 
147  static inline __device__ float4 sub(float4 a, float4 b) {
148  float4 v;
149  v.x = a.x - b.x;
150  v.y = a.y - b.y;
151  v.z = a.z - b.z;
152  v.w = a.w - b.w;
153  return v;
154  }
155 
156  static inline __device__ float4 add(float4 a, float b) {
157  float4 v;
158  v.x = a.x + b;
159  v.y = a.y + b;
160  v.z = a.z + b;
161  v.w = a.w + b;
162  return v;
163  }
164 
165  static inline __device__ float4 sub(float4 a, float b) {
166  float4 v;
167  v.x = a.x - b;
168  v.y = a.y - b;
169  v.z = a.z - b;
170  v.w = a.w - b;
171  return v;
172  }
173 
174  static inline __device__ float4 mul(float4 a, float4 b) {
175  float4 v;
176  v.x = a.x * b.x;
177  v.y = a.y * b.y;
178  v.z = a.z * b.z;
179  v.w = a.w * b.w;
180  return v;
181  }
182 
183  static inline __device__ float4 mul(float4 a, float b) {
184  float4 v;
185  v.x = a.x * b;
186  v.y = a.y * b;
187  v.z = a.z * b;
188  v.w = a.w * b;
189  return v;
190  }
191 
192  static inline __device__ float4 neg(float4 v) {
193  v.x = -v.x;
194  v.y = -v.y;
195  v.z = -v.z;
196  v.w = -v.w;
197  return v;
198  }
199 
200  /// For a vector type, this is a horizontal add, returning sum(v_i)
201  static inline __device__ float reduceAdd(float4 v) {
202  return v.x + v.y + v.z + v.w;
203  }
204 
205  // not implemented for vector types
206  // static inline __device__ bool lt(float4 a, float4 b);
207  // static inline __device__ bool gt(float4 a, float4 b);
208  // static inline __device__ bool eq(float4 a, float4 b);
209 
210  static inline __device__ float4 zero() {
211  float4 v;
212  v.x = 0.0f;
213  v.y = 0.0f;
214  v.z = 0.0f;
215  v.w = 0.0f;
216  return v;
217  }
218 };
219 
220 #ifdef FAISS_USE_FLOAT16
221 
222 template <>
223 struct Math<half> {
224  typedef half ScalarType;
225 
226  static inline __device__ half add(half a, half b) {
227 #ifdef FAISS_USE_FULL_FLOAT16
228  return __hadd(a, b);
229 #else
230  return __float2half(__half2float(a) + __half2float(b));
231 #endif
232  }
233 
234  static inline __device__ half sub(half a, half b) {
235 #ifdef FAISS_USE_FULL_FLOAT16
236  return __hsub(a, b);
237 #else
238  return __float2half(__half2float(a) - __half2float(b));
239 #endif
240  }
241 
242  static inline __device__ half mul(half a, half b) {
243 #ifdef FAISS_USE_FULL_FLOAT16
244  return __hmul(a, b);
245 #else
246  return __float2half(__half2float(a) * __half2float(b));
247 #endif
248  }
249 
250  static inline __device__ half neg(half v) {
251 #ifdef FAISS_USE_FULL_FLOAT16
252  return __hneg(v);
253 #else
254  return __float2half(-__half2float(v));
255 #endif
256  }
257 
258  static inline __device__ half reduceAdd(half v) {
259  return v;
260  }
261 
262  static inline __device__ bool lt(half a, half b) {
263 #ifdef FAISS_USE_FULL_FLOAT16
264  return __hlt(a, b);
265 #else
266  return __half2float(a) < __half2float(b);
267 #endif
268  }
269 
270  static inline __device__ bool gt(half a, half b) {
271 #ifdef FAISS_USE_FULL_FLOAT16
272  return __hgt(a, b);
273 #else
274  return __half2float(a) > __half2float(b);
275 #endif
276  }
277 
278  static inline __device__ bool eq(half a, half b) {
279 #ifdef FAISS_USE_FULL_FLOAT16
280  return __heq(a, b);
281 #else
282  return __half2float(a) == __half2float(b);
283 #endif
284  }
285 
286  static inline __device__ half zero() {
287 #if CUDA_VERSION >= 9000
288  return 0;
289 #else
290  half h;
291  h.x = 0;
292  return h;
293 #endif
294  }
295 };
296 
297 template <>
298 struct Math<half2> {
299  typedef half ScalarType;
300 
301  static inline __device__ half2 add(half2 a, half2 b) {
302 #ifdef FAISS_USE_FULL_FLOAT16
303  return __hadd2(a, b);
304 #else
305  float2 af = __half22float2(a);
306  float2 bf = __half22float2(b);
307 
308  af.x += bf.x;
309  af.y += bf.y;
310 
311  return __float22half2_rn(af);
312 #endif
313  }
314 
315  static inline __device__ half2 sub(half2 a, half2 b) {
316 #ifdef FAISS_USE_FULL_FLOAT16
317  return __hsub2(a, b);
318 #else
319  float2 af = __half22float2(a);
320  float2 bf = __half22float2(b);
321 
322  af.x -= bf.x;
323  af.y -= bf.y;
324 
325  return __float22half2_rn(af);
326 #endif
327  }
328 
329  static inline __device__ half2 add(half2 a, half b) {
330 #ifdef FAISS_USE_FULL_FLOAT16
331  half2 b2 = __half2half2(b);
332  return __hadd2(a, b2);
333 #else
334  float2 af = __half22float2(a);
335  float bf = __half2float(b);
336 
337  af.x += bf;
338  af.y += bf;
339 
340  return __float22half2_rn(af);
341 #endif
342  }
343 
344  static inline __device__ half2 sub(half2 a, half b) {
345 #ifdef FAISS_USE_FULL_FLOAT16
346  half2 b2 = __half2half2(b);
347  return __hsub2(a, b2);
348 #else
349  float2 af = __half22float2(a);
350  float bf = __half2float(b);
351 
352  af.x -= bf;
353  af.y -= bf;
354 
355  return __float22half2_rn(af);
356 #endif
357  }
358 
359  static inline __device__ half2 mul(half2 a, half2 b) {
360 #ifdef FAISS_USE_FULL_FLOAT16
361  return __hmul2(a, b);
362 #else
363  float2 af = __half22float2(a);
364  float2 bf = __half22float2(b);
365 
366  af.x *= bf.x;
367  af.y *= bf.y;
368 
369  return __float22half2_rn(af);
370 #endif
371  }
372 
373  static inline __device__ half2 mul(half2 a, half b) {
374 #ifdef FAISS_USE_FULL_FLOAT16
375  half2 b2 = __half2half2(b);
376  return __hmul2(a, b2);
377 #else
378  float2 af = __half22float2(a);
379  float bf = __half2float(b);
380 
381  af.x *= bf;
382  af.y *= bf;
383 
384  return __float22half2_rn(af);
385 #endif
386  }
387 
388  static inline __device__ half2 neg(half2 v) {
389 #ifdef FAISS_USE_FULL_FLOAT16
390  return __hneg2(v);
391 #else
392  float2 vf = __half22float2(v);
393  vf.x = -vf.x;
394  vf.y = -vf.y;
395 
396  return __float22half2_rn(vf);
397 #endif
398  }
399 
400  static inline __device__ half reduceAdd(half2 v) {
401 #ifdef FAISS_USE_FULL_FLOAT16
402  half hv = __high2half(v);
403  half lv = __low2half(v);
404 
405  return __hadd(hv, lv);
406 #else
407  float2 vf = __half22float2(v);
408  vf.x += vf.y;
409 
410  return __float2half(vf.x);
411 #endif
412  }
413 
414  // not implemented for vector types
415  // static inline __device__ bool lt(half2 a, half2 b);
416  // static inline __device__ bool gt(half2 a, half2 b);
417  // static inline __device__ bool eq(half2 a, half2 b);
418 
419  static inline __device__ half2 zero() {
420  return __half2half2(Math<half>::zero());
421  }
422 };
423 
424 template <>
425 struct Math<Half4> {
426  typedef half ScalarType;
427 
428  static inline __device__ Half4 add(Half4 a, Half4 b) {
429  Half4 h;
430  h.a = Math<half2>::add(a.a, b.a);
431  h.b = Math<half2>::add(a.b, b.b);
432  return h;
433  }
434 
435  static inline __device__ Half4 sub(Half4 a, Half4 b) {
436  Half4 h;
437  h.a = Math<half2>::sub(a.a, b.a);
438  h.b = Math<half2>::sub(a.b, b.b);
439  return h;
440  }
441 
442  static inline __device__ Half4 add(Half4 a, half b) {
443  Half4 h;
444  h.a = Math<half2>::add(a.a, b);
445  h.b = Math<half2>::add(a.b, b);
446  return h;
447  }
448 
449  static inline __device__ Half4 sub(Half4 a, half b) {
450  Half4 h;
451  h.a = Math<half2>::sub(a.a, b);
452  h.b = Math<half2>::sub(a.b, b);
453  return h;
454  }
455 
456  static inline __device__ Half4 mul(Half4 a, Half4 b) {
457  Half4 h;
458  h.a = Math<half2>::mul(a.a, b.a);
459  h.b = Math<half2>::mul(a.b, b.b);
460  return h;
461  }
462 
463  static inline __device__ Half4 mul(Half4 a, half b) {
464  Half4 h;
465  h.a = Math<half2>::mul(a.a, b);
466  h.b = Math<half2>::mul(a.b, b);
467  return h;
468  }
469 
470  static inline __device__ Half4 neg(Half4 v) {
471  Half4 h;
472  h.a = Math<half2>::neg(v.a);
473  h.b = Math<half2>::neg(v.b);
474  return h;
475  }
476 
477  static inline __device__ half reduceAdd(Half4 v) {
478  half hx = Math<half2>::reduceAdd(v.a);
479  half hy = Math<half2>::reduceAdd(v.b);
480  return Math<half>::add(hx, hy);
481  }
482 
483  // not implemented for vector types
484  // static inline __device__ bool lt(Half4 a, Half4 b);
485  // static inline __device__ bool gt(Half4 a, Half4 b);
486  // static inline __device__ bool eq(Half4 a, Half4 b);
487 
488  static inline __device__ Half4 zero() {
489  Half4 h;
490  h.a = Math<half2>::zero();
491  h.b = Math<half2>::zero();
492  return h;
493  }
494 };
495 
496 template <>
497 struct Math<Half8> {
498  typedef half ScalarType;
499 
500  static inline __device__ Half8 add(Half8 a, Half8 b) {
501  Half8 h;
502  h.a = Math<Half4>::add(a.a, b.a);
503  h.b = Math<Half4>::add(a.b, b.b);
504  return h;
505  }
506 
507  static inline __device__ Half8 sub(Half8 a, Half8 b) {
508  Half8 h;
509  h.a = Math<Half4>::sub(a.a, b.a);
510  h.b = Math<Half4>::sub(a.b, b.b);
511  return h;
512  }
513 
514  static inline __device__ Half8 add(Half8 a, half b) {
515  Half8 h;
516  h.a = Math<Half4>::add(a.a, b);
517  h.b = Math<Half4>::add(a.b, b);
518  return h;
519  }
520 
521  static inline __device__ Half8 sub(Half8 a, half b) {
522  Half8 h;
523  h.a = Math<Half4>::sub(a.a, b);
524  h.b = Math<Half4>::sub(a.b, b);
525  return h;
526  }
527 
528  static inline __device__ Half8 mul(Half8 a, Half8 b) {
529  Half8 h;
530  h.a = Math<Half4>::mul(a.a, b.a);
531  h.b = Math<Half4>::mul(a.b, b.b);
532  return h;
533  }
534 
535  static inline __device__ Half8 mul(Half8 a, half b) {
536  Half8 h;
537  h.a = Math<Half4>::mul(a.a, b);
538  h.b = Math<Half4>::mul(a.b, b);
539  return h;
540  }
541 
542  static inline __device__ Half8 neg(Half8 v) {
543  Half8 h;
544  h.a = Math<Half4>::neg(v.a);
545  h.b = Math<Half4>::neg(v.b);
546  return h;
547  }
548 
549  static inline __device__ half reduceAdd(Half8 v) {
550  half hx = Math<Half4>::reduceAdd(v.a);
551  half hy = Math<Half4>::reduceAdd(v.b);
552  return Math<half>::add(hx, hy);
553  }
554 
555  // not implemented for vector types
556  // static inline __device__ bool lt(Half8 a, Half8 b);
557  // static inline __device__ bool gt(Half8 a, Half8 b);
558  // static inline __device__ bool eq(Half8 a, Half8 b);
559 
560  static inline __device__ Half8 zero() {
561  Half8 h;
562  h.a = Math<Half4>::zero();
563  h.b = Math<Half4>::zero();
564  return h;
565  }
566 };
567 
568 #endif // FAISS_USE_FLOAT16
569 
570 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)