Faiss
L2Norm.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "L2Norm.cuh"
#include "../../FaissAssert.h"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Reductions.cuh"

namespace faiss { namespace gpu {

// Input: (batch x dim)
// Output: (batch norm)
// Done under the presumption that the dimension size is not too large
// (<10k or so), since there wouldn't be enough parallelism applying a
// single block to the problem. Also that each vector is large enough
// (>64), since a single block works on multiple rows' norms at the
// same time.
// T: the type we are doing the math in (e.g., float, half)
// TVec: the potentially vectorized type we are loading in (e.g.,
// float4, half2)
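//
// For illustration only (hypothetical sizes, not taken from this file): with
// dim = 256 floats loaded as float4, each row is 64 TVec columns wide, so a
// 64-thread block covers a row in a single pass (NormLoop = false); with
// RowTileSize = 8 each block then reduces 8 consecutive rows and the grid
// needs ceil(batch / 8) blocks.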
template <typename T, typename TVec, typename IndexType,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void
l2NormRowMajor(Tensor<TVec, 2, true, IndexType> input,
               Tensor<T, 1, true, IndexType> output) {
  extern __shared__ char smemByte[]; // #warps * RowTileSize elements
  T* smem = (T*) smemByte;

  IndexType numWarps = utils::divUp(blockDim.x, kWarpSize);
  IndexType laneId = getLaneId();
  IndexType warpId = threadIdx.x / kWarpSize;

  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  IndexType rowStart = RowTileSize * blockIdx.x;
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // We are handling the very end of the input matrix rows
    for (IndexType row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (IndexType col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)

    if (NormLoop) {
      // A single block of threads is not big enough to span each
      // vector
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (IndexType col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      TVec tmp[RowTileSize];

      // A block of threads is the exact size of the vector
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts in each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  __syncthreads();

  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}

// Input: (dim x batch)
// Output: (batch norm)
// Handles the case where `input` is column major. A single thread calculates
// the norm of each vector instead of a block-wide reduction.
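// Because `input` is (dim x batch) with batch as the innermost dimension,
// consecutive threads in a warp handle consecutive batch columns; each step of
// the loop over `dim` below therefore reads a contiguous run of memory, which
// is what keeps the per-thread loads coalesced.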
template <typename T, typename IndexType, bool NormSquared>
__global__ void
l2NormColMajor(Tensor<T, 2, true, IndexType> input,
               Tensor<T, 1, true, IndexType> output) {
  // grid-stride loop to handle all batch elements
  for (IndexType batch = blockIdx.x * blockDim.x + threadIdx.x;
       batch < input.getSize(1);
       batch += gridDim.x * blockDim.x) {
    float sum = 0;

    // This is still a coalesced load from the memory
    for (IndexType dim = 0; dim < input.getSize(0); ++dim) {
      // Just do the math in float32, even if the input is float16
      float v = ConvertTo<float>::to(input[dim][batch]);
      sum += v * v;
    }

    if (!NormSquared) {
      sum = sqrtf(sum);
    }

    output[batch] = ConvertTo<T>::to(sum);
  }
}

template <typename T, typename TVec, typename IndexType>
void runL2Norm(Tensor<T, 2, true, IndexType>& input,
               bool inputRowMajor,
               Tensor<T, 1, true, IndexType>& output,
               bool normSquared,
               cudaStream_t stream) {
  IndexType maxThreads = (IndexType) getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;

#define RUN_L2_ROW_MAJOR(TYPE_T, TYPE_TVEC, INPUT)                            \
  do {                                                                        \
    if (normLoop) {                                                           \
      if (normSquared) {                                                      \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, true> \
          <<<grid, block, smem, stream>>>(INPUT, output);                     \
      } else {                                                                \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, false> \
          <<<grid, block, smem, stream>>>(INPUT, output);                     \
      }                                                                       \
    } else {                                                                  \
      if (normSquared) {                                                      \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, true> \
          <<<grid, block, smem, stream>>>(INPUT, output);                     \
      } else {                                                                \
        l2NormRowMajor<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, false> \
          <<<grid, block, smem, stream>>>(INPUT, output);                     \
      }                                                                       \
    }                                                                         \
  } while (0)
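
  // The norm-loop and norm-squared choices are runtime booleans here, but the
  // kernel takes them as compile-time template parameters, so the macro above
  // simply spells out the four possible instantiations and launches the
  // appropriate one.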

  if (inputRowMajor) {
    //
    // Row-major kernel
    //

    if (input.template canCastResize<TVec>()) {
      // Can load using the vectorized type
      auto inputV = input.template castResize<TVec>();

      auto dim = inputV.getSize(1);
      bool normLoop = dim > maxThreads;
      auto numThreads = min(dim, maxThreads);

      auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
      auto block = dim3(numThreads);

      auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);
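      // Shared memory holds one partial sum of type T per (tiled row, warp)
      // pair; e.g. (hypothetical sizes) 256 threads make 8 warps, so with
      // rowTileSize = 8 that is 64 elements of T per block.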

      RUN_L2_ROW_MAJOR(T, TVec, inputV);
    } else {
      // Can't load using the vectorized type

      auto dim = input.getSize(1);
      bool normLoop = dim > maxThreads;
      auto numThreads = min(dim, maxThreads);

      auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
      auto block = dim3(numThreads);

      auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

      RUN_L2_ROW_MAJOR(T, T, input);
    }
  } else {
    //
    // Column-major kernel
    //

    // Just use a fixed-sized block, since the kernel threads are fully
    // independent
    auto block = 128;

    // Cap the grid size at 2^16 since there is a grid-stride loop to handle
    // processing everything
    auto grid = (int)
      std::min(utils::divUp(input.getSize(1), (IndexType) block),
               (IndexType) 65536);

    if (normSquared) {
      l2NormColMajor<T, IndexType, true><<<grid, block, 0, stream>>>(
        input, output);
    } else {
      l2NormColMajor<T, IndexType, false><<<grid, block, 0, stream>>>(
        input, output);
    }
  }

#undef RUN_L2_ROW_MAJOR

  CUDA_TEST_ERROR();
}

void runL2Norm(Tensor<float, 2, true>& input,
               bool inputRowMajor,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<float, float4, int>(
      input, inputRowMajor, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();

    runL2Norm<float, float4, long>(
      inputCast, inputRowMajor, outputCast, normSquared, stream);
  }
}

#ifdef FAISS_USE_FLOAT16
void runL2Norm(Tensor<half, 2, true>& input,
               bool inputRowMajor,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<half, half2, int>(
      input, inputRowMajor, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<long>();
    auto outputCast = output.castIndexType<long>();

    runL2Norm<half, half2, long>(
      inputCast, inputRowMajor, outputCast, normSquared, stream);
  }
}
#endif

} } // namespace
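
For reference, a minimal sketch of how the float overload above might be driven from host code. The pointer-plus-sizes Tensor constructor and the names computeSquaredNorms, d_vectors, d_norms, numVecs and dim are assumptions for illustration only, not definitions from this file.

// Minimal usage sketch (assumptions noted above): wrap existing device
// allocations in Tensor views and compute squared L2 norms of row-major
// vectors on the given stream.
#include "L2Norm.cuh"

void computeSquaredNorms(float* d_vectors,  // numVecs x dim, row-major, device
                         float* d_norms,    // numVecs, device
                         int numVecs,
                         int dim,
                         cudaStream_t stream) {
  faiss::gpu::Tensor<float, 2, true> input(d_vectors, {numVecs, dim});
  faiss::gpu::Tensor<float, 1, true> output(d_norms, {numVecs});

  // inputRowMajor = true, normSquared = true
  faiss::gpu::runL2Norm(input, true, output, true, stream);
}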