Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
L2Norm.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
#include "L2Norm.cuh"
#include "../../FaissAssert.h"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/StaticUtils.h"
#include <cstdint>
20 
21 namespace faiss { namespace gpu {
22 
// Input: (batch x dim) matrix
// Output: (batch) norms, one per input row
// One block computes the norms of RowTileSize consecutive rows; the
// last block handles the ragged tail when the row count is not a
// multiple of RowTileSize.
// Done under the presumption that the dimension size is not too large
// (<10k or so), since there wouldn't be enough parallelism applying a
// single block to the problem. Also that each vector is large enough
// (>64), since a single block works on multiple rows' norms at the
// same time.
// T: the type we are doing the math in (e.g., float, half)
// TVec: the potentially vectorized type we are loading in (e.g.,
// float4, half2)
// IndexType: integer type used for tensor indexing (e.g., int, long).
// Renamed from `int64_t`: the old parameter name shadowed the
// standard <cstdint> typedef and misleadingly suggested 64-bit
// arithmetic even when instantiated with `int`.
// NormLoop: true when blockDim.x is smaller than the (vectorized)
// dimension, so each thread strides over columns
// NormSquared: when true writes ||v||^2, otherwise ||v||
// Requires dynamic shared memory of
// sizeof(T) * RowTileSize * ceil(blockDim.x / kWarpSize) bytes.
template <typename T, typename TVec, typename IndexType,
          int RowTileSize, bool NormLoop, bool NormSquared>
__global__ void l2Norm(Tensor<TVec, 2, true, IndexType> input,
                       Tensor<T, 1, true, IndexType> output) {
  extern __shared__ char smemByte[]; // #warps * RowTileSize elements
  T* smem = (T*) smemByte;

  IndexType numWarps = utils::divUp(blockDim.x, kWarpSize);
  IndexType laneId = getLaneId();
  IndexType warpId = threadIdx.x / kWarpSize;

  bool lastRowTile = (blockIdx.x == (gridDim.x - 1));
  IndexType rowStart = RowTileSize * blockIdx.x;
  // Per-thread partial sums, one per row in this tile
  T rowNorm[RowTileSize];

  if (lastRowTile) {
    // We are handling the very end of the input matrix rows
    for (IndexType row = 0; row < input.getSize(0) - rowStart; ++row) {
      if (NormLoop) {
        rowNorm[0] = Math<T>::zero();

        for (IndexType col = threadIdx.x;
             col < input.getSize(1); col += blockDim.x) {
          TVec val = input[rowStart + row][col];
          val = Math<TVec>::mul(val, val);
          rowNorm[0] = Math<T>::add(rowNorm[0], Math<TVec>::reduceAdd(val));
        }
      } else {
        TVec val = input[rowStart + row][threadIdx.x];
        val = Math<TVec>::mul(val, val);
        rowNorm[0] = Math<TVec>::reduceAdd(val);
      }

      rowNorm[0] = warpReduceAllSum(rowNorm[0]);
      if (laneId == 0) {
        smem[row * numWarps + warpId] = rowNorm[0];
      }
    }
  } else {
    // We are guaranteed that all RowTileSize rows are available in
    // [rowStart, rowStart + RowTileSize)

    if (NormLoop) {
      // A single block of threads is not big enough to span each
      // vector
      TVec tmp[RowTileSize];

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<T>::zero();
      }

      for (IndexType col = threadIdx.x;
           col < input.getSize(1); col += blockDim.x) {
        // Load, square and accumulate in separate unrolled passes so
        // the loads for all tile rows can be issued back to back
#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = input[rowStart + row][col];
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
        }

#pragma unroll
        for (int row = 0; row < RowTileSize; ++row) {
          rowNorm[row] = Math<T>::add(rowNorm[row],
                                      Math<TVec>::reduceAdd(tmp[row]));
        }
      }
    } else {
      TVec tmp[RowTileSize];

      // A block of threads is the exact size of the vector
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = input[rowStart + row][threadIdx.x];
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        tmp[row] = Math<TVec>::mul(tmp[row], tmp[row]);
      }

#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        rowNorm[row] = Math<TVec>::reduceAdd(tmp[row]);
      }
    }

    // Sum up all parts in each warp
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        smem[row * numWarps + warpId] = rowNorm[row];
      }
    }
  }

  // All warps have published per-warp partial sums to shared memory;
  // barrier before warp 0 reads them back
  __syncthreads();

  // Sum across warps
  if (warpId == 0) {
#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = laneId < numWarps ?
        smem[row * numWarps + laneId] : Math<T>::zero();
    }

#pragma unroll
    for (int row = 0; row < RowTileSize; ++row) {
      rowNorm[row] = warpReduceAllSum(rowNorm[row]);
    }

    // Write out answer
    if (laneId == 0) {
#pragma unroll
      for (int row = 0; row < RowTileSize; ++row) {
        // Was `int`, which silently truncated the output index for
        // row counts beyond 2^31 when IndexType is a 64-bit type
        IndexType outCol = rowStart + row;

        if (lastRowTile) {
          if (outCol < output.getSize(0)) {
            output[outCol] =
              NormSquared ? rowNorm[row] :
              ConvertTo<T>::to(
                sqrtf(ConvertTo<float>::to(rowNorm[row])));
          }
        } else {
          output[outCol] =
            NormSquared ? rowNorm[row] :
            ConvertTo<T>::to(
              sqrtf(ConvertTo<float>::to(rowNorm[row])));
        }
      }
    }
  }
}
175 
// Host-side launcher for the l2Norm kernel over a (batch x dim)
// input, writing one norm (or squared norm when normSquared is true)
// per row into output.
// Chooses the vectorized load path when the rows can be reinterpreted
// as TVec, and a column-looping kernel (NormLoop) when dim exceeds
// the device's max threads per block.
// IndexType: integer indexing type; renamed from `int64_t` to avoid
// shadowing the standard <cstdint> typedef (template arguments are
// positional, so callers are unaffected).
template <typename T, typename TVec, typename IndexType>
void runL2Norm(Tensor<T, 2, true, IndexType>& input,
               Tensor<T, 1, true, IndexType>& output,
               bool normSquared,
               cudaStream_t stream) {
  // One output slot per input row
  FAISS_ASSERT(input.getSize(0) == output.getSize(0));

  IndexType maxThreads = (IndexType) getMaxThreadsCurrentDevice();
  constexpr int rowTileSize = 8;

#define RUN_L2(TYPE_T, TYPE_TVEC, INPUT)                                 \
  do {                                                                   \
    if (normLoop) {                                                      \
      if (normSquared) {                                                 \
        l2Norm<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, true>    \
          <<<grid, block, smem, stream>>>(INPUT, output);                \
      } else {                                                           \
        l2Norm<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, true, false>   \
          <<<grid, block, smem, stream>>>(INPUT, output);                \
      }                                                                  \
    } else {                                                             \
      if (normSquared) {                                                 \
        l2Norm<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, true>   \
          <<<grid, block, smem, stream>>>(INPUT, output);                \
      } else {                                                           \
        l2Norm<TYPE_T, TYPE_TVEC, IndexType, rowTileSize, false, false>  \
          <<<grid, block, smem, stream>>>(INPUT, output);                \
      }                                                                  \
    }                                                                    \
  } while (0)

  if (input.template canCastResize<TVec>()) {
    // Can load using the vectorized type
    auto inputV = input.template castResize<TVec>();

    auto dim = inputV.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(inputV.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    // One partial sum per (tile row, warp) pair
    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, TVec, inputV);
  } else {
    // Can't load using the vectorized type

    auto dim = input.getSize(1);
    bool normLoop = dim > maxThreads;
    auto numThreads = min(dim, maxThreads);

    auto grid = dim3(utils::divUp(input.getSize(0), rowTileSize));
    auto block = dim3(numThreads);

    auto smem = sizeof(T) * rowTileSize * utils::divUp(numThreads, kWarpSize);

    RUN_L2(T, T, input);
  }

#undef RUN_L2

  // Surface any launch-configuration error from the kernel launches
  CUDA_TEST_ERROR();
}
240 
// float specialization: dispatches on index width. 32-bit indexing is
// used when the tensor's extents and strides fit in int (cheaper
// device address math); otherwise falls back to 64-bit indices.
// Uses int64_t rather than `long` for the wide path so the index is
// genuinely 64 bits on LLP64 platforms (e.g. Windows), where `long`
// is only 32 bits and would defeat the fallback.
void runL2Norm(Tensor<float, 2, true>& input,
               Tensor<float, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<float, float4, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<int64_t>();
    auto outputCast = output.castIndexType<int64_t>();
    runL2Norm<float, float4, int64_t>(
      inputCast, outputCast, normSquared, stream);
  }
}
253 
#ifdef FAISS_USE_FLOAT16
// half specialization, mirroring the float overload: 32-bit indexing
// when the tensor fits, otherwise a 64-bit index fallback. int64_t
// (not `long`) keeps the wide path 64 bits on LLP64 platforms,
// consistent with the float overload.
void runL2Norm(Tensor<half, 2, true>& input,
               Tensor<half, 1, true>& output,
               bool normSquared,
               cudaStream_t stream) {
  if (input.canUseIndexType<int>()) {
    runL2Norm<half, half2, int>(input, output, normSquared, stream);
  } else {
    auto inputCast = input.castIndexType<int64_t>();
    auto outputCast = output.castIndexType<int64_t>();
    runL2Norm<half, half2, int64_t>(
      inputCast, outputCast, normSquared, stream);
  }
}
#endif
268 
269 } } // namespace
Related helper: `static __device__ T reduceAdd(T v)` — for a vector type, this performs a horizontal add, returning the sum of the components v_i.