Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BroadcastSum.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include "../utils/Float16.cuh"
12 #include "../utils/Tensor.cuh"
13 
namespace faiss { namespace gpu {

//
// Broadcast helpers: combine a 1-d vector `input` with every row or every
// column of a 2-d matrix `output`, which is written in place. Each routine
// takes the CUDA stream on which its work is issued.
//
// From the index formulas below:
//  - the "AlongColumns" routines index input by output's second index, so
//    input's length presumably must match output's second dimension;
//  - the "AlongRows" routine indexes input by output's first index, so
//    input's length presumably must match output's first dimension.
// NOTE(review): the size checks, and whether these calls return before the
// stream work completes, live in the implementation file, which is not
// visible here. Confirm there.
//

/// Adds `input` element-wise into every row of `output`:
///   output[x][i] += input[i]              for all x
void runSumAlongColumns(Tensor<float, 1, true>& input,
                        Tensor<float, 2, true>& output,
                        cudaStream_t stream);

/// Overwrites every row of `output` with a copy of `input`:
///   output[x][i] = input[i]               for all x
void runAssignAlongColumns(Tensor<float, 1, true>& input,
                           Tensor<float, 2, true>& output,
                           cudaStream_t stream);

/// Adds input[i] to every element of row i of `output`:
///   output[i][x] += input[i]              for all x
/// If `zeroClamp` is true, the result is clamped from below at zero instead:
///   output[i][x] = max(output[i][x] + input[i], 0)   for all x
void runSumAlongRows(Tensor<float, 1, true>& input,
                     Tensor<float, 2, true>& output,
                     bool zeroClamp,
                     cudaStream_t stream);

#ifdef FAISS_USE_FLOAT16
// Half-precision overloads. They are declared only when FAISS_USE_FLOAT16 is
// defined and follow the same element-wise formulas as the float versions
// above.

void runSumAlongColumns(Tensor<half, 1, true>& input,
                        Tensor<half, 2, true>& output,
                        cudaStream_t stream);

void runAssignAlongColumns(Tensor<half, 1, true>& input,
                           Tensor<half, 2, true>& output,
                           cudaStream_t stream);

void runSumAlongRows(Tensor<half, 1, true>& input,
                     Tensor<half, 2, true>& output,
                     bool zeroClamp,
                     cudaStream_t stream);
#endif // FAISS_USE_FLOAT16

} } // namespace