Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MatrixMult.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #include "MatrixMult.cuh"
10 #include "DeviceMemory.h"
11 #include "DeviceUtils.h" // CUDA_VERIFY
12 #include "DeviceTensor.cuh"
13 #include "HostTensor.cuh"
14 
15 namespace faiss { namespace gpu {
16 
// Type-dispatch trait mapping a scalar type T to the matching cuBLAS
// GEMM call. The primary template is intentionally empty: only the
// explicit specializations below (float, and half when
// FAISS_USE_FLOAT16 is defined) provide a usable gemm().
template <typename T>
struct CublasGemm {
};
20 
template <>
struct CublasGemm<float> {
  // Single-precision GEMM via cublasSgemm. The useHgemm flag is
  // accepted only for interface parity with the half specialization;
  // it has no effect on this path.
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             float fAlpha,
                             const float *A,
                             int lda,
                             const float *B,
                             int ldb,
                             float fBeta,
                             float *C,
                             int ldc,
                             bool useHgemm) {
    // cuBLAS takes the scaling factors by pointer.
    const float* alpha = &fAlpha;
    const float* beta = &fBeta;

    return cublasSgemm(handle,
                       transa, transb,
                       m, n, k,
                       alpha,
                       A, lda,
                       B, ldb,
                       beta,
                       C, ldc);
  }
};
42 
#ifdef FAISS_USE_FLOAT16
// GEMM on half-typed matrices. Scalars arrive as float and are
// converted to half only on the native fp16 path.
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             const float fAlpha,
                             const half *A,
                             int lda,
                             const half *B,
                             int ldb,
                             const float fBeta,
                             half *C,
                             int ldc,
                             bool useHgemm) {
    // Native fp16 math: used only when the current device supports it
    // AND the caller explicitly requested it via useHgemm.
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // Fallback: half-typed storage with fp32 compute via cublasSgemmEx.
    // CUDA 8.0 changes the half datatype specifier
    // (CUBLAS_DATA_HALF was renamed to CUDA_R_16F).
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta,
                         C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
84 
85 
// Computes c = alpha * a * b + beta * c on row-major 2-d tensors, where
// each operand may additionally be flagged as transposed (transA/transB/
// transC describe the row-major view the caller intends). Requires the
// innermost dimension of a, b, and c to be contiguous (stride 1). The
// GEMM is queued asynchronously on `stream`; asserts on cuBLAS failure.
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  // NOTE(review): return status of cublasSetStream is not checked here.
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Innermost dimension must be packed for cuBLAS.
  FAISS_ASSERT(a.getStride(1) == 1);
  FAISS_ASSERT(b.getStride(1) == 1);
  FAISS_ASSERT(c.getStride(1) == 1);

  // Now, we have to represent the matrix multiplication in
  // column-major layout: cuBLAS is column-major, so the row-major
  // product C = A * B is evaluated as C' = B' * A' by swapping the
  // operand order. When transC is set the roles flip back again.
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  // Leading dimensions are the row strides of the row-major tensors.
  int lda = transC ? a.getStride(0) : b.getStride(0);
  int ldb = transC ? b.getStride(0) : a.getStride(0);
  int ldc = c.getStride(0);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  // transC undoes the operand swap above, so the transpose flags
  // invert as well.
  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  // On failure, report the operand shapes; a trailing ' marks an
  // operand flagged as transposed.
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
  CUDA_TEST_ERROR();
}
155 
// Non-template entry point for single-precision GEMM; simply forwards
// to the templated implementation above.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<float>(c, transC,
                       a, transA,
                       b, transB,
                       alpha, beta, useHgemm,
                       handle, stream);
}
167 
#ifdef FAISS_USE_FLOAT16
// Non-template entry point for half-precision GEMM; simply forwards
// to the templated implementation above.
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<half>(c, transC,
                      a, transA,
                      b, transB,
                      alpha, beta, useHgemm,
                      handle, stream);
}
#endif
181 
// Runs c[i] = alpha * a[i] * b[i] + beta * c[i] for each slice i along
// dimension 0, issuing one GEMM per batch element on `stream`. Always
// uses the Sgemm path (useHgemm = false).
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  // All three tensors must agree on the batch dimension.
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  int numBatches = a.getSize(0);

  for (int batch = 0; batch < numBatches; ++batch) {
    auto cSlice = c[batch].view();
    auto aSlice = a[batch].view();
    auto bSlice = b[batch].view();

    runMatrixMult(cSlice, transC,
                  aSlice, transA,
                  bSlice, transB,
                  alpha, beta,
                  false, // Sgemm only for the iterated path
                  handle, stream);
  }
}
204 
// Computes c[i] = alpha * a[i] * b[i] + beta * c[i] for every slice i
// along dimension 0 with a single cublasSgemmBatched call. Builds the
// per-slice device pointer arrays on the host, stages them to the
// device via temporary buffers drawn from `mem`, and queues everything
// asynchronously on `stream`. Asserts on cuBLAS failure.
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout: cuBLAS is column-major, so the row-major
  // product C = A * B is evaluated as C' = B' * A' by swapping the
  // operand order. When transC is set the roles flip back again.
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  // Leading dimensions are the row strides within one batch slice.
  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  // transC undoes the operand swap above, so the transpose flags
  // invert as well.
  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  // Host-side arrays of per-slice device pointers; the A/B swap here
  // mirrors the pA/pB swap above.
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  // Elements between consecutive batch slices.
  size_t aOffset = a.getStride(0);
  size_t bOffset = b.getStride(0);
  size_t cOffset = c.getStride(0);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }

  // Copy the pointer arrays to temporary device buffers from `mem`,
  // ordered on `stream`.
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}
284 
285 } } // namespace