Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BroadcastSum.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include "../utils/Float16.cuh"
13 #include "../utils/Tensor.cuh"
14 
15 namespace faiss { namespace gpu {
16
// Broadcast helpers. Each function combines a 1-D `input` tensor with every
// row or every column of a 2-D `output` tensor and writes the result back
// into `output` in place.
// The `float` overloads are always declared. The `half` overloads are
// declared only when FAISS_USE_FLOAT16 is defined.
// NOTE(review): the third Tensor template argument (`true`) is presumably an
// inner-dimension contiguity flag. Confirm against ../utils/Tensor.cuh.
// NOTE(review): `stream` is presumably the CUDA stream the work is launched
// on, and the call may return before the work completes. Confirm against the
// .cu definitions before relying on `output` on the host.
17 // output[x][i] += input[i] for all x
// In words: `input` is added element-wise to every row of `output`.
// The indexing implies that input's length equals output's size in
// dimension 1 (its number of columns).
// NOTE(review): this header does not check that size agreement. Confirm that
// the definition validates it.
//   input  - 1-D tensor whose values are added (one value per column).
//   output - 2-D tensor, updated in place.
//   stream - CUDA stream handed to the implementation.
18 void runSumAlongColumns(Tensor<float, 1, true>& input,
19  Tensor<float, 2, true>& output,
20  cudaStream_t stream);
21
22 #ifdef FAISS_USE_FLOAT16
// `half` overload with the same semantics as the float version above.
// It is declared only when FAISS_USE_FLOAT16 is defined.
23 void runSumAlongColumns(Tensor<half, 1, true>& input,
24  Tensor<half, 2, true>& output,
25  cudaStream_t stream);
26 #endif
27
28 // output[x][i] = input[i] for all x
// In words: `input` is copied into every row of `output`, overwriting
// whatever `output` held before.
// As with runSumAlongColumns, input's length must match output's size in
// dimension 1 (its number of columns).
//   input  - 1-D tensor to replicate (one value per column).
//   output - 2-D tensor, overwritten in place.
//   stream - CUDA stream handed to the implementation.
29 void runAssignAlongColumns(Tensor<float, 1, true>& input,
30  Tensor<float, 2, true>& output,
31  cudaStream_t stream);
32
33 #ifdef FAISS_USE_FLOAT16
// `half` overload with the same semantics as the float version above.
// It is declared only when FAISS_USE_FLOAT16 is defined.
34 void runAssignAlongColumns(Tensor<half, 1, true>& input,
35  Tensor<half, 2, true>& output,
36  cudaStream_t stream);
37 #endif
38
39 // output[i][x] += input[i] for all x
40 // If zeroClamp, output[i][x] = max(output[i][x] + input[i], 0) for all x
// In words: input[i] is added to every element of row i of `output`.
// The indexing implies that input's length equals output's size in
// dimension 0 (its number of rows).
//   input     - 1-D tensor with one value per row of `output`.
//   output    - 2-D tensor, updated in place.
//   zeroClamp - when true, each updated element is clamped from below at 0.
//   stream    - CUDA stream handed to the implementation.
41 void runSumAlongRows(Tensor<float, 1, true>& input,
42  Tensor<float, 2, true>& output,
43  bool zeroClamp,
44  cudaStream_t stream);
45
46 #ifdef FAISS_USE_FLOAT16
// `half` overload with the same semantics as the float version above,
// including zeroClamp.
// It is declared only when FAISS_USE_FLOAT16 is defined.
47 void runSumAlongRows(Tensor<half, 1, true>& input,
48  Tensor<half, 2, true>& output,
49  bool zeroClamp,
50  cudaStream_t stream);
51 #endif
52
53 } } // namespace