13 #include "Float16.cuh"
21 namespace faiss {
namespace gpu {
27 static inline __device__ T add(T a, T b) {
31 static inline __device__ T sub(T a, T b) {
35 static inline __device__ T mul(T a, T b) {
39 static inline __device__ T neg(T v) {
48 static inline __device__
bool lt(T a, T b) {
52 static inline __device__
bool gt(T a, T b) {
56 static inline __device__
bool eq(T a, T b) {
60 static inline __device__ T zero() {
67 typedef float ScalarType;
69 static inline __device__ float2 add(float2 a, float2 b) {
76 static inline __device__ float2 sub(float2 a, float2 b) {
83 static inline __device__ float2 add(float2 a,
float b) {
90 static inline __device__ float2 sub(float2 a,
float b) {
97 static inline __device__ float2 mul(float2 a, float2 b) {
104 static inline __device__ float2 mul(float2 a,
float b) {
111 static inline __device__ float2 neg(float2 v) {
127 static inline __device__ float2 zero() {
137 typedef float ScalarType;
139 static inline __device__ float4 add(float4 a, float4 b) {
148 static inline __device__ float4 sub(float4 a, float4 b) {
157 static inline __device__ float4 add(float4 a,
float b) {
166 static inline __device__ float4 sub(float4 a,
float b) {
175 static inline __device__ float4 mul(float4 a, float4 b) {
184 static inline __device__ float4 mul(float4 a,
float b) {
193 static inline __device__ float4 neg(float4 v) {
203 return v.x + v.y + v.z + v.w;
211 static inline __device__ float4 zero() {
221 #ifdef FAISS_USE_FLOAT16
225 typedef half ScalarType;
227 static inline __device__ half add(half a, half b) {
228 #ifdef FAISS_USE_FULL_FLOAT16
231 return __float2half(__half2float(a) + __half2float(b));
235 static inline __device__ half sub(half a, half b) {
236 #ifdef FAISS_USE_FULL_FLOAT16
239 return __float2half(__half2float(a) - __half2float(b));
243 static inline __device__ half mul(half a, half b) {
244 #ifdef FAISS_USE_FULL_FLOAT16
247 return __float2half(__half2float(a) * __half2float(b));
251 static inline __device__ half neg(half v) {
252 #ifdef FAISS_USE_FULL_FLOAT16
255 return __float2half(-__half2float(v));
259 static inline __device__ half
reduceAdd(half v) {
263 static inline __device__
bool lt(half a, half b) {
264 #ifdef FAISS_USE_FULL_FLOAT16
267 return __half2float(a) < __half2float(b);
271 static inline __device__
bool gt(half a, half b) {
272 #ifdef FAISS_USE_FULL_FLOAT16
275 return __half2float(a) > __half2float(b);
279 static inline __device__
bool eq(half a, half b) {
280 #ifdef FAISS_USE_FULL_FLOAT16
283 return __half2float(a) == __half2float(b);
287 static inline __device__ half zero() {
296 typedef half ScalarType;
298 static inline __device__ half2 add(half2 a, half2 b) {
299 #ifdef FAISS_USE_FULL_FLOAT16
300 return __hadd2(a, b);
302 float2 af = __half22float2(a);
303 float2 bf = __half22float2(b);
308 return __float22half2_rn(af);
312 static inline __device__ half2 sub(half2 a, half2 b) {
313 #ifdef FAISS_USE_FULL_FLOAT16
314 return __hsub2(a, b);
316 float2 af = __half22float2(a);
317 float2 bf = __half22float2(b);
322 return __float22half2_rn(af);
326 static inline __device__ half2 add(half2 a, half b) {
327 #ifdef FAISS_USE_FULL_FLOAT16
328 half2 b2 = __half2half2(b);
329 return __hadd2(a, b2);
331 float2 af = __half22float2(a);
332 float bf = __half2float(b);
337 return __float22half2_rn(af);
341 static inline __device__ half2 sub(half2 a, half b) {
342 #ifdef FAISS_USE_FULL_FLOAT16
343 half2 b2 = __half2half2(b);
344 return __hsub2(a, b2);
346 float2 af = __half22float2(a);
347 float bf = __half2float(b);
352 return __float22half2_rn(af);
356 static inline __device__ half2 mul(half2 a, half2 b) {
357 #ifdef FAISS_USE_FULL_FLOAT16
358 return __hmul2(a, b);
360 float2 af = __half22float2(a);
361 float2 bf = __half22float2(b);
366 return __float22half2_rn(af);
370 static inline __device__ half2 mul(half2 a, half b) {
371 #ifdef FAISS_USE_FULL_FLOAT16
372 half2 b2 = __half2half2(b);
373 return __hmul2(a, b2);
375 float2 af = __half22float2(a);
376 float bf = __half2float(b);
381 return __float22half2_rn(af);
385 static inline __device__ half2 neg(half2 v) {
386 #ifdef FAISS_USE_FULL_FLOAT16
389 float2 vf = __half22float2(v);
393 return __float22half2_rn(vf);
397 static inline __device__ half
reduceAdd(half2 v) {
398 #ifdef FAISS_USE_FULL_FLOAT16
399 half hv = __high2half(v);
400 half lv = __low2half(v);
402 return __hadd(hv, lv);
404 float2 vf = __half22float2(v);
407 return __float2half(vf.x);
416 static inline __device__ half2 zero() {
417 return __half2half2(Math<half>::zero());
423 typedef half ScalarType;
425 static inline __device__ Half4 add(Half4 a, Half4 b) {
427 h.a = Math<half2>::add(a.a, b.a);
428 h.b = Math<half2>::add(a.b, b.b);
432 static inline __device__ Half4 sub(Half4 a, Half4 b) {
434 h.a = Math<half2>::sub(a.a, b.a);
435 h.b = Math<half2>::sub(a.b, b.b);
439 static inline __device__ Half4 add(Half4 a, half b) {
441 h.a = Math<half2>::add(a.a, b);
442 h.b = Math<half2>::add(a.b, b);
446 static inline __device__ Half4 sub(Half4 a, half b) {
448 h.a = Math<half2>::sub(a.a, b);
449 h.b = Math<half2>::sub(a.b, b);
453 static inline __device__ Half4 mul(Half4 a, Half4 b) {
455 h.a = Math<half2>::mul(a.a, b.a);
456 h.b = Math<half2>::mul(a.b, b.b);
460 static inline __device__ Half4 mul(Half4 a, half b) {
462 h.a = Math<half2>::mul(a.a, b);
463 h.b = Math<half2>::mul(a.b, b);
467 static inline __device__ Half4 neg(Half4 v) {
469 h.a = Math<half2>::neg(v.a);
470 h.b = Math<half2>::neg(v.b);
474 static inline __device__ half
reduceAdd(Half4 v) {
477 return Math<half>::add(hx, hy);
485 static inline __device__ Half4 zero() {
487 h.a = Math<half2>::zero();
488 h.b = Math<half2>::zero();
495 typedef half ScalarType;
497 static inline __device__ Half8 add(Half8 a, Half8 b) {
499 h.a = Math<Half4>::add(a.a, b.a);
500 h.b = Math<Half4>::add(a.b, b.b);
504 static inline __device__ Half8 sub(Half8 a, Half8 b) {
506 h.a = Math<Half4>::sub(a.a, b.a);
507 h.b = Math<Half4>::sub(a.b, b.b);
511 static inline __device__ Half8 add(Half8 a, half b) {
513 h.a = Math<Half4>::add(a.a, b);
514 h.b = Math<Half4>::add(a.b, b);
518 static inline __device__ Half8 sub(Half8 a, half b) {
520 h.a = Math<Half4>::sub(a.a, b);
521 h.b = Math<Half4>::sub(a.b, b);
525 static inline __device__ Half8 mul(Half8 a, Half8 b) {
527 h.a = Math<Half4>::mul(a.a, b.a);
528 h.b = Math<Half4>::mul(a.b, b.b);
532 static inline __device__ Half8 mul(Half8 a, half b) {
534 h.a = Math<Half4>::mul(a.a, b);
535 h.b = Math<Half4>::mul(a.b, b);
539 static inline __device__ Half8 neg(Half8 v) {
541 h.a = Math<Half4>::neg(v.a);
542 h.b = Math<Half4>::neg(v.b);
546 static inline __device__ half
reduceAdd(Half8 v) {
549 return Math<half>::add(hx, hy);
557 static inline __device__ Half8 zero() {
559 h.a = Math<Half4>::zero();
560 h.b = Math<Half4>::zero();
565 #endif // FAISS_USE_FLOAT16
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)