12 #include "DeviceDefs.cuh"
13 #include "PtxUtils.cuh"
14 #include "ReductionOperators.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
19 namespace faiss {
namespace gpu {
/// Reduces `val` across groups of `ReduceWidth` consecutive warp lanes using
/// the binary combiner `op`, via an XOR butterfly: at each step every lane
/// combines its value with the value held by lane `laneId ^ mask`.
/// After log2(ReduceWidth) steps every lane of a group holds that group's
/// reduction (not only lane 0).
///
/// Preconditions:
///   - ReduceWidth is a power of two no larger than kWarpSize, otherwise the
///     butterfly does not pair every lane exactly once per step.
///   - `op` is associative and commutative, since operands are combined in a
///     lane-dependent order.
///   - NOTE(review): the participant mask used by shfl_xor is defined in
///     WarpShuffles.cuh (not visible here); confirm that every lane of each
///     group reaches this call.
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
    static_assert(
            ReduceWidth > 0 && (ReduceWidth & (ReduceWidth - 1)) == 0,
            "ReduceWidth must be a power of two");
    static_assert(
            ReduceWidth <= kWarpSize,
            "ReduceWidth cannot exceed the warp size");

#pragma unroll
    for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
        val = op(val, shfl_xor(val, mask));
    }

    // Every lane now holds the full reduction of its ReduceWidth-lane group.
    return val;
}
/// Sums `val` across groups of `ReduceWidth` consecutive warp lanes.
/// Every lane of a group receives the group's total.
/// Thin wrapper over warpReduceAll with the Sum<T> combiner.
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
    return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
}
// Block-wide reduction of one value per thread using the combiner `op`.
// Visible steps:
//   1. each warp reduces its own lanes with warpReduceAll;
//   2. lanes with laneId < divUp(blockDim.x, kWarpSize) load smem[laneId]
//      (presumably the per-warp partials) and a second warpReduceAll
//      combines them.
//
// Template parameters:
//   BroadcastAll      - NOTE(review): not referenced by any visible line;
//                       presumably selects whether every thread receives the
//                       block-wide result. Confirm against upstream.
//   KillWARDependency - gates the trailing `if` block. Presumably it adds a
//                       barrier so callers can reuse `smem` without a
//                       write-after-read hazard. Confirm.
// smem: scratch buffer read at indices [0, divUp(blockDim.x, kWarpSize)), so
//   it needs at least one T per warp in the block. TODO confirm the exact
//   size contract.
// NOTE(review): assuming getLaneId() returns 0..kWarpSize-1, only kWarpSize
//   partials can be loaded in step 2. This presumably requires
//   blockDim.x <= kWarpSize * kWarpSize. Confirm.
//
// NOTE(review): this copy of the function is damaged:
//   - stray line-number tokens prefix several lines;
//   - no write of the per-warp results into smem is visible;
//   - no __syncthreads() is visible between the two stages;
//   - the ternary below has no false arm;
//   - the `if (KillWARDependency)` body, the return, and the closing braces
//     are missing.
// Restore from upstream before use. If the missing lines contain
// __syncthreads(), every thread of the block must call this function.
38 template <
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
39 __device__
inline T blockReduceAll(T val, Op op, T* smem) {
// warpId is derived from a 1-D block index (threadIdx.x only).
40 int laneId = getLaneId();
41 int warpId = threadIdx.x / kWarpSize;
// Stage 1: reduce within each warp.
43 val = warpReduceAll<T, Op>(val, op);
// Stage 2: one lane per warp-partial loads smem[laneId]. Lanes beyond the
// warp count take the ternary's false arm, which is missing in this copy.
50 val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
52 val = warpReduceAll<T, Op>(val, op);
// Orders this thread's prior memory writes as observed by other threads in
// the same block.
55 __threadfence_block();
68 if (KillWARDependency) {
// Block-wide reduction of `Num` independent values per thread using the
// combiner `op`. This is the array counterpart of the scalar blockReduceAll.
// `val` is an array parameter, so it decays to a pointer. Results are
// written back through it into the caller's array (see the `val[i] = ...`
// stores).
//
// Visible steps, each applied to all Num entries:
//   1. each warp reduces val[i] with warpReduceAll;
//   2. the partials are stored warp-major: smem[warpId * Num + i];
//   3. lanes with laneId < divUp(blockDim.x, kWarpSize) read
//      smem[laneId * Num + i], and a second warpReduceAll combines them.
//
// smem: written at warpId * Num + i, so it needs at least
//   Num * divUp(blockDim.x, kWarpSize) elements of T. TODO confirm.
// BroadcastAll: NOTE(review): not referenced by any visible line; presumably
//   selects whether every thread receives the results. Confirm.
// KillWARDependency: gates the trailing `if` block. Presumably it adds a
//   barrier before `smem` may be reused. Confirm.
//
// NOTE(review): this copy of the function is damaged:
//   - stray line-number tokens prefix several lines;
//   - no lane-0 guard or __syncthreads() is visible around the smem store;
//   - the `val[i] =` target and the false arm of the stage-3 ternary are
//     missing;
//   - the bodies of the last two loops, the `if (KillWARDependency)` body,
//     and the closing braces are missing.
// Restore from upstream before use.
76 template <
int Num,
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
77 __device__
inline void blockReduceAll(T val[Num], Op op, T* smem) {
// warpId is derived from a 1-D block index (threadIdx.x only).
78 int laneId = getLaneId();
79 int warpId = threadIdx.x / kWarpSize;
// Stage 1: per-warp reduction of each of the Num values.
82 for (
int i = 0; i < Num; ++i) {
83 val[i] = warpReduceAll<T, Op>(val[i], op);
// Stage 2: publish per-warp partials, warp-major layout.
88 for (
int i = 0; i < Num; ++i) {
89 smem[warpId * Num + i] = val[i];
// Stage 3: load the partials (one warp-partial per lane) and reduce again.
97 for (
int i = 0; i < Num; ++i) {
99 laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
101 val[i] = warpReduceAll<T, Op>(val[i], op);
// Orders this thread's prior memory writes as observed by other threads in
// the same block.
105 __threadfence_block();
109 for (
int i = 0; i < Num; ++i) {
119 for (
int i = 0; i < Num; ++i) {
124 if (KillWARDependency) {
/// Block-wide sum of one value per thread.
/// Forwards to blockReduceAll with the Sum<T> combiner. `smem` is the
/// scratch buffer that blockReduceAll requires. BroadcastAll and
/// KillWARDependency are passed through unchanged.
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
    return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
            val, Sum<T>(), smem);
}
// Block-wide sum of `Num` independent values per thread.
// Forwards to the array overload of blockReduceAll with the Sum<T>
// combiner. `vals` is an array parameter that decays to a pointer, so
// results land in the caller's array. `smem` is the scratch buffer that
// blockReduceAll requires. BroadcastAll and KillWARDependency are passed
// through unchanged.
// NOTE(review): stray line-number tokens prefix several lines of this copy,
// and the closing brace is not visible here. Confirm it is present.
137 template <
int Num,
typename T,
bool BroadcastAll,
bool KillWARDependency>
138 __device__
inline void blockReduceAllSum(T vals[Num], T* smem) {
// Returning a void expression from a void function is legal C++.
139 return blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
140 vals, Sum<T>(), smem);