12 #include "../../FaissAssert.h"
13 #include "../utils/ConversionOperators.cuh"
14 #include "../utils/DeviceDefs.cuh"
15 #include "../utils/DeviceUtils.h"
16 #include "../utils/Float16.cuh"
17 #include "../utils/MathOperators.cuh"
18 #include "../utils/PtxUtils.cuh"
19 #include "../utils/StaticUtils.h"
20 #include "../utils/Reductions.cuh"
22 namespace faiss {
namespace gpu {
// Computes the (squared) L2 norm of each row of `input`, writing one scalar
// per row into `output`.
//
// Launch configuration (set up by runL2Norm below):
//   - grid.x  = ceil(numRows / RowTileSize); each block reduces RowTileSize
//     consecutive rows (the last block may own fewer).
//   - block.x = min(numCols, maxThreads); the threads of a block cooperate
//     across the columns of each row.
//   - dynamic shared memory: RowTileSize * numWarps * sizeof(T), holding the
//     per-warp partial sums.
//
// Template parameters:
//   T           scalar accumulator / output element type
//   TVec        load type, possibly a vectorized form of T;
//               Math<TVec>::reduceAdd horizontally sums its lanes into a T
//   TIndex      index type used by the tensors (int or long)
//   RowTileSize number of rows handled by one block
//   NormLoop    true when one block is narrower than a row, so each thread
//               must loop over columns with stride blockDim.x
//   NormSquared true to emit ||x||^2 instead of ||x||
template <typename T, typename TVec, typename TIndex,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true, TIndex> input,
                       Tensor<T, 1, true, TIndex> output) {
  // Dynamic shared memory: RowTileSize * numWarps partial sums of type T
  extern __shared__ char smemByte[];
  T* smem = (T*) smemByte;

  TIndex numWarps = utils::divUp(blockDim.x, kWarpSize);
  TIndex laneId = getLaneId();
  TIndex warpId = threadIdx.x / kWarpSize;

  // The last block may cover fewer than RowTileSize rows
  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  TIndex rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // Partial tile: only input.getSize(0) - rowStart rows remain; process
    // them one at a time with an explicit bound.
    for (TIndex row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        // Block is narrower than the row; stride over the columns
        rowNorm[0] = Math<T>::zero();

        for (TIndex col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        // Block exactly spans the row; one load per thread
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);

      // One partial sum per warp goes to shared memory
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // Full tile: all rows in [rowStart, rowStart + RowTileSize) are valid,
    // so unroll across the tile for instruction-level parallelism.
    if (NormLoop) {
      // Block is narrower than each row; stride over the columns
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (TIndex col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      // Block exactly spans each row; one load per thread per row
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts within each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // One partial sum per warp per row goes to shared memory
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  // All per-warp partials must be visible before warp 0 combines them
  __syncthreads();

  // Sum across warps; only warp 0 participates
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer; sqrt is taken in float for accuracy, then converted
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          // Guard against rows past the end of the matrix
          if (outCol < output.getSize(0)) {
            output[outCol] = ConvertTo<T>::to(
              NormSquared ? rowNorm[row] :
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] = ConvertTo<T>::to(
            NormSquared ? rowNorm[row] :
            sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
// Host-side launcher for the l2Norm kernel: computes the per-row (squared)
// L2 norm of `input` into `output`, asynchronously on `stream`.
//
// Tiles rowTileSize rows per block; when the row width can be reinterpreted
// as the vectorized type TVec, loads are widened via castResize. A column
// loop (NormLoop) is used whenever a row is wider than the maximum number
// of threads per block on the current device.
//
// Parameters:
//   input       (numRows, dim) tensor of T
//   output      (numRows) tensor of T receiving the norms
//   normSquared if true, write ||x||^2; otherwise write ||x||
//   stream      CUDA stream on which the kernel is launched
template <typename T, typename TVec, typename TIndex>
void runL2Norm(Tensor<T, 2, true, TIndex>& input,
               Tensor<T, 1, true, TIndex>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  TIndex maxThreads = (TIndex) getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;

  // Dispatch over the two runtime booleans to the four kernel instantiations
#define RUN_L2(TYPE_T, TYPE_TVEC, INPUT)                                \
  do {                                                                  \
    if (normLoop) {                                                     \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, true, true>      \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, true, false>     \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    } else {                                                            \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, false, true>     \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, false, false>    \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    }                                                                   \
  } while (0)

  if (input.template canCastResize<TVec>()) {
    // Can load using the vectorized type
    auto inputV = input.template castResize<TVec>();

    auto dim = inputV.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    // One partial sum of T per warp per tiled row
    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, TVec, inputV);
  } else {
    // Can't load using the vectorized type; fall back to scalar loads
    auto dim = input.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, T, input);
  }

#undef RUN_L2

  CUDA_TEST_ERROR();
}
// float specialization: prefers 32-bit indexing (with float4 vectorized
// loads) and falls back to 64-bit indexing for very large tensors.
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<float, float4, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();
    runL2Norm<float, float4, long>(inputCast, outputCast, normSquared, stream);
  }
}
#ifdef FAISS_USE_FLOAT16
// half specialization: loads via half2, accumulating in half; prefers 32-bit
// indexing and falls back to 64-bit indexing for very large tensors.
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<half, half2, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();
    runL2Norm<half, half2, long>(inputCast, outputCast, normSquared, stream);
  }
}
#endif
// NOTE(review): stray text, apparently torn from a utility-header comment
// describing Math<TVec>::reduceAdd; preserved here as a comment:
//   static __device__ T reduceAdd(T v)
//   For a vector type, this is a horizontal add, returning sum(v_i)