Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MathOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "Float16.cuh"
14 
15 //
16 // Templated wrappers to express math for different scalar and vector
17 // types, so kernels can have the same written form but can operate
18 // over half and float, and on vector types transparently
19 //
20 
21 namespace faiss { namespace gpu {
22 
23 template <typename T>
24 struct Math {
25  typedef T ScalarType;
26 
27  static inline __device__ T add(T a, T b) {
28  return a + b;
29  }
30 
31  static inline __device__ T sub(T a, T b) {
32  return a - b;
33  }
34 
35  static inline __device__ T mul(T a, T b) {
36  return a * b;
37  }
38 
39  static inline __device__ T neg(T v) {
40  return -v;
41  }
42 
43  /// For a vector type, this is a horizontal add, returning sum(v_i)
44  static inline __device__ T reduceAdd(T v) {
45  return v;
46  }
47 
48  static inline __device__ bool lt(T a, T b) {
49  return a < b;
50  }
51 
52  static inline __device__ bool gt(T a, T b) {
53  return a > b;
54  }
55 
56  static inline __device__ bool eq(T a, T b) {
57  return a == b;
58  }
59 
60  static inline __device__ T zero() {
61  return (T) 0;
62  }
63 };
64 
65 template <>
66 struct Math<float2> {
67  typedef float ScalarType;
68 
69  static inline __device__ float2 add(float2 a, float2 b) {
70  float2 v;
71  v.x = a.x + b.x;
72  v.y = a.y + b.y;
73  return v;
74  }
75 
76  static inline __device__ float2 sub(float2 a, float2 b) {
77  float2 v;
78  v.x = a.x - b.x;
79  v.y = a.y - b.y;
80  return v;
81  }
82 
83  static inline __device__ float2 add(float2 a, float b) {
84  float2 v;
85  v.x = a.x + b;
86  v.y = a.y + b;
87  return v;
88  }
89 
90  static inline __device__ float2 sub(float2 a, float b) {
91  float2 v;
92  v.x = a.x - b;
93  v.y = a.y - b;
94  return v;
95  }
96 
97  static inline __device__ float2 mul(float2 a, float2 b) {
98  float2 v;
99  v.x = a.x * b.x;
100  v.y = a.y * b.y;
101  return v;
102  }
103 
104  static inline __device__ float2 mul(float2 a, float b) {
105  float2 v;
106  v.x = a.x * b;
107  v.y = a.y * b;
108  return v;
109  }
110 
111  static inline __device__ float2 neg(float2 v) {
112  v.x = -v.x;
113  v.y = -v.y;
114  return v;
115  }
116 
117  /// For a vector type, this is a horizontal add, returning sum(v_i)
118  static inline __device__ float reduceAdd(float2 v) {
119  return v.x + v.y;
120  }
121 
122  // not implemented for vector types
123  // static inline __device__ bool lt(float2 a, float2 b);
124  // static inline __device__ bool gt(float2 a, float2 b);
125  // static inline __device__ bool eq(float2 a, float2 b);
126 
127  static inline __device__ float2 zero() {
128  float2 v;
129  v.x = 0.0f;
130  v.y = 0.0f;
131  return v;
132  }
133 };
134 
135 template <>
136 struct Math<float4> {
137  typedef float ScalarType;
138 
139  static inline __device__ float4 add(float4 a, float4 b) {
140  float4 v;
141  v.x = a.x + b.x;
142  v.y = a.y + b.y;
143  v.z = a.z + b.z;
144  v.w = a.w + b.w;
145  return v;
146  }
147 
148  static inline __device__ float4 sub(float4 a, float4 b) {
149  float4 v;
150  v.x = a.x - b.x;
151  v.y = a.y - b.y;
152  v.z = a.z - b.z;
153  v.w = a.w - b.w;
154  return v;
155  }
156 
157  static inline __device__ float4 add(float4 a, float b) {
158  float4 v;
159  v.x = a.x + b;
160  v.y = a.y + b;
161  v.z = a.z + b;
162  v.w = a.w + b;
163  return v;
164  }
165 
166  static inline __device__ float4 sub(float4 a, float b) {
167  float4 v;
168  v.x = a.x - b;
169  v.y = a.y - b;
170  v.z = a.z - b;
171  v.w = a.w - b;
172  return v;
173  }
174 
175  static inline __device__ float4 mul(float4 a, float4 b) {
176  float4 v;
177  v.x = a.x * b.x;
178  v.y = a.y * b.y;
179  v.z = a.z * b.z;
180  v.w = a.w * b.w;
181  return v;
182  }
183 
184  static inline __device__ float4 mul(float4 a, float b) {
185  float4 v;
186  v.x = a.x * b;
187  v.y = a.y * b;
188  v.z = a.z * b;
189  v.w = a.w * b;
190  return v;
191  }
192 
193  static inline __device__ float4 neg(float4 v) {
194  v.x = -v.x;
195  v.y = -v.y;
196  v.z = -v.z;
197  v.w = -v.w;
198  return v;
199  }
200 
201  /// For a vector type, this is a horizontal add, returning sum(v_i)
202  static inline __device__ float reduceAdd(float4 v) {
203  return v.x + v.y + v.z + v.w;
204  }
205 
206  // not implemented for vector types
207  // static inline __device__ bool lt(float4 a, float4 b);
208  // static inline __device__ bool gt(float4 a, float4 b);
209  // static inline __device__ bool eq(float4 a, float4 b);
210 
211  static inline __device__ float4 zero() {
212  float4 v;
213  v.x = 0.0f;
214  v.y = 0.0f;
215  v.z = 0.0f;
216  v.w = 0.0f;
217  return v;
218  }
219 };
220 
221 #ifdef FAISS_USE_FLOAT16
222 
223 template <>
224 struct Math<half> {
225  typedef half ScalarType;
226 
227  static inline __device__ half add(half a, half b) {
228 #ifdef FAISS_USE_FULL_FLOAT16
229  return __hadd(a, b);
230 #else
231  return __float2half(__half2float(a) + __half2float(b));
232 #endif
233  }
234 
235  static inline __device__ half sub(half a, half b) {
236 #ifdef FAISS_USE_FULL_FLOAT16
237  return __hsub(a, b);
238 #else
239  return __float2half(__half2float(a) - __half2float(b));
240 #endif
241  }
242 
243  static inline __device__ half mul(half a, half b) {
244 #ifdef FAISS_USE_FULL_FLOAT16
245  return __hmul(a, b);
246 #else
247  return __float2half(__half2float(a) * __half2float(b));
248 #endif
249  }
250 
251  static inline __device__ half neg(half v) {
252 #ifdef FAISS_USE_FULL_FLOAT16
253  return __hneg(v);
254 #else
255  return __float2half(-__half2float(v));
256 #endif
257  }
258 
259  static inline __device__ half reduceAdd(half v) {
260  return v;
261  }
262 
263  static inline __device__ bool lt(half a, half b) {
264 #ifdef FAISS_USE_FULL_FLOAT16
265  return __hlt(a, b);
266 #else
267  return __half2float(a) < __half2float(b);
268 #endif
269  }
270 
271  static inline __device__ bool gt(half a, half b) {
272 #ifdef FAISS_USE_FULL_FLOAT16
273  return __hgt(a, b);
274 #else
275  return __half2float(a) > __half2float(b);
276 #endif
277  }
278 
279  static inline __device__ bool eq(half a, half b) {
280 #ifdef FAISS_USE_FULL_FLOAT16
281  return __heq(a, b);
282 #else
283  return __half2float(a) == __half2float(b);
284 #endif
285  }
286 
287  static inline __device__ half zero() {
288 #if CUDA_VERSION >= 9000
289  return 0;
290 #else
291  half h;
292  h.x = 0;
293  return h;
294 #endif
295  }
296 };
297 
298 template <>
299 struct Math<half2> {
300  typedef half ScalarType;
301 
302  static inline __device__ half2 add(half2 a, half2 b) {
303 #ifdef FAISS_USE_FULL_FLOAT16
304  return __hadd2(a, b);
305 #else
306  float2 af = __half22float2(a);
307  float2 bf = __half22float2(b);
308 
309  af.x += bf.x;
310  af.y += bf.y;
311 
312  return __float22half2_rn(af);
313 #endif
314  }
315 
316  static inline __device__ half2 sub(half2 a, half2 b) {
317 #ifdef FAISS_USE_FULL_FLOAT16
318  return __hsub2(a, b);
319 #else
320  float2 af = __half22float2(a);
321  float2 bf = __half22float2(b);
322 
323  af.x -= bf.x;
324  af.y -= bf.y;
325 
326  return __float22half2_rn(af);
327 #endif
328  }
329 
330  static inline __device__ half2 add(half2 a, half b) {
331 #ifdef FAISS_USE_FULL_FLOAT16
332  half2 b2 = __half2half2(b);
333  return __hadd2(a, b2);
334 #else
335  float2 af = __half22float2(a);
336  float bf = __half2float(b);
337 
338  af.x += bf;
339  af.y += bf;
340 
341  return __float22half2_rn(af);
342 #endif
343  }
344 
345  static inline __device__ half2 sub(half2 a, half b) {
346 #ifdef FAISS_USE_FULL_FLOAT16
347  half2 b2 = __half2half2(b);
348  return __hsub2(a, b2);
349 #else
350  float2 af = __half22float2(a);
351  float bf = __half2float(b);
352 
353  af.x -= bf;
354  af.y -= bf;
355 
356  return __float22half2_rn(af);
357 #endif
358  }
359 
360  static inline __device__ half2 mul(half2 a, half2 b) {
361 #ifdef FAISS_USE_FULL_FLOAT16
362  return __hmul2(a, b);
363 #else
364  float2 af = __half22float2(a);
365  float2 bf = __half22float2(b);
366 
367  af.x *= bf.x;
368  af.y *= bf.y;
369 
370  return __float22half2_rn(af);
371 #endif
372  }
373 
374  static inline __device__ half2 mul(half2 a, half b) {
375 #ifdef FAISS_USE_FULL_FLOAT16
376  half2 b2 = __half2half2(b);
377  return __hmul2(a, b2);
378 #else
379  float2 af = __half22float2(a);
380  float bf = __half2float(b);
381 
382  af.x *= bf;
383  af.y *= bf;
384 
385  return __float22half2_rn(af);
386 #endif
387  }
388 
389  static inline __device__ half2 neg(half2 v) {
390 #ifdef FAISS_USE_FULL_FLOAT16
391  return __hneg2(v);
392 #else
393  float2 vf = __half22float2(v);
394  vf.x = -vf.x;
395  vf.y = -vf.y;
396 
397  return __float22half2_rn(vf);
398 #endif
399  }
400 
401  static inline __device__ half reduceAdd(half2 v) {
402 #ifdef FAISS_USE_FULL_FLOAT16
403  half hv = __high2half(v);
404  half lv = __low2half(v);
405 
406  return __hadd(hv, lv);
407 #else
408  float2 vf = __half22float2(v);
409  vf.x += vf.y;
410 
411  return __float2half(vf.x);
412 #endif
413  }
414 
415  // not implemented for vector types
416  // static inline __device__ bool lt(half2 a, half2 b);
417  // static inline __device__ bool gt(half2 a, half2 b);
418  // static inline __device__ bool eq(half2 a, half2 b);
419 
420  static inline __device__ half2 zero() {
421  return __half2half2(Math<half>::zero());
422  }
423 };
424 
425 template <>
426 struct Math<Half4> {
427  typedef half ScalarType;
428 
429  static inline __device__ Half4 add(Half4 a, Half4 b) {
430  Half4 h;
431  h.a = Math<half2>::add(a.a, b.a);
432  h.b = Math<half2>::add(a.b, b.b);
433  return h;
434  }
435 
436  static inline __device__ Half4 sub(Half4 a, Half4 b) {
437  Half4 h;
438  h.a = Math<half2>::sub(a.a, b.a);
439  h.b = Math<half2>::sub(a.b, b.b);
440  return h;
441  }
442 
443  static inline __device__ Half4 add(Half4 a, half b) {
444  Half4 h;
445  h.a = Math<half2>::add(a.a, b);
446  h.b = Math<half2>::add(a.b, b);
447  return h;
448  }
449 
450  static inline __device__ Half4 sub(Half4 a, half b) {
451  Half4 h;
452  h.a = Math<half2>::sub(a.a, b);
453  h.b = Math<half2>::sub(a.b, b);
454  return h;
455  }
456 
457  static inline __device__ Half4 mul(Half4 a, Half4 b) {
458  Half4 h;
459  h.a = Math<half2>::mul(a.a, b.a);
460  h.b = Math<half2>::mul(a.b, b.b);
461  return h;
462  }
463 
464  static inline __device__ Half4 mul(Half4 a, half b) {
465  Half4 h;
466  h.a = Math<half2>::mul(a.a, b);
467  h.b = Math<half2>::mul(a.b, b);
468  return h;
469  }
470 
471  static inline __device__ Half4 neg(Half4 v) {
472  Half4 h;
473  h.a = Math<half2>::neg(v.a);
474  h.b = Math<half2>::neg(v.b);
475  return h;
476  }
477 
478  static inline __device__ half reduceAdd(Half4 v) {
479  half hx = Math<half2>::reduceAdd(v.a);
480  half hy = Math<half2>::reduceAdd(v.b);
481  return Math<half>::add(hx, hy);
482  }
483 
484  // not implemented for vector types
485  // static inline __device__ bool lt(Half4 a, Half4 b);
486  // static inline __device__ bool gt(Half4 a, Half4 b);
487  // static inline __device__ bool eq(Half4 a, Half4 b);
488 
489  static inline __device__ Half4 zero() {
490  Half4 h;
491  h.a = Math<half2>::zero();
492  h.b = Math<half2>::zero();
493  return h;
494  }
495 };
496 
497 template <>
498 struct Math<Half8> {
499  typedef half ScalarType;
500 
501  static inline __device__ Half8 add(Half8 a, Half8 b) {
502  Half8 h;
503  h.a = Math<Half4>::add(a.a, b.a);
504  h.b = Math<Half4>::add(a.b, b.b);
505  return h;
506  }
507 
508  static inline __device__ Half8 sub(Half8 a, Half8 b) {
509  Half8 h;
510  h.a = Math<Half4>::sub(a.a, b.a);
511  h.b = Math<Half4>::sub(a.b, b.b);
512  return h;
513  }
514 
515  static inline __device__ Half8 add(Half8 a, half b) {
516  Half8 h;
517  h.a = Math<Half4>::add(a.a, b);
518  h.b = Math<Half4>::add(a.b, b);
519  return h;
520  }
521 
522  static inline __device__ Half8 sub(Half8 a, half b) {
523  Half8 h;
524  h.a = Math<Half4>::sub(a.a, b);
525  h.b = Math<Half4>::sub(a.b, b);
526  return h;
527  }
528 
529  static inline __device__ Half8 mul(Half8 a, Half8 b) {
530  Half8 h;
531  h.a = Math<Half4>::mul(a.a, b.a);
532  h.b = Math<Half4>::mul(a.b, b.b);
533  return h;
534  }
535 
536  static inline __device__ Half8 mul(Half8 a, half b) {
537  Half8 h;
538  h.a = Math<Half4>::mul(a.a, b);
539  h.b = Math<Half4>::mul(a.b, b);
540  return h;
541  }
542 
543  static inline __device__ Half8 neg(Half8 v) {
544  Half8 h;
545  h.a = Math<Half4>::neg(v.a);
546  h.b = Math<Half4>::neg(v.b);
547  return h;
548  }
549 
550  static inline __device__ half reduceAdd(Half8 v) {
551  half hx = Math<Half4>::reduceAdd(v.a);
552  half hy = Math<Half4>::reduceAdd(v.b);
553  return Math<half>::add(hx, hy);
554  }
555 
556  // not implemented for vector types
557  // static inline __device__ bool lt(Half8 a, Half8 b);
558  // static inline __device__ bool gt(Half8 a, Half8 b);
559  // static inline __device__ bool eq(Half8 a, Half8 b);
560 
561  static inline __device__ Half8 zero() {
562  Half8 h;
563  h.a = Math<Half4>::zero();
564  h.b = Math<Half4>::zero();
565  return h;
566  }
567 };
568 
569 #endif // FAISS_USE_FLOAT16
570 
571 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)