Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MathOperators.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "Float16.cuh"
14 
15 //
16 // Templated wrappers to express math for different scalar and vector
17 // types, so kernels can have the same written form but can operate
18 // over half and float, and on vector types transparently
19 //
20 
21 namespace faiss { namespace gpu {
22 
23 template <typename T>
24 struct Math {
25  typedef T ScalarType;
26 
27  static inline __device__ T add(T a, T b) {
28  return a + b;
29  }
30 
31  static inline __device__ T sub(T a, T b) {
32  return a - b;
33  }
34 
35  static inline __device__ T mul(T a, T b) {
36  return a * b;
37  }
38 
39  static inline __device__ T neg(T v) {
40  return -v;
41  }
42 
43  /// For a vector type, this is a horizontal add, returning sum(v_i)
44  static inline __device__ T reduceAdd(T v) {
45  return v;
46  }
47 
48  static inline __device__ bool lt(T a, T b) {
49  return a < b;
50  }
51 
52  static inline __device__ bool gt(T a, T b) {
53  return a > b;
54  }
55 
56  static inline __device__ bool eq(T a, T b) {
57  return a == b;
58  }
59 
60  static inline __device__ T zero() {
61  return (T) 0;
62  }
63 };
64 
65 template <>
66 struct Math<float2> {
67  typedef float ScalarType;
68 
69  static inline __device__ float2 add(float2 a, float2 b) {
70  float2 v;
71  v.x = a.x + b.x;
72  v.y = a.y + b.y;
73  return v;
74  }
75 
76  static inline __device__ float2 sub(float2 a, float2 b) {
77  float2 v;
78  v.x = a.x - b.x;
79  v.y = a.y - b.y;
80  return v;
81  }
82 
83  static inline __device__ float2 add(float2 a, float b) {
84  float2 v;
85  v.x = a.x + b;
86  v.y = a.y + b;
87  return v;
88  }
89 
90  static inline __device__ float2 sub(float2 a, float b) {
91  float2 v;
92  v.x = a.x - b;
93  v.y = a.y - b;
94  return v;
95  }
96 
97  static inline __device__ float2 mul(float2 a, float2 b) {
98  float2 v;
99  v.x = a.x * b.x;
100  v.y = a.y * b.y;
101  return v;
102  }
103 
104  static inline __device__ float2 mul(float2 a, float b) {
105  float2 v;
106  v.x = a.x * b;
107  v.y = a.y * b;
108  return v;
109  }
110 
111  static inline __device__ float2 neg(float2 v) {
112  v.x = -v.x;
113  v.y = -v.y;
114  return v;
115  }
116 
117  /// For a vector type, this is a horizontal add, returning sum(v_i)
118  static inline __device__ float reduceAdd(float2 v) {
119  return v.x + v.y;
120  }
121 
122  // not implemented for vector types
123  // static inline __device__ bool lt(float2 a, float2 b);
124  // static inline __device__ bool gt(float2 a, float2 b);
125  // static inline __device__ bool eq(float2 a, float2 b);
126 
127  static inline __device__ float2 zero() {
128  float2 v;
129  v.x = 0.0f;
130  v.y = 0.0f;
131  return v;
132  }
133 };
134 
135 template <>
136 struct Math<float4> {
137  typedef float ScalarType;
138 
139  static inline __device__ float4 add(float4 a, float4 b) {
140  float4 v;
141  v.x = a.x + b.x;
142  v.y = a.y + b.y;
143  v.z = a.z + b.z;
144  v.w = a.w + b.w;
145  return v;
146  }
147 
148  static inline __device__ float4 sub(float4 a, float4 b) {
149  float4 v;
150  v.x = a.x - b.x;
151  v.y = a.y - b.y;
152  v.z = a.z - b.z;
153  v.w = a.w - b.w;
154  return v;
155  }
156 
157  static inline __device__ float4 add(float4 a, float b) {
158  float4 v;
159  v.x = a.x + b;
160  v.y = a.y + b;
161  v.z = a.z + b;
162  v.w = a.w + b;
163  return v;
164  }
165 
166  static inline __device__ float4 sub(float4 a, float b) {
167  float4 v;
168  v.x = a.x - b;
169  v.y = a.y - b;
170  v.z = a.z - b;
171  v.w = a.w - b;
172  return v;
173  }
174 
175  static inline __device__ float4 mul(float4 a, float4 b) {
176  float4 v;
177  v.x = a.x * b.x;
178  v.y = a.y * b.y;
179  v.z = a.z * b.z;
180  v.w = a.w * b.w;
181  return v;
182  }
183 
184  static inline __device__ float4 mul(float4 a, float b) {
185  float4 v;
186  v.x = a.x * b;
187  v.y = a.y * b;
188  v.z = a.z * b;
189  v.w = a.w * b;
190  return v;
191  }
192 
193  static inline __device__ float4 neg(float4 v) {
194  v.x = -v.x;
195  v.y = -v.y;
196  v.z = -v.z;
197  v.w = -v.w;
198  return v;
199  }
200 
201  /// For a vector type, this is a horizontal add, returning sum(v_i)
202  static inline __device__ float reduceAdd(float4 v) {
203  return v.x + v.y + v.z + v.w;
204  }
205 
206  // not implemented for vector types
207  // static inline __device__ bool lt(float4 a, float4 b);
208  // static inline __device__ bool gt(float4 a, float4 b);
209  // static inline __device__ bool eq(float4 a, float4 b);
210 
211  static inline __device__ float4 zero() {
212  float4 v;
213  v.x = 0.0f;
214  v.y = 0.0f;
215  v.z = 0.0f;
216  v.w = 0.0f;
217  return v;
218  }
219 };
220 
221 #ifdef FAISS_USE_FLOAT16
222 
223 template <>
224 struct Math<half> {
225  typedef half ScalarType;
226 
227  static inline __device__ half add(half a, half b) {
228 #ifdef FAISS_USE_FULL_FLOAT16
229  return __hadd(a, b);
230 #else
231  return __float2half(__half2float(a) + __half2float(b));
232 #endif
233  }
234 
235  static inline __device__ half sub(half a, half b) {
236 #ifdef FAISS_USE_FULL_FLOAT16
237  return __hsub(a, b);
238 #else
239  return __float2half(__half2float(a) - __half2float(b));
240 #endif
241  }
242 
243  static inline __device__ half mul(half a, half b) {
244 #ifdef FAISS_USE_FULL_FLOAT16
245  return __hmul(a, b);
246 #else
247  return __float2half(__half2float(a) * __half2float(b));
248 #endif
249  }
250 
251  static inline __device__ half neg(half v) {
252 #ifdef FAISS_USE_FULL_FLOAT16
253  return __hneg(v);
254 #else
255  return __float2half(-__half2float(v));
256 #endif
257  }
258 
259  static inline __device__ half reduceAdd(half v) {
260  return v;
261  }
262 
263  static inline __device__ bool lt(half a, half b) {
264 #ifdef FAISS_USE_FULL_FLOAT16
265  return __hlt(a, b);
266 #else
267  return __half2float(a) < __half2float(b);
268 #endif
269  }
270 
271  static inline __device__ bool gt(half a, half b) {
272 #ifdef FAISS_USE_FULL_FLOAT16
273  return __hgt(a, b);
274 #else
275  return __half2float(a) > __half2float(b);
276 #endif
277  }
278 
279  static inline __device__ bool eq(half a, half b) {
280 #ifdef FAISS_USE_FULL_FLOAT16
281  return __heq(a, b);
282 #else
283  return __half2float(a) == __half2float(b);
284 #endif
285  }
286 
287  static inline __device__ half zero() {
288  half h;
289  h.x = 0;
290  return h;
291  }
292 };
293 
294 template <>
295 struct Math<half2> {
296  typedef half ScalarType;
297 
298  static inline __device__ half2 add(half2 a, half2 b) {
299 #ifdef FAISS_USE_FULL_FLOAT16
300  return __hadd2(a, b);
301 #else
302  float2 af = __half22float2(a);
303  float2 bf = __half22float2(b);
304 
305  af.x += bf.x;
306  af.y += bf.y;
307 
308  return __float22half2_rn(af);
309 #endif
310  }
311 
312  static inline __device__ half2 sub(half2 a, half2 b) {
313 #ifdef FAISS_USE_FULL_FLOAT16
314  return __hsub2(a, b);
315 #else
316  float2 af = __half22float2(a);
317  float2 bf = __half22float2(b);
318 
319  af.x -= bf.x;
320  af.y -= bf.y;
321 
322  return __float22half2_rn(af);
323 #endif
324  }
325 
326  static inline __device__ half2 add(half2 a, half b) {
327 #ifdef FAISS_USE_FULL_FLOAT16
328  half2 b2 = __half2half2(b);
329  return __hadd2(a, b2);
330 #else
331  float2 af = __half22float2(a);
332  float bf = __half2float(b);
333 
334  af.x += bf;
335  af.y += bf;
336 
337  return __float22half2_rn(af);
338 #endif
339  }
340 
341  static inline __device__ half2 sub(half2 a, half b) {
342 #ifdef FAISS_USE_FULL_FLOAT16
343  half2 b2 = __half2half2(b);
344  return __hsub2(a, b2);
345 #else
346  float2 af = __half22float2(a);
347  float bf = __half2float(b);
348 
349  af.x -= bf;
350  af.y -= bf;
351 
352  return __float22half2_rn(af);
353 #endif
354  }
355 
356  static inline __device__ half2 mul(half2 a, half2 b) {
357 #ifdef FAISS_USE_FULL_FLOAT16
358  return __hmul2(a, b);
359 #else
360  float2 af = __half22float2(a);
361  float2 bf = __half22float2(b);
362 
363  af.x *= bf.x;
364  af.y *= bf.y;
365 
366  return __float22half2_rn(af);
367 #endif
368  }
369 
370  static inline __device__ half2 mul(half2 a, half b) {
371 #ifdef FAISS_USE_FULL_FLOAT16
372  half2 b2 = __half2half2(b);
373  return __hmul2(a, b2);
374 #else
375  float2 af = __half22float2(a);
376  float bf = __half2float(b);
377 
378  af.x *= bf;
379  af.y *= bf;
380 
381  return __float22half2_rn(af);
382 #endif
383  }
384 
385  static inline __device__ half2 neg(half2 v) {
386 #ifdef FAISS_USE_FULL_FLOAT16
387  return __hneg2(v);
388 #else
389  float2 vf = __half22float2(v);
390  vf.x = -vf.x;
391  vf.y = -vf.y;
392 
393  return __float22half2_rn(vf);
394 #endif
395  }
396 
397  static inline __device__ half reduceAdd(half2 v) {
398 #ifdef FAISS_USE_FULL_FLOAT16
399  half hv = __high2half(v);
400  half lv = __low2half(v);
401 
402  return __hadd(hv, lv);
403 #else
404  float2 vf = __half22float2(v);
405  vf.x += vf.y;
406 
407  return __float2half(vf.x);
408 #endif
409  }
410 
411  // not implemented for vector types
412  // static inline __device__ bool lt(half2 a, half2 b);
413  // static inline __device__ bool gt(half2 a, half2 b);
414  // static inline __device__ bool eq(half2 a, half2 b);
415 
416  static inline __device__ half2 zero() {
417  return __half2half2(Math<half>::zero());
418  }
419 };
420 
421 template <>
422 struct Math<Half4> {
423  typedef half ScalarType;
424 
425  static inline __device__ Half4 add(Half4 a, Half4 b) {
426  Half4 h;
427  h.a = Math<half2>::add(a.a, b.a);
428  h.b = Math<half2>::add(a.b, b.b);
429  return h;
430  }
431 
432  static inline __device__ Half4 sub(Half4 a, Half4 b) {
433  Half4 h;
434  h.a = Math<half2>::sub(a.a, b.a);
435  h.b = Math<half2>::sub(a.b, b.b);
436  return h;
437  }
438 
439  static inline __device__ Half4 add(Half4 a, half b) {
440  Half4 h;
441  h.a = Math<half2>::add(a.a, b);
442  h.b = Math<half2>::add(a.b, b);
443  return h;
444  }
445 
446  static inline __device__ Half4 sub(Half4 a, half b) {
447  Half4 h;
448  h.a = Math<half2>::sub(a.a, b);
449  h.b = Math<half2>::sub(a.b, b);
450  return h;
451  }
452 
453  static inline __device__ Half4 mul(Half4 a, Half4 b) {
454  Half4 h;
455  h.a = Math<half2>::mul(a.a, b.a);
456  h.b = Math<half2>::mul(a.b, b.b);
457  return h;
458  }
459 
460  static inline __device__ Half4 mul(Half4 a, half b) {
461  Half4 h;
462  h.a = Math<half2>::mul(a.a, b);
463  h.b = Math<half2>::mul(a.b, b);
464  return h;
465  }
466 
467  static inline __device__ Half4 neg(Half4 v) {
468  Half4 h;
469  h.a = Math<half2>::neg(v.a);
470  h.b = Math<half2>::neg(v.b);
471  return h;
472  }
473 
474  static inline __device__ half reduceAdd(Half4 v) {
475  half hx = Math<half2>::reduceAdd(v.a);
476  half hy = Math<half2>::reduceAdd(v.b);
477  return Math<half>::add(hx, hy);
478  }
479 
480  // not implemented for vector types
481  // static inline __device__ bool lt(Half4 a, Half4 b);
482  // static inline __device__ bool gt(Half4 a, Half4 b);
483  // static inline __device__ bool eq(Half4 a, Half4 b);
484 
485  static inline __device__ Half4 zero() {
486  Half4 h;
487  h.a = Math<half2>::zero();
488  h.b = Math<half2>::zero();
489  return h;
490  }
491 };
492 
493 template <>
494 struct Math<Half8> {
495  typedef half ScalarType;
496 
497  static inline __device__ Half8 add(Half8 a, Half8 b) {
498  Half8 h;
499  h.a = Math<Half4>::add(a.a, b.a);
500  h.b = Math<Half4>::add(a.b, b.b);
501  return h;
502  }
503 
504  static inline __device__ Half8 sub(Half8 a, Half8 b) {
505  Half8 h;
506  h.a = Math<Half4>::sub(a.a, b.a);
507  h.b = Math<Half4>::sub(a.b, b.b);
508  return h;
509  }
510 
511  static inline __device__ Half8 add(Half8 a, half b) {
512  Half8 h;
513  h.a = Math<Half4>::add(a.a, b);
514  h.b = Math<Half4>::add(a.b, b);
515  return h;
516  }
517 
518  static inline __device__ Half8 sub(Half8 a, half b) {
519  Half8 h;
520  h.a = Math<Half4>::sub(a.a, b);
521  h.b = Math<Half4>::sub(a.b, b);
522  return h;
523  }
524 
525  static inline __device__ Half8 mul(Half8 a, Half8 b) {
526  Half8 h;
527  h.a = Math<Half4>::mul(a.a, b.a);
528  h.b = Math<Half4>::mul(a.b, b.b);
529  return h;
530  }
531 
532  static inline __device__ Half8 mul(Half8 a, half b) {
533  Half8 h;
534  h.a = Math<Half4>::mul(a.a, b);
535  h.b = Math<Half4>::mul(a.b, b);
536  return h;
537  }
538 
539  static inline __device__ Half8 neg(Half8 v) {
540  Half8 h;
541  h.a = Math<Half4>::neg(v.a);
542  h.b = Math<Half4>::neg(v.b);
543  return h;
544  }
545 
546  static inline __device__ half reduceAdd(Half8 v) {
547  half hx = Math<Half4>::reduceAdd(v.a);
548  half hy = Math<Half4>::reduceAdd(v.b);
549  return Math<half>::add(hx, hy);
550  }
551 
552  // not implemented for vector types
553  // static inline __device__ bool lt(Half8 a, Half8 b);
554  // static inline __device__ bool gt(Half8 a, Half8 b);
555  // static inline __device__ bool eq(Half8 a, Half8 b);
556 
557  static inline __device__ Half8 zero() {
558  Half8 h;
559  h.a = Math<Half4>::zero();
560  h.b = Math<Half4>::zero();
561  return h;
562  }
563 };
564 
565 #endif // FAISS_USE_FLOAT16
566 
567 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)