Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Reductions.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #pragma once
11 
12 #include "DeviceDefs.cuh"
13 #include "PtxUtils.cuh"
14 #include "ReductionOperators.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
17 #include <cuda.h>
18 
19 namespace faiss { namespace gpu {
20 
/// Reduces `val` with `op` across each aligned group of ReduceWidth
/// consecutive lanes using an XOR butterfly. For an associative and
/// commutative `op`, every lane of a group ends up holding that group's
/// reduced value (with the default width: the whole-warp result).
///
/// ReduceWidth must be a power of two no larger than kWarpSize. Only then do
/// the XOR masks ReduceWidth/2, ..., 1 pair up every lane of a group exactly
/// once per step; other widths would silently give wrong results, so they are
/// rejected at compile time.
///
/// NOTE(review): every lane of each group is presumably expected to be active,
/// since each lane reads a partner lane via shfl_xor — confirm at call sites.
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
  static_assert(ReduceWidth > 0 && (ReduceWidth & (ReduceWidth - 1)) == 0,
                "ReduceWidth must be a power of two");
  static_assert(ReduceWidth <= kWarpSize,
                "ReduceWidth cannot exceed the warp size");

#pragma unroll
  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
    // Combine with the lane whose id differs from ours only in bit `mask`.
    val = op(val, shfl_xor(val, mask));
  }

  return val;
}
30 
/// Sums a register value across each aligned group of ReduceWidth lanes.
/// With the default ReduceWidth (kWarpSize), every lane of the warp receives
/// the sum over the whole warp.
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
  Sum<T> adder;
  return warpReduceAll<T, Sum<T>, ReduceWidth>(val, adder);
}
36 
/// Performs a block-wide reduction of one register value per thread with `op`.
///
/// Works in two phases: each warp reduces its own values, then warp 0 combines
/// the per-warp partials staged in `smem`.
///
/// Template parameters:
///   BroadcastAll      - if true, the block-wide result is written to smem[0]
///                       and read back by every thread, so all threads return
///                       it. If false, only the lanes of warp 0 return the
///                       block-wide result; every other thread returns the
///                       reduction of its own warp.
///   KillWARDependency - if true, a trailing __syncthreads() ensures every
///                       thread has finished reading `smem` before any thread
///                       returns, so the caller may overwrite `smem` right away.
///
/// Preconditions (inferred from the indexing below; confirm at call sites):
///   - all threads of the block must call this, since it contains
///     __syncthreads();
///   - `smem` must hold at least divUp(blockDim.x, kWarpSize) elements;
///   - only threadIdx.x is used, so this presumably assumes a 1-D block;
///   - warp 0 combines at most kWarpSize partials (one per lane);
///   - NOTE(review): warpReduceAll exchanges across all kWarpSize lanes, so a
///     blockDim.x that is not a multiple of kWarpSize presumably yields an
///     undefined partial for the last warp — confirm.
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Phase 1: reduce within each warp; lane 0 publishes the warp's partial.
  T result = warpReduceAll<T, Op>(val, op);
  if (lane == 0) {
    smem[warp] = result;
  }
  __syncthreads();

  // Phase 2: warp 0 folds the per-warp partials. Lanes with no corresponding
  // warp contribute op.identity() so they do not affect the result.
  if (warp == 0) {
    const int numWarps = utils::divUp(blockDim.x, kWarpSize);
    result = warpReduceAll<T, Op>(
        lane < numWarps ? smem[lane] : op.identity(), op);

    if (BroadcastAll) {
      // Presumably orders warp 0's earlier smem reads before lane 0
      // overwrites smem[0] with the final value.
      __threadfence_block();
      if (lane == 0) {
        smem[0] = result;
      }
    }
  }

  if (BroadcastAll) {
    // Wait for lane 0 of warp 0 to publish, then every thread picks it up.
    __syncthreads();
    result = smem[0];
  }

  if (KillWARDependency) {
    // Prevent any thread from reusing smem while others may still read it.
    __syncthreads();
  }

  return result;
}
74 
/// Performs a block-wide reduction of Num register values per thread at once,
/// combining each index val[k] independently with `op`. Results are written
/// back into `val` in place.
///
/// Template parameters:
///   BroadcastAll      - if true, the Num block-wide results are written to
///                       smem[0..Num) and read back by every thread, so all
///                       threads end with them in `val`. If false, only the
///                       lanes of warp 0 hold the block-wide results; every
///                       other thread holds its own warp's reductions.
///   KillWARDependency - if true, a trailing __syncthreads() ensures every
///                       thread has finished reading `smem` before any thread
///                       returns, so the caller may overwrite `smem` right away.
///
/// Preconditions (inferred from the indexing below; confirm at call sites):
///   - all threads of the block must call this, since it contains
///     __syncthreads();
///   - `smem` must hold at least Num * divUp(blockDim.x, kWarpSize) elements
///     (the per-warp partials are stored warp-major: smem[warp * Num + k]);
///   - only threadIdx.x is used, so this presumably assumes a 1-D block;
///   - warp 0 combines at most kWarpSize partials per index (one per lane);
///   - NOTE(review): as with the single-value version, a blockDim.x that is
///     not a multiple of kWarpSize is presumably unsupported — confirm.
template <int Num, typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Phase 1: warp-level reduction of each of the Num values.
#pragma unroll
  for (int k = 0; k < Num; ++k) {
    val[k] = warpReduceAll<T, Op>(val[k], op);
  }

  // Lane 0 of each warp stores its Num partials in the warp's slot.
  if (lane == 0) {
    T* warpSlot = smem + warp * Num;
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      warpSlot[k] = val[k];
    }
  }

  __syncthreads();

  // Phase 2: warp 0 folds the per-warp partials for each index. Lanes with no
  // corresponding warp contribute op.identity().
  if (warp == 0) {
    const int numWarps = utils::divUp(blockDim.x, kWarpSize);
    const bool hasPartial = lane < numWarps;
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      val[k] = warpReduceAll<T, Op>(
          hasPartial ? smem[lane * Num + k] : op.identity(), op);
    }

    if (BroadcastAll) {
      // Presumably orders warp 0's earlier smem reads before lane 0
      // overwrites smem[0..Num) with the final values.
      __threadfence_block();
      if (lane == 0) {
#pragma unroll
        for (int k = 0; k < Num; ++k) {
          smem[k] = val[k];
        }
      }
    }
  }

  if (BroadcastAll) {
    // Wait for lane 0 of warp 0 to publish, then every thread picks them up.
    __syncthreads();
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      val[k] = smem[k];
    }
  }

  if (KillWARDependency) {
    // Prevent any thread from reusing smem while others may still read it.
    __syncthreads();
  }
}
128 
129 
/// Sums a register value across the entire block by calling the single-value
/// blockReduceAll with Sum<T>. With BroadcastAll, every thread returns the
/// block sum; otherwise only warp 0's lanes do. With KillWARDependency, `smem`
/// may be reused immediately after the call. Must be called by every thread
/// of the block.
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
  Sum<T> adder;
  return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
      val, adder, smem);
}
136 
/// Sums Num register values per thread across the entire block, writing the
/// sums back into `vals` in place. Calls the multi-value blockReduceAll with
/// Sum<T>; BroadcastAll and KillWARDependency are forwarded unchanged. Must be
/// called by every thread of the block.
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
  blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
      vals, Sum<T>(), smem);
}
142 
143 } } // namespace