Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Reductions.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
9 #pragma once
10 
11 #include "DeviceDefs.cuh"
12 #include "PtxUtils.cuh"
13 #include "ReductionOperators.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include <cuda.h>
17 
18 namespace faiss { namespace gpu {
19 
/// All-reduces `val` with `op` using an XOR-shuffle butterfly.
///
/// Each step exchanges values with the lane whose id differs by `laneMask`
/// and combines them, halving the distance every step (ReduceWidth/2, ..., 1).
/// After the loop, every lane holds the `op`-reduction of the aligned group of
/// ReduceWidth lanes it belongs to (the whole warp by default).
///
/// Assumptions (not checked here):
///  - ReduceWidth is presumably a power of two no larger than kWarpSize.
///  - `op` should be associative and commutative, since lanes combine
///    operands in different orders.
///  - NOTE(review): which lanes must be active depends on the participation
///    mask used by shfl_xor (WarpShuffles.cuh) — confirm before calling from
///    divergent code or a partially populated warp.
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
  for (int laneMask = ReduceWidth >> 1; laneMask >= 1; laneMask >>= 1) {
    const T partner = shfl_xor(val, laneMask);
    val = op(val, partner);
  }
  return val;
}
29 
/// Sums `val` across each aligned group of ReduceWidth lanes (the whole warp
/// by default). Every lane receives its group's total; this is warpReduceAll
/// specialised to the Sum<T> operator.
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
  const Sum<T> sumOp = Sum<T>();
  return warpReduceAll<T, Sum<T>, ReduceWidth>(val, sumOp);
}
35 
/// Block-wide reduction of one value per thread, in two stages:
///  1. Each warp reduces its threads' values with warpReduceAll, and lane 0
///     of every warp stores the warp partial in smem[warpId].
///  2. After a __syncthreads(), warp 0 loads those partials (lanes without a
///     corresponding warp use op.identity()) and reduces them again.
///
/// Template flags:
///  - BroadcastAll: if true, lane 0 of warp 0 writes the block result to
///    smem[0]. After an extra __syncthreads(), every thread in the block
///    returns that value. If false, threads of warp 0 return the block
///    result, while threads of other warps return their own warp's partial.
///  - KillWARDependency: if true, a trailing __syncthreads() guarantees that
///    all threads have finished reading smem before any thread returns. This
///    removes the write-after-read hazard if the caller reuses smem next.
///
/// Preconditions (not checked here):
///  - Must be reached by every thread of the block (it calls __syncthreads()).
///  - smem must hold at least divUp(blockDim.x, kWarpSize) values of T.
///  - Warps are assigned from threadIdx.x only, so this appears to assume a
///    1-D block — TODO confirm for callers launching 2-D/3-D blocks.
///  - Warp 0 gathers one partial per lane, so at most kWarpSize warps
///    (kWarpSize * kWarpSize threads) are supported.
///  - NOTE(review): a partially populated last warp (blockDim.x not a
///    multiple of kWarpSize) relies on shfl_xor's behaviour for missing
///    lanes; verify against WarpShuffles.cuh before using such block sizes.
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;
  const int numWarps = utils::divUp(blockDim.x, kWarpSize);

  // Stage 1: reduce inside each warp and publish one partial per warp.
  val = warpReduceAll<T, Op>(val, op);
  if (lane == 0) {
    smem[warp] = val;
  }
  __syncthreads();

  // Stage 2: warp 0 folds the per-warp partials into the block result.
  if (warp == 0) {
    if (lane < numWarps) {
      val = smem[lane];
    } else {
      // No warp maps to this lane; contribute the neutral element.
      val = op.identity();
    }
    val = warpReduceAll<T, Op>(val, op);

    if (BroadcastAll) {
      // Block-scope fence before lane 0 publishes the final value.
      __threadfence_block();
      if (lane == 0) {
        smem[0] = val;
      }
    }
  }

  // Barriers below depend only on template flags, so every thread takes the
  // same path and reaches the same __syncthreads() calls.
  if (BroadcastAll) {
    __syncthreads();
    val = smem[0];
  }

  if (KillWARDependency) {
    __syncthreads();
  }

  return val;
}
73 
/// Block-wide reduction of Num independent values per thread, performed in
/// place on `val`.
///
/// This is the same two-stage scheme as the single-value blockReduceAll,
/// applied to each of the Num values:
///  1. Every warp reduces each val[i] with warpReduceAll. Lane 0 of warp w
///     stores the warp partials in smem[w * Num + i].
///  2. After a __syncthreads(), warp 0 lane l loads smem[l * Num + i]
///     (or op.identity() if no warp maps to lane l) and reduces again.
///
/// Template flags:
///  - BroadcastAll: if true, lane 0 of warp 0 writes the Num results to
///    smem[0 .. Num-1]. After an extra __syncthreads(), every thread's
///    val[] holds the block results. If false, only warp 0's val[] holds the
///    block results; other warps keep their warp partials.
///  - KillWARDependency: if true, a trailing __syncthreads() guarantees that
///    all smem reads are complete before any thread returns.
///
/// Preconditions (not checked here):
///  - Must be reached by every thread of the block (it calls __syncthreads()).
///  - smem must hold at least divUp(blockDim.x, kWarpSize) * Num values of T.
///  - Warps are assigned from threadIdx.x only, so this appears to assume a
///    1-D block — TODO confirm.
///  - At most kWarpSize warps per block are supported.
///  - NOTE(review): a partially populated last warp relies on shfl_xor's
///    behaviour for missing lanes; verify against WarpShuffles.cuh.
template <int Num, typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;
  const int numWarps = utils::divUp(blockDim.x, kWarpSize);

  // Stage 1: per-warp reduction of every value.
#pragma unroll
  for (int k = 0; k < Num; ++k) {
    val[k] = warpReduceAll<T, Op>(val[k], op);
  }

  // Warp w owns the Num consecutive slots starting at w * Num.
  if (lane == 0) {
    const int slotBase = warp * Num;
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      smem[slotBase + k] = val[k];
    }
  }
  __syncthreads();

  // Stage 2: warp 0 combines the partials of all warps, value by value.
  if (warp == 0) {
    const bool ownsWarpSlot = lane < numWarps;
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      if (ownsWarpSlot) {
        val[k] = smem[lane * Num + k];
      } else {
        // No warp maps to this lane; contribute the neutral element.
        val[k] = op.identity();
      }
      val[k] = warpReduceAll<T, Op>(val[k], op);
    }

    if (BroadcastAll) {
      // Block-scope fence before lane 0 publishes the final values.
      __threadfence_block();
      if (lane == 0) {
#pragma unroll
        for (int k = 0; k < Num; ++k) {
          smem[k] = val[k];
        }
      }
    }
  }

  // Barriers below depend only on template flags and are therefore
  // block-uniform.
  if (BroadcastAll) {
    __syncthreads();
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      val[k] = smem[k];
    }
  }

  if (KillWARDependency) {
    __syncthreads();
  }
}
127 
128 
/// Sums `val` across the entire thread block; this is blockReduceAll
/// specialised to Sum<T>.
///
/// BroadcastAll and KillWARDependency behave exactly as for blockReduceAll.
/// smem needs room for one T per warp in the block.
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
  const Sum<T> sumOp = Sum<T>();
  return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
      val, sumOp, smem);
}
135 
/// Sums each of the Num values in `vals` across the entire thread block, in
/// place; this is the multi-value blockReduceAll specialised to Sum<T>.
///
/// BroadcastAll and KillWARDependency behave exactly as for blockReduceAll.
/// smem needs room for Num values of T per warp in the block.
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
  const Sum<T> sumOp = Sum<T>();
  blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
      vals, sumOp, smem);
}
141 
142 } } // namespace