faiss/gpu/utils/Reductions.cuh

146 lines
3.3 KiB
Plaintext

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include "DeviceDefs.cuh"
#include "PtxUtils.cuh"
#include "ReductionOperators.cuh"
#include "StaticUtils.h"
#include "WarpShuffles.cuh"
#include <cuda.h>
namespace faiss { namespace gpu {
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
val = op(val, shfl_xor(val, mask));
}
return val;
}
/// Sums a register value across all warp threads
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
}
/// Performs a block-wide reduction
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
int laneId = getLaneId();
int warpId = threadIdx.x / kWarpSize;
val = warpReduceAll<T, Op>(val, op);
if (laneId == 0) {
smem[warpId] = val;
}
__syncthreads();
if (warpId == 0) {
val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
op.identity();
val = warpReduceAll<T, Op>(val, op);
if (BroadcastAll) {
__threadfence_block();
if (laneId == 0) {
smem[0] = val;
}
}
}
if (BroadcastAll) {
__syncthreads();
val = smem[0];
}
if (KillWARDependency) {
__syncthreads();
}
return val;
}
/// Performs a block-wide reduction of multiple values simultaneously
template <int Num, typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
int laneId = getLaneId();
int warpId = threadIdx.x / kWarpSize;
#pragma unroll
for (int i = 0; i < Num; ++i) {
val[i] = warpReduceAll<T, Op>(val[i], op);
}
if (laneId == 0) {
#pragma unroll
for (int i = 0; i < Num; ++i) {
smem[warpId * Num + i] = val[i];
}
}
__syncthreads();
if (warpId == 0) {
#pragma unroll
for (int i = 0; i < Num; ++i) {
val[i] =
laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
op.identity();
val[i] = warpReduceAll<T, Op>(val[i], op);
}
if (BroadcastAll) {
__threadfence_block();
if (laneId == 0) {
#pragma unroll
for (int i = 0; i < Num; ++i) {
smem[i] = val[i];
}
}
}
}
if (BroadcastAll) {
__syncthreads();
#pragma unroll
for (int i = 0; i < Num; ++i) {
val[i] = smem[i];
}
}
if (KillWARDependency) {
__syncthreads();
}
}
/// Sums a register value across the entire block
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
val, Sum<T>(), smem);
}
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
return blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
vals, Sum<T>(), smem);
}
} } // namespace