Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MatrixMult.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "MatrixMult.cuh"
11 #include "DeviceMemory.h"
12 #include "DeviceUtils.h" // CUDA_VERIFY
13 #include "DeviceTensor.cuh"
14 #include "HostTensor.cuh"
15 
16 namespace faiss { namespace gpu {
17 
// Type-dispatch shim over cuBLAS GEMM: runMatrixMult<T> calls
// CublasGemm<T>::gemm, and the specializations below select the proper
// cuBLAS entry point per scalar type. The primary template is deliberately
// empty so that instantiating it for an unsupported T fails at compile time.
template <typename T>
struct CublasGemm {
};
21 
// float specialization: always dispatches to cublasSgemm. The useHgemm
// flag exists only for signature parity with the half specialization and
// is ignored here.
template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             float fAlpha,
                             const float *A,
                             int lda,
                             const float *B,
                             int ldb,
                             float fBeta,
                             float *C,
                             int ldc,
                             bool useHgemm) {
    // Scaling factors are passed by value; cuBLAS expects host pointers
    // to them (default pointer mode).
    const float* alphaPtr = &fAlpha;
    const float* betaPtr = &fBeta;

    return cublasSgemm(handle, transa, transb, m, n, k,
                       alphaPtr, A, lda, B, ldb,
                       betaPtr, C, ldc);
  }
};
43 
#ifdef FAISS_USE_FLOAT16
// half specialization: runs true fp16 math via cublasHgemm when the
// current device supports it and the caller requested it; otherwise
// falls back to cublasSgemmEx, which takes fp16 inputs/outputs but
// accumulates in fp32.
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             const float fAlpha,
                             const half *A,
                             int lda,
                             const half *B,
                             int ldb,
                             const float fBeta,
                             half *C,
                             int ldc,
                             bool useHgemm) {
    bool runNativeHalf =
      getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm;

    if (runNativeHalf) {
      // cublasHgemm wants the scaling factors as half, converted on host
      half alphaHalf = hostFloat2Half(fAlpha);
      half betaHalf = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &alphaHalf, A, lda, B, ldb,
                         &betaHalf, C, ldc);
    }

    // CUDA 8.0 changes the half datatype specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    // Mixed-precision path: half storage, float alpha/beta and accumulation
    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta,
                         C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
85 
86 
// Computes c = alpha * a * b + beta * c on row-major 2-D tensors, where each
// operand may additionally be flagged as transposed. cuBLAS is column-major,
// so the row-major product is issued as the equivalent column-major product
// c' = b' * a' (operand order swapped); transC cancels that swap.
// Preconditions: shapes must form a valid (m x k) * (k x n) = (m x n)
// product, and the innermost dimension of every tensor must be contiguous.
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  // All subsequent cuBLAS work on this handle is ordered on `stream`
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Innermost dimension must be contiguous, since only the outer (row)
  // stride can be expressed through the lda/ldb/ldc arguments below
  FAISS_ASSERT(a.getStride(1) == 1);
  FAISS_ASSERT(b.getStride(1) == 1);
  FAISS_ASSERT(c.getStride(1) == 1);

  // Now, we have to represent the matrix multiplication in
  // column-major layout. Computing c' = b' * a' swaps the roles of a and b;
  // when transC is set, the swap is undone.
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  // Leading dimensions are the row strides of the (possibly swapped) operands
  int lda = transC ? a.getStride(0) : b.getStride(0);
  int ldb = transC ? b.getStride(0) : a.getStride(0);
  int ldc = c.getStride(0);

  // A row-major matrix viewed as column-major is already transposed, so the
  // cuBLAS op flags follow the swapped operands; with transC the flags invert
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  // On failure, report the original shapes and transpose flags (a trailing
  // ' marks a transposed operand) to make debugging easier
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
  CUDA_TEST_ERROR();
}
156 
// Non-templated float entry point; forwards to the templated implementation.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<float>(c, transC,
                       a, transA,
                       b, transB,
                       alpha, beta, useHgemm, handle, stream);
}
168 
#ifdef FAISS_USE_FLOAT16
// Non-templated half entry point; forwards to the templated implementation.
// Only compiled when float16 support is enabled.
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<half>(c, transC,
                      a, transA,
                      b, transB,
                      alpha, beta, useHgemm, handle, stream);
}
#endif
182 
// Performs one independent GEMM per outermost slice of the rank-3 tensors:
// c[i] = alpha * a[i] * b[i] + beta * c[i]. Issues a separate cuBLAS call
// for each slice (contrast with runBatchMatrixMult below).
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  int numMatrices = a.getSize(0);

  for (int mat = 0; mat < numMatrices; ++mat) {
    auto cSlice = c[mat].view();
    auto aSlice = a[mat].view();
    auto bSlice = b[mat].view();

    // useHgemm = false: this path handles float data only
    runMatrixMult(cSlice, transC,
                  aSlice, transA,
                  bSlice, transB,
                  alpha, beta, false, handle, stream);
  }
}
205 
// Performs all outermost-slice GEMMs c[i] = alpha * a[i] * b[i] + beta * c[i]
// in a single cublasSgemmBatched call. Uses the same row-major-to-column-major
// operand swap as runMatrixMult above; per-slice pointer tables are built on
// the host and staged to the device through `mem`-provided temporary storage.
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  // All subsequent cuBLAS work on this handle is ordered on `stream`
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout: c' = b' * a' swaps the roles of a and b,
  // and transC undoes that swap
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  // Leading dimensions come from the inner-matrix row strides of the
  // (possibly swapped) operands
  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  // Row-major data viewed column-major is already transposed; op flags
  // follow the swapped operands and invert when transC is set
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  // Host-side tables of per-slice base pointers, one entry per batch member
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  // Distance (in elements) between consecutive inner matrices
  size_t aOffset = a.getStride(0);
  size_t bOffset = b.getStride(0);
  size_t cOffset = c.getStride(0);

  for (int i = 0; i < a.getSize(0); ++i) {
    // Same a/b swap as for pA/pB above, applied per slice
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }

  // Copy the pointer tables to temporary device memory on `stream`
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}
285 
286 } } // namespace