L2Select.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "L2Select.cuh"
#include "../../FaissAssert.h"

#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/MathOperators.cuh"
#include "../utils/Pair.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/Select.cuh"
#include "../utils/Tensor.cuh"
#include "../utils/StaticUtils.h"

namespace faiss { namespace gpu {

// L2 + select kernel for k == 1, implements re-use of ||c||^2
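//
// Background on the inputs this kernel assumes: the squared L2 distance
// decomposes as ||q - c||^2 = ||q||^2 - 2 q . c + ||c||^2. productDistances
// is expected to hold the query-dependent terms per (query, centroid) pair,
// while centroidDistances holds ||c||^2 per centroid. Since ||c||^2 does not
// depend on the query row, each thread reads it once per column and adds it
// to all kRowsPerBlock rows it covers. (||q||^2 is constant within a row, so
// it cannot change which column attains the minimum.)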
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices) {
  // Each block handles kRowsPerBlock rows of the distances (results)
  Pair<T, int> threadMin[kRowsPerBlock];
  __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

  T distance[kRowsPerBlock];

#pragma unroll
  for (int i = 0; i < kRowsPerBlock; ++i) {
    threadMin[i].k = Limits<T>::getMax();
    threadMin[i].v = -1;
  }

  // blockIdx.x: which chunk of rows we are responsible for updating
  int rowStart = blockIdx.x * kRowsPerBlock;

  // FIXME: if we have exact multiples, don't need this
  bool endRow = (blockIdx.x == gridDim.x - 1);

  if (endRow) {
    if (productDistances.getSize(0) % kRowsPerBlock == 0) {
      endRow = false;
    }
  }

  if (endRow) {
    for (int row = rowStart; row < productDistances.getSize(0); ++row) {
      for (int col = threadIdx.x; col < productDistances.getSize(1);
           col += blockDim.x) {
        distance[0] = Math<T>::add(centroidDistances[col],
                                   productDistances[row][col]);

        if (Math<T>::lt(distance[0], threadMin[0].k)) {
          threadMin[0].k = distance[0];
          threadMin[0].v = col;
        }
      }

      // Reduce within the block
      threadMin[0] =
        blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
          threadMin[0], Min<Pair<T, int> >(), blockMin);

      if (threadIdx.x == 0) {
        outDistances[row][0] = threadMin[0].k;
        outIndices[row][0] = threadMin[0].v;
      }

      // so we can use the shared memory again
      __syncthreads();

      threadMin[0].k = Limits<T>::getMax();
      threadMin[0].v = -1;
    }
  } else {
    for (int col = threadIdx.x; col < productDistances.getSize(1);
         col += blockDim.x) {
      T centroidDistance = centroidDistances[col];

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = productDistances[rowStart + row][col];
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = Math<T>::add(distance[row], centroidDistance);
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        if (Math<T>::lt(distance[row], threadMin[row].k)) {
          threadMin[row].k = distance[row];
          threadMin[row].v = col;
        }
      }
    }

    // Reduce within the block
    blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >,
                   false, false>(
      threadMin, Min<Pair<T, int> >(), blockMin);

    if (threadIdx.x == 0) {
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        outDistances[rowStart + row][0] = threadMin[row].k;
        outIndices[rowStart + row][0] = threadMin[row].v;
      }
    }
  }
}

// L2 + select kernel for k > 1. Each block handles a single row, so there
// is no cross-row re-use of ||c||^2 (the centroid norm is re-read for each
// row, unlike the k == 1 kernel above).
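//
// A note on the selection machinery (as used here, not a full description
// of Select.cuh): BlockSelect tracks the current best-k candidates in
// per-thread queues backed by shared-memory warp queues (smemK/smemV).
// add() includes warp-synchronous queue maintenance, so every lane of a
// warp must call it together; addThreadQ() only touches the calling
// thread's own queue, which makes it safe for the ragged tail where some
// lanes have no valid column. reduce() merges all queues at the end.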
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices,
                             int k, T initK) {
  // Each block handles a single row of the distances (results)
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ T smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  BlockSelect<T, int, false, Comparator<T>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);

  int row = blockIdx.x;

  // Whole warps must participate in the selection
  int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
  int i = threadIdx.x;

  for (; i < limit; i += blockDim.x) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.add(v, i);
  }

  if (i < productDistances.getSize(1)) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.addThreadQ(v, i);
  }

  heap.reduce();
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[row][i] = smemK[i];
    outIndices[row][i] = smemV[i];
  }
}

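// Host-side entry point: given per-(query, centroid) partial distances and
// per-centroid squared norms, writes the k smallest distances and their
// centroid indices for each query row.
//
// A minimal usage sketch (hypothetical shapes; the tensors must already be
// device-resident and satisfy the asserts in the body):
//
//   // productDistances:  [numQueries][numCentroids]
//   // centroidDistances: [numCentroids]
//   // outDistances, outIndices: [numQueries][k]
//   runL2SelectMin(productDistances, centroidDistances,
//                  outDistances, outIndices, k, stream);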
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
                    Tensor<T, 1, true>& centroidDistances,
                    Tensor<T, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
  FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
  FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

  if (k == 1) {
    constexpr int kThreadsPerBlock = 256;
    constexpr int kRowsPerBlock = 8;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));

    l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(productDistances, centroidDistances,
                                   outDistances, outIndices);
  } else {
    auto grid = dim3(outDistances.getSize(0));

#define RUN_L2_SELECT(BLOCK, NUM_WARP_Q, NUM_THREAD_Q)                    \
    do {                                                                  \
      l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, BLOCK>                    \
        <<<grid, BLOCK, 0, stream>>>(productDistances, centroidDistances, \
                                     outDistances, outIndices,            \
                                     k, Limits<T>::getMax());             \
    } while (0)
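
    // How the template parameters are chosen: NUM_WARP_Q is the warp-queue
    // length and must be at least k (hence the power-of-two buckets below);
    // NUM_THREAD_Q is the per-thread queue length, presumably tuned to
    // balance register pressure against merge frequency.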

    // block size 128 for everything <= 1024
    if (k <= 32) {
      RUN_L2_SELECT(128, 32, 2);
    } else if (k <= 64) {
      RUN_L2_SELECT(128, 64, 3);
    } else if (k <= 128) {
      RUN_L2_SELECT(128, 128, 3);
    } else if (k <= 256) {
      RUN_L2_SELECT(128, 256, 4);
    } else if (k <= 512) {
      RUN_L2_SELECT(128, 512, 8);
    } else if (k <= 1024) {
      RUN_L2_SELECT(128, 1024, 8);

#if GPU_MAX_SELECTION_K >= 2048
    } else if (k <= 2048) {
      // smaller block for less shared memory
      RUN_L2_SELECT(64, 2048, 8);
#endif

    } else {
      FAISS_ASSERT(false);
    }
  }

  CUDA_TEST_ERROR();
}

void runL2SelectMin(Tensor<float, 2, true>& productDistances,
                    Tensor<float, 1, true>& centroidDistances,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<float>(productDistances,
                        centroidDistances,
                        outDistances,
                        outIndices,
                        k,
                        stream);
}

#ifdef FAISS_USE_FLOAT16
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
                    Tensor<half, 1, true>& centroidDistances,
                    Tensor<half, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<half>(productDistances,
                       centroidDistances,
                       outDistances,
                       outIndices,
                       k,
                       stream);
}
#endif

} } // namespace