14 #include "Float16.cuh"
22 namespace faiss {
namespace gpu {
// Generic device-side math helpers parameterized on a type T.
// NOTE(review): only the signatures are visible in this view; the function
// bodies are elided, so the per-function notes below are inferred from the
// names and signatures and should be confirmed against the full file.

// Binary arithmetic: each takes two T values and returns a T.
28 static inline __device__ T add(T a, T b) {
32 static inline __device__ T sub(T a, T b) {
36 static inline __device__ T mul(T a, T b) {
// Unary helper: takes one T and returns a T (presumably the negation of v).
40 static inline __device__ T neg(T v) {
// Comparison predicates: each takes two T values and returns bool
// (presumably a < b, a > b and a == b respectively — confirm).
49 static inline __device__
bool lt(T a, T b) {
53 static inline __device__
bool gt(T a, T b) {
57 static inline __device__
bool eq(T a, T b) {
// Takes no arguments and returns a T (presumably the additive identity).
61 static inline __device__ T zero() {
// Math helpers for the float2 vector type.
// NOTE(review): function bodies are elided in this view; the behavior
// described below is inferred from the signatures only.

// The scalar element type associated with float2 is float.
68 typedef float ScalarType;
// Vector/vector overloads: both operands and the result are float2.
70 static inline __device__ float2 add(float2 a, float2 b) {
77 static inline __device__ float2 sub(float2 a, float2 b) {
// Vector/scalar overloads: the second operand is a single float
// (presumably applied to every component of a — confirm).
84 static inline __device__ float2 add(float2 a,
float b) {
91 static inline __device__ float2 sub(float2 a,
float b) {
// Multiplication, again in vector/vector and vector/scalar forms.
98 static inline __device__ float2 mul(float2 a, float2 b) {
105 static inline __device__ float2 mul(float2 a,
float b) {
// Unary helper returning a float2 (presumably the component-wise negation).
112 static inline __device__ float2 neg(float2 v) {
// Takes no arguments and returns a float2 (presumably all components zero).
128 static inline __device__ float2 zero() {
// Math helpers for the float4 vector type.
// NOTE(review): most function bodies are elided in this view; the behavior
// described below is inferred from the signatures unless a statement is shown.

// The scalar element type associated with float4 is float.
138 typedef float ScalarType;
// Vector/vector overloads: both operands and the result are float4.
140 static inline __device__ float4 add(float4 a, float4 b) {
149 static inline __device__ float4 sub(float4 a, float4 b) {
// Vector/scalar overloads: the second operand is a single float
// (presumably applied to every component of a — confirm).
158 static inline __device__ float4 add(float4 a,
float b) {
167 static inline __device__ float4 sub(float4 a,
float b) {
// Multiplication, again in vector/vector and vector/scalar forms.
176 static inline __device__ float4 mul(float4 a, float4 b) {
185 static inline __device__ float4 mul(float4 a,
float b) {
// Unary helper returning a float4 (presumably the component-wise negation).
194 static inline __device__ float4 neg(float4 v) {
// Sums the four components x, y, z and w of v into one float.
// NOTE(review): the enclosing signature is not visible here; presumably this
// is the body of the float4 reduceAdd — confirm.
204 return v.x + v.y + v.z + v.w;
// Takes no arguments and returns a float4 (presumably all components zero).
212 static inline __device__ float4 zero() {
222 #ifdef FAISS_USE_FLOAT16
// Math helpers for the scalar half type.
// Each arithmetic/comparison helper checks FAISS_USE_FULL_FLOAT16. The branch
// taken when that macro is defined is not visible in this view; the visible
// return statements convert the operands to float with __half2float, operate
// in float, and (for arithmetic) convert the result back with __float2half.
// NOTE(review): presumably the visible returns are the #else fallback used
// when native half arithmetic is unavailable — confirm against the full file.

// The scalar element type associated with half is half itself.
226 typedef half ScalarType;
// a + b, computed in float and narrowed back to half.
228 static inline __device__ half add(half a, half b) {
229 #ifdef FAISS_USE_FULL_FLOAT16
232 return __float2half(__half2float(a) + __half2float(b));
// a - b, computed in float and narrowed back to half.
236 static inline __device__ half sub(half a, half b) {
237 #ifdef FAISS_USE_FULL_FLOAT16
240 return __float2half(__half2float(a) - __half2float(b));
// a * b, computed in float and narrowed back to half.
244 static inline __device__ half mul(half a, half b) {
245 #ifdef FAISS_USE_FULL_FLOAT16
248 return __float2half(__half2float(a) * __half2float(b));
// -v, computed in float and narrowed back to half.
252 static inline __device__ half neg(half v) {
253 #ifdef FAISS_USE_FULL_FLOAT16
256 return __float2half(-__half2float(v));
// Takes a single half and returns a half; the body is not visible here
// (presumably returns v unchanged, since there is only one element).
260 static inline __device__ half
reduceAdd(half v) {
// a < b, evaluated on the float conversions of both operands.
264 static inline __device__
bool lt(half a, half b) {
265 #ifdef FAISS_USE_FULL_FLOAT16
268 return __half2float(a) < __half2float(b);
// a > b, evaluated on the float conversions of both operands.
272 static inline __device__
bool gt(half a, half b) {
273 #ifdef FAISS_USE_FULL_FLOAT16
276 return __half2float(a) > __half2float(b);
// a == b, evaluated on the float conversions of both operands.
280 static inline __device__
bool eq(half a, half b) {
281 #ifdef FAISS_USE_FULL_FLOAT16
284 return __half2float(a) == __half2float(b);
// Takes no arguments and returns a half; body not visible
// (presumably the half representation of zero).
288 static inline __device__ half zero() {
// Math helpers for the packed half2 type (two half values).
// With FAISS_USE_FULL_FLOAT16 defined, the helpers use the paired half
// intrinsics (__hadd2, __hsub2, __hmul2). The other visible statements widen
// the operands to float2/float, and narrow the result back to half2 with
// __float22half2_rn; the float arithmetic between those steps is elided in
// this view.
// NOTE(review): the #else/#endif lines are not visible; presumably the float
// round-trip is the fallback when FAISS_USE_FULL_FLOAT16 is not defined.
// NOTE(review): native half arithmetic intrinsics typically need a minimum
// compute capability — confirm the macro is only set for suitable targets.

// The scalar element type associated with half2 is half.
297 typedef half ScalarType;
// Pairwise a + b.
299 static inline __device__ half2 add(half2 a, half2 b) {
300 #ifdef FAISS_USE_FULL_FLOAT16
301 return __hadd2(a, b);
303 float2 af = __half22float2(a);
304 float2 bf = __half22float2(b);
309 return __float22half2_rn(af);
// Pairwise a - b.
313 static inline __device__ half2 sub(half2 a, half2 b) {
314 #ifdef FAISS_USE_FULL_FLOAT16
315 return __hsub2(a, b);
317 float2 af = __half22float2(a);
318 float2 bf = __half22float2(b);
323 return __float22half2_rn(af);
// Adds the scalar b to a; the native path first replicates b into both
// halves of a half2 via __half2half2.
327 static inline __device__ half2 add(half2 a, half b) {
328 #ifdef FAISS_USE_FULL_FLOAT16
329 half2 b2 = __half2half2(b);
330 return __hadd2(a, b2);
332 float2 af = __half22float2(a);
333 float bf = __half2float(b);
338 return __float22half2_rn(af);
// Subtracts the scalar b from a; the native path replicates b into a half2.
342 static inline __device__ half2 sub(half2 a, half b) {
343 #ifdef FAISS_USE_FULL_FLOAT16
344 half2 b2 = __half2half2(b);
345 return __hsub2(a, b2);
347 float2 af = __half22float2(a);
348 float bf = __half2float(b);
353 return __float22half2_rn(af);
// Pairwise a * b.
357 static inline __device__ half2 mul(half2 a, half2 b) {
358 #ifdef FAISS_USE_FULL_FLOAT16
359 return __hmul2(a, b);
361 float2 af = __half22float2(a);
362 float2 bf = __half22float2(b);
367 return __float22half2_rn(af);
// Multiplies a by the scalar b; the native path replicates b into a half2.
371 static inline __device__ half2 mul(half2 a, half b) {
372 #ifdef FAISS_USE_FULL_FLOAT16
373 half2 b2 = __half2half2(b);
374 return __hmul2(a, b2);
376 float2 af = __half22float2(a);
377 float bf = __half2float(b);
382 return __float22half2_rn(af);
// Unary helper returning a half2. The native branch is not visible; the
// visible statements widen v to float2 and narrow the result back
// (presumably after negating both components — confirm).
386 static inline __device__ half2 neg(half2 v) {
387 #ifdef FAISS_USE_FULL_FLOAT16
390 float2 vf = __half22float2(v);
394 return __float22half2_rn(vf);
// Reduces the two packed halves to a single half. The native path extracts
// the high and low halves and adds them with __hadd. The other path widens
// to float2 and returns vf.x as a half (presumably after vf.y has been
// accumulated into vf.x on an elided line — confirm).
398 static inline __device__ half
reduceAdd(half2 v) {
399 #ifdef FAISS_USE_FULL_FLOAT16
400 half hv = __high2half(v);
401 half lv = __low2half(v);
403 return __hadd(hv, lv);
405 float2 vf = __half22float2(v);
408 return __float2half(vf.x);
// Replicates Math<half>::zero() into both halves of a half2.
417 static inline __device__ half2 zero() {
418 return __half2half2(Math<half>::zero());
// Math helpers for Half4, built from the Math<half2> helpers.
// Every visible operation applies the matching Math<half2> function to the
// .a member and to the .b member and stores the results in h.a and h.b.
// NOTE(review): the declaration of h and the return statements are elided in
// this view; presumably h is a local Half4 that is returned — confirm.

// The scalar element type associated with Half4 is half.
424 typedef half ScalarType;
// Member-wise a + b.
426 static inline __device__ Half4 add(Half4 a, Half4 b) {
428 h.a = Math<half2>::add(a.a, b.a);
429 h.b = Math<half2>::add(a.b, b.b);
// Member-wise a - b.
433 static inline __device__ Half4 sub(Half4 a, Half4 b) {
435 h.a = Math<half2>::sub(a.a, b.a);
436 h.b = Math<half2>::sub(a.b, b.b);
// Adds the scalar b to both members of a.
440 static inline __device__ Half4 add(Half4 a, half b) {
442 h.a = Math<half2>::add(a.a, b);
443 h.b = Math<half2>::add(a.b, b);
// Subtracts the scalar b from both members of a.
447 static inline __device__ Half4 sub(Half4 a, half b) {
449 h.a = Math<half2>::sub(a.a, b);
450 h.b = Math<half2>::sub(a.b, b);
// Member-wise a * b.
454 static inline __device__ Half4 mul(Half4 a, Half4 b) {
456 h.a = Math<half2>::mul(a.a, b.a);
457 h.b = Math<half2>::mul(a.b, b.b);
// Multiplies both members of a by the scalar b.
461 static inline __device__ Half4 mul(Half4 a, half b) {
463 h.a = Math<half2>::mul(a.a, b);
464 h.b = Math<half2>::mul(a.b, b);
// Applies Math<half2>::neg to both members of v.
468 static inline __device__ Half4 neg(Half4 v) {
470 h.a = Math<half2>::neg(v.a);
471 h.b = Math<half2>::neg(v.b);
// Reduces v to a single half by adding two partial values hx and hy with
// Math<half>::add. NOTE(review): the definitions of hx and hy are elided;
// presumably they are the reductions of v.a and v.b — confirm.
475 static inline __device__ half
reduceAdd(Half4 v) {
478 return Math<half>::add(hx, hy);
// Sets both members to Math<half2>::zero().
486 static inline __device__ Half4 zero() {
488 h.a = Math<half2>::zero();
489 h.b = Math<half2>::zero();
// Math helpers for Half8, built from the Math<Half4> helpers.
// Every visible operation applies the matching Math<Half4> function to the
// .a member and to the .b member and stores the results in h.a and h.b.
// NOTE(review): the declaration of h and the return statements are elided in
// this view; presumably h is a local Half8 that is returned — confirm.

// The scalar element type associated with Half8 is half.
496 typedef half ScalarType;
// Member-wise a + b.
498 static inline __device__ Half8 add(Half8 a, Half8 b) {
500 h.a = Math<Half4>::add(a.a, b.a);
501 h.b = Math<Half4>::add(a.b, b.b);
// Member-wise a - b.
505 static inline __device__ Half8 sub(Half8 a, Half8 b) {
507 h.a = Math<Half4>::sub(a.a, b.a);
508 h.b = Math<Half4>::sub(a.b, b.b);
// Adds the scalar b to both members of a.
512 static inline __device__ Half8 add(Half8 a, half b) {
514 h.a = Math<Half4>::add(a.a, b);
515 h.b = Math<Half4>::add(a.b, b);
// Subtracts the scalar b from both members of a.
519 static inline __device__ Half8 sub(Half8 a, half b) {
521 h.a = Math<Half4>::sub(a.a, b);
522 h.b = Math<Half4>::sub(a.b, b);
// Member-wise a * b.
526 static inline __device__ Half8 mul(Half8 a, Half8 b) {
528 h.a = Math<Half4>::mul(a.a, b.a);
529 h.b = Math<Half4>::mul(a.b, b.b);
// Multiplies both members of a by the scalar b.
533 static inline __device__ Half8 mul(Half8 a, half b) {
535 h.a = Math<Half4>::mul(a.a, b);
536 h.b = Math<Half4>::mul(a.b, b);
// Applies Math<Half4>::neg to both members of v.
540 static inline __device__ Half8 neg(Half8 v) {
542 h.a = Math<Half4>::neg(v.a);
543 h.b = Math<Half4>::neg(v.b);
// Reduces v to a single half by adding two partial values hx and hy with
// Math<half>::add. NOTE(review): the definitions of hx and hy are elided;
// presumably they are the reductions of v.a and v.b — confirm.
547 static inline __device__ half
reduceAdd(Half8 v) {
550 return Math<half>::add(hx, hy);
// Sets both members to Math<Half4>::zero().
558 static inline __device__ Half8 zero() {
560 h.a = Math<Half4>::zero();
561 h.b = Math<Half4>::zero();
566 #endif // FAISS_USE_FLOAT16
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)