#include "MatrixMult.cuh"
#include "DeviceMemory.h"
#include "DeviceUtils.h"
#include "DeviceTensor.cuh"
#include "HostTensor.cuh"

namespace faiss {
namespace gpu {

// Dispatch to the precision-appropriate cuBLAS gemm via template
// specialization; the float specialization ignores the useHgemm flag.
template <typename T>
struct CublasGemm {
};

template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha,
                             const float* A, int lda,
                             const float* B, int ldb,
                             float fBeta,
                             float* C, int ldc,
                             bool useHgemm) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
  }
};
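
// The half specialization below prefers true fp16 math (cublasHgemm)
// when the device supports it and the caller asked for it; otherwise it
// falls back to cublasSgemmEx, which keeps fp16 storage but accumulates
// in fp32.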
#ifdef FAISS_USE_FLOAT16
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             float fAlpha,
                             const half* A, int lda,
                             const half* B, int ldb,
                             float fBeta,
                             half* C, int ldc,
                             bool useHgemm) {
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // CUDA 8.0 renamed the half type specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta, C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
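
// All Tensor arguments are row-major, while cuBLAS assumes column-major
// storage. A row-major matrix is bit-identical to its column-major
// transpose, so rather than copying anything we rely on
//
//   C = A * B  <=>  C' = B' * A'
//
// and pass the operands to cuBLAS in swapped order; this is why pA/pB
// inside runMatrixMult point at b/a (or a/b once transC flips the
// output).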
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now express the multiplication in column-major terms: hand cuBLAS
  // the operands in swapped order (see the note above runMatrixMult)
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride-1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  int lda = transC ? a.getSize(1) : b.getSize(1);
  int ldb = transC ? b.getSize(1) : a.getSize(1);
  int ldc = c.getSize(1);

  // Since the operands are swapped, A's transpose flag follows transB
  // and B's follows transA; a transposed output flips both
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
}

void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<float>(c, transC, a, transA, b, transB,
                              alpha, beta, useHgemm, handle, stream);
}

#ifdef FAISS_USE_FLOAT16
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<half>(c, transC, a, transA, b, transB,
                             alpha, beta, useHgemm, handle, stream);
}
#endif // FAISS_USE_FLOAT16

void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta,
                  false, handle, stream);
  }
}
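
// runIteratedMatrixMult issues one gemm per slice on the same stream.
// runBatchMatrixMult below instead builds device arrays of per-slice
// pointers and makes a single cublasSgemmBatched call, which tends to
// be the better choice for many small matrices.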

void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // for each slice, using the row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // As above, express the multiplication in column-major terms with
  // swapped operands
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride-1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getSize(2) : b.getSize(2);
  int ldb = transC ? b.getSize(2) : a.getSize(2);
  int ldc = c.getSize(2);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  // Build host-side tables of per-slice matrix pointers
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  size_t aOffset = a.getSize(1) * a.getSize(2);
  size_t bOffset = b.getSize(1) * b.getSize(2);
  size_t cOffset = c.getSize(1) * c.getSize(2);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }
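
  // cublasSgemmBatched consumes *device* arrays of matrix pointers, so
  // copy the host tables to temporary GPU buffers first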
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
}

} // namespace gpu
} // namespace faiss