Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MathOperators.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #pragma once
13 
14 #include "Float16.cuh"
15 
16 //
17 // Templated wrappers to express math for different scalar and vector
18 // types, so kernels can have the same written form but can operate
19 // over half and float, and on vector types transparently
20 //
21 
22 namespace faiss { namespace gpu {
23 
24 template <typename T>
25 struct Math {
26  typedef T ScalarType;
27 
28  static inline __device__ T add(T a, T b) {
29  return a + b;
30  }
31 
32  static inline __device__ T sub(T a, T b) {
33  return a - b;
34  }
35 
36  static inline __device__ T mul(T a, T b) {
37  return a * b;
38  }
39 
40  static inline __device__ T neg(T v) {
41  return -v;
42  }
43 
44  /// For a vector type, this is a horizontal add, returning sum(v_i)
45  static inline __device__ T reduceAdd(T v) {
46  return v;
47  }
48 
49  static inline __device__ bool lt(T a, T b) {
50  return a < b;
51  }
52 
53  static inline __device__ bool gt(T a, T b) {
54  return a > b;
55  }
56 
57  static inline __device__ bool eq(T a, T b) {
58  return a == b;
59  }
60 
61  static inline __device__ T zero() {
62  return (T) 0;
63  }
64 };
65 
66 template <>
67 struct Math<float2> {
68  typedef float ScalarType;
69 
70  static inline __device__ float2 add(float2 a, float2 b) {
71  float2 v;
72  v.x = a.x + b.x;
73  v.y = a.y + b.y;
74  return v;
75  }
76 
77  static inline __device__ float2 sub(float2 a, float2 b) {
78  float2 v;
79  v.x = a.x - b.x;
80  v.y = a.y - b.y;
81  return v;
82  }
83 
84  static inline __device__ float2 add(float2 a, float b) {
85  float2 v;
86  v.x = a.x + b;
87  v.y = a.y + b;
88  return v;
89  }
90 
91  static inline __device__ float2 sub(float2 a, float b) {
92  float2 v;
93  v.x = a.x - b;
94  v.y = a.y - b;
95  return v;
96  }
97 
98  static inline __device__ float2 mul(float2 a, float2 b) {
99  float2 v;
100  v.x = a.x * b.x;
101  v.y = a.y * b.y;
102  return v;
103  }
104 
105  static inline __device__ float2 mul(float2 a, float b) {
106  float2 v;
107  v.x = a.x * b;
108  v.y = a.y * b;
109  return v;
110  }
111 
112  static inline __device__ float2 neg(float2 v) {
113  v.x = -v.x;
114  v.y = -v.y;
115  return v;
116  }
117 
118  /// For a vector type, this is a horizontal add, returning sum(v_i)
119  static inline __device__ float reduceAdd(float2 v) {
120  return v.x + v.y;
121  }
122 
123  // not implemented for vector types
124  // static inline __device__ bool lt(float2 a, float2 b);
125  // static inline __device__ bool gt(float2 a, float2 b);
126  // static inline __device__ bool eq(float2 a, float2 b);
127 
128  static inline __device__ float2 zero() {
129  float2 v;
130  v.x = 0.0f;
131  v.y = 0.0f;
132  return v;
133  }
134 };
135 
136 template <>
137 struct Math<float4> {
138  typedef float ScalarType;
139 
140  static inline __device__ float4 add(float4 a, float4 b) {
141  float4 v;
142  v.x = a.x + b.x;
143  v.y = a.y + b.y;
144  v.z = a.z + b.z;
145  v.w = a.w + b.w;
146  return v;
147  }
148 
149  static inline __device__ float4 sub(float4 a, float4 b) {
150  float4 v;
151  v.x = a.x - b.x;
152  v.y = a.y - b.y;
153  v.z = a.z - b.z;
154  v.w = a.w - b.w;
155  return v;
156  }
157 
158  static inline __device__ float4 add(float4 a, float b) {
159  float4 v;
160  v.x = a.x + b;
161  v.y = a.y + b;
162  v.z = a.z + b;
163  v.w = a.w + b;
164  return v;
165  }
166 
167  static inline __device__ float4 sub(float4 a, float b) {
168  float4 v;
169  v.x = a.x - b;
170  v.y = a.y - b;
171  v.z = a.z - b;
172  v.w = a.w - b;
173  return v;
174  }
175 
176  static inline __device__ float4 mul(float4 a, float4 b) {
177  float4 v;
178  v.x = a.x * b.x;
179  v.y = a.y * b.y;
180  v.z = a.z * b.z;
181  v.w = a.w * b.w;
182  return v;
183  }
184 
185  static inline __device__ float4 mul(float4 a, float b) {
186  float4 v;
187  v.x = a.x * b;
188  v.y = a.y * b;
189  v.z = a.z * b;
190  v.w = a.w * b;
191  return v;
192  }
193 
194  static inline __device__ float4 neg(float4 v) {
195  v.x = -v.x;
196  v.y = -v.y;
197  v.z = -v.z;
198  v.w = -v.w;
199  return v;
200  }
201 
202  /// For a vector type, this is a horizontal add, returning sum(v_i)
203  static inline __device__ float reduceAdd(float4 v) {
204  return v.x + v.y + v.z + v.w;
205  }
206 
207  // not implemented for vector types
208  // static inline __device__ bool lt(float4 a, float4 b);
209  // static inline __device__ bool gt(float4 a, float4 b);
210  // static inline __device__ bool eq(float4 a, float4 b);
211 
212  static inline __device__ float4 zero() {
213  float4 v;
214  v.x = 0.0f;
215  v.y = 0.0f;
216  v.z = 0.0f;
217  v.w = 0.0f;
218  return v;
219  }
220 };
221 
222 #ifdef FAISS_USE_FLOAT16
223 
224 template <>
225 struct Math<half> {
226  typedef half ScalarType;
227 
228  static inline __device__ half add(half a, half b) {
229 #ifdef FAISS_USE_FULL_FLOAT16
230  return __hadd(a, b);
231 #else
232  return __float2half(__half2float(a) + __half2float(b));
233 #endif
234  }
235 
236  static inline __device__ half sub(half a, half b) {
237 #ifdef FAISS_USE_FULL_FLOAT16
238  return __hsub(a, b);
239 #else
240  return __float2half(__half2float(a) - __half2float(b));
241 #endif
242  }
243 
244  static inline __device__ half mul(half a, half b) {
245 #ifdef FAISS_USE_FULL_FLOAT16
246  return __hmul(a, b);
247 #else
248  return __float2half(__half2float(a) * __half2float(b));
249 #endif
250  }
251 
252  static inline __device__ half neg(half v) {
253 #ifdef FAISS_USE_FULL_FLOAT16
254  return __hneg(v);
255 #else
256  return __float2half(-__half2float(v));
257 #endif
258  }
259 
260  static inline __device__ half reduceAdd(half v) {
261  return v;
262  }
263 
264  static inline __device__ bool lt(half a, half b) {
265 #ifdef FAISS_USE_FULL_FLOAT16
266  return __hlt(a, b);
267 #else
268  return __half2float(a) < __half2float(b);
269 #endif
270  }
271 
272  static inline __device__ bool gt(half a, half b) {
273 #ifdef FAISS_USE_FULL_FLOAT16
274  return __hgt(a, b);
275 #else
276  return __half2float(a) > __half2float(b);
277 #endif
278  }
279 
280  static inline __device__ bool eq(half a, half b) {
281 #ifdef FAISS_USE_FULL_FLOAT16
282  return __heq(a, b);
283 #else
284  return __half2float(a) == __half2float(b);
285 #endif
286  }
287 
288  static inline __device__ half zero() {
289  half h;
290  h.x = 0;
291  return h;
292  }
293 };
294 
295 template <>
296 struct Math<half2> {
297  typedef half ScalarType;
298 
299  static inline __device__ half2 add(half2 a, half2 b) {
300 #ifdef FAISS_USE_FULL_FLOAT16
301  return __hadd2(a, b);
302 #else
303  float2 af = __half22float2(a);
304  float2 bf = __half22float2(b);
305 
306  af.x += bf.x;
307  af.y += bf.y;
308 
309  return __float22half2_rn(af);
310 #endif
311  }
312 
313  static inline __device__ half2 sub(half2 a, half2 b) {
314 #ifdef FAISS_USE_FULL_FLOAT16
315  return __hsub2(a, b);
316 #else
317  float2 af = __half22float2(a);
318  float2 bf = __half22float2(b);
319 
320  af.x -= bf.x;
321  af.y -= bf.y;
322 
323  return __float22half2_rn(af);
324 #endif
325  }
326 
327  static inline __device__ half2 add(half2 a, half b) {
328 #ifdef FAISS_USE_FULL_FLOAT16
329  half2 b2 = __half2half2(b);
330  return __hadd2(a, b2);
331 #else
332  float2 af = __half22float2(a);
333  float bf = __half2float(b);
334 
335  af.x += bf;
336  af.y += bf;
337 
338  return __float22half2_rn(af);
339 #endif
340  }
341 
342  static inline __device__ half2 sub(half2 a, half b) {
343 #ifdef FAISS_USE_FULL_FLOAT16
344  half2 b2 = __half2half2(b);
345  return __hsub2(a, b2);
346 #else
347  float2 af = __half22float2(a);
348  float bf = __half2float(b);
349 
350  af.x -= bf;
351  af.y -= bf;
352 
353  return __float22half2_rn(af);
354 #endif
355  }
356 
357  static inline __device__ half2 mul(half2 a, half2 b) {
358 #ifdef FAISS_USE_FULL_FLOAT16
359  return __hmul2(a, b);
360 #else
361  float2 af = __half22float2(a);
362  float2 bf = __half22float2(b);
363 
364  af.x *= bf.x;
365  af.y *= bf.y;
366 
367  return __float22half2_rn(af);
368 #endif
369  }
370 
371  static inline __device__ half2 mul(half2 a, half b) {
372 #ifdef FAISS_USE_FULL_FLOAT16
373  half2 b2 = __half2half2(b);
374  return __hmul2(a, b2);
375 #else
376  float2 af = __half22float2(a);
377  float bf = __half2float(b);
378 
379  af.x *= bf;
380  af.y *= bf;
381 
382  return __float22half2_rn(af);
383 #endif
384  }
385 
386  static inline __device__ half2 neg(half2 v) {
387 #ifdef FAISS_USE_FULL_FLOAT16
388  return __hneg2(v);
389 #else
390  float2 vf = __half22float2(v);
391  vf.x = -vf.x;
392  vf.y = -vf.y;
393 
394  return __float22half2_rn(vf);
395 #endif
396  }
397 
398  static inline __device__ half reduceAdd(half2 v) {
399 #ifdef FAISS_USE_FULL_FLOAT16
400  half hv = __high2half(v);
401  half lv = __low2half(v);
402 
403  return __hadd(hv, lv);
404 #else
405  float2 vf = __half22float2(v);
406  vf.x += vf.y;
407 
408  return __float2half(vf.x);
409 #endif
410  }
411 
412  // not implemented for vector types
413  // static inline __device__ bool lt(half2 a, half2 b);
414  // static inline __device__ bool gt(half2 a, half2 b);
415  // static inline __device__ bool eq(half2 a, half2 b);
416 
417  static inline __device__ half2 zero() {
418  return __half2half2(Math<half>::zero());
419  }
420 };
421 
422 template <>
423 struct Math<Half4> {
424  typedef half ScalarType;
425 
426  static inline __device__ Half4 add(Half4 a, Half4 b) {
427  Half4 h;
428  h.a = Math<half2>::add(a.a, b.a);
429  h.b = Math<half2>::add(a.b, b.b);
430  return h;
431  }
432 
433  static inline __device__ Half4 sub(Half4 a, Half4 b) {
434  Half4 h;
435  h.a = Math<half2>::sub(a.a, b.a);
436  h.b = Math<half2>::sub(a.b, b.b);
437  return h;
438  }
439 
440  static inline __device__ Half4 add(Half4 a, half b) {
441  Half4 h;
442  h.a = Math<half2>::add(a.a, b);
443  h.b = Math<half2>::add(a.b, b);
444  return h;
445  }
446 
447  static inline __device__ Half4 sub(Half4 a, half b) {
448  Half4 h;
449  h.a = Math<half2>::sub(a.a, b);
450  h.b = Math<half2>::sub(a.b, b);
451  return h;
452  }
453 
454  static inline __device__ Half4 mul(Half4 a, Half4 b) {
455  Half4 h;
456  h.a = Math<half2>::mul(a.a, b.a);
457  h.b = Math<half2>::mul(a.b, b.b);
458  return h;
459  }
460 
461  static inline __device__ Half4 mul(Half4 a, half b) {
462  Half4 h;
463  h.a = Math<half2>::mul(a.a, b);
464  h.b = Math<half2>::mul(a.b, b);
465  return h;
466  }
467 
468  static inline __device__ Half4 neg(Half4 v) {
469  Half4 h;
470  h.a = Math<half2>::neg(v.a);
471  h.b = Math<half2>::neg(v.b);
472  return h;
473  }
474 
475  static inline __device__ half reduceAdd(Half4 v) {
476  half hx = Math<half2>::reduceAdd(v.a);
477  half hy = Math<half2>::reduceAdd(v.b);
478  return Math<half>::add(hx, hy);
479  }
480 
481  // not implemented for vector types
482  // static inline __device__ bool lt(Half4 a, Half4 b);
483  // static inline __device__ bool gt(Half4 a, Half4 b);
484  // static inline __device__ bool eq(Half4 a, Half4 b);
485 
486  static inline __device__ Half4 zero() {
487  Half4 h;
488  h.a = Math<half2>::zero();
489  h.b = Math<half2>::zero();
490  return h;
491  }
492 };
493 
494 template <>
495 struct Math<Half8> {
496  typedef half ScalarType;
497 
498  static inline __device__ Half8 add(Half8 a, Half8 b) {
499  Half8 h;
500  h.a = Math<Half4>::add(a.a, b.a);
501  h.b = Math<Half4>::add(a.b, b.b);
502  return h;
503  }
504 
505  static inline __device__ Half8 sub(Half8 a, Half8 b) {
506  Half8 h;
507  h.a = Math<Half4>::sub(a.a, b.a);
508  h.b = Math<Half4>::sub(a.b, b.b);
509  return h;
510  }
511 
512  static inline __device__ Half8 add(Half8 a, half b) {
513  Half8 h;
514  h.a = Math<Half4>::add(a.a, b);
515  h.b = Math<Half4>::add(a.b, b);
516  return h;
517  }
518 
519  static inline __device__ Half8 sub(Half8 a, half b) {
520  Half8 h;
521  h.a = Math<Half4>::sub(a.a, b);
522  h.b = Math<Half4>::sub(a.b, b);
523  return h;
524  }
525 
526  static inline __device__ Half8 mul(Half8 a, Half8 b) {
527  Half8 h;
528  h.a = Math<Half4>::mul(a.a, b.a);
529  h.b = Math<Half4>::mul(a.b, b.b);
530  return h;
531  }
532 
533  static inline __device__ Half8 mul(Half8 a, half b) {
534  Half8 h;
535  h.a = Math<Half4>::mul(a.a, b);
536  h.b = Math<Half4>::mul(a.b, b);
537  return h;
538  }
539 
540  static inline __device__ Half8 neg(Half8 v) {
541  Half8 h;
542  h.a = Math<Half4>::neg(v.a);
543  h.b = Math<Half4>::neg(v.b);
544  return h;
545  }
546 
547  static inline __device__ half reduceAdd(Half8 v) {
548  half hx = Math<Half4>::reduceAdd(v.a);
549  half hy = Math<Half4>::reduceAdd(v.b);
550  return Math<half>::add(hx, hy);
551  }
552 
553  // not implemented for vector types
554  // static inline __device__ bool lt(Half8 a, Half8 b);
555  // static inline __device__ bool gt(Half8 a, Half8 b);
556  // static inline __device__ bool eq(Half8 a, Half8 b);
557 
558  static inline __device__ Half8 zero() {
559  Half8 h;
560  h.a = Math<Half4>::zero();
561  h.b = Math<Half4>::zero();
562  return h;
563  }
564 };
565 
566 #endif // FAISS_USE_FLOAT16
567 
568 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)