#include "MatrixMult.cuh"
#include "DeviceMemory.h"
#include "DeviceUtils.h"
#include "DeviceTensor.cuh"
#include "HostTensor.cuh"

namespace faiss { namespace gpu {

// Dispatches to the cuBLAS GEMM routine matching the tensor element type
template <typename T>
struct CublasGemm {
};

template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             const float fAlpha,
                             const float* A, int lda,
                             const float* B, int ldb,
                             const float fBeta,
                             float* C, int ldc) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
  }
};

#ifdef FAISS_USE_FLOAT16
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m, int n, int k,
                             const float fAlpha,
                             const half* A, int lda,
                             const half* B, int ldb,
                             const float fBeta,
                             half* C, int ldc) {
    // Use native fp16 arithmetic where the device supports it
    if (getDeviceSupportsFloat16Math(getCurrentDevice())) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // Otherwise, compute in fp32 on fp16 data; the half type enum
    // differs between CUDA 7.5 and later versions
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta, C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
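
// Multiplies row-major matrices: c = alpha * a * b + beta * c. Each trans
// flag marks an operand that is stored transposed relative to its logical shape.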
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);
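
  // Sizes as the row-major caller sees them: we need
  // (aM x aK) * (bK x bN) = (cM x cN)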
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);
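
  // cuBLAS is column-major, so compute c^T = b^T * a^T by swapping the
  // operands; the row-major result then lands directly in c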
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride-1 dimension of the output
  int n = c.getSize(0);
  int k = transA ? a.getSize(0) : a.getSize(1);

  int lda = transC ? a.getSize(1) : b.getSize(1);
  int ldb = transC ? b.getSize(1) : a.getSize(1);
  int ldc = c.getSize(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc);

  FAISS_ASSERT(err == CUBLAS_STATUS_SUCCESS);
  CUDA_VERIFY(cudaGetLastError());
}

void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<float>(c, transC, a, transA, b, transB,
                              alpha, beta, handle, stream);
}

#ifdef FAISS_USE_FLOAT16
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<half>(c, transC, a, transA, b, transB,
                             alpha, beta, handle, stream);
}
#endif // FAISS_USE_FLOAT16
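
// Computes c[i] = alpha * a[i] * b[i] + beta * c[i] for every matrix in the
// batch, one GEMM at a time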
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
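
  // Issue a separate 2-D multiplication per batch element, all on the same stream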
  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta, handle, stream);
  }
}
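
// Same batched computation as above, but performed in a single
// cublasSgemmBatched call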
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);
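
  // Per-matrix sizes as the row-major caller sees them: we need
  // (aM x aK) * (bK x bN) = (cM x cN)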
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);
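
  // As above, swap the operands so the row-major result comes out of the
  // column-major cuBLAS call directly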
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride-1 dimension of the output
  int n = c.getSize(1);
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getSize(2) : b.getSize(2);
  int ldb = transC ? b.getSize(2) : a.getSize(2);
  int ldc = c.getSize(2);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }
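
  // Build host-side arrays of per-matrix pointers, then copy them to the
  // device for the batched GEMM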
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  size_t aOffset = a.getSize(1) * a.getSize(2);
  size_t bOffset = b.getSize(1) * b.getSize(2);
  size_t cOffset = c.getSize(1) * c.getSize(2);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }

  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);
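
  // A single batched GEMM covers every matrix in the batch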
  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));

  FAISS_ASSERT(err == CUBLAS_STATUS_SUCCESS);
  CUDA_VERIFY(cudaGetLastError());
}

} } // namespace