14 #include "DeviceDefs.cuh"
15 #include "PtxUtils.cuh"
16 #include "ReductionOperators.cuh"
17 #include "StaticUtils.h"
18 #include "WarpShuffles.cuh"
21 namespace faiss {
namespace gpu {
/// Combines `val` across the lanes of a warp using the binary operator `op`
/// and returns the combined value.
///
/// The loop walks an XOR butterfly: starting at `ReduceWidth / 2` and halving
/// each step, every lane combines its running value with the value obtained
/// from `shfl_xor(val, mask)`.
/// NOTE(review): this assumes `shfl_xor(val, mask)` yields the value held by
/// the lane whose id is `laneId ^ mask` -- confirm against WarpShuffles.cuh.
///
/// The halving mask sequence only pairs up every lane of a `ReduceWidth`-sized
/// group when `ReduceWidth` is a power of two; with `ReduceWidth == 1` the
/// loop does not execute and `val` is returned unchanged.
///
/// @tparam T           value type being reduced
/// @tparam Op          callable as `op(T, T)` producing the combined value
/// @tparam ReduceWidth number of lanes in each reduction group
///                     (defaults to `kWarpSize`)
/// @param val this lane's input value
/// @param op  the combining operator
/// @return the running value after the last butterfly step
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
    for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
        val = op(val, shfl_xor(val, mask));
    }

    return val;
}
/// Warp-level reduction of `val` with the `Sum<T>` operator.
///
/// Thin convenience wrapper: forwards to
/// `warpReduceAll<T, Sum<T>, ReduceWidth>` with a default-constructed
/// `Sum<T>`.
///
/// @tparam T           value type being reduced
/// @tparam ReduceWidth number of lanes in each reduction group
///                     (defaults to `kWarpSize`)
/// @param val this lane's input value
/// @return whatever `warpReduceAll` returns for this lane
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
    return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
}
// Reduces one value per thread with `op`, first within each warp and then
// across the per-warp values that are read back from `smem`.
//
// NOTE(review): this extract is incomplete. The embedded original line
// numbers skip (45 -> 52, 52 -> 54, 57 -> 70), so the store into `smem`, any
// barriers, the false operand of the conditional below, the code controlled
// by `BroadcastAll`, and the return statement are not visible here. Restore
// the missing lines from the upstream file before relying on this block.
//
// Template parameters:
//   T                 - value type being reduced
//   Op                - combining operator, passed on to warpReduceAll
//   BroadcastAll      - not used in any visible line; presumably selects
//                       whether every thread receives the result (TODO confirm)
//   KillWARDependency - guards a trailing conditional whose body is not visible
// Parameters:
//   val  - this thread's input value
//   op   - the combining operator
//   smem - scratch buffer indexed by lane id below; presumably shared memory
//          with at least one slot per warp in the block (TODO confirm)
40 template <
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
41 __device__
inline T blockReduceAll(T val, Op op, T* smem) {
// Lane index from the getLaneId() helper (defined outside this extract).
42 int laneId = getLaneId();
// Index of this thread's warp along the x dimension of the block.
43 int warpId = threadIdx.x / kWarpSize;
// Stage 1: combine the values held by the lanes of this warp.
45 val = warpReduceAll<T, Op>(val, op);
// Stage 2 input: lanes whose id is below the number of warps in the block
// (blockDim.x divided by kWarpSize, rounded up) read the slot smem[laneId].
// The value used by the remaining lanes is on a line missing from this
// extract.
52 val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
// Stage 2: combine the values just loaded.
54 val = warpReduceAll<T, Op>(val, op);
// Block-scope memory fence; the code that follows it is not visible here.
57 __threadfence_block();
// Compile-time switch; body not visible. NOTE(review): the name suggests it
// protects `smem` against a write-after-read hazard when the caller reuses the
// buffer -- confirm against the full source.
70 if (KillWARDependency) {
// Array form of the block reduction: reduces `Num` values per thread with
// `op`, writing results back into `val` (the function returns void).
//
// NOTE(review): this extract is incomplete. The embedded original line
// numbers skip repeatedly (e.g. 85 -> 90, 91 -> 99, 103 -> 107, 111 -> 121),
// so the conditions guarding each stage, any barriers, the false operand and
// assignment target of the conditional expression below, and the bodies of
// the last two loops are not visible here. Restore the missing lines from the
// upstream file before relying on this block.
//
// Template parameters:
//   Num               - number of values reduced per thread
//   T                 - value type being reduced
//   Op                - combining operator, passed on to warpReduceAll
//   BroadcastAll      - not used in any visible line; presumably selects
//                       whether every thread receives the results (TODO confirm)
//   KillWARDependency - guards a trailing conditional whose body is not visible
// Parameters:
//   val  - this thread's `Num` input values, updated in place
//   op   - the combining operator
//   smem - scratch buffer addressed as [warp * Num + i]; presumably shared
//          memory holding Num slots per warp in the block (TODO confirm)
78 template <
int Num,
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
79 __device__
inline void blockReduceAll(T val[Num], Op op, T* smem) {
// Lane index from the getLaneId() helper (defined outside this extract).
80 int laneId = getLaneId();
// Index of this thread's warp along the x dimension of the block.
81 int warpId = threadIdx.x / kWarpSize;
// Stage 1: combine each of the Num values across the lanes of this warp.
84 for (
int i = 0; i < Num; ++i) {
85 val[i] = warpReduceAll<T, Op>(val[i], op);
// Store this warp's Num values into its own contiguous group of smem slots.
// Which lanes perform the store is decided on a line missing from this extract.
90 for (
int i = 0; i < Num; ++i) {
91 smem[warpId * Num + i] = val[i];
// Stage 2: lanes whose id is below the number of warps in the block read the
// i-th slot of the group indexed by their lane id, then the loaded values are
// combined. The left-hand side of the conditional expression (presumably
// `val[i] =`) and its false operand are on lines missing from this extract.
99 for (
int i = 0; i < Num; ++i) {
101 laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
103 val[i] = warpReduceAll<T, Op>(val[i], op);
// Block-scope memory fence; the code that follows it is not visible here.
107 __threadfence_block();
// Two further loops over the Num values; their bodies are not visible.
// NOTE(review): presumably one publishes the final values to smem and the
// other reads them back when BroadcastAll is set -- confirm against the full
// source.
111 for (
int i = 0; i < Num; ++i) {
121 for (
int i = 0; i < Num; ++i) {
// Compile-time switch; body not visible. NOTE(review): the name suggests it
// protects `smem` against a write-after-read hazard when the caller reuses the
// buffer -- confirm against the full source.
126 if (KillWARDependency) {
/// Block-level reduction of a single per-thread value with the `Sum<T>`
/// operator.
///
/// Thin convenience wrapper: forwards to the scalar
/// `blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>` with a
/// default-constructed `Sum<T>`; both boolean flags are passed through
/// unchanged.
///
/// @tparam T                 value type being reduced
/// @tparam BroadcastAll      forwarded to `blockReduceAll`
/// @tparam KillWARDependency forwarded to `blockReduceAll`
/// @param val  this thread's input value
/// @param smem scratch buffer handed to `blockReduceAll`
/// @return whatever `blockReduceAll` returns for this thread
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
    return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
            val, Sum<T>(), smem);
}
/// Block-level reduction of `Num` per-thread values with the `Sum<T>`
/// operator.
///
/// Thin convenience wrapper: forwards to the array form
/// `blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>` with a
/// default-constructed `Sum<T>`. Nothing is returned; `vals` is an array
/// parameter (i.e. a pointer), so any results are delivered through it by
/// `blockReduceAll`.
///
/// @tparam Num               number of values reduced per thread
/// @tparam T                 value type being reduced
/// @tparam BroadcastAll      forwarded to `blockReduceAll`
/// @tparam KillWARDependency forwarded to `blockReduceAll`
/// @param vals this thread's `Num` input values
/// @param smem scratch buffer handed to `blockReduceAll`
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
    blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
            vals, Sum<T>(), smem);
}