/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include "Float16.cuh"

//
// Templated wrappers to express math for different scalar and vector
// types, so kernels can have the same written form but can operate
// over half and float, and on vector types transparently
//

namespace faiss { namespace gpu {

/// Primary template: scalar types that support the built-in arithmetic and
/// comparison operators (e.g. float). ScalarType is T itself, and
/// reduceAdd is the identity because a scalar has a single lane.
template <typename T>
struct Math {
  typedef T ScalarType;

  static inline __device__ T add(T a, T b) {
    return a + b;
  }

  static inline __device__ T sub(T a, T b) {
    return a - b;
  }

  static inline __device__ T mul(T a, T b) {
    return a * b;
  }

  static inline __device__ T neg(T v) {
    return -v;
  }

  /// For a vector type, this is a horizontal add, returning sum(v_i)
  static inline __device__ T reduceAdd(T v) {
    return v;
  }

  static inline __device__ bool lt(T a, T b) {
    return a < b;
  }

  static inline __device__ bool gt(T a, T b) {
    return a > b;
  }

  static inline __device__ bool eq(T a, T b) {
    return a == b;
  }

  /// Additive identity, obtained by casting the literal 0 to T.
  static inline __device__ T zero() {
    return (T) 0;
  }
};

/// Component-wise float2 arithmetic. The (float2, float) overloads apply
/// the scalar to both lanes. ScalarType (float) is what reduceAdd returns.
template <>
struct Math<float2> {
  typedef float ScalarType;

  static inline __device__ float2 add(float2 a, float2 b) {
    float2 v;
    v.x = a.x + b.x;
    v.y = a.y + b.y;
    return v;
  }

  static inline __device__ float2 sub(float2 a, float2 b) {
    float2 v;
    v.x = a.x - b.x;
    v.y = a.y - b.y;
    return v;
  }

  /// Adds the scalar b to every lane of a.
  static inline __device__ float2 add(float2 a, float b) {
    float2 v;
    v.x = a.x + b;
    v.y = a.y + b;
    return v;
  }

  /// Subtracts the scalar b from every lane of a.
  static inline __device__ float2 sub(float2 a, float b) {
    float2 v;
    v.x = a.x - b;
    v.y = a.y - b;
    return v;
  }

  static inline __device__ float2 mul(float2 a, float2 b) {
    float2 v;
    v.x = a.x * b.x;
    v.y = a.y * b.y;
    return v;
  }

  /// Multiplies every lane of a by the scalar b.
  static inline __device__ float2 mul(float2 a, float b) {
    float2 v;
    v.x = a.x * b;
    v.y = a.y * b;
    return v;
  }

  static inline __device__ float2 neg(float2 v) {
    // v is a by-value copy, so negating it in place does not affect the caller
    v.x = -v.x;
    v.y = -v.y;
    return v;
  }

  /// For a vector type, this is a horizontal add, returning sum(v_i)
  static inline __device__ float reduceAdd(float2 v) {
    return v.x + v.y;
  }

  // not implemented for vector types
  // static inline __device__ bool lt(float2 a, float2 b);
  // static inline __device__ bool gt(float2 a, float2 b);
  // static inline __device__ bool eq(float2 a, float2 b);

  static inline __device__ float2 zero() {
    float2 v;
    v.x = 0.0f;
    v.y = 0.0f;
    return v;
  }
};

/// Component-wise float4 arithmetic. The (float4, float) overloads apply
/// the scalar to all four lanes. ScalarType (float) is what reduceAdd returns.
template <>
struct Math<float4> {
  typedef float ScalarType;

  static inline __device__ float4 add(float4 a, float4 b) {
    float4 v;
    v.x = a.x + b.x;
    v.y = a.y + b.y;
    v.z = a.z + b.z;
    v.w = a.w + b.w;
    return v;
  }

  static inline __device__ float4 sub(float4 a, float4 b) {
    float4 v;
    v.x = a.x - b.x;
    v.y = a.y - b.y;
    v.z = a.z - b.z;
    v.w = a.w - b.w;
    return v;
  }

  /// Adds the scalar b to every lane of a.
  static inline __device__ float4 add(float4 a, float b) {
    float4 v;
    v.x = a.x + b;
    v.y = a.y + b;
    v.z = a.z + b;
    v.w = a.w + b;
    return v;
  }

  /// Subtracts the scalar b from every lane of a.
  static inline __device__ float4 sub(float4 a, float b) {
    float4 v;
    v.x = a.x - b;
    v.y = a.y - b;
    v.z = a.z - b;
    v.w = a.w - b;
    return v;
  }

  static inline __device__ float4 mul(float4 a, float4 b) {
    float4 v;
    v.x = a.x * b.x;
    v.y = a.y * b.y;
    v.z = a.z * b.z;
    v.w = a.w * b.w;
    return v;
  }

  /// Multiplies every lane of a by the scalar b.
  static inline __device__ float4 mul(float4 a, float b) {
    float4 v;
    v.x = a.x * b;
    v.y = a.y * b;
    v.z = a.z * b;
    v.w = a.w * b;
    return v;
  }

  static inline __device__ float4 neg(float4 v) {
    // v is a by-value copy, so negating it in place does not affect the caller
    v.x = -v.x;
    v.y = -v.y;
    v.z = -v.z;
    v.w = -v.w;
    return v;
  }

  /// For a vector type, this is a horizontal add, returning sum(v_i)
  static inline __device__ float reduceAdd(float4 v) {
    return v.x + v.y + v.z + v.w;
  }

  // not implemented for vector types
  // static inline __device__ bool lt(float4 a, float4 b);
  // static inline __device__ bool gt(float4 a, float4 b);
  // static inline __device__ bool eq(float4 a, float4 b);

  static inline __device__ float4 zero() {
    float4 v;
    v.x = 0.0f;
    v.y = 0.0f;
    v.z = 0.0f;
    v.w = 0.0f;
    return v;
  }
};

#ifdef FAISS_USE_FLOAT16

/// Scalar half arithmetic.
/// With FAISS_USE_FULL_FLOAT16 defined, the half intrinsics (__hadd, __hlt,
/// ...) are used directly. Otherwise each operand is widened to float, the
/// operation is done in float, and the result is converted back with
/// __float2half.
/// NOTE(review): the FULL_FLOAT16 intrinsics are presumably only enabled for
/// targets with native fp16 arithmetic — confirm in Float16.cuh.
template <>
struct Math<half> {
  typedef half ScalarType;

  static inline __device__ half add(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hadd(a, b);
#else
    return __float2half(__half2float(a) + __half2float(b));
#endif
  }

  static inline __device__ half sub(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hsub(a, b);
#else
    return __float2half(__half2float(a) - __half2float(b));
#endif
  }

  static inline __device__ half mul(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hmul(a, b);
#else
    return __float2half(__half2float(a) * __half2float(b));
#endif
  }

  static inline __device__ half neg(half v) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hneg(v);
#else
    return __float2half(-__half2float(v));
#endif
  }

  /// Scalar, so the horizontal add is the identity.
  static inline __device__ half reduceAdd(half v) {
    return v;
  }

  static inline __device__ bool lt(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hlt(a, b);
#else
    return __half2float(a) < __half2float(b);
#endif
  }

  static inline __device__ bool gt(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hgt(a, b);
#else
    return __half2float(a) > __half2float(b);
#endif
  }

  static inline __device__ bool eq(half a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __heq(a, b);
#else
    return __half2float(a) == __half2float(b);
#endif
  }

  /// Returns +0.0 as a half.
  /// Built with __float2half rather than by writing the raw storage member
  /// (`half h; h.x = 0;`): the `__half` type shipped with CUDA 9 and later
  /// no longer exposes a public `x` member, so that form fails to compile
  /// there. 0.0f converts exactly, giving the same all-zero bit pattern.
  static inline __device__ half zero() {
    return __float2half(0.0f);
  }
};

/// Two-lane half2 arithmetic.
/// With FAISS_USE_FULL_FLOAT16 defined, the paired __h*2 intrinsics are
/// used. Otherwise both lanes are widened to float2, computed in float, and
/// packed back with round-to-nearest (__float22half2_rn). The (half2, half)
/// overloads apply the scalar to both lanes.
template <>
struct Math<half2> {
  typedef half ScalarType;

  static inline __device__ half2 add(half2 a, half2 b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hadd2(a, b);
#else
    float2 af = __half22float2(a);
    float2 bf = __half22float2(b);

    af.x += bf.x;
    af.y += bf.y;

    return __float22half2_rn(af);
#endif
  }

  static inline __device__ half2 sub(half2 a, half2 b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hsub2(a, b);
#else
    float2 af = __half22float2(a);
    float2 bf = __half22float2(b);

    af.x -= bf.x;
    af.y -= bf.y;

    return __float22half2_rn(af);
#endif
  }

  /// Adds the scalar b to both lanes of a.
  static inline __device__ half2 add(half2 a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    // Broadcast b into both lanes so the paired intrinsic can be used.
    half2 b2 = __half2half2(b);
    return __hadd2(a, b2);
#else
    float2 af = __half22float2(a);
    float bf = __half2float(b);

    af.x += bf;
    af.y += bf;

    return __float22half2_rn(af);
#endif
  }

  /// Subtracts the scalar b from both lanes of a.
  static inline __device__ half2 sub(half2 a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    half2 b2 = __half2half2(b);
    return __hsub2(a, b2);
#else
    float2 af = __half22float2(a);
    float bf = __half2float(b);

    af.x -= bf;
    af.y -= bf;

    return __float22half2_rn(af);
#endif
  }

  static inline __device__ half2 mul(half2 a, half2 b) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hmul2(a, b);
#else
    float2 af = __half22float2(a);
    float2 bf = __half22float2(b);

    af.x *= bf.x;
    af.y *= bf.y;

    return __float22half2_rn(af);
#endif
  }

  /// Multiplies both lanes of a by the scalar b.
  static inline __device__ half2 mul(half2 a, half b) {
#ifdef FAISS_USE_FULL_FLOAT16
    half2 b2 = __half2half2(b);
    return __hmul2(a, b2);
#else
    float2 af = __half22float2(a);
    float bf = __half2float(b);

    af.x *= bf;
    af.y *= bf;

    return __float22half2_rn(af);
#endif
  }

  static inline __device__ half2 neg(half2 v) {
#ifdef FAISS_USE_FULL_FLOAT16
    return __hneg2(v);
#else
    float2 vf = __half22float2(v);
    vf.x = -vf.x;
    vf.y = -vf.y;

    return __float22half2_rn(vf);
#endif
  }

  /// Horizontal add: returns the sum of the two lanes as a single half.
  static inline __device__ half reduceAdd(half2 v) {
#ifdef FAISS_USE_FULL_FLOAT16
    half hv = __high2half(v);
    half lv = __low2half(v);

    return __hadd(hv, lv);
#else
    float2 vf = __half22float2(v);
    vf.x += vf.y;

    return __float2half(vf.x);
#endif
  }

  // not implemented for vector types
  // static inline __device__ bool lt(half2 a, half2 b);
  // static inline __device__ bool gt(half2 a, half2 b);
  // static inline __device__ bool eq(half2 a, half2 b);

  /// Both lanes set to the half zero from Math<half>.
  static inline __device__ half2 zero() {
    return __half2half2(Math<half>::zero());
  }
};

/// Half4 arithmetic: each operation is forwarded to Math<half2> on the two
/// half2 members `a` and `b`.
template <>
struct Math<Half4> {
  typedef half ScalarType;

  static inline __device__ Half4 add(Half4 a, Half4 b) {
    Half4 h;
    h.a = Math<half2>::add(a.a, b.a);
    h.b = Math<half2>::add(a.b, b.b);
    return h;
  }

  static inline __device__ Half4 sub(Half4 a, Half4 b) {
    Half4 h;
    h.a = Math<half2>::sub(a.a, b.a);
    h.b = Math<half2>::sub(a.b, b.b);
    return h;
  }

  /// Adds the scalar b to every lane of a.
  static inline __device__ Half4 add(Half4 a, half b) {
    Half4 h;
    h.a = Math<half2>::add(a.a, b);
    h.b = Math<half2>::add(a.b, b);
    return h;
  }

  /// Subtracts the scalar b from every lane of a.
  static inline __device__ Half4 sub(Half4 a, half b) {
    Half4 h;
    h.a = Math<half2>::sub(a.a, b);
    h.b = Math<half2>::sub(a.b, b);
    return h;
  }

  static inline __device__ Half4 mul(Half4 a, Half4 b) {
    Half4 h;
    h.a = Math<half2>::mul(a.a, b.a);
    h.b = Math<half2>::mul(a.b, b.b);
    return h;
  }

  /// Multiplies every lane of a by the scalar b.
  static inline __device__ Half4 mul(Half4 a, half b) {
    Half4 h;
    h.a = Math<half2>::mul(a.a, b);
    h.b = Math<half2>::mul(a.b, b);
    return h;
  }

  static inline __device__ Half4 neg(Half4 v) {
    Half4 h;
    h.a = Math<half2>::neg(v.a);
    h.b = Math<half2>::neg(v.b);
    return h;
  }

  /// Horizontal add: reduces each half2 half, then adds the two partial sums.
  static inline __device__ half reduceAdd(Half4 v) {
    half hx = Math<half2>::reduceAdd(v.a);
    half hy = Math<half2>::reduceAdd(v.b);
    return Math<half>::add(hx, hy);
  }

  // not implemented for vector types
  // static inline __device__ bool lt(Half4 a, Half4 b);
  // static inline __device__ bool gt(Half4 a, Half4 b);
  // static inline __device__ bool eq(Half4 a, Half4 b);

  static inline __device__ Half4 zero() {
    Half4 h;
    h.a = Math<half2>::zero();
    h.b = Math<half2>::zero();
    return h;
  }
};

/// Half8 arithmetic: each operation is forwarded to Math<Half4> on the two
/// Half4 members `a` and `b`.
template <>
struct Math<Half8> {
  typedef half ScalarType;

  static inline __device__ Half8 add(Half8 a, Half8 b) {
    Half8 h;
    h.a = Math<Half4>::add(a.a, b.a);
    h.b = Math<Half4>::add(a.b, b.b);
    return h;
  }

  static inline __device__ Half8 sub(Half8 a, Half8 b) {
    Half8 h;
    h.a = Math<Half4>::sub(a.a, b.a);
    h.b = Math<Half4>::sub(a.b, b.b);
    return h;
  }

  /// Adds the scalar b to every lane of a.
  static inline __device__ Half8 add(Half8 a, half b) {
    Half8 h;
    h.a = Math<Half4>::add(a.a, b);
    h.b = Math<Half4>::add(a.b, b);
    return h;
  }

  /// Subtracts the scalar b from every lane of a.
  static inline __device__ Half8 sub(Half8 a, half b) {
    Half8 h;
    h.a = Math<Half4>::sub(a.a, b);
    h.b = Math<Half4>::sub(a.b, b);
    return h;
  }

  static inline __device__ Half8 mul(Half8 a, Half8 b) {
    Half8 h;
    h.a = Math<Half4>::mul(a.a, b.a);
    h.b = Math<Half4>::mul(a.b, b.b);
    return h;
  }

  /// Multiplies every lane of a by the scalar b.
  static inline __device__ Half8 mul(Half8 a, half b) {
    Half8 h;
    h.a = Math<Half4>::mul(a.a, b);
    h.b = Math<Half4>::mul(a.b, b);
    return h;
  }

  static inline __device__ Half8 neg(Half8 v) {
    Half8 h;
    h.a = Math<Half4>::neg(v.a);
    h.b = Math<Half4>::neg(v.b);
    return h;
  }

  /// Horizontal add: reduces each Half4 half, then adds the two partial sums.
  static inline __device__ half reduceAdd(Half8 v) {
    half hx = Math<Half4>::reduceAdd(v.a);
    half hy = Math<Half4>::reduceAdd(v.b);
    return Math<half>::add(hx, hy);
  }

  // not implemented for vector types
  // static inline __device__ bool lt(Half8 a, Half8 b);
  // static inline __device__ bool gt(Half8 a, Half8 b);
  // static inline __device__ bool eq(Half8 a, Half8 b);

  static inline __device__ Half8 zero() {
    Half8 h;
    h.a = Math<Half4>::zero();
    h.b = Math<Half4>::zero();
    return h;
  }
};

#endif // FAISS_USE_FLOAT16

} } // namespace