11 #include "MatrixMult.cuh"
12 #include "DeviceMemory.h"
13 #include "DeviceUtils.h"
14 #include "DeviceTensor.cuh"
15 #include "HostTensor.cuh"
17 namespace faiss {
namespace gpu {
// Float-specialized GEMM dispatcher: forwards directly to cublasSgemm.
// NOTE(review): this excerpt is truncated — the m/n/k, fAlpha/fBeta,
// A/B/C pointer and lda/ldb/ldc parameter lines, plus the closing brace,
// are not visible here (embedded line numbering jumps 27 -> 40).
25 static cublasStatus_t gemm(cublasHandle_t handle,
26 cublasOperation_t transa,
27 cublasOperation_t transb,
// Single-precision path; returns the raw cublasStatus_t for the caller
// to check (see FAISS_ASSERT_FMT at the call sites below).
40 return cublasSgemm(handle, transa, transb, m, n, k,
41 &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
45 #ifdef FAISS_USE_FLOAT16
// Half-specialized GEMM dispatcher. Chooses between true fp16 math
// (cublasHgemm) and fp16-storage/fp32-math (cublasSgemmEx) at runtime.
// NOTE(review): truncated excerpt — parameter lines (m/n/k, fAlpha/fBeta,
// A/B/C, leading dims, useHgemm) and closing braces are not visible
// (embedded line numbering jumps 50 -> 63, 79 -> 85).
48 static cublasStatus_t gemm(cublasHandle_t handle,
49 cublasOperation_t transa,
50 cublasOperation_t transb,
// Use native half-precision arithmetic only when the device supports it
// AND the caller explicitly asked for Hgemm; otherwise fall through to
// the Sgemm-with-half-storage path below.
63 if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
// cublasHgemm takes half scalars, so convert the float alpha/beta on host.
64 half hAlpha = hostFloat2Half(fAlpha);
65 half hBeta = hostFloat2Half(fBeta);
67 return cublasHgemm(handle, transa, transb, m, n, k,
68 &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
// CUDA 7.5 named the half data-type enum differently from CUDA 8+.
// NOTE(review): an `#else` line appears to be missing between the two
// halfType assignments (embedded numbering jumps 73 -> 75) — as written
// both assignments would redefine halfType; confirm against the full file.
72 #if CUDA_VERSION == 7050
73 auto halfType = CUBLAS_DATA_HALF;
75 auto halfType = CUDA_R_16F;
76 #endif // CUDA_VERSION
// Mixed path: inputs stored as half, accumulation done in fp32.
78 return cublasSgemmEx(handle, transa, transb, m, n, k,
79 &fAlpha, A, halfType, lda,
85 #endif // FAISS_USE_FLOAT16
// Computes C = alpha * op(A) * op(B) + beta * C on row-major 2-D Tensors
// via cuBLAS, which is column-major. A row-major matrix reinterpreted as
// column-major is its transpose, so the code below swaps the A/B operands
// and the m/n extents (computing C' = B' * A') rather than transposing data.
// NOTE(review): truncated excerpt — the `template <typename T>` header,
// the alpha/beta/useHgemm parameters, the pC pointer, parts of the gemm
// call, and the closing brace are not visible here.
90 runMatrixMult(Tensor<T, 2, true>& c,
bool transC,
91 Tensor<T, 2, true>& a,
bool transA,
92 Tensor<T, 2, true>& b,
bool transB,
96 cublasHandle_t handle,
97 cudaStream_t stream) {
// Issue all cuBLAS work for this call on the caller's stream.
98 cublasSetStream(handle, stream);
// Logical (post-transpose) extents of each operand.
102 int aM = transA ? a.getSize(1) : a.getSize(0);
103 int aK = transA ? a.getSize(0) : a.getSize(1);
105 int bK = transB ? b.getSize(1) : b.getSize(0);
106 int bN = transB ? b.getSize(0) : b.getSize(1);
108 int cM = transC ? c.getSize(1) : c.getSize(0);
109 int cN = transC ? c.getSize(0) : c.getSize(1);
// Shape compatibility: (aM x aK) * (bK x bN) = (cM x cN).
111 FAISS_ASSERT(aM == cM);
112 FAISS_ASSERT(aK == bK);
113 FAISS_ASSERT(bN == cN);
// Innermost dimension must be contiguous for the lda/ldb/ldc
// leading-dimension arguments below to be meaningful.
115 FAISS_ASSERT(a.getStride(1) == 1);
116 FAISS_ASSERT(b.getStride(1) == 1);
117 FAISS_ASSERT(c.getStride(1) == 1);
// Operand swap for the row-major trick; transC flips which tensor plays
// the role of cuBLAS's "A".
121 T* pA = transC ? a.data() : b.data();
122 T* pB = transC ? b.data() : a.data();
// m/n are swapped relative to the logical C extents (column-major view).
125 int m = c.getSize(1);
126 int n = c.getSize(0);
127 int k = transA ? a.getSize(0) : a.getSize(1);
// Leading dimensions follow the operand swap above.
129 int lda = transC ? a.getStride(0) : b.getStride(0);
130 int ldb = transC ? b.getStride(0) : a.getStride(0);
131 int ldc = c.getStride(0);
// Transpose flags are crossed (trA comes from transB and vice versa)
// because the operands were swapped.
133 auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
134 auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
// NOTE(review): these reassignments invert the flags; an enclosing
// `if (transC)` guard appears to be missing from this excerpt
// (embedded numbering jumps 134 -> 137) — confirm against the full file.
137 gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
138 gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
// Dispatch to the type-specialized gemm wrapper; err is checked below.
141 auto err = CublasGemm<T>::gemm(handle,
144 pA, lda, pB, ldb, beta,
// Failure message encodes which gemm was used and the operand shapes,
// with ' marking a transposed operand.
147 FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
148 "cublas failed (%d): %s "
149 "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
151 useHgemm ?
"Hgemm" :
"Sgemm",
152 a.getSize(0), a.getSize(1), transA ?
"'" :
"",
153 b.getSize(0), b.getSize(1), transB ?
"'" :
"",
154 c.getSize(0), c.getSize(1), transC ?
"'" :
"");
// Public float entry point: forwards to the runMatrixMult<T> template.
// NOTE(review): truncated excerpt — the alpha/beta/useHgemm parameter
// lines and the closing brace are not visible here (embedded numbering
// jumps 160 -> 164).
158 void runMatrixMult(Tensor<float, 2, true>& c,
bool transC,
159 Tensor<float, 2, true>& a,
bool transA,
160 Tensor<float, 2, true>& b,
bool transB,
164 cublasHandle_t handle,
165 cudaStream_t stream) {
166 return runMatrixMult<float>(c, transC, a, transA, b, transB,
167 alpha, beta, useHgemm, handle, stream);
170 #ifdef FAISS_USE_FLOAT16
// Public half entry point (only compiled with FAISS_USE_FLOAT16):
// forwards to the runMatrixMult<T> template, mirroring the float overload.
// NOTE(review): truncated excerpt — the alpha/beta/useHgemm parameter
// lines and the closing brace are not visible here (embedded numbering
// jumps 173 -> 177).
171 void runMatrixMult(Tensor<half, 2, true>& c,
bool transC,
172 Tensor<half, 2, true>& a,
bool transA,
173 Tensor<half, 2, true>& b,
bool transB,
177 cublasHandle_t handle,
178 cudaStream_t stream) {
179 return runMatrixMult<half>(c, transC, a, transA, b, transB,
180 alpha, beta, useHgemm, handle, stream);
// Batched 3-D matmul implemented as a host-side loop of 2-D runMatrixMult
// calls, one per batch slice (contrast with runBatchMatrixMult below,
// which uses cublasSgemmBatched).
// NOTE(review): truncated excerpt — the return type / alpha/beta
// parameters, the aView/bView arguments of the inner call, and the
// closing braces are not visible here.
185 runIteratedMatrixMult(Tensor<float, 3, true>& c,
bool transC,
186 Tensor<float, 3, true>& a,
bool transA,
187 Tensor<float, 3, true>& b,
bool transB,
190 cublasHandle_t handle,
191 cudaStream_t stream) {
// All three tensors must share the same batch (outermost) size.
192 FAISS_ASSERT(c.getSize(0) == a.getSize(0));
193 FAISS_ASSERT(a.getSize(0) == b.getSize(0));
195 for (
int i = 0; i < a.getSize(0); ++i) {
// 2-D views of the i-th batch slice of each tensor.
196 auto cView = c[i].view();
197 auto aView = a[i].view();
198 auto bView = b[i].view();
200 runMatrixMult(cView, transC,
203 alpha, beta,
// useHgemm is hard-coded false: the iterated path always runs Sgemm.
false, handle, stream);
// Batched 3-D matmul via a single cublasSgemmBatched call: builds
// per-slice pointer arrays on the host, uploads them, and issues one
// batched GEMM. Uses the same row-major operand/extent swap as the 2-D
// runMatrixMult above (cuBLAS is column-major).
// NOTE(review): truncated excerpt — the return type, alpha/beta
// parameters, the `mem` (DeviceMemory) parameter used at the DeviceTensor
// constructors, the m/n/k/gemmTrA/gemmTrB arguments of the batched call,
// the `auto err =` assignment, and the closing braces are not visible here.
208 runBatchMatrixMult(Tensor<float, 3, true>& c,
bool transC,
209 Tensor<float, 3, true>& a,
bool transA,
210 Tensor<float, 3, true>& b,
bool transB,
214 cublasHandle_t handle,
215 cudaStream_t stream) {
// All three tensors must share the same batch (outermost) size.
216 FAISS_ASSERT(c.getSize(0) == a.getSize(0));
217 FAISS_ASSERT(a.getSize(0) == b.getSize(0));
218 cublasSetStream(handle, stream);
// Logical (post-transpose) extents of each batch slice.
222 int aM = transA ? a.getSize(2) : a.getSize(1);
223 int aK = transA ? a.getSize(1) : a.getSize(2);
225 int bK = transB ? b.getSize(2) : b.getSize(1);
226 int bN = transB ? b.getSize(1) : b.getSize(2);
228 int cM = transC ? c.getSize(2) : c.getSize(1);
229 int cN = transC ? c.getSize(1) : c.getSize(2);
// Shape compatibility per slice: (aM x aK) * (bK x bN) = (cM x cN).
231 FAISS_ASSERT(aM == cM);
232 FAISS_ASSERT(aK == bK);
233 FAISS_ASSERT(bN == cN);
// Operand swap for the row-major-as-column-major trick (see
// runMatrixMult above); transC flips which tensor is cuBLAS's "A".
237 float* pA = transC ? a.data() : b.data();
238 float* pB = transC ? b.data() : a.data();
239 float* pC = c.data();
// m/n swapped relative to logical C extents (column-major view).
241 int m = c.getSize(2);
242 int n = c.getSize(1);
243 int k = transA ? a.getSize(1) : a.getSize(2);
// Leading dimensions of a slice's rows, following the operand swap.
245 int lda = transC ? a.getStride(1) : b.getStride(1);
246 int ldb = transC ? b.getStride(1) : a.getStride(1);
247 int ldc = c.getStride(1);
// Transpose flags are crossed because the operands were swapped.
249 auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
250 auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
// NOTE(review): an enclosing `if (transC)` guard for these flag
// inversions appears to be missing from this excerpt (embedded numbering
// jumps 250 -> 253) — confirm against the full file.
253 gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
254 gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
// Host-side arrays of per-slice base pointers, one entry per batch.
257 HostTensor<float*, 1, true> hostA({a.getSize(0)});
258 HostTensor<float*, 1, true> hostB({b.getSize(0)});
259 HostTensor<float*, 1, true> hostC({c.getSize(0)});
// Distance (in elements) between consecutive batch slices.
261 size_t aOffset = a.getStride(0);
262 size_t bOffset = b.getStride(0);
263 size_t cOffset = c.getStride(0);
265 for (
int i = 0; i < a.getSize(0); ++i) {
// Same transC-dependent operand swap as pA/pB above, per slice.
266 hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
267 hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
268 hostC[i] = c.data() + i * cOffset;
// Upload the pointer arrays to the device on the same stream
// (`mem` presumably is a DeviceMemory parameter — not visible here).
271 DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
272 DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
273 DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);
// One batched GEMM over all slices.
// NOTE(review): the `auto err =` capture for this call is not visible in
// this excerpt (embedded numbering jumps 273 -> 276) though `err` is
// checked below — confirm against the full file.
276 cublasSgemmBatched(handle,
279 (
const float**) deviceA.data(), lda,
280 (
const float**) deviceB.data(), ldb, &beta,
281 deviceC.data(), ldc, a.getSize(0));
282 FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
283 "cublasSgemmBatched failed (%d)", (
int) err);