#include "MatrixMult.cuh"
#include "DeviceMemory.h"
#include "DeviceUtils.h"
#include "DeviceTensor.cuh"
#include "HostTensor.cuh"

namespace faiss { namespace gpu {
// Type-dispatched wrapper over the cuBLAS GEMM entry points; specialized
// below for float (Sgemm) and half (Hgemm / SgemmEx)
template <typename T>
struct CublasGemm {
};

template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha,
                             const float* A, int lda,
                             const float* B, int ldb,
                             float fBeta,
                             float* C, int ldc,
                             bool useHgemm) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
  }
};
#ifdef FAISS_USE_FLOAT16
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha,
                             const half* A, int lda,
                             const half* B, int ldb,
                             float fBeta,
                             half* C, int ldc,
                             bool useHgemm) {
    // Use true fp16 math only where the device supports it natively;
    // otherwise fall back to fp16 storage with fp32 compute via SgemmEx
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // CUDA 8+ renamed the half datatype specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta, C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
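
// Computes c = alpha * a * b + beta * c, where any of a, b and c may
// independently be flagged as transposed. All tensors are row-major and
// must have a contiguous innermost dimension.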
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // The innermost dimension must be contiguous
  FAISS_ASSERT(a.getStride(1) == 1);
  FAISS_ASSERT(b.getStride(1) == 1);
  FAISS_ASSERT(c.getStride(1) == 1);
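
  // cuBLAS operates on column-major data, while our tensors are row-major.
  // A row-major matrix is the transpose of the same memory read
  // column-major, so instead of computing c = a * b we have cuBLAS compute
  // c' = b' * a', which just swaps the operand order below.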
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();
  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  // The leading dimensions are the row strides of the row-major tensors
  int lda = transC ? a.getStride(0) : b.getStride(0);
  int ldb = transC ? b.getStride(0) : a.getStride(0);
  int ldc = c.getStride(0);
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }
  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
}
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<float>(c, transC, a, transA, b, transB,
                              alpha, beta, useHgemm, handle, stream);
}
#ifdef FAISS_USE_FLOAT16
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<half>(c, transC, a, transA, b, transB,
                             alpha, beta, useHgemm, handle, stream);
}
#endif // FAISS_USE_FLOAT16
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  // Perform one gemm per 2-d matrix in the batch
  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta,
                  false, handle, stream);
  }
}
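
// runIteratedMatrixMult above issues one gemm per 2-d slice, paying a
// launch cost per matrix; runBatchMatrixMult below instead gathers the
// per-slice base pointers and makes a single cublasSgemmBatched call,
// which is usually faster for large batches of small multiplications.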
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);
  // As in the 2-d case, swap operands to express the row-major product
  // in cuBLAS's column-major terms
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  // Distance, in elements, between consecutive matrices in each batch
  size_t aOffset = a.getStride(0);
  size_t bOffset = b.getStride(0);
  size_t cOffset = c.getStride(0);
  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }
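
  // cublasSgemmBatched reads the arrays of per-matrix pointers from device
  // memory, so copy the host-side pointer arrays to the GPU on our stream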
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);
  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
}
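
// Usage sketch (hypothetical tensors, for illustration only): multiplying a
// batch of row-major slices, a (B x M x K) by b (B x K x N) into
// c (B x M x N), with the pointer arrays drawn from a DeviceMemory
// allocator:
//
//   runBatchMatrixMult(c, false, a, false, b, false,
//                      1.0f, 0.0f, mem, handle, stream);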