11 #include "DeviceDefs.cuh"
12 #include "PtxUtils.cuh"
13 #include "ReductionOperators.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
18 namespace faiss {
namespace gpu {
/// XOR-shuffle all-reduce over `ReduceWidth` lanes.
///
/// Starting from ReduceWidth / 2 and halving down to 1, each step combines
/// this thread's running value with the value returned by
/// `shfl_xor(val, mask)` (presumably the partner lane's value -- see
/// WarpShuffles.cuh).
///
/// @tparam T            type of the value being reduced
/// @tparam Op           binary functor, invoked as `op(a, b)`
/// @tparam ReduceWidth  number of lanes spanned by the reduction; defaults
///                      to kWarpSize. NOTE(review): the halving loop looks
///                      like it expects a power of two -- confirm.
/// @param val  this thread's contribution
/// @param op   reduction operator
/// @return     the value accumulated after the final shuffle step
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
    for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
        val = op(val, shfl_xor(val, mask));
    }

    return val;
}
/// Convenience wrapper around warpReduceAll that uses a default-constructed
/// Sum<T> as the reduction operator.
///
/// @tparam T            type of the value being reduced
/// @tparam ReduceWidth  forwarded to warpReduceAll; defaults to kWarpSize
/// @param val  this thread's contribution
/// @return     whatever warpReduceAll returns for (val, Sum<T>())
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
    return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
}
// Block-wide reduction of one value per thread, staged through `smem`.
//
// NOTE(review): the leading integers on these lines (37, 38, 39, ...) look
// like line numbers carried over from the original file, and the gaps between
// them suggest statements are elided from this listing (the stores into
// `smem`, any barriers, the else-branch of the ternary, and the final
// return) -- confirm against the complete file before relying on this text.
//
// Template parameters:
//   T                  value type being reduced
//   Op                 reduction functor passed through to warpReduceAll
//   BroadcastAll       compile-time flag; its use is not visible here
//   KillWARDependency  compile-time flag tested near the end of the body
// Parameters:
//   val   this thread's contribution
//   op    reduction operator
//   smem  scratch buffer indexed by lane id below (presumably shared
//         memory -- TODO confirm with callers)
37 template <
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
38 __device__
inline T blockReduceAll(T val, Op op, T* smem) {
// laneId comes from getLaneId(); warpId is the thread's x index divided by
// the warp size.
39 int laneId = getLaneId();
40 int warpId = threadIdx.x / kWarpSize;
// First stage: reduce within each warp.
42 val = warpReduceAll<T, Op>(val, op);
// Lanes below divUp(blockDim.x, kWarpSize) (presumably the number of warps in
// the block) load smem[laneId]; the value used by the remaining lanes is not
// visible in this listing.
49 val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
// Second stage: reduce the values just loaded from smem.
51 val = warpReduceAll<T, Op>(val, op);
// Block-scope memory fence.
54 __threadfence_block();
// Compile-time switch; the guarded statement is not visible here.
// NOTE(review): the name suggests it guards against a write-after-read
// hazard on smem when the buffer is reused -- confirm.
67 if (KillWARDependency) {
// Block-wide reduction of `Num` values per thread, performed in place on
// `val` and staged through `smem`.
//
// NOTE(review): the leading integers on these lines (75, 76, 77, ...) look
// like line numbers carried over from the original file, and the gaps between
// them suggest statements are elided from this listing (barriers, the
// else-branch of the ternary, and the bodies of the last two loops) --
// confirm against the complete file before relying on this text.
//
// Template parameters:
//   Num                number of values each thread reduces
//   T                  value type being reduced
//   Op                 reduction functor passed through to warpReduceAll
//   BroadcastAll       compile-time flag; its use is not visible here
//   KillWARDependency  compile-time flag tested near the end of the body
// Parameters:
//   val   per-thread array of Num values, overwritten with reduced values
//   op    reduction operator
//   smem  scratch buffer addressed as smem[warpId * Num + i], i.e. Num
//         consecutive slots per warp (presumably shared memory holding at
//         least Num entries for every warp in the block -- TODO confirm)
75 template <
int Num,
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
76 __device__
inline void blockReduceAll(T val[Num], Op op, T* smem) {
// laneId comes from getLaneId(); warpId is the thread's x index divided by
// the warp size.
77 int laneId = getLaneId();
78 int warpId = threadIdx.x / kWarpSize;
// First stage: reduce each of the Num values within the warp.
81 for (
int i = 0; i < Num; ++i) {
82 val[i] = warpReduceAll<T, Op>(val[i], op);
// Store this warp's Num values into its slice of smem. Which lanes perform
// the store is not visible in this listing.
87 for (
int i = 0; i < Num; ++i) {
88 smem[warpId * Num + i] = val[i];
// Second stage: lanes below divUp(blockDim.x, kWarpSize) (presumably the
// number of warps) read slot i of the slice indexed by their lane id, then
// the loaded values are reduced across the warp. The assignment target and
// the ternary's else-branch are not visible here.
96 for (
int i = 0; i < Num; ++i) {
98 laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
100 val[i] = warpReduceAll<T, Op>(val[i], op);
// Block-scope memory fence.
104 __threadfence_block();
// The bodies of the next two loops over the Num values are not visible.
108 for (
int i = 0; i < Num; ++i) {
118 for (
int i = 0; i < Num; ++i) {
// Compile-time switch; the guarded statement is not visible here.
// NOTE(review): the name suggests it guards against a write-after-read
// hazard on smem when the buffer is reused -- confirm.
123 if (KillWARDependency) {
/// Convenience wrapper around the single-value blockReduceAll that uses a
/// default-constructed Sum<T> as the reduction operator.
///
/// @tparam T                  type of the value being reduced
/// @tparam BroadcastAll       forwarded unchanged to blockReduceAll
/// @tparam KillWARDependency  forwarded unchanged to blockReduceAll
/// @param val   this thread's contribution
/// @param smem  scratch buffer forwarded to blockReduceAll
/// @return      whatever blockReduceAll returns for (val, Sum<T>(), smem)
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
    return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
            val, Sum<T>(), smem);
}
/// Convenience wrapper around the Num-value blockReduceAll that uses a
/// default-constructed Sum<T> as the reduction operator. Results are
/// delivered through `vals`; nothing is returned.
///
/// @tparam Num                number of values each thread passes in
/// @tparam T                  type of the values being reduced
/// @tparam BroadcastAll       forwarded unchanged to blockReduceAll
/// @tparam KillWARDependency  forwarded unchanged to blockReduceAll
/// @param vals  per-thread array of Num values, forwarded to blockReduceAll
/// @param smem  scratch buffer forwarded to blockReduceAll
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
    blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
            vals, Sum<T>(), smem);
}