// MatrixMult.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include <cublas_v2.h>
#include "Float16.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

// Forward declaration; only referenced by runBatchMatrixMult below.
class DeviceMemory;

/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
/// The trans* flags presumably mark the corresponding matrix as stored
/// transposed -- confirm against the implementation in MatrixMult.cu.
/// The gemm is issued through the cuBLAS `handle`, ordered on `stream`.
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm, // ignored for float32
                   cublasHandle_t handle,
                   cudaStream_t stream);

#ifdef FAISS_USE_FLOAT16
/// C = alpha * A * B + beta * C, with half-precision storage.
/// Expects row major layout, not fortran/blas column major!
/// `useHgemm` selects between a true half gemm and an alternative path
/// (NOTE(review): presumably a float-accumulating fallback -- confirm in
/// MatrixMult.cu).
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   bool useHgemm,
                   cublasHandle_t handle,
                   cudaStream_t stream);
#endif

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via iterated gemm (one gemm call
/// per slice `i`).
/// Expects row major layout, not fortran/blas column major!
void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           Tensor<float, 3, true>& a, bool transA,
                           Tensor<float, 3, true>& b, bool transB,
                           float alpha,
                           float beta,
                           cublasHandle_t handle,
                           cudaStream_t stream);

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via batched gemm.
/// Expects row major layout, not fortran/blas column major!
/// `mem` supplies temporary device memory (NOTE(review): likely for the
/// per-batch device pointer arrays that cublas*gemmBatched requires --
/// confirm in MatrixMult.cu).
void runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                        Tensor<float, 3, true>& a, bool transA,
                        Tensor<float, 3, true>& b, bool transB,
                        float alpha,
                        float beta,
                        DeviceMemory& mem,
                        cublasHandle_t handle,
                        cudaStream_t stream);

} } // namespace faiss::gpu