12 #include "Float16.cuh"
20 namespace faiss {
namespace gpu {
26 static inline __device__ T add(T a, T b) {
30 static inline __device__ T sub(T a, T b) {
34 static inline __device__ T mul(T a, T b) {
38 static inline __device__ T neg(T v) {
47 static inline __device__
bool lt(T a, T b) {
51 static inline __device__
bool gt(T a, T b) {
55 static inline __device__
bool eq(T a, T b) {
59 static inline __device__ T zero() {
66 typedef float ScalarType;
68 static inline __device__ float2 add(float2 a, float2 b) {
75 static inline __device__ float2 sub(float2 a, float2 b) {
82 static inline __device__ float2 add(float2 a,
float b) {
89 static inline __device__ float2 sub(float2 a,
float b) {
96 static inline __device__ float2 mul(float2 a, float2 b) {
103 static inline __device__ float2 mul(float2 a,
float b) {
110 static inline __device__ float2 neg(float2 v) {
126 static inline __device__ float2 zero() {
136 typedef float ScalarType;
138 static inline __device__ float4 add(float4 a, float4 b) {
147 static inline __device__ float4 sub(float4 a, float4 b) {
156 static inline __device__ float4 add(float4 a,
float b) {
165 static inline __device__ float4 sub(float4 a,
float b) {
174 static inline __device__ float4 mul(float4 a, float4 b) {
183 static inline __device__ float4 mul(float4 a,
float b) {
192 static inline __device__ float4 neg(float4 v) {
202 return v.x + v.y + v.z + v.w;
210 static inline __device__ float4 zero() {
220 #ifdef FAISS_USE_FLOAT16
224 typedef half ScalarType;
226 static inline __device__ half add(half a, half b) {
227 #ifdef FAISS_USE_FULL_FLOAT16
230 return __float2half(__half2float(a) + __half2float(b));
234 static inline __device__ half sub(half a, half b) {
235 #ifdef FAISS_USE_FULL_FLOAT16
238 return __float2half(__half2float(a) - __half2float(b));
242 static inline __device__ half mul(half a, half b) {
243 #ifdef FAISS_USE_FULL_FLOAT16
246 return __float2half(__half2float(a) * __half2float(b));
250 static inline __device__ half neg(half v) {
251 #ifdef FAISS_USE_FULL_FLOAT16
254 return __float2half(-__half2float(v));
258 static inline __device__ half
reduceAdd(half v) {
262 static inline __device__
bool lt(half a, half b) {
263 #ifdef FAISS_USE_FULL_FLOAT16
266 return __half2float(a) < __half2float(b);
270 static inline __device__
bool gt(half a, half b) {
271 #ifdef FAISS_USE_FULL_FLOAT16
274 return __half2float(a) > __half2float(b);
278 static inline __device__
bool eq(half a, half b) {
279 #ifdef FAISS_USE_FULL_FLOAT16
282 return __half2float(a) == __half2float(b);
286 static inline __device__ half zero() {
287 #if CUDA_VERSION >= 9000
299 typedef half ScalarType;
301 static inline __device__ half2 add(half2 a, half2 b) {
302 #ifdef FAISS_USE_FULL_FLOAT16
303 return __hadd2(a, b);
305 float2 af = __half22float2(a);
306 float2 bf = __half22float2(b);
311 return __float22half2_rn(af);
315 static inline __device__ half2 sub(half2 a, half2 b) {
316 #ifdef FAISS_USE_FULL_FLOAT16
317 return __hsub2(a, b);
319 float2 af = __half22float2(a);
320 float2 bf = __half22float2(b);
325 return __float22half2_rn(af);
329 static inline __device__ half2 add(half2 a, half b) {
330 #ifdef FAISS_USE_FULL_FLOAT16
331 half2 b2 = __half2half2(b);
332 return __hadd2(a, b2);
334 float2 af = __half22float2(a);
335 float bf = __half2float(b);
340 return __float22half2_rn(af);
344 static inline __device__ half2 sub(half2 a, half b) {
345 #ifdef FAISS_USE_FULL_FLOAT16
346 half2 b2 = __half2half2(b);
347 return __hsub2(a, b2);
349 float2 af = __half22float2(a);
350 float bf = __half2float(b);
355 return __float22half2_rn(af);
359 static inline __device__ half2 mul(half2 a, half2 b) {
360 #ifdef FAISS_USE_FULL_FLOAT16
361 return __hmul2(a, b);
363 float2 af = __half22float2(a);
364 float2 bf = __half22float2(b);
369 return __float22half2_rn(af);
373 static inline __device__ half2 mul(half2 a, half b) {
374 #ifdef FAISS_USE_FULL_FLOAT16
375 half2 b2 = __half2half2(b);
376 return __hmul2(a, b2);
378 float2 af = __half22float2(a);
379 float bf = __half2float(b);
384 return __float22half2_rn(af);
388 static inline __device__ half2 neg(half2 v) {
389 #ifdef FAISS_USE_FULL_FLOAT16
392 float2 vf = __half22float2(v);
396 return __float22half2_rn(vf);
400 static inline __device__ half
reduceAdd(half2 v) {
401 #ifdef FAISS_USE_FULL_FLOAT16
402 half hv = __high2half(v);
403 half lv = __low2half(v);
405 return __hadd(hv, lv);
407 float2 vf = __half22float2(v);
410 return __float2half(vf.x);
419 static inline __device__ half2 zero() {
420 return __half2half2(Math<half>::zero());
426 typedef half ScalarType;
428 static inline __device__ Half4 add(Half4 a, Half4 b) {
430 h.a = Math<half2>::add(a.a, b.a);
431 h.b = Math<half2>::add(a.b, b.b);
435 static inline __device__ Half4 sub(Half4 a, Half4 b) {
437 h.a = Math<half2>::sub(a.a, b.a);
438 h.b = Math<half2>::sub(a.b, b.b);
442 static inline __device__ Half4 add(Half4 a, half b) {
444 h.a = Math<half2>::add(a.a, b);
445 h.b = Math<half2>::add(a.b, b);
449 static inline __device__ Half4 sub(Half4 a, half b) {
451 h.a = Math<half2>::sub(a.a, b);
452 h.b = Math<half2>::sub(a.b, b);
456 static inline __device__ Half4 mul(Half4 a, Half4 b) {
458 h.a = Math<half2>::mul(a.a, b.a);
459 h.b = Math<half2>::mul(a.b, b.b);
463 static inline __device__ Half4 mul(Half4 a, half b) {
465 h.a = Math<half2>::mul(a.a, b);
466 h.b = Math<half2>::mul(a.b, b);
470 static inline __device__ Half4 neg(Half4 v) {
472 h.a = Math<half2>::neg(v.a);
473 h.b = Math<half2>::neg(v.b);
477 static inline __device__ half
reduceAdd(Half4 v) {
480 return Math<half>::add(hx, hy);
488 static inline __device__ Half4 zero() {
490 h.a = Math<half2>::zero();
491 h.b = Math<half2>::zero();
498 typedef half ScalarType;
500 static inline __device__ Half8 add(Half8 a, Half8 b) {
502 h.a = Math<Half4>::add(a.a, b.a);
503 h.b = Math<Half4>::add(a.b, b.b);
507 static inline __device__ Half8 sub(Half8 a, Half8 b) {
509 h.a = Math<Half4>::sub(a.a, b.a);
510 h.b = Math<Half4>::sub(a.b, b.b);
514 static inline __device__ Half8 add(Half8 a, half b) {
516 h.a = Math<Half4>::add(a.a, b);
517 h.b = Math<Half4>::add(a.b, b);
521 static inline __device__ Half8 sub(Half8 a, half b) {
523 h.a = Math<Half4>::sub(a.a, b);
524 h.b = Math<Half4>::sub(a.b, b);
528 static inline __device__ Half8 mul(Half8 a, Half8 b) {
530 h.a = Math<Half4>::mul(a.a, b.a);
531 h.b = Math<Half4>::mul(a.b, b.b);
535 static inline __device__ Half8 mul(Half8 a, half b) {
537 h.a = Math<Half4>::mul(a.a, b);
538 h.b = Math<Half4>::mul(a.b, b);
542 static inline __device__ Half8 neg(Half8 v) {
544 h.a = Math<Half4>::neg(v.a);
545 h.b = Math<Half4>::neg(v.b);
549 static inline __device__ half
reduceAdd(Half8 v) {
552 return Math<half>::add(hx, hy);
560 static inline __device__ Half8 zero() {
562 h.a = Math<Half4>::zero();
563 h.b = Math<Half4>::zero();
568 #endif // FAISS_USE_FLOAT16
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)