Faiss — MatrixMult.cuh (GPU matrix multiplication wrapper declarations)
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include <cublas_v2.h>
13 #include "Float16.cuh"
14 #include "Tensor.cuh"
15 
namespace faiss { namespace gpu {

// Forward declaration; provides temporary device scratch allocations
// (defined elsewhere in the project).
class DeviceMemory;

/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
/// All tensors must be 2-d and resident on the device associated with
/// `handle`/`stream`. The `transX` flags request a transposed view of
/// the corresponding operand.
/// `useHgemm` is accepted only for signature parity with the float16
/// overload; it is ignored for float32 inputs.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm, // ignored for float32
                   cublasHandle_t handle,
                   cudaStream_t stream);

#ifdef FAISS_USE_FLOAT16
/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
/// Half-precision variant; alpha/beta are still given as float.
/// If `useHgemm` is true, the multiply runs in half precision
/// (cublasHgemm); otherwise a mixed-precision path is presumably
/// used — confirm against the .cu implementation.
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream);
#endif

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via iterated gemm
/// Expects row major layout, not fortran/blas column major!
/// Issues one gemm per outer slice; no extra device memory required.
void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           Tensor<float, 3, true>& a, bool transA,
                           Tensor<float, 3, true>& b, bool transB,
                           float alpha,
                           float beta,
                           cublasHandle_t handle,
                           cudaStream_t stream);

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via batched gemm
/// Expects row major layout, not fortran/blas column major!
/// `mem` supplies temporary device storage — presumably for the
/// per-slice pointer arrays batched gemm requires; verify in the .cu.
void runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                        Tensor<float, 3, true>& a, bool transA,
                        Tensor<float, 3, true>& b, bool transB,
                        float alpha,
                        float beta,
                        DeviceMemory& mem,
                        cublasHandle_t handle,
                        cudaStream_t stream);

} } // namespace