Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Reductions.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #pragma once
13 
14 #include "DeviceDefs.cuh"
15 #include "PtxUtils.cuh"
16 #include "ReductionOperators.cuh"
17 #include "StaticUtils.h"
18 #include "WarpShuffles.cuh"
19 #include <cuda.h>
20 
21 namespace faiss { namespace gpu {
22 
/// Reduces `val` with `op` by repeatedly combining it with the value
/// returned from `shfl_xor` for xor masks ReduceWidth / 2, ReduceWidth / 4,
/// ..., 1.
///
/// NOTE(review): presumably `shfl_xor(v, m)` yields `v` as held by the lane
/// whose id is (own lane id ^ m), in which case every participating lane ends
/// with the same fully reduced value -- confirm against WarpShuffles.cuh.
///
/// @param val this thread's input value
/// @param op  binary functor invoked as op(a, b)
/// @return    the accumulated value after the last exchange step
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
  for (int xorMask = ReduceWidth / 2; xorMask > 0; xorMask /= 2) {
    // Fetch the partner's partial result, then fold it into our own.
    const T partner = shfl_xor(val, xorMask);
    val = op(val, partner);
  }

  return val;
}
32 
/// Convenience wrapper: runs warpReduceAll over `val` using the `Sum<T>`
/// functor, forwarding the `ReduceWidth` template argument.
///
/// @param val this thread's input value
/// @return    whatever warpReduceAll returns for the Sum<T> operator
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
  Sum<T> adder = Sum<T>();
  return warpReduceAll<T, Sum<T>, ReduceWidth>(val, adder);
}
38 
/// Block-wide reduction of a single value per thread.
///
/// Steps, as implemented below:
///   1. each warp reduces its values with warpReduceAll;
///   2. lane 0 of each warp stores the warp's result to smem[warpId];
///   3. after a block barrier, warp 0 loads one partial per warp (lanes past
///      the warp count use op.identity()) and reduces again;
///   4. if BroadcastAll, lane 0 of warp 0 publishes the result in smem[0] and
///      every thread reads it back after another barrier.
///
/// Indexing uses threadIdx.x and blockDim.x only.
/// `smem` is indexed up to divUp(blockDim.x, kWarpSize) - 1, so it must hold
/// at least that many elements of T.
/// NOTE(review): presumably `smem` points to shared memory -- confirm at the
/// call sites.
///
/// @tparam BroadcastAll       when true, all threads return the value read
///                            from smem[0]; when false, threads outside warp 0
///                            return their own warp's partial result
/// @tparam KillWARDependency  when true, a trailing __syncthreads() is issued
///                            (presumably so the caller may overwrite `smem`
///                            right after returning -- TODO confirm)
/// @param val   this thread's input value
/// @param op    binary functor that also provides identity()
/// @param smem  scratch buffer (see size requirement above)
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Phase 1: per-warp reduction; one lane per warp records the partial.
  val = warpReduceAll<T, Op>(val, op);
  if (lane == 0) {
    smem[warp] = val;
  }
  __syncthreads();

  // Phase 2: the first warp folds the per-warp partials together.
  if (warp == 0) {
    if (lane < utils::divUp(blockDim.x, kWarpSize)) {
      val = smem[lane];
    } else {
      val = op.identity();
    }
    val = warpReduceAll<T, Op>(val, op);

    if (BroadcastAll) {
      __threadfence_block();

      if (lane == 0) {
        smem[0] = val;
      }
    }
  }

  // Phase 3 (optional): everybody picks up the published result.
  if (BroadcastAll) {
    __syncthreads();
    val = smem[0];
  }

  // Optional trailing barrier, issued after all reads of smem above.
  if (KillWARDependency) {
    __syncthreads();
  }

  return val;
}
76 
/// Block-wide reduction of `Num` values per thread at once; `val` is
/// updated in place.
///
/// Follows the same scheme as the single-value overload, applied to each of
/// the Num slots: per-warp reduction, lane 0 writes the warp's partials to
/// smem[warpId * Num + slot], warp 0 reduces the partials after a barrier,
/// and (if BroadcastAll) the results are published in smem[0 .. Num-1] and
/// read back by every thread.
///
/// Indexing uses threadIdx.x and blockDim.x only.
/// `smem` is indexed up to Num * divUp(blockDim.x, kWarpSize) - 1, so it must
/// hold at least that many elements of T.
/// NOTE(review): presumably `smem` points to shared memory -- confirm at the
/// call sites.
///
/// @tparam BroadcastAll       when true, all threads end with the values read
///                            from smem[0 .. Num-1]; when false, threads
///                            outside warp 0 keep their own warp's partials
/// @tparam KillWARDependency  when true, a trailing __syncthreads() is issued
///                            (presumably so the caller may overwrite `smem`
///                            right after returning -- TODO confirm)
/// @param val   in/out array of Num per-thread values
/// @param op    binary functor that also provides identity()
/// @param smem  scratch buffer (see size requirement above)
template <int Num, typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Phase 1: reduce every slot within the warp.
#pragma unroll
  for (int slot = 0; slot < Num; ++slot) {
    val[slot] = warpReduceAll<T, Op>(val[slot], op);
  }

  // One lane per warp records that warp's Num partial results.
  if (lane == 0) {
    T* warpOut = smem + warp * Num;
#pragma unroll
    for (int slot = 0; slot < Num; ++slot) {
      warpOut[slot] = val[slot];
    }
  }

  __syncthreads();

  // Phase 2: the first warp folds the per-warp partials, slot by slot.
  if (warp == 0) {
#pragma unroll
    for (int slot = 0; slot < Num; ++slot) {
      if (lane < utils::divUp(blockDim.x, kWarpSize)) {
        val[slot] = smem[lane * Num + slot];
      } else {
        val[slot] = op.identity();
      }
      val[slot] = warpReduceAll<T, Op>(val[slot], op);
    }

    if (BroadcastAll) {
      __threadfence_block();

      if (lane == 0) {
#pragma unroll
        for (int slot = 0; slot < Num; ++slot) {
          smem[slot] = val[slot];
        }
      }
    }
  }

  // Phase 3 (optional): everybody picks up the published results.
  if (BroadcastAll) {
    __syncthreads();
#pragma unroll
    for (int slot = 0; slot < Num; ++slot) {
      val[slot] = smem[slot];
    }
  }

  // Optional trailing barrier, issued after all reads of smem above.
  if (KillWARDependency) {
    __syncthreads();
  }
}
130 
131 
/// Convenience wrapper: runs the single-value blockReduceAll with the
/// `Sum<T>` functor, forwarding BroadcastAll and KillWARDependency.
///
/// @param val   this thread's input value
/// @param smem  scratch buffer handed through to blockReduceAll
/// @return      whatever blockReduceAll returns for this thread
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
  Sum<T> adder = Sum<T>();
  return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
    val, adder, smem);
}
138 
/// Convenience wrapper: runs the Num-value blockReduceAll with the `Sum<T>`
/// functor, forwarding BroadcastAll and KillWARDependency. `vals` is updated
/// in place by blockReduceAll.
///
/// @param vals  in/out array of Num per-thread values
/// @param smem  scratch buffer handed through to blockReduceAll
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
  Sum<T> adder = Sum<T>();
  blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
    vals, adder, smem);
}
144 
145 } } // namespace