/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include "DeviceDefs.cuh"
#include "PtxUtils.cuh"
#include "ReductionOperators.cuh"
#include "StaticUtils.h"
#include "WarpShuffles.cuh"
// NOTE(review): the target of this include was lost from the file;
// <cuda.h> is assumed -- confirm against version control.
#include <cuda.h>

namespace faiss { namespace gpu {

/// Combines `val` across the lanes of a warp with the binary operator `op`.
///
/// For mask = ReduceWidth / 2, ReduceWidth / 4, ..., 1 each lane replaces
/// its value with `op(val, shfl_xor(val, mask))`.
///
/// NOTE(review): this assumes `shfl_xor` (WarpShuffles.cuh) returns the
/// value held by lane `laneId ^ mask`. Under that assumption every lane of
/// each aligned group of `ReduceWidth` lanes ends up with the reduction of
/// that group, provided `op` is associative and commutative (the two
/// partner lanes evaluate `op(a, b)` and `op(b, a)` respectively).
/// `ReduceWidth` is expected to be a power of two no larger than kWarpSize.
///
/// NOTE(review): the template parameter list was missing from the file and
/// has been reconstructed from the call sites; the `kWarpSize` default for
/// `ReduceWidth` is required by the callers below that omit it.
///
/// @tparam T            value type being reduced
/// @tparam Op           functor type callable as `T op(T, T)`
/// @tparam ReduceWidth  number of adjacent lanes that are reduced together
/// @param val           this lane's input value
/// @param op            the reduction operator
/// @return              the reduced value as seen by this lane
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
    val = op(val, shfl_xor(val, mask));
  }

  return val;
}

/// Sums a register value across all warp threads
///
/// Thin wrapper that runs warpReduceAll with the `Sum<T>` operator.
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
  return warpReduceAll<T, Sum<T>, ReduceWidth>(val, Sum<T>());
}

/// Performs a block-wide reduction
///
/// Two-stage scheme: every warp reduces its own values, lane 0 of each warp
/// stores the warp partial in `smem[warpId]`, and warp 0 then reduces those
/// partials (lanes beyond the number of warps contribute `op.identity()`).
///
/// Requirements visible from the code:
///  - contains `__syncthreads()`, so every thread of the block must call it;
///  - only `threadIdx.x` / `blockDim.x` are consulted (1-D block layout
///    assumed -- TODO confirm with callers);
///  - `smem` must hold at least `divUp(blockDim.x, kWarpSize)` elements of T;
///  - warp 0 reads one partial per lane, so the number of warps in the block
///    must not exceed the number of lanes in a warp;
///  - `Op` must provide `identity()` in addition to `operator()`.
///
/// @tparam BroadcastAll       if true, the final result is written to
///                            `smem[0]` and read back by every thread; if
///                            false, only the threads of warp 0 return the
///                            block-wide result and all other threads return
///                            their own warp's partial
/// @tparam KillWARDependency  if true, a trailing barrier is issued so all
///                            reads of `smem` are complete before any thread
///                            returns (lets the caller overwrite `smem`
///                            immediately)
/// @param val   this thread's input value
/// @param op    the reduction operator
/// @param smem  scratch buffer shared by the block (see size requirement)
/// @return      see `BroadcastAll`
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;

  // Stage 1: per-warp reduction, one partial per warp published to smem.
  val = warpReduceAll<T, Op>(val, op);
  if (laneId == 0) {
    smem[warpId] = val;
  }
  __syncthreads();

  // Stage 2: warp 0 reduces the per-warp partials.
  if (warpId == 0) {
    val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId] :
      op.identity();
    val = warpReduceAll<T, Op>(val, op);

    if (BroadcastAll) {
      // Fence before lane 0 overwrites smem[0], a slot that was read as a
      // partial just above.
      __threadfence_block();

      if (laneId == 0) {
        smem[0] = val;
      }
    }
  }

  if (BroadcastAll) {
    // Make the final value visible to, and read by, every thread.
    __syncthreads();
    val = smem[0];
  }

  if (KillWARDependency) {
    __syncthreads();
  }

  return val;
}

/// Performs a block-wide reduction of multiple values simultaneously
///
/// Same two-stage scheme as the scalar overload, applied independently to
/// each of the `Num` entries of `val`; the result is written back into
/// `val` in place. Warp partials are stored interleaved as
/// `smem[warpId * Num + i]`, so `smem` must hold at least
/// `Num * divUp(blockDim.x, kWarpSize)` elements of T. With `BroadcastAll`
/// the final values are published in `smem[0 .. Num-1]` and copied into
/// every thread's `val`; without it only warp 0 holds the block-wide
/// results. All threads of the block must call this function.
///
/// NOTE(review): the template parameter list was missing from the file;
/// placing `Num` first is an assumption (it cannot be deduced from the
/// `T val[Num]` parameter and must be given explicitly) -- confirm against
/// existing callers.
template <int Num, typename T, typename Op, bool BroadcastAll,
          bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;

  // Stage 1: per-warp reduction of each of the Num values.
#pragma unroll
  for (int i = 0; i < Num; ++i) {
    val[i] = warpReduceAll<T, Op>(val[i], op);
  }

  if (laneId == 0) {
#pragma unroll
    for (int i = 0; i < Num; ++i) {
      smem[warpId * Num + i] = val[i];
    }
  }

  __syncthreads();

  // Stage 2: warp 0 reduces the per-warp partials for each value.
  if (warpId == 0) {
#pragma unroll
    for (int i = 0; i < Num; ++i) {
      val[i] =
        laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId * Num + i] :
        op.identity();
      val[i] = warpReduceAll<T, Op>(val[i], op);
    }

    if (BroadcastAll) {
      // Fence before lane 0 overwrites smem[0 .. Num-1], slots that were
      // read as partials just above.
      __threadfence_block();

      if (laneId == 0) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
          smem[i] = val[i];
        }
      }
    }
  }

  if (BroadcastAll) {
    __syncthreads();

#pragma unroll
    for (int i = 0; i < Num; ++i) {
      val[i] = smem[i];
    }
  }

  if (KillWARDependency) {
    __syncthreads();
  }
}

/// Sums a register value across the entire block
///
/// Thin wrapper that runs the scalar blockReduceAll with the `Sum<T>`
/// operator; see blockReduceAll for the `smem` size requirement and the
/// meaning of `BroadcastAll` / `KillWARDependency`.
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
  return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
    val, Sum<T>(), smem);
}

/// Sums `Num` register values across the entire block, in place.
///
/// Thin wrapper that runs the multi-value blockReduceAll with the `Sum<T>`
/// operator; `vals` is overwritten with the results.
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
  blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
    vals, Sum<T>(), smem);
}

} } // namespace