11 #include "Float16.cuh"
19 namespace faiss {
namespace gpu {
25 static inline __device__ T add(T a, T b) {
29 static inline __device__ T sub(T a, T b) {
33 static inline __device__ T mul(T a, T b) {
37 static inline __device__ T neg(T v) {
46 static inline __device__
bool lt(T a, T b) {
50 static inline __device__
bool gt(T a, T b) {
54 static inline __device__
bool eq(T a, T b) {
58 static inline __device__ T zero() {
65 typedef float ScalarType;
67 static inline __device__ float2 add(float2 a, float2 b) {
74 static inline __device__ float2 sub(float2 a, float2 b) {
81 static inline __device__ float2 add(float2 a,
float b) {
88 static inline __device__ float2 sub(float2 a,
float b) {
95 static inline __device__ float2 mul(float2 a, float2 b) {
102 static inline __device__ float2 mul(float2 a,
float b) {
109 static inline __device__ float2 neg(float2 v) {
125 static inline __device__ float2 zero() {
135 typedef float ScalarType;
137 static inline __device__ float4 add(float4 a, float4 b) {
146 static inline __device__ float4 sub(float4 a, float4 b) {
155 static inline __device__ float4 add(float4 a,
float b) {
164 static inline __device__ float4 sub(float4 a,
float b) {
173 static inline __device__ float4 mul(float4 a, float4 b) {
182 static inline __device__ float4 mul(float4 a,
float b) {
191 static inline __device__ float4 neg(float4 v) {
201 return v.x + v.y + v.z + v.w;
209 static inline __device__ float4 zero() {
219 #ifdef FAISS_USE_FLOAT16
223 typedef half ScalarType;
225 static inline __device__ half add(half a, half b) {
226 #ifdef FAISS_USE_FULL_FLOAT16
229 return __float2half(__half2float(a) + __half2float(b));
233 static inline __device__ half sub(half a, half b) {
234 #ifdef FAISS_USE_FULL_FLOAT16
237 return __float2half(__half2float(a) - __half2float(b));
241 static inline __device__ half mul(half a, half b) {
242 #ifdef FAISS_USE_FULL_FLOAT16
245 return __float2half(__half2float(a) * __half2float(b));
249 static inline __device__ half neg(half v) {
250 #ifdef FAISS_USE_FULL_FLOAT16
253 return __float2half(-__half2float(v));
257 static inline __device__ half
reduceAdd(half v) {
261 static inline __device__
bool lt(half a, half b) {
262 #ifdef FAISS_USE_FULL_FLOAT16
265 return __half2float(a) < __half2float(b);
269 static inline __device__
bool gt(half a, half b) {
270 #ifdef FAISS_USE_FULL_FLOAT16
273 return __half2float(a) > __half2float(b);
277 static inline __device__
bool eq(half a, half b) {
278 #ifdef FAISS_USE_FULL_FLOAT16
281 return __half2float(a) == __half2float(b);
285 static inline __device__ half zero() {
286 #if CUDA_VERSION >= 9000
298 typedef half ScalarType;
300 static inline __device__ half2 add(half2 a, half2 b) {
301 #ifdef FAISS_USE_FULL_FLOAT16
302 return __hadd2(a, b);
304 float2 af = __half22float2(a);
305 float2 bf = __half22float2(b);
310 return __float22half2_rn(af);
314 static inline __device__ half2 sub(half2 a, half2 b) {
315 #ifdef FAISS_USE_FULL_FLOAT16
316 return __hsub2(a, b);
318 float2 af = __half22float2(a);
319 float2 bf = __half22float2(b);
324 return __float22half2_rn(af);
328 static inline __device__ half2 add(half2 a, half b) {
329 #ifdef FAISS_USE_FULL_FLOAT16
330 half2 b2 = __half2half2(b);
331 return __hadd2(a, b2);
333 float2 af = __half22float2(a);
334 float bf = __half2float(b);
339 return __float22half2_rn(af);
343 static inline __device__ half2 sub(half2 a, half b) {
344 #ifdef FAISS_USE_FULL_FLOAT16
345 half2 b2 = __half2half2(b);
346 return __hsub2(a, b2);
348 float2 af = __half22float2(a);
349 float bf = __half2float(b);
354 return __float22half2_rn(af);
358 static inline __device__ half2 mul(half2 a, half2 b) {
359 #ifdef FAISS_USE_FULL_FLOAT16
360 return __hmul2(a, b);
362 float2 af = __half22float2(a);
363 float2 bf = __half22float2(b);
368 return __float22half2_rn(af);
372 static inline __device__ half2 mul(half2 a, half b) {
373 #ifdef FAISS_USE_FULL_FLOAT16
374 half2 b2 = __half2half2(b);
375 return __hmul2(a, b2);
377 float2 af = __half22float2(a);
378 float bf = __half2float(b);
383 return __float22half2_rn(af);
387 static inline __device__ half2 neg(half2 v) {
388 #ifdef FAISS_USE_FULL_FLOAT16
391 float2 vf = __half22float2(v);
395 return __float22half2_rn(vf);
399 static inline __device__ half
reduceAdd(half2 v) {
400 #ifdef FAISS_USE_FULL_FLOAT16
401 half hv = __high2half(v);
402 half lv = __low2half(v);
404 return __hadd(hv, lv);
406 float2 vf = __half22float2(v);
409 return __float2half(vf.x);
418 static inline __device__ half2 zero() {
419 return __half2half2(Math<half>::zero());
425 typedef half ScalarType;
427 static inline __device__ Half4 add(Half4 a, Half4 b) {
429 h.a = Math<half2>::add(a.a, b.a);
430 h.b = Math<half2>::add(a.b, b.b);
434 static inline __device__ Half4 sub(Half4 a, Half4 b) {
436 h.a = Math<half2>::sub(a.a, b.a);
437 h.b = Math<half2>::sub(a.b, b.b);
441 static inline __device__ Half4 add(Half4 a, half b) {
443 h.a = Math<half2>::add(a.a, b);
444 h.b = Math<half2>::add(a.b, b);
448 static inline __device__ Half4 sub(Half4 a, half b) {
450 h.a = Math<half2>::sub(a.a, b);
451 h.b = Math<half2>::sub(a.b, b);
455 static inline __device__ Half4 mul(Half4 a, Half4 b) {
457 h.a = Math<half2>::mul(a.a, b.a);
458 h.b = Math<half2>::mul(a.b, b.b);
462 static inline __device__ Half4 mul(Half4 a, half b) {
464 h.a = Math<half2>::mul(a.a, b);
465 h.b = Math<half2>::mul(a.b, b);
469 static inline __device__ Half4 neg(Half4 v) {
471 h.a = Math<half2>::neg(v.a);
472 h.b = Math<half2>::neg(v.b);
476 static inline __device__ half
reduceAdd(Half4 v) {
479 return Math<half>::add(hx, hy);
487 static inline __device__ Half4 zero() {
489 h.a = Math<half2>::zero();
490 h.b = Math<half2>::zero();
497 typedef half ScalarType;
499 static inline __device__ Half8 add(Half8 a, Half8 b) {
501 h.a = Math<Half4>::add(a.a, b.a);
502 h.b = Math<Half4>::add(a.b, b.b);
506 static inline __device__ Half8 sub(Half8 a, Half8 b) {
508 h.a = Math<Half4>::sub(a.a, b.a);
509 h.b = Math<Half4>::sub(a.b, b.b);
513 static inline __device__ Half8 add(Half8 a, half b) {
515 h.a = Math<Half4>::add(a.a, b);
516 h.b = Math<Half4>::add(a.b, b);
520 static inline __device__ Half8 sub(Half8 a, half b) {
522 h.a = Math<Half4>::sub(a.a, b);
523 h.b = Math<Half4>::sub(a.b, b);
527 static inline __device__ Half8 mul(Half8 a, Half8 b) {
529 h.a = Math<Half4>::mul(a.a, b.a);
530 h.b = Math<Half4>::mul(a.b, b.b);
534 static inline __device__ Half8 mul(Half8 a, half b) {
536 h.a = Math<Half4>::mul(a.a, b);
537 h.b = Math<Half4>::mul(a.b, b);
541 static inline __device__ Half8 neg(Half8 v) {
543 h.a = Math<Half4>::neg(v.a);
544 h.b = Math<Half4>::neg(v.b);
548 static inline __device__ half
reduceAdd(Half8 v) {
551 return Math<half>::add(hx, hy);
559 static inline __device__ Half8 zero() {
561 h.a = Math<Half4>::zero();
562 h.b = Math<Half4>::zero();
567 #endif // FAISS_USE_FLOAT16
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float2 v)
For a vector type, this is a horizontal add, returning sum(v_i)
static __device__ float reduceAdd(float4 v)
For a vector type, this is a horizontal add, returning sum(v_i)