13 #include "DeviceDefs.cuh"
14 #include "PtxUtils.cuh"
15 #include "ReductionOperators.cuh"
16 #include "StaticUtils.h"
17 #include "WarpShuffles.cuh"
20 namespace faiss {
namespace gpu {
// Warp-level all-reduce of a single value.
//
// Repeatedly combines this thread's `val` with the value returned by
// shfl_xor(val, mask) using `op`, for mask = ReduceWidth / 2, ReduceWidth / 4,
// ..., 1. ReduceWidth defaults to kWarpSize.
//
// Template parameters:
//   T           - value type being reduced
//   Op          - binary functor invoked as op(a, b) and returning a T
//   ReduceWidth - starting width of the exchange pattern; the loop only makes
//                 sense for powers of two, since the mask is halved each step
//
// NOTE(review): shfl_xor comes from WarpShuffles.cuh and is presumably a
// wrapper around the warp XOR-shuffle intrinsic -- confirm there.
// NOTE(review): the end of this function (loop close and return statement) is
// not visible in this excerpt; presumably the combined `val` is returned.
22 template <
typename T,
typename Op,
int ReduceW
idth = kWarpSize>
23 __device__
inline T warpReduceAll(T val, Op op) {
// Halve the XOR mask every iteration until it reaches zero.
25 for (
int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
// Fold the partner's value (selected by `mask`) into our running value.
26 val = op(val, shfl_xor(val, mask));
// Convenience wrapper: warpReduceAll specialised for summation.
//
// Forwards `val` and the caller's ReduceWidth (default kWarpSize) to
// warpReduceAll, using a default-constructed Sum<T> as the combining functor,
// and returns whatever warpReduceAll returns.
//
// NOTE(review): Sum<T> is declared elsewhere (presumably in
// ReductionOperators.cuh); its exact semantics are not visible here.
33 template <
typename T,
int ReduceW
idth = kWarpSize>
34 __device__
inline T warpReduceAllSum(T val) {
35 return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
// Block-level reduction of a single value per thread, staged through `smem`.
//
// Visible structure:
//   1. every warp reduces its own values with warpReduceAll;
//   2. lanes whose laneId is below utils::divUp(blockDim.x, kWarpSize) (the
//      number of warps covering the block, assuming divUp is a ceiling
//      division -- confirm) reload a per-warp partial from smem[laneId];
//   3. a second warpReduceAll combines those partials.
//
// Parameters:
//   val  - this thread's contribution
//   op   - binary functor used to combine values
//   smem - scratch buffer; it is read at indices [0, number of warps), so it
//          must hold at least that many T elements
//
// Template flags:
//   BroadcastAll      - not referenced in the visible lines of this excerpt
//   KillWARDependency - guards a trailing section whose body is not visible
//
// Only threadIdx.x / blockDim.x are consulted, so this appears to assume a
// one-dimensional thread block -- TODO confirm with callers.
//
// NOTE(review): several statements are elided from this excerpt (the store of
// each warp's partial into smem, the barriers between phases, the fallback
// operand of the ternary, the BroadcastAll handling and the return). The
// comments here describe only what is visible.
39 template <
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
40 __device__
inline T blockReduceAll(T val, Op op, T* smem) {
41 int laneId = getLaneId();
// Warp index within the block along x.
42 int warpId = threadIdx.x / kWarpSize;
// Phase 1: reduce within each warp.
44 val = warpReduceAll<T, Op>(val, op);
// Phase 2: lanes that map to an existing warp pick up that warp's partial;
// the value used by the remaining lanes is on an elided line.
51 val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
53 val = warpReduceAll<T, Op>(val, op);
// Block-scope memory fence; the condition guarding it is not visible here.
56 __threadfence_block();
// NOTE(review): body elided; the flag name suggests an extra barrier that
// protects smem against write-after-read hazards on reuse -- confirm.
69 if (KillWARDependency) {
// Block-level reduction of Num values per thread, staged through `smem`.
//
// This overload returns void and assigns into val[i], so results are written
// back through `val`. Note that a `T val[Num]` parameter decays to `T*`, so
// the array length is not enforced by the type system.
//
// Visible structure:
//   1. each of the Num values is reduced within the warp via warpReduceAll;
//   2. per-warp partials are stored at smem[warpId * Num + i];
//   3. lanes whose laneId is below utils::divUp(blockDim.x, kWarpSize) reload
//      smem[laneId * Num + i], and a second warpReduceAll combines them.
//
// Parameters:
//   val  - Num per-thread contributions, updated in place
//   op   - binary functor used to combine values
//   smem - scratch buffer indexed up to (number of warps) * Num - 1, so it
//          must hold at least Num * divUp(blockDim.x, kWarpSize) T elements
//
// Template flags:
//   BroadcastAll      - not referenced in the visible lines of this excerpt
//   KillWARDependency - guards a trailing section whose body is not visible
//
// Only threadIdx.x / blockDim.x are consulted, so this appears to assume a
// one-dimensional thread block -- TODO confirm with callers.
//
// NOTE(review): many statements are elided from this excerpt (conditions
// around the smem store and reload, barriers, the fallback operand of the
// ternary, and the bodies of the last two loops). The comments here describe
// only what is visible.
77 template <
int Num,
typename T,
typename Op,
bool BroadcastAll,
bool KillWARDependency>
78 __device__
inline void blockReduceAll(T val[Num], Op op, T* smem) {
79 int laneId = getLaneId();
// Warp index within the block along x.
80 int warpId = threadIdx.x / kWarpSize;
// Phase 1: reduce every slot within the warp.
83 for (
int i = 0; i < Num; ++i) {
84 val[i] = warpReduceAll<T, Op>(val[i], op);
// Publish this warp's Num partials; each warp owns a contiguous run of Num
// elements starting at warpId * Num.
89 for (
int i = 0; i < Num; ++i) {
90 smem[warpId * Num + i] = val[i];
// Phase 2: reload one warp's partials per lane and reduce across them.
98 for (
int i = 0; i < Num; ++i) {
// The assignment target and the fallback operand of this ternary are on
// elided lines.
100 laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
102 val[i] = warpReduceAll<T, Op>(val[i], op);
// Block-scope memory fence; the condition guarding it is not visible here.
106 __threadfence_block();
// NOTE(review): the bodies of the next two loops are elided from this excerpt.
110 for (
int i = 0; i < Num; ++i) {
120 for (
int i = 0; i < Num; ++i) {
// NOTE(review): body elided; the flag name suggests an extra barrier that
// protects smem against write-after-read hazards on reuse -- confirm.
125 if (KillWARDependency) {
// Convenience wrapper: scalar blockReduceAll specialised for summation.
//
// Forwards `val` and `smem` to blockReduceAll with a default-constructed
// Sum<T> functor, passing BroadcastAll and KillWARDependency through
// unchanged, and returns its result.
//
// NOTE(review): the requirements on `smem` are those of blockReduceAll; see
// that function rather than relying on this wrapper.
132 template <
typename T,
bool BroadcastAll,
bool KillWARDependency>
133 __device__
inline T blockReduceAllSum(T val, T* smem) {
134 return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
135 val, Sum<T>(), smem);
// Convenience wrapper: Num-value blockReduceAll specialised for summation.
//
// Forwards `vals` and `smem` to the array overload of blockReduceAll with a
// default-constructed Sum<T> functor, passing Num, BroadcastAll and
// KillWARDependency through unchanged. The function returns void; results are
// delivered through `vals`, which the callee updates in place.
//
// The `return` of a void call expression is valid C++ and simply forwards
// control; no value is produced.
138 template <
int Num,
typename T,
bool BroadcastAll,
bool KillWARDependency>
139 __device__
inline void blockReduceAllSum(T vals[Num], T* smem) {
140 return blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
141 vals, Sum<T>(), smem);