Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MatrixMult.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "MatrixMult.cuh"
12 #include "DeviceMemory.h"
13 #include "DeviceUtils.h" // CUDA_VERIFY
14 #include "DeviceTensor.cuh"
15 #include "HostTensor.cuh"
16 
17 namespace faiss { namespace gpu {
18 
// Dispatcher that maps the element type T to the matching cuBLAS gemm
// call. Only the explicit specializations below (float, and half when
// FAISS_USE_FLOAT16 is defined) provide a gemm(); instantiating the
// primary template for any other T is a compile-time error at the call.
template <typename T>
struct CublasGemm {
};
22 
// fp32 specialization: forwards directly to cublasSgemm.
template <>
struct CublasGemm<float> {
  /// Issues C = alpha * op(A) * op(B) + beta * C in single precision on
  /// the given handle. `useHgemm` exists only for interface parity with
  /// the half specialization and is ignored here.
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             float fAlpha,
                             const float *A,
                             int lda,
                             const float *B,
                             int ldb,
                             float fBeta,
                             float *C,
                             int ldc,
                             bool useHgemm) {
    // cuBLAS takes the scaling factors by pointer (host pointer mode).
    const float* alpha = &fAlpha;
    const float* beta = &fBeta;

    return cublasSgemm(handle, transa, transb,
                       m, n, k,
                       alpha, A, lda,
                       B, ldb,
                       beta, C, ldc);
  }
};
44 
#ifdef FAISS_USE_FLOAT16
// fp16 specialization: storage is half, but the alpha/beta scaling
// factors still arrive as float and are converted only where needed.
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             const float fAlpha,
                             const half *A,
                             int lda,
                             const half *B,
                             int ldb,
                             const float fBeta,
                             half *C,
                             int ldc,
                             bool useHgemm) {
    // Use native fp16 arithmetic only when the caller asked for it AND
    // the current device supports it; otherwise fall through to the
    // mixed-precision path below.
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      // Hgemm wants half-typed scaling factors.
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // CUDA 8.0 changes the half datatype specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    // Mixed precision fallback: half storage in/out, float accumulation
    // and scaling.
    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta,
                         C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
86 
87 
// Computes c = alpha * op(a) * op(b) + beta * c for row-major 2-D
// tensors via CublasGemm<T>. transX marks tensor x as logically
// transposed; useHgemm requests native fp16 math (honored only by the
// half specialization). Work is issued asynchronously on `stream`.
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  // All subsequent cuBLAS calls on this handle run on `stream`.
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout
  // cuBLAS is column-major; a row-major matrix X is a column-major X^T,
  // so we compute C^T = B^T * A^T by swapping the operand order (and
  // swapping back when the output itself is marked transposed).
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  // Leading dimensions are the row-major innermost sizes of the
  // (swapped) operands.
  int lda = transC ? a.getSize(1) : b.getSize(1);
  int ldb = transC ? b.getSize(1) : a.getSize(1);
  int ldc = c.getSize(1);

  // Transpose flags for the swapped operands; the column-major
  // reinterpretation above inverts them when transC is set.
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  // Report the failing problem shape ("'" marks a transposed operand).
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
  CUDA_TEST_ERROR();
}
153 
// Non-templated fp32 entry point; forwards to the templated
// implementation above with identical semantics.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<float>(c, transC,
                       a, transA,
                       b, transB,
                       alpha, beta, useHgemm, handle, stream);
}
165 
#ifdef FAISS_USE_FLOAT16
// Non-templated fp16 entry point; forwards to the templated
// implementation above with identical semantics.
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<half>(c, transC,
                      a, transA,
                      b, transB,
                      alpha, beta, useHgemm, handle, stream);
}
#endif
179 
// Performs the batch of 2-D GEMMs described by the outermost dimension
// of the 3-D tensors as one runMatrixMult call per batch element, all
// issued on `stream`. fp16 math is never requested (useHgemm = false).
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  // All three tensors must agree on the batch dimension
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  int numBatches = a.getSize(0);
  for (int batch = 0; batch < numBatches; ++batch) {
    auto cSlice = c[batch].view();
    auto aSlice = a[batch].view();
    auto bSlice = b[batch].view();

    runMatrixMult(cSlice, transC,
                  aSlice, transA,
                  bSlice, transB,
                  alpha, beta,
                  false, // no Hgemm for float
                  handle, stream);
  }
}
202 
// Performs the batch of 2-D GEMMs described by the outermost dimension
// of the 3-D tensors with a single cublasSgemmBatched call. Per-batch
// matrix pointer arrays are staged on the host and copied to temporary
// device memory from `mem` on `stream`.
// Fixes vs. prior version: removed dead pA/pB/pC locals (the batched
// path uses the pointer arrays instead), and widened the per-batch
// offset products to size_t before multiplying to avoid int overflow
// on very large tensors.
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  // All three tensors must agree on the batch dimension
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout: a row-major matrix X is a column-major X^T, so
  // we compute C^T = B^T * A^T by swapping the operand order (and
  // swapping back when the output itself is marked transposed).
  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  // Leading dimensions are the row-major innermost sizes of the
  // (swapped) operands.
  int lda = transC ? a.getSize(2) : b.getSize(2);
  int ldb = transC ? b.getSize(2) : a.getSize(2);
  int ldc = c.getSize(2);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  // Host-side arrays of per-batch matrix pointers for SgemmBatched
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  // Elements in one batch slice of each tensor; widen before
  // multiplying so the product cannot overflow int.
  size_t aOffset = (size_t) a.getSize(1) * a.getSize(2);
  size_t bOffset = (size_t) b.getSize(1) * b.getSize(2);
  size_t cOffset = (size_t) c.getSize(1) * c.getSize(2);

  for (int i = 0; i < a.getSize(0); ++i) {
    // a and b are swapped per the column-major reinterpretation above
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }

  // Copy the pointer arrays to temporary device memory on `stream`
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}
282 
283 } } // namespace