Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
L2Norm.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the CC-by-NC license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
11 #include "L2Norm.cuh"
12 #include "../../FaissAssert.h"
13 #include "../utils/ConversionOperators.cuh"
14 #include "../utils/DeviceDefs.cuh"
15 #include "../utils/DeviceUtils.h"
16 #include "../utils/Float16.cuh"
17 #include "../utils/MathOperators.cuh"
18 #include "../utils/PtxUtils.cuh"
19 #include "../utils/StaticUtils.h"
20 #include "../utils/Reductions.cuh"
21 
22 namespace faiss { namespace gpu {
23 
24 // Input: (batch x dim), # repeats
25 // Output: (# repeats, norm of batch vector)
26 // Done under the presumption that the dimension size is not too large
27 // (<10k or so), since there wouldn't be enough parallelism applying a
28 // single block to the problem. Also that each vector is large enough
29 // (>64), since a single block works on multiple rows' norms at the
30 // same time.
31 // T: the type we are doing the math in (e.g., float, half)
32 // TVec: the potentially vectorized type we are loading in (e.g.,
33 // float4, half2)
// Kernel: computes the (optionally squared) L2 norm of each row of `input`.
//
// Launch contract (established by runL2Norm below):
//   grid  = divUp(numRows, RowTileSize) blocks, 1D
//   block = min(dim, maxThreads) threads, 1D
//   dynamic shared memory = RowTileSize * numWarps elements of T,
//     laid out as smem[row * numWarps + warpId] (per-row warp partials)
//
// Template parameters:
//   T           - accumulation/output scalar type (e.g. float, half)
//   TVec        - load type, possibly vectorized (e.g. float4, half2);
//                 `input` is viewed as (numRows x dim/vecWidth) of TVec
//   RowTileSize - rows processed per block
//   NormLoop    - true when blockDim.x < dim, so each thread must stride
//                 over multiple columns; false when blockDim.x == dim
//   NormSquared - if true, emit sum of squares; else emit sqrt of it
template <typename T, typename TVec,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true> input,
                       Tensor<T, 1, true> output) {
  extern __shared__ char smemByte[]; // #warps * RowTileSize elements
  T* smem = (T*) smemByte;

  int numWarps = utils::divUp(blockDim.x, kWarpSize);
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;

  // The last block may cover fewer than RowTileSize valid rows.
  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  int rowStart = RowTileSize * blockIdx.x;
  // Per-thread partial sums; the full-tile path uses all RowTileSize
  // slots, the last-tile path reuses slot 0 for each row in turn.
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // We are handling the very end of the input matrix rows; process
    // only the rows that actually exist, one at a time.
    for (int row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        // Thread-strided pass over the row: square each vector element
        // and fold its horizontal sum into the scalar accumulator.
        for (int col = threadIdx.x; col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        // Exactly one column per thread.
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      // Reduce within the warp; lane 0 publishes this warp's partial.
      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)

    if (NormLoop) {
      // A single block of threads is not big enough to span each
      // vector
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      // Load/multiply/accumulate are split into three unrolled loops so
      // the loads for all tile rows are issued before they are consumed.
      for (int col = threadIdx.x; col < input.getSize(1); col += blockDim.x) {
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      TVec tmp[RowTileSize];

      // A block of threads is the exact size of the vector
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts in each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Lane 0 of each warp publishes its per-row partials to shared
    // memory for the cross-warp reduction below.
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  // All warp partials must be visible before warp 0 reads them.
  __syncthreads();

  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      // NOTE(review): in the last (partial) tile, smem slots for rows
      // beyond the valid range were never written, so this reads
      // uninitialized shared memory; the garbage is reduced but never
      // written out (guarded by the outCol bound check below).
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          // Partial tile: only rows that exist are stored.
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              // sqrt is taken in float regardless of T.
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
174 
// Host-side dispatcher: picks the load type (vectorized TVec when the
// input's innermost dimension permits the cast, scalar T otherwise),
// sizes the launch, and instantiates the matching l2Norm kernel.
//
// input:       (numRows x dim) row-major matrix
// output:      numRows norms (squared iff normSquared)
// normSquared: skip the final sqrt when true
// stream:      CUDA stream for the (asynchronous) kernel launch
template <typename T, typename TVec>
void runL2Norm(Tensor<T, 2, true>& input,
               Tensor<T, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  int maxThreads = getMaxThreadsCurrentDevice();
  // Rows handled per block; must match the kernel's RowTileSize.
  constexpr int kRowTile = 8;

// Instantiate one concrete kernel variant.
#define LAUNCH_L2(TYPE_T, TYPE_TVEC, INPUT, LOOP, SQUARED)              \
  l2Norm<TYPE_T, TYPE_TVEC, kRowTile, LOOP, SQUARED>                    \
    <<<grid, block, smem, stream>>>(INPUT, output)

// Select the variant from the two runtime booleans.
#define DISPATCH_L2(TYPE_T, TYPE_TVEC, INPUT)                           \
  do {                                                                  \
    if (strided) {                                                      \
      if (normSquared) {                                                \
        LAUNCH_L2(TYPE_T, TYPE_TVEC, INPUT, true, true);                \
      } else {                                                          \
        LAUNCH_L2(TYPE_T, TYPE_TVEC, INPUT, true, false);               \
      }                                                                 \
    } else {                                                            \
      if (normSquared) {                                                \
        LAUNCH_L2(TYPE_T, TYPE_TVEC, INPUT, false, true);               \
      } else {                                                          \
        LAUNCH_L2(TYPE_T, TYPE_TVEC, INPUT, false, false);              \
      }                                                                 \
    }                                                                   \
  } while (0)

  if (input.template canCastResize<TVec>()) {
    // Vectorized loads are possible; reinterpret the input as TVec.
    auto inputV = input.template castResize<TVec>();

    int vecDim = inputV.getSize(1);
    // If one block cannot span a row, the kernel loops over columns.
    bool strided = vecDim > maxThreads;
    int threadsPerBlock = min(vecDim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), kRowTile));
    auto block = dim3(threadsPerBlock);

    // One T per (tile row, warp) pair of partial sums.
    auto smem = sizeof(T) * kRowTile * utils::divUp(threadsPerBlock, kWarpSize);

    DISPATCH_L2(T, TVec, inputV);
  } else {
    // Fall back to scalar loads.
    int vecDim = input.getSize(1);
    bool strided = vecDim > maxThreads;
    int threadsPerBlock = min(vecDim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), kRowTile));
    auto block = dim3(threadsPerBlock);

    auto smem = sizeof(T) * kRowTile * utils::divUp(threadsPerBlock, kWarpSize);

    DISPATCH_L2(T, T, input);
  }

#undef DISPATCH_L2
#undef LAUNCH_L2

  CUDA_TEST_ERROR();
}
239 
// Public float32 entry point; uses float4 vectorized loads when the
// input's alignment and width allow it.
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<float, float4>(input, output, normSquared, stream);
}
246 
247 #ifdef FAISS_USE_FLOAT16
// Public float16 entry point; uses half2 vectorized loads when the
// input's alignment and width allow it.
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<half, half2>(input, output, normSquared, stream);
}
254 #endif
255 
256 } } // namespace
static __device__ T reduceAdd(T v)
For a vector type, this is a horizontal add, returning sum(v_i)