// Faiss GPU — L2Norm.cu (row-wise L2 norm kernels)
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "L2Norm.cuh"
13 #include "../../FaissAssert.h"
14 #include "../utils/ConversionOperators.cuh"
15 #include "../utils/DeviceDefs.cuh"
16 #include "../utils/DeviceUtils.h"
17 #include "../utils/Float16.cuh"
18 #include "../utils/MathOperators.cuh"
19 #include "../utils/PtxUtils.cuh"
20 #include "../utils/StaticUtils.h"
21 #include "../utils/Reductions.cuh"
22 
23 namespace faiss { namespace gpu {
24 
// Computes the L2 norm (or squared L2 norm, per NormSquared) of every
// row of `input`, writing one scalar per row into `output`.
//
// Input: (batch x dim) in the possibly-vectorized type TVec
// Output: (batch) row norms in type T
//
// Launch layout (set up by runL2Norm): a 1-d grid where each block owns
// RowTileSize consecutive rows; dynamic shared memory must hold
// (#warps * RowTileSize) elements of T for the cross-warp reduction.
//
// Done under the presumption that the dimension size is not too large
// (<10k or so), since there wouldn't be enough parallelism applying a
// single block to the problem. Also that each vector is large enough
// (>64), since a single block works on multiple rows' norms at the
// same time.
// T: the type we are doing the math in (e.g., float, half)
// TVec: the potentially vectorized type we are loading in (e.g.,
// float4, half2)
// NormLoop: true when the block is narrower than a row (blockDim.x <
// dim), so each thread strides over several columns
// NormSquared: true -> store the squared norm; false -> take sqrtf (in
// float) before the store
template <typename T, typename TVec,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true> input,
                       Tensor<T, 1, true> output) {
  // Dynamic shared memory: one partial sum per (row-in-tile, warp) pair
  extern __shared__ char smemByte[]; // #warps * RowTileSize elements
  T* smem = (T*) smemByte;

  int numWarps = utils::divUp(blockDim.x, kWarpSize);
  int laneId = getLaneId();
  int warpId = threadIdx.x / kWarpSize;

  // The final block may own fewer than RowTileSize valid rows
  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  int rowStart = RowTileSize * blockIdx.x;
  // Per-thread partial sums, one slot per row of the tile
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // We are handling the very end of the input matrix rows; process
    // the remaining rows one at a time, reusing rowNorm[0] as the
    // accumulator for each
    for (int row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (int col = threadIdx.x; col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          // reduceAdd folds the lanes of the vectorized TVec into one T
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        // One column per thread spans the entire row
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      // Per-warp sum; lane 0 publishes the warp's partial to smem
      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)

    if (NormLoop) {
      // A single block of threads is not big enough to span each
      // vector
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (int col = threadIdx.x; col < input.getSize(1); col += blockDim.x) {
        // Load, square, and accumulate in separate unrolled passes so
        // the loads for all tile rows can be issued back to back
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      TVec tmp[RowTileSize];

      // A block of threads is the exact size of the vector
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts in each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  // Make every warp's partials visible before warp 0 combines them
  __syncthreads();

  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      // NOTE(review): in the last (partial) tile, smem slots for rows
      // past the end of the input were never written, so this reads
      // uninitialized shared memory for those rows. The values are
      // discarded by the outCol bound check below, so the output is
      // unaffected — but worth confirming this is intentional.
      rowNorm[row] = laneId < numWarps ?
                     smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        int outCol = rowStart + row;

        if (lastRowTile) {
          // Partial tile: only store rows that actually exist
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
175 
// Host-side launcher for the l2Norm kernel.
//
// Chooses between the vectorized (TVec) and scalar (T) load paths based
// on whether `input` can be reinterpreted as TVec, sizes the grid so
// each block covers kRowTile rows, and dispatches to the kernel
// specialization matching the runtime flags. Asynchronous with respect
// to the host; work is enqueued on `stream`.
//
// input: (batch x dim) vectors; output: (batch) norms — sizes asserted
// normSquared: write ||v||^2 instead of ||v||
template <typename T, typename TVec>
void runL2Norm(Tensor<T, 2, true>& input,
               Tensor<T, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  int threadCap = getMaxThreadsCurrentDevice();
  constexpr int kRowTile = 8;

  // Map the two runtime booleans onto the kernel's compile-time
  // template flags; expects loopOverDim / gridSize / blockSize /
  // smemBytes to be in scope at the expansion site.
#define LAUNCH_L2_NORM(TYPE_T, TYPE_TVEC, INPUT)                     \
  do {                                                               \
    if (loopOverDim) {                                               \
      if (normSquared) {                                             \
        l2Norm<TYPE_T, TYPE_TVEC, kRowTile, true, true>              \
          <<<gridSize, blockSize, smemBytes, stream>>>(INPUT, output); \
      } else {                                                       \
        l2Norm<TYPE_T, TYPE_TVEC, kRowTile, true, false>             \
          <<<gridSize, blockSize, smemBytes, stream>>>(INPUT, output); \
      }                                                              \
    } else {                                                         \
      if (normSquared) {                                             \
        l2Norm<TYPE_T, TYPE_TVEC, kRowTile, false, true>             \
          <<<gridSize, blockSize, smemBytes, stream>>>(INPUT, output); \
      } else {                                                       \
        l2Norm<TYPE_T, TYPE_TVEC, kRowTile, false, false>            \
          <<<gridSize, blockSize, smemBytes, stream>>>(INPUT, output); \
      }                                                              \
    }                                                                \
  } while (0)

  if (input.template canCastResize<TVec>()) {
    // Vectorized path: reinterpret each row as dim/width TVec elements
    auto vecInput = input.template castResize<TVec>();

    int vecDim = vecInput.getSize(1);
    // If one block cannot span a row, the kernel must loop over columns
    bool loopOverDim = vecDim > threadCap;
    int threadsPerBlock = min(vecDim, threadCap);

    auto gridSize = dim3(utils::divUp(vecInput.getSize(0), kRowTile));
    auto blockSize = dim3(threadsPerBlock);

    // One T per (row-in-tile, warp) pair for the cross-warp reduction
    auto smemBytes =
      sizeof(T) * kRowTile * utils::divUp(threadsPerBlock, kWarpSize);

    LAUNCH_L2_NORM(T, TVec, vecInput);
  } else {
    // Scalar path: load one T per column

    int dim = input.getSize(1);
    bool loopOverDim = dim > threadCap;
    int threadsPerBlock = min(dim, threadCap);

    auto gridSize = dim3(utils::divUp(input.getSize(0), kRowTile));
    auto blockSize = dim3(threadsPerBlock);

    auto smemBytes =
      sizeof(T) * kRowTile * utils::divUp(threadsPerBlock, kWarpSize);

    LAUNCH_L2_NORM(T, T, input);
  }

#undef LAUNCH_L2_NORM

  // Surface any launch-configuration error immediately
  CUDA_VERIFY(cudaGetLastError());
}
240 
// Public float entry point: computes per-row L2 norms of `input` into
// `output` on `stream`, loading via float4 when alignment permits.
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<float, float4>(input, output, normSquared, stream);
}
247 
#ifdef FAISS_USE_FLOAT16
// Public half-precision entry point (only when FP16 support is compiled
// in): math in half, loading via half2 when alignment permits.
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  runL2Norm<half, half2>(input, output, normSquared, stream);
}
#endif
256 
257 } } // namespace
// Math<TVec>::reduceAdd — static __device__ T reduceAdd(T v):
// for a vector type, this is a horizontal add, returning sum(v_i).