MatrixMult.cu

/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "MatrixMult.cuh"
#include "DeviceMemory.h"
#include "DeviceUtils.h" // CUDA_VERIFY
#include "DeviceTensor.cuh"
#include "HostTensor.cuh"

namespace faiss { namespace gpu {
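
// The primary template is intentionally empty: instantiating CublasGemm<T>
// for an unsupported element type fails at compile time. Only the float
// (and, when FAISS_USE_FLOAT16 is defined, half) specializations below are
// usable.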
template <typename T>
struct CublasGemm {
};

template <>
struct CublasGemm<float> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             float fAlpha,
                             const float *A,
                             int lda,
                             const float *B,
                             int ldb,
                             float fBeta,
                             float *C,
                             int ldc) {
    return cublasSgemm(handle, transa, transb, m, n, k,
                       &fAlpha, A, lda, B, ldb, &fBeta, C, ldc);
  }
};

#ifdef FAISS_USE_FLOAT16
template <>
struct CublasGemm<half> {
  static cublasStatus_t gemm(cublasHandle_t handle,
                             cublasOperation_t transa,
                             cublasOperation_t transb,
                             int m,
                             int n,
                             int k,
                             const float fAlpha,
                             const half *A,
                             int lda,
                             const half *B,
                             int ldb,
                             const float fBeta,
                             half *C,
                             int ldc) {
    if (getDeviceSupportsFloat16Math(getCurrentDevice())) {
      half hAlpha = hostFloat2Half(fAlpha);
      half hBeta = hostFloat2Half(fBeta);

      return cublasHgemm(handle, transa, transb, m, n, k,
                         &hAlpha, A, lda, B, ldb, &hBeta, C, ldc);
    }

    // CUDA 8.0 changes the half datatype specifier
#if CUDA_VERSION == 7050
    auto halfType = CUBLAS_DATA_HALF;
#else
    auto halfType = CUDA_R_16F;
#endif // CUDA_VERSION
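
    // Devices without native fp16 arithmetic fall back to cublasSgemmEx,
    // which keeps A, B and C in half precision but performs the
    // computation itself in fp32.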
    return cublasSgemmEx(handle, transa, transb, m, n, k,
                         &fAlpha, A, halfType, lda,
                         B, halfType, ldb,
                         &fBeta,
                         C, halfType, ldc);
  }
};
#endif // FAISS_USE_FLOAT16

template <typename T>
void
runMatrixMult(Tensor<T, 2, true>& c, bool transC,
              Tensor<T, 2, true>& a, bool transA,
              Tensor<T, 2, true>& b, bool transB,
              float alpha,
              float beta,
              cublasHandle_t handle,
              cudaStream_t stream) {
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(1) : a.getSize(0);
  int aK = transA ? a.getSize(0) : a.getSize(1);

  int bK = transB ? b.getSize(1) : b.getSize(0);
  int bN = transB ? b.getSize(0) : b.getSize(1);

  int cM = transC ? c.getSize(1) : c.getSize(0);
  int cN = transC ? c.getSize(0) : c.getSize(1);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // cuBLAS assumes column-major layout, but our tensors are row-major. A
  // row-major matrix is the transpose of the same data interpreted as
  // column-major, so instead of computing C = A * B we ask cuBLAS for
  // C^T = B^T * A^T; this amounts to swapping the A and B operands.
  T* pA = transC ? a.data() : b.data();
  T* pB = transC ? b.data() : a.data();
  T* pC = c.data();

  int m = c.getSize(1); // stride 1 size
  int n = c.getSize(0); // other size
  int k = transA ? a.getSize(0) : a.getSize(1);

  int lda = transC ? a.getSize(1) : b.getSize(1);
  int ldb = transC ? b.getSize(1) : a.getSize(1);
  int ldc = c.getSize(1);
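
  // Since the operands are swapped, the transpose flag for the first GEMM
  // argument is driven by transB and vice versa; transC flips both.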
  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  auto err = CublasGemm<T>::gemm(handle,
                                 gemmTrA, gemmTrB,
                                 m, n, k, alpha,
                                 pA, lda, pB, ldb, beta,
                                 pC, ldc);

  FAISS_ASSERT(err == CUBLAS_STATUS_SUCCESS);
  CUDA_VERIFY(cudaGetLastError());
}

void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<float>(c, transC, a, transA, b, transB,
                              alpha, beta, handle, stream);
}

#ifdef FAISS_USE_FLOAT16
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  return runMatrixMult<half>(c, transC, a, transA, b, transB,
                             alpha, beta, handle, stream);
}
#endif // FAISS_USE_FLOAT16
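
// Performs one matrix multiplication per slice of the batch, issuing a
// separate GEMM call for each element on the given stream.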
void
runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                      Tensor<float, 3, true>& a, bool transA,
                      Tensor<float, 3, true>& b, bool transB,
                      float alpha,
                      float beta,
                      cublasHandle_t handle,
                      cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta, handle, stream);
  }
}
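
// Batched version: a single cublasSgemmBatched call multiplies all slices
// at once, amortizing launch overhead when the individual matrices are
// small.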
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   DeviceMemory& mem,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // cuBLAS assumes column-major layout; as in runMatrixMult above, we
  // compute C^T = B^T * A^T by swapping the A and B operands. Here the
  // swap is applied per slice when the pointer arrays are built below.
  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getSize(2) : b.getSize(2);
  int ldb = transC ? b.getSize(2) : a.getSize(2);
  int ldc = c.getSize(2);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }
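
  // cublasSgemmBatched consumes an array of pointers, one per batch slice;
  // build the arrays on the host first.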
  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  size_t aOffset = a.getSize(1) * a.getSize(2);
  size_t bOffset = b.getSize(1) * b.getSize(2);
  size_t cOffset = c.getSize(1) * c.getSize(2);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }
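
  // cuBLAS requires that the pointer arrays themselves reside in device
  // memory; copy them over on the same stream.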
  DeviceTensor<float*, 1, true> deviceA(mem, hostA, stream);
  DeviceTensor<float*, 1, true> deviceB(mem, hostB, stream);
  DeviceTensor<float*, 1, true> deviceC(mem, hostC, stream);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT(err == CUBLAS_STATUS_SUCCESS);
  CUDA_VERIFY(cudaGetLastError());
}

} } // namespace
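
// Usage sketch (illustrative, not from this file): multiplying two
// row-major float matrices with runMatrixMult. It assumes the caller
// already owns a DeviceMemory `mem`, a cuBLAS `handle` and a CUDA `stream`,
// and that DeviceTensor exposes a size-list constructor as used elsewhere
// in the Faiss GPU code:
//
//   DeviceTensor<float, 2, true> a(mem, {m, k}, stream); // m x k input
//   DeviceTensor<float, 2, true> b(mem, {k, n}, stream); // k x n input
//   DeviceTensor<float, 2, true> c(mem, {m, n}, stream); // m x n output
//
//   // c = 1.0f * (a * b) + 0.0f * c, all row-major, no transposes
//   runMatrixMult(c, false, a, false, b, false, 1.0f, 0.0f, handle, stream);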