MatrixMult.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include <cublas_v2.h>
#include "Float16.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

class DeviceMemory;

/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
void runMatrixMult(Tensor<float, 2, true>& c, bool transC,
                   Tensor<float, 2, true>& a, bool transA,
                   Tensor<float, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream);
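
// Usage sketch (illustrative, not part of the original header): computing
// c = a * b for row-major a of shape [m, k], b of shape [k, n] and c of
// shape [m, n]. The tensors, the cuBLAS handle and the CUDA stream are
// assumed to have been created elsewhere by the caller.
//
//   Tensor<float, 2, true> a = ...;   // [m, k], row major, device memory
//   Tensor<float, 2, true> b = ...;   // [k, n]
//   Tensor<float, 2, true> c = ...;   // [m, n], receives the result
//
//   runMatrixMult(c, false,           // c is stored untransposed
//                 a, false,           // use a as-is
//                 b, false,           // use b as-is
//                 1.0f,               // alpha
//                 0.0f,               // beta = 0: overwrite c
//                 handle,             // cublasHandle_t for this device
//                 stream);            // cudaStream_t the GEMM runs on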

#ifdef FAISS_USE_FLOAT16
/// C = alpha * A * B + beta * C
/// Expects row major layout, not fortran/blas column major!
void runMatrixMult(Tensor<half, 2, true>& c, bool transC,
                   Tensor<half, 2, true>& a, bool transA,
                   Tensor<half, 2, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream);
#endif

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via iterated gemm
/// Expects row major layout, not fortran/blas column major!
void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           Tensor<float, 3, true>& a, bool transA,
                           Tensor<float, 3, true>& b, bool transB,
                           float alpha,
                           float beta,
                           cublasHandle_t handle,
                           cudaStream_t stream);
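
// Usage sketch (illustrative): running one independent GEMM per slice of
// the outermost dimension, issued as a loop of 2-D multiplies per the
// comment above. Shapes assumed here: a is [numBatches, m, k],
// b is [numBatches, k, n], c is [numBatches, m, n]; handle and stream come
// from elsewhere.
//
//   runIteratedMatrixMult(c, false,
//                         a, false,
//                         b, false,
//                         1.0f, 0.0f,
//                         handle, stream);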

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via batched gemm
/// Expects row major layout, not fortran/blas column major!
void runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                        Tensor<float, 3, true>& a, bool transA,
                        Tensor<float, 3, true>& b, bool transB,
                        float alpha,
                        float beta,
                        DeviceMemory& mem,
                        cublasHandle_t handle,
                        cudaStream_t stream);
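
// Usage sketch (illustrative): the same per-slice computation as
// runIteratedMatrixMult, but issued as a single batched gemm. The extra
// DeviceMemory argument is assumed to be the caller's temporary-memory
// provider, used for any scratch buffers the batched call may need.
// Shapes as above: a [numBatches, m, k], b [numBatches, k, n],
// c [numBatches, m, n].
//
//   runBatchMatrixMult(c, false,
//                      a, false,
//                      b, false,
//                      1.0f, 0.0f,
//                      mem,            // temporary device memory provider
//                      handle, stream);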

} } // namespace