Faiss
MatrixMult.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cublas_v2.h>
#include "Float16.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

class DeviceMemory;

/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm, // ignored for float32
                   cublasHandle_t handle,
                   cudaStream_t stream);
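As an illustrative sketch (not part of the original header), one way this float32 wrapper might be called; the helper name exampleGemmF32, the tensor shapes, and the alpha/beta choices are assumptions:

    // Sketch only: c = a * b for row-major, device-resident tensors,
    // assuming a is (m x k), b is (k x n), and c is (m x n).
    inline void exampleGemmF32(Tensor<float, 2, true>& a,
                               Tensor<float, 2, true>& b,
                               Tensor<float, 2, true>& c,
                               cublasHandle_t handle,
                               cudaStream_t stream) {
      runMatrixMult(c, false,  // C written as-is (no transpose)
                    a, false,  // A as-is
                    b, false,  // B as-is
                    1.0f,      // alpha
                    0.0f,      // beta: overwrite C rather than accumulate
                    false,     // useHgemm: ignored for float32
                    handle, stream);
    }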
#ifdef FAISS_USE_FLOAT16
/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream);
#endif
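A similar hedged sketch for the half-precision overload, which is only available when FAISS_USE_FLOAT16 is defined; treating useHgemm = true as a request for the native half-precision GEMM path is an assumption to verify against MatrixMult.cu:

    // Sketch only: c = a * b on half tensors (requires FAISS_USE_FLOAT16).
    inline void exampleGemmF16(Tensor<half, 2, true>& a,
                               Tensor<half, 2, true>& b,
                               Tensor<half, 2, true>& c,
                               cublasHandle_t handle,
                               cudaStream_t stream) {
      runMatrixMult(c, false, a, false, b, false,
                    1.0f, 0.0f,
                    true, // useHgemm: assumed to request the half GEMM path
                    handle, stream);
    }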
/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via iterated gemm
/// Expects row major layout, not fortran/blas column major!
void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           Tensor<float, 3, true>& a, bool transA,
                           Tensor<float, 3, true>& b, bool transB,
                           float alpha,
                           float beta,
                           cublasHandle_t handle,
                           cudaStream_t stream);
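For the 3-d variants, a minimal sketch of a per-slice multiply via the iterated path; the shape comments and the helper name are assumptions, not part of the header:

    // Sketch only: for each slice i along the outermost dimension, computes
    // c[i] = a[i] * b[i], issuing one GEMM per slice on the given stream.
    inline void exampleIteratedGemm(Tensor<float, 3, true>& a, // (batch x m x k)
                                    Tensor<float, 3, true>& b, // (batch x k x n)
                                    Tensor<float, 3, true>& c, // (batch x m x n)
                                    cublasHandle_t handle,
                                    cudaStream_t stream) {
      runIteratedMatrixMult(c, false, a, false, b, false,
                            1.0f, 0.0f, handle, stream);
    }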
/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via batched gemm
/// Expects row major layout, not fortran/blas column major!
void runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                        Tensor<float, 3, true>& a, bool transA,
                        Tensor<float, 3, true>& b, bool transB,
                        float alpha,
                        float beta,
                        DeviceMemory& mem,
                        cublasHandle_t handle,
                        cudaStream_t stream);
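And the batched path, which additionally takes a DeviceMemory reference for temporary device allocations; in Faiss this would typically be obtained from the GpuResources object, though that wiring is an assumption here:

    // Sketch only: same per-slice result as the iterated version, but issued
    // as a single batched GEMM; mem supplies scratch device memory.
    inline void exampleBatchGemm(Tensor<float, 3, true>& a,
                                 Tensor<float, 3, true>& b,
                                 Tensor<float, 3, true>& c,
                                 DeviceMemory& mem,
                                 cublasHandle_t handle,
                                 cudaStream_t stream) {
      runBatchMatrixMult(c, false, a, false, b, false,
                         1.0f, 0.0f, mem, handle, stream);
    }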
} } // namespace