Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Reductions.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #pragma once
12 
13 #include "DeviceDefs.cuh"
14 #include "PtxUtils.cuh"
15 #include "ReductionOperators.cuh"
16 #include "StaticUtils.h"
17 #include "WarpShuffles.cuh"
18 #include <cuda.h>
19 
20 namespace faiss { namespace gpu {
21 
/// Combines `val` across groups of `ReduceWidth` lanes with `op`, using an
/// XOR-butterfly exchange: at each step a lane combines its own value with the
/// value obtained from `shfl_xor(val, mask)`, for mask = ReduceWidth / 2,
/// ReduceWidth / 4, ..., 1.
///
/// Assuming `shfl_xor(val, mask)` yields the value held by lane
/// (laneId ^ mask), every lane of a group finishes with the reduction over the
/// whole group, so no separate broadcast is needed.
///
/// @tparam T           value type being reduced
/// @tparam Op          binary functor, invoked as `op(a, b)` and returning T
/// @tparam ReduceWidth number of lanes per reduction group; must be a power of
///                     two in [1, kWarpSize] (checked at compile time)
/// @param val this lane's input value
/// @param op  the combining operator
/// @return the combined value for this lane's group
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
  // The butterfly pattern only visits each lane of the group exactly once when
  // the width is a power of two that fits inside a warp; any other width
  // would silently produce a wrong result, so reject it at compile time.
  static_assert(ReduceWidth > 0 && ReduceWidth <= kWarpSize &&
                (ReduceWidth & (ReduceWidth - 1)) == 0,
                "ReduceWidth must be a power of two no larger than kWarpSize");

#pragma unroll
  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
    val = op(val, shfl_xor(val, mask));
  }

  return val;
}
31 
/// Adds up a per-thread register value over groups of `ReduceWidth` lanes.
///
/// Convenience wrapper: forwards to warpReduceAll with a `Sum<T>` operator.
///
/// @tparam T           value type being summed
/// @tparam ReduceWidth number of lanes per reduction group (defaults to the
///                     full warp)
/// @param val this lane's contribution
/// @return whatever warpReduceAll returns for the `Sum<T>` operator
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
  Sum<T> adder = Sum<T>();
  return warpReduceAll<T, Sum<T>, ReduceWidth>(val, adder);
}
37 
/// Block-wide reduction of one value per thread.
///
/// Steps:
///  1. every warp combines its lanes with warpReduceAll;
///  2. lane 0 of each warp stores the warp's partial result to smem[warpId];
///  3. after a block barrier, warp 0 loads one partial per lane (lanes beyond
///     the number of warps use op.identity()) and combines them with a second
///     warpReduceAll.
///
/// @tparam BroadcastAll      when true, lane 0 of warp 0 writes the final
///                           result to smem[0] and, after another barrier,
///                           every thread reads it back and returns it. When
///                           false, no broadcast happens and only warp 0 has
///                           run the final combining step.
/// @tparam KillWARDependency when true, one more __syncthreads() is issued
///                           before returning (the name suggests this protects
///                           a caller that rewrites `smem` right afterwards;
///                           confirm against callers).
/// @param val  this thread's input value
/// @param op   combining functor; must provide `op(a, b)` and `identity()`
/// @param smem scratch buffer (presumably shared memory) written at
///             smem[warpId], so it needs one T per warp of the block
/// @return the value held by this thread after the steps above
///
/// Contains __syncthreads(), so every thread of the block has to call it.
/// NOTE(review): the warp index is derived from threadIdx.x only, so a 1-D
/// thread block appears to be assumed; verify against callers.
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Stage 1: intra-warp reduction, one partial published per warp.
  val = warpReduceAll<T, Op>(val, op);
  if (lane == 0) {
    smem[warp] = val;
  }
  __syncthreads();

  // Stage 2: the first warp folds the per-warp partials together.
  if (warp == 0) {
    const auto warpsInBlock = utils::divUp(blockDim.x, kWarpSize);

    // Lanes without a matching warp contribute the operator's identity.
    val = lane < warpsInBlock ? smem[lane] : op.identity();
    val = warpReduceAll<T, Op>(val, op);

    if (BroadcastAll) {
      __threadfence_block();

      if (lane == 0) {
        smem[0] = val;
      }
    }
  }

  // Optional stage 3: hand the final value to every thread of the block.
  if (BroadcastAll) {
    __syncthreads();
    val = smem[0];
  }

  // Optional trailing barrier, issued only on request.
  if (KillWARDependency) {
    __syncthreads();
  }

  return val;
}
75 
/// Block-wide reduction of `Num` independent values per thread, carried out
/// in one pass. Results are written back into `val` in place.
///
/// Steps:
///  1. every warp combines each of its Num values with warpReduceAll;
///  2. lane 0 of each warp stores its Num partials to
///     smem[warpId * Num + i];
///  3. after a block barrier, warp 0 loads the partials (lane L reads
///     smem[L * Num + i]; lanes beyond the number of warps use
///     op.identity()) and combines each slot with a second warpReduceAll.
///
/// @tparam Num               number of values reduced simultaneously
/// @tparam BroadcastAll      when true, lane 0 of warp 0 writes the Num final
///                           results to smem[0..Num-1] and, after another
///                           barrier, every thread copies them into `val`.
///                           When false, no broadcast happens and only warp 0
///                           has run the final combining step.
/// @tparam KillWARDependency when true, one more __syncthreads() is issued
///                           before returning (the name suggests this protects
///                           a caller that rewrites `smem` right afterwards;
///                           confirm against callers).
/// @param val  in/out array of Num per-thread values
/// @param op   combining functor; must provide `op(a, b)` and `identity()`
/// @param smem scratch buffer (presumably shared memory) indexed up to
///             warpId * Num + (Num - 1), i.e. Num elements of T per warp
///
/// Contains __syncthreads(), so every thread of the block has to call it.
/// NOTE(review): the warp index is derived from threadIdx.x only, so a 1-D
/// thread block appears to be assumed; verify against callers.
template <int Num, typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
  const int lane = getLaneId();
  const int warp = threadIdx.x / kWarpSize;

  // Stage 1a: intra-warp reduction of every slot.
#pragma unroll
  for (int k = 0; k < Num; ++k) {
    val[k] = warpReduceAll<T, Op>(val[k], op);
  }

  // Stage 1b: each warp publishes its Num partials contiguously.
  if (lane == 0) {
    T* warpSlots = smem + warp * Num;
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      warpSlots[k] = val[k];
    }
  }

  __syncthreads();

  // Stage 2: the first warp folds the per-warp partials together.
  if (warp == 0) {
    const auto warpsInBlock = utils::divUp(blockDim.x, kWarpSize);

#pragma unroll
    for (int k = 0; k < Num; ++k) {
      // Lanes without a matching warp contribute the operator's identity.
      val[k] = lane < warpsInBlock ? smem[lane * Num + k] : op.identity();
      val[k] = warpReduceAll<T, Op>(val[k], op);
    }

    if (BroadcastAll) {
      __threadfence_block();

      if (lane == 0) {
#pragma unroll
        for (int k = 0; k < Num; ++k) {
          smem[k] = val[k];
        }
      }
    }
  }

  // Optional stage 3: hand the final values to every thread of the block.
  if (BroadcastAll) {
    __syncthreads();
#pragma unroll
    for (int k = 0; k < Num; ++k) {
      val[k] = smem[k];
    }
  }

  // Optional trailing barrier, issued only on request.
  if (KillWARDependency) {
    __syncthreads();
  }
}
129 
130 
/// Adds up one register value per thread over the entire block.
///
/// Convenience wrapper: forwards to the single-value blockReduceAll with a
/// `Sum<T>` operator; BroadcastAll and KillWARDependency are passed through
/// unchanged.
///
/// @param val  this thread's contribution
/// @param smem scratch buffer handed to blockReduceAll
/// @return whatever blockReduceAll returns for this thread
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
  Sum<T> adder = Sum<T>();
  return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
    val, adder, smem);
}
137 
/// Adds up `Num` independent register values per thread over the entire
/// block; the results are written back into `vals` in place.
///
/// Convenience wrapper: forwards to the multi-value blockReduceAll with a
/// `Sum<T>` operator; BroadcastAll and KillWARDependency are passed through
/// unchanged.
///
/// @param vals in/out array of Num per-thread values
/// @param smem scratch buffer handed to blockReduceAll
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
  Sum<T> adder = Sum<T>();
  blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
    vals, adder, smem);
}
143 
144 } } // namespace