#include "MatrixMult.cuh"
#include "DeviceMemory.h"
#include "DeviceUtils.h"
#include "DeviceTensor.cuh"
#include "HostTensor.cuh"
namespace faiss { namespace gpu {
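
// Wraps the cuBLAS gemm entry points behind a single type-dispatched
// interface; the float specialization ignores useHgemm, which only
// matters for the half specialization below.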
template <typename T>
struct CublasGemm {
};

template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha, const float* A, int lda,
                             const float* B, int ldb,
                             float fBeta, float* C, int ldc,
                             bool useHgemm) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
  }
};

#ifdef FAISS_USE_FLOAT16
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha, const half* A, int lda,
                             const half* B, int ldb,
                             float fBeta, half* C, int ldc,
                             bool useHgemm) {
    // Use pure fp16 arithmetic only when the device supports it and
    // the caller asked for it
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // CUDA 8.0 renamed the half data type specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    // Fall back to fp16 storage with fp32 arithmetic
    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta, C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16

template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);
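
  // Check that we have (m x k) * (k x n) = (m x n)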
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // The innermost dimension must be contiguous for gemm
  FAISS_ASSERT(a.getStride(1) == 1);
  FAISS_ASSERT(b.getStride(1) == 1);
  FAISS_ASSERT(c.getStride(1) == 1);
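
  // cuBLAS works on column-major data, while our Tensors are
  // row-major. A row-major C = A * B occupies the same memory as the
  // column-major product C^T = B^T * A^T, so the operand pointers are
  // swapped below whenever C itself is not transposed.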
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride-1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  int lda = transC ? a.getStride(0) : b.getStride(0);
  int ldb = transC ? b.getStride(0) : a.getStride(0);
  int ldc = c.getStride(0);
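
  // The operand swap applies to the transpose flags as well;
  // transposing C flips them back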
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
}

void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<float>(c, transC, a, transA, b, transB,
                              alpha, beta, useHgemm, handle, stream);
}

#ifdef FAISS_USE_FLOAT16
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<half>(c, transC, a, transA, b, transB,
                             alpha, beta, useHgemm, handle, stream);
}
#endif // FAISS_USE_FLOAT16
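
// Usage sketch (hypothetical tensor names; assumes row-major
// DeviceTensors of shape (m x k), (k x n), (m x n) already on device):
//
//   // c = 1.0f * a * b, no transposes, fp32 gemm:
//   runMatrixMult(c, false, a, false, b, false,
//                 1.0f, 0.0f, false, handle, stream);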

void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
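
  // Issue one gemm per batch element, all on the same stream; unlike
  // runBatchMatrixMult below, this needs no device-side pointer arrays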
  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta, false, handle, stream);
  }
}

void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);
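
  // Check that we have (m x k) * (k x n) = (m x n) per batch element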
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);
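
  // As in runMatrixMult above, swap the operands to express the
  // row-major product in cuBLAS's column-major terms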
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride-1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }
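
  // cublasSgemmBatched consumes device arrays of per-matrix pointers;
  // build them on the host first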
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  size_t aOffset = a.getStride(0);
  size_t bOffset = b.getStride(0);
  size_t cOffset = c.getStride(0);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }
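
  // Copy the pointer arrays into temporary device memory on our stream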
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
}

} } // namespace