Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MatrixMult.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "MatrixMult.cuh"
12 #include "DeviceMemory.h"
13 #include "DeviceUtils.h" // CUDA_VERIFY
14 #include "DeviceTensor.cuh"
15 #include "HostTensor.cuh"
16 
17 namespace faiss { namespace gpu {
18 
// Type-dispatched wrapper around the cuBLAS GEMM entry points. Only the
// explicit specializations below (float, and half when FAISS_USE_FLOAT16
// is defined) are usable; instantiating the primary template for any
// other T fails at compile time since it has no gemm member.
template <typename T>
struct CublasGemm {
};
22 
// Single-precision GEMM: forwards directly to cublasSgemm.
// The `useHgemm` flag is accepted for signature parity with the half
// specialization but has no effect for float data.
template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             float fAlpha,
                             const float *A,
                             int lda,
                             const float *B,
                             int ldb,
                             float fBeta,
                             float *C,
                             int ldc,
                             bool useHgemm) {
    // cublasSgemm takes alpha/beta by pointer; pass the locals' addresses
    return cublasSgemm(handle, transa, transb,
                       m, n, k,
                       &fAlpha, A, lda,
                       B, ldb,
                       &fBeta, C, ldc);
  }
};
44 
#ifdef FAISS_USE_FLOAT16
// Half-precision GEMM. Chooses between native fp16 math (cublasHgemm)
// and fp16-storage/fp32-compute (cublasSgemmEx) at runtime.
template <>
struct CublasGemm<half> {
  // alpha/beta are supplied as float; they are converted to half only on
  // the native Hgemm path below.
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             const float fAlpha,
                             const half *A,
                             int lda,
                             const half *B,
                             int ldb,
                             const float fBeta,
                             half *C,
                             int ldc,
                             bool useHgemm) {
    // Use native fp16 arithmetic only when the caller asked for it AND
    // the current device actually supports fp16 math
    if (getDeviceSupportsFloat16Math(getCurrentDevice()) && useHgemm) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // Fallback: fp16 storage with fp32 accumulation via cublasSgemmEx.
    // CUDA 8.0 changes the half datatype specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION

    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta,
                         C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16
86 
87 
// Computes c = alpha * a * b + beta * c for row-major 2-d tensors, where
// each operand may be individually marked transposed. cuBLAS is
// column-major: a row-major X is a column-major X^T, so we compute
// c^T = b^T * a^T by swapping the operand roles (and inverting the
// transpose flags when transC is set).
// Asserts on shape mismatch or on a non-success cuBLAS status.
template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              bool useHgemm,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // The innermost dimension must be contiguous for cuBLAS
  FAISS_ASSERT(a.getStride(1) == 1);
  FAISS_ASSERT(b.getStride(1) == 1);
  FAISS_ASSERT(c.getStride(1) == 1);

  // Now, we have to represent the matrix multiplication in
  // column-major layout; pA/pB name the operands in cuBLAS'
  // (column-major) terms
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = aK;           // inner dimension (already validated == bK above)

  int lda = transC ? a.getStride(0) : b.getStride(0);
  int ldb = transC ? b.getStride(0) : a.getStride(0);
  int ldc = c.getStride(0);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Writing into a transposed c flips the sense of both transpose flags
  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc, useHgemm);

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): %s "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   (int) err,
                   useHgemm ? "Hgemm" : "Sgemm",
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
  CUDA_TEST_ERROR();
}
157 
// Non-template entry point for float GEMM; forwards to the template
// implementation above.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<float>(c, transC,
                       a, transA,
                       b, transB,
                       alpha, beta, useHgemm,
                       handle, stream);
}
169 
#ifdef FAISS_USE_FLOAT16
// Non-template entry point for half GEMM; forwards to the template
// implementation above. Only compiled when fp16 support is enabled.
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  runMatrixMult<half>(c, transC,
                      a, transA,
                      b, transB,
                      alpha, beta, useHgemm,
                      handle, stream);
}
#endif
183 
// Computes c_i = alpha * a_i * b_i + beta * c_i for every matrix in the
// batch (outermost dimension), issuing one GEMM call per matrix.
// See runBatchMatrixMult for the single-call batched alternative.
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  int numMatrices = a.getSize(0);
  for (int mat = 0; mat < numMatrices; ++mat) {
    auto cMat = c[mat].view();
    auto aMat = a[mat].view();
    auto bMat = b[mat].view();

    // useHgemm = false: always take the Sgemm path for float data
    runMatrixMult(cMat, transC,
                  aMat, transA,
                  bMat, transB,
                  alpha, beta, false, handle, stream);
  }
}
206 
// Computes c_i = alpha * a_i * b_i + beta * c_i over the outermost
// (batch) dimension with a single cublasSgemmBatched call. Pointer
// arrays for the batch are staged on the host and copied to temporary
// device memory drawn from `mem` on `stream`.
// As in runMatrixMult, row-major inputs are expressed as the swapped
// column-major product c^T = b^T * a^T for cuBLAS.
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout; pA/pB name the operands in cuBLAS'
  // (column-major) terms
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = aK;           // inner dimension (already validated == bK above)

  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  // Writing into a transposed c flips the sense of both transpose flags
  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  // Per-matrix pointer arrays, staged on the host
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  // Distance between consecutive matrices, in cuBLAS operand terms
  // (selected once here rather than per iteration)
  size_t pAOffset = transC ? a.getStride(0) : b.getStride(0);
  size_t pBOffset = transC ? b.getStride(0) : a.getStride(0);
  size_t pCOffset = c.getStride(0);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = pA + i * pAOffset;
    hostB[i] = pB + i * pBOffset;
    hostC[i] = pC + i * pCOffset;
  }

  // Copy the pointer arrays to temporary device storage
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}
286 
287 } } // namespace