L2Norm.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "L2Norm.cuh"
#include "../../FaissAssert.h"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Reductions.cuh"

namespace faiss { namespace gpu {

// Input: (batch x dim)
// Output: (norm of each batch vector)
// Done under the presumption that the dimension size is not too large
// (<10k or so), since there wouldn't be enough parallelism applying a
// single block to the problem. Also that each vector is large enough
// (>64), since a single block works on multiple rows' norms at the
// same time.
// T: the type we are doing the math in (e.g., float, half)
// TVec: the potentially vectorized type we are loading in (e.g.,
// float4, half2)
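//
// Launch-shape sketch (illustrative numbers, not taken from this file):
// for a (10000 x 128) float input loaded as float4, each row has 32
// vectorized columns, so the host code below launches 32-thread blocks
// in a grid of ceil(10000 / RowTileSize) blocks, each block producing
// RowTileSize norms.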
template <typename T, typename TVec, typename TIndex,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true, TIndex> input,
                       Tensor<T, 1, true, TIndex> output) {
  extern __shared__ char smemByte[]; // #warps * RowTileSize elements
  T* smem = (T*) smemByte;

  TIndex numWarps = utils::divUp(blockDim.x, kWarpSize);
  TIndex laneId = getLaneId();
  TIndex warpId = threadIdx.x / kWarpSize;

  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  TIndex rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // We are handling the very end of the input matrix rows
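    // This tail tile may hold fewer than RowTileSize rows, so rows are
    // processed one at a time here rather than through the unrolled
    // RowTileSize-wide path in the branch below.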
    for (TIndex row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (TIndex col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)

    if (NormLoop) {
      // A single block of threads is not big enough to span each
      // vector
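      // (e.g., with 4096 vectorized columns and blockDim.x = 1024, each
      // thread accumulates 4 columns per row in the block-stride loop
      // below; illustrative numbers)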
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (TIndex col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      TVec tmp[RowTileSize];

      // A block of threads is the exact size of the vector
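      // (NormLoop == false: the host launched exactly getSize(1)
      // threads, so each thread owns one column and no column loop is
      // needed)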
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts in each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  __syncthreads();

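  // Note: one warp suffices for the cross-warp pass below because
  // numWarps <= kWarpSize (blockDim.x is at most 1024 = 32 * 32 on
  // current devices), so lane i of warp 0 can pick up warp i's partial
  // sum for each row.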
  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}

template <typename T, typename TVec, typename TIndex>
void runL2Norm(Tensor<T, 2, true, TIndex>& input,
               Tensor<T, 1, true, TIndex>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  TIndex maxThreads = (TIndex) getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;

#define RUN_L2(TYPE_T, TYPE_TVEC, INPUT)                                \
  do {                                                                  \
    if (normLoop) {                                                     \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, true, true>      \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, true, false>     \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    } else {                                                            \
      if (normSquared) {                                                \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, false, true>     \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      } else {                                                          \
        l2Norm<TYPE_T, TYPE_TVEC, TIndex, rowTileSize, false, false>    \
          <<<grid, block, smem, stream>>>(INPUT, output);               \
      }                                                                 \
    }                                                                   \
  } while (0)
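
  // The macro above lowers the two runtime flags (normLoop, normSquared)
  // into four compile-time kernel instantiations; it also relies on the
  // locals `grid`, `block`, `smem`, and `normLoop` being in scope at each
  // expansion site below.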

  if (input.template canCastResize<TVec>()) {
    // Can load using the vectorized type
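    // (for float -> float4 this requires the innermost dimension to be
    // divisible by 4 and the base pointer to be 16-byte aligned; this is
    // the kind of check canCastResize() performs)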
    auto inputV = input.template castResize<TVec>();

    auto dim = inputV.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);
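    // (e.g., numThreads = 1024 -> 32 warps, so with rowTileSize = 8 and
    // T = float this is 4 * 8 * 32 = 1024 bytes of dynamic shared memory;
    // illustrative numbers)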

    RUN_L2(T, TVec, inputV);
  } else {
    // Can't load using the vectorized type

    auto dim = input.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, T, input);
  }

#undef RUN_L2

  CUDA_TEST_ERROR();
}
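
// Usage sketch (hypothetical call site; `vecs` and `norms` stand in for
// caller-owned device tensors of shape (n x d) and (n)):
//
//   Tensor<float, 2, true> vecs = ...;   // n x d, row-major, on device
//   Tensor<float, 1, true> norms = ...;  // n
//   runL2Norm(vecs, norms, true /* squared norm */, stream);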

void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<float, float4, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();
    runL2Norm<float, float4, long>(inputCast, outputCast, normSquared, stream);
  }
}
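
// The index-width dispatch above keeps addressing math in 32-bit `int`
// whenever the tensor's offsets fit, and recasts to 64-bit (`long`)
// indices only for very large inputs, since 32-bit arithmetic is cheaper
// on the GPU.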

#ifdef FAISS_USE_FLOAT16
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<half, half2, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();
    runL2Norm<half, half2, long>(inputCast, outputCast, normSquared, stream);
  }
}
#endif

} } // namespace