13 #include "Float16.cuh"
21 namespace faiss {
namespace gpu {
27 static inline __device__ T add(T a, T b) {
31 static inline __device__ T sub(T a, T b) {
35 static inline __device__ T mul(T a, T b) {
39 static inline __device__ T neg(T v) {
48 static inline __device__
bool lt(T a, T b) {
52 static inline __device__
bool gt(T a, T b) {
56 static inline __device__
bool eq(T a, T b) {
60 static inline __device__ T zero() {
67 typedef float ScalarType;
69 static inline __device__ float2 add(float2 a, float2 b) {
76 static inline __device__ float2 sub(float2 a, float2 b) {
83 static inline __device__ float2 add(float2 a,
float b) {
90 static inline __device__ float2 sub(float2 a,
float b) {
97 static inline __device__ float2 mul(float2 a, float2 b) {
104 static inline __device__ float2 mul(float2 a,
float b) {
111 static inline __device__ float2 neg(float2 v) {
127 static inline __device__ float2 zero() {
137 typedef float ScalarType;
139 static inline __device__ float4 add(float4 a, float4 b) {
148 static inline __device__ float4 sub(float4 a, float4 b) {
157 static inline __device__ float4 add(float4 a,
float b) {
166 static inline __device__ float4 sub(float4 a,
float b) {
175 static inline __device__ float4 mul(float4 a, float4 b) {
184 static inline __device__ float4 mul(float4 a,
float b) {
193 static inline __device__ float4 neg(float4 v) {
203 return v.x + v.y + v.z + v.w;
211 static inline __device__ float4 zero() {
221 #ifdef FAISS_USE_FLOAT16
225 typedef half ScalarType;
227 static inline __device__ half add(half a, half b) {
228 #ifdef FAISS_USE_FULL_FLOAT16
231 return __float2half(__half2float(a) + __half2float(b));
235 static inline __device__ half sub(half a, half b) {
236 #ifdef FAISS_USE_FULL_FLOAT16
239 return __float2half(__half2float(a) - __half2float(b));
243 static inline __device__ half mul(half a, half b) {
244 #ifdef FAISS_USE_FULL_FLOAT16
247 return __float2half(__half2float(a) * __half2float(b));
251 static inline __device__ half neg(half v) {
252 #ifdef FAISS_USE_FULL_FLOAT16
255 return __float2half(-__half2float(v));
259 static inline __device__ half
reduceAdd(half v) {
263 static inline __device__
bool lt(half a, half b) {
264 #ifdef FAISS_USE_FULL_FLOAT16
267 return __half2float(a) < __half2float(b);
271 static inline __device__
bool gt(half a, half b) {
272 #ifdef FAISS_USE_FULL_FLOAT16
275 return __half2float(a) > __half2float(b);
279 static inline __device__
bool eq(half a, half b) {
280 #ifdef FAISS_USE_FULL_FLOAT16
283 return __half2float(a) == __half2float(b);
287 static inline __device__ half zero() {
288 #if CUDA_VERSION >= 9000
300 typedef half ScalarType;
302 static inline __device__ half2 add(half2 a, half2 b) {
303 #ifdef FAISS_USE_FULL_FLOAT16
304 return __hadd2(a, b);
306 float2 af = __half22float2(a);
307 float2 bf = __half22float2(b);
312 return __float22half2_rn(af);
316 static inline __device__ half2 sub(half2 a, half2 b) {
317 #ifdef FAISS_USE_FULL_FLOAT16
318 return __hsub2(a, b);
320 float2 af = __half22float2(a);
321 float2 bf = __half22float2(b);
326 return __float22half2_rn(af);
330 static inline __device__ half2 add(half2 a, half b) {
331 #ifdef FAISS_USE_FULL_FLOAT16
332 half2 b2 = __half2half2(b);
333 return __hadd2(a, b2);
335 float2 af = __half22float2(a);
336 float bf = __half2float(b);
341 return __float22half2_rn(af);
345 static inline __device__ half2 sub(half2 a, half b) {
346 #ifdef FAISS_USE_FULL_FLOAT16
347 half2 b2 = __half2half2(b);
348 return __hsub2(a, b2);
350 float2 af = __half22float2(a);
351 float bf = __half2float(b);
356 return __float22half2_rn(af);
360 static inline __device__ half2 mul(half2 a, half2 b) {
361 #ifdef FAISS_USE_FULL_FLOAT16
362 return __hmul2(a, b);
364 float2 af = __half22float2(a);
365 float2 bf = __half22float2(b);
370 return __float22half2_rn(af);
374 static inline __device__ half2 mul(half2 a, half b) {
375 #ifdef FAISS_USE_FULL_FLOAT16
376 half2 b2 = __half2half2(b);
377 return __hmul2(a, b2);
379 float2 af = __half22float2(a);
380 float bf = __half2float(b);
385 return __float22half2_rn(af);
389 static inline __device__ half2 neg(half2 v) {
390 #ifdef FAISS_USE_FULL_FLOAT16
393 float2 vf = __half22float2(v);
397 return __float22half2_rn(vf);
401 static inline __device__ half
reduceAdd(half2 v) {
402 #ifdef FAISS_USE_FULL_FLOAT16
403 half hv = __high2half(v);
404 half lv = __low2half(v);
406 return __hadd(hv, lv);
408 float2 vf = __half22float2(v);
411 return __float2half(vf.x);
420 static inline __device__ half2 zero() {
421 return __half2half2(Math<half>::zero());
427 typedef half ScalarType;
429 static inline __device__ Half4 add(Half4 a, Half4 b) {
431 h.a = Math<half2>::add(a.a, b.a);
432 h.b = Math<half2>::add(a.b, b.b);
436 static inline __device__ Half4 sub(Half4 a, Half4 b) {
438 h.a = Math<half2>::sub(a.a, b.a);
439 h.b = Math<half2>::sub(a.b, b.b);
443 static inline __device__ Half4 add(Half4 a, half b) {
445 h.a = Math<half2>::add(a.a, b);
446 h.b = Math<half2>::add(a.b, b);
450 static inline __device__ Half4 sub(Half4 a, half b) {
452 h.a = Math<half2>::sub(a.a, b);
453 h.b = Math<half2>::sub(a.b, b);
457 static inline __device__ Half4 mul(Half4 a, Half4 b) {
459 h.a = Math<half2>::mul(a.a, b.a);
460 h.b = Math<half2>::mul(a.b, b.b);
464 static inline __device__ Half4 mul(Half4 a, half b) {
466 h.a = Math<half2>::mul(a.a, b);
467 h.b = Math<half2>::mul(a.b, b);
471 static inline __device__ Half4 neg(Half4 v) {
473 h.a = Math<half2>::neg(v.a);
474 h.b = Math<half2>::neg(v.b);
478 static inline __device__ half
reduceAdd(Half4 v) {
481 return Math<half>::add(hx, hy);
489 static inline __device__ Half4 zero() {
491 h.a = Math<half2>::zero();
492 h.b = Math<half2>::zero();
499 typedef half ScalarType;
501 static inline __device__ Half8 add(Half8 a, Half8 b) {
503 h.a = Math<Half4>::add(a.a, b.a);
504 h.b = Math<Half4>::add(a.b, b.b);
508 static inline __device__ Half8 sub(Half8 a, Half8 b) {
510 h.a = Math<Half4>::sub(a.a, b.a);
511 h.b = Math<Half4>::sub(a.b, b.b);
515 static inline __device__ Half8 add(Half8 a, half b) {
517 h.a = Math<Half4>::add(a.a, b);
518 h.b = Math<Half4>::add(a.b, b);
522 static inline __device__ Half8 sub(Half8 a, half b) {
524 h.a = Math<Half4>::sub(a.a, b);
525 h.b = Math<Half4>::sub(a.b, b);
529 static inline __device__ Half8 mul(Half8 a, Half8 b) {
531 h.a = Math<Half4>::mul(a.a, b.a);
532 h.b = Math<Half4>::mul(a.b, b.b);
536 static inline __device__ Half8 mul(Half8 a, half b) {
538 h.a = Math<Half4>::mul(a.a, b);
539 h.b = Math<Half4>::mul(a.b, b);
543 static inline __device__ Half8 neg(Half8 v) {
545 h.a = Math<Half4>::neg(v.a);
546 h.b = Math<Half4>::neg(v.b);
550 static inline __device__ half
reduceAdd(Half8 v) {
553 return Math<half>::add(hx, hy);
561 static inline __device__ Half8 zero() {
563 h.a = Math<Half4>::zero();
564 h.b = Math<Half4>::zero();
569 #endif // FAISS_USE_FLOAT16
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)