Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
L2Select.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "L2Select.cuh"
13 #include "../../FaissAssert.h"
14 
15 #include "../utils/DeviceUtils.h"
16 #include "../utils/MathOperators.cuh"
17 #include "../utils/Pair.cuh"
18 #include "../utils/Reductions.cuh"
19 #include "../utils/Select.cuh"
20 #include "../utils/Tensor.cuh"
21 #include "../utils/StaticUtils.h"
22 
23 namespace faiss { namespace gpu {
24 
// L2 + select kernel for k == 1, implements re-use of ||c||^2.
//
// Computes, for each query row, the single nearest centroid under the
// decomposition ||q - c||^2 = ||q||^2 - 2<q, c> + ||c||^2: productDistances
// holds the query-dependent part and centroidDistances holds ||c||^2, which
// is shared across all rows handled by the block.
//
// Launch layout (see runL2SelectMin): 1-D grid, one block per chunk of
// kRowsPerBlock consecutive rows; blockDim.x == kBlockSize threads stride
// the columns. Requires kBlockSize to be a multiple of kWarpSize (the
// blockMin sizing below assumes kBlockSize / kWarpSize whole warps).
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices) {
  // Each block handles kRowsPerBlock rows of the distances (results).
  // threadMin[i] is this thread's running (distance, column) minimum for
  // row i; .k is the distance key, .v is the column index.
  Pair<T, int> threadMin[kRowsPerBlock];
  // Scratch for blockReduceAll: one slot per warp per row.
  __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

  T distance[kRowsPerBlock];

#pragma unroll
  for (int i = 0; i < kRowsPerBlock; ++i) {
    // Initialize to "no candidate seen": max distance, invalid index.
    threadMin[i].k = Limits<T>::getMax();
    threadMin[i].v = -1;
  }

  // blockIdx.x: which chunk of rows we are responsible for updating
  int rowStart = blockIdx.x * kRowsPerBlock;

  // FIXME: if we have exact multiples, don't need this
  bool endRow = (blockIdx.x == gridDim.x - 1);

  if (endRow) {
    // The last block only needs the slow path when the row count is not a
    // multiple of kRowsPerBlock (i.e. it owns a partial chunk).
    if (productDistances.getSize(0) % kRowsPerBlock == 0) {
      endRow = false;
    }
  }

  if (endRow) {
    // Slow path for the ragged tail: process one row at a time so we never
    // index past the end of productDistances; only threadMin[0]/distance[0]
    // are used here.
    for (int row = rowStart; row < productDistances.getSize(0); ++row) {
      for (int col = threadIdx.x; col < productDistances.getSize(1);
           col += blockDim.x) {
        distance[0] = Math<T>::add(centroidDistances[col],
                                   productDistances[row][col]);

        if (Math<T>::lt(distance[0], threadMin[0].k)) {
          threadMin[0].k = distance[0];
          threadMin[0].v = col;
        }
      }

      // Reduce within the block
      threadMin[0] =
        blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
          threadMin[0], Min<Pair<T, int> >(), blockMin);

      // Only one thread writes the per-row result (k == 1).
      if (threadIdx.x == 0) {
        outDistances[row][0] = threadMin[0].k;
        outIndices[row][0] = threadMin[0].v;
      }

      // so we can use the shared memory again
      __syncthreads();

      // Reset the running minimum before starting the next row.
      threadMin[0].k = Limits<T>::getMax();
      threadMin[0].v = -1;
    }
  } else {
    // Fast path: the block owns a full chunk of kRowsPerBlock rows, so each
    // column's ||c||^2 is loaded once and re-used across all rows.
    for (int col = threadIdx.x; col < productDistances.getSize(1);
         col += blockDim.x) {
      T centroidDistance = centroidDistances[col];

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = productDistances[rowStart + row][col];
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = Math<T>::add(distance[row], centroidDistance);
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        if (Math<T>::lt(distance[row], threadMin[row].k)) {
          threadMin[row].k = distance[row];
          threadMin[row].v = col;
        }
      }
    }

    // Reduce within the block
    blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >, false, false>(
      threadMin, Min<Pair<T, int> >(), blockMin);

    if (threadIdx.x == 0) {
      // Thread 0 holds the block-wide minima after the reduction; write
      // one (distance, index) pair per row.
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        outDistances[rowStart + row][0] = threadMin[row].k;
        outIndices[rowStart + row][0] = threadMin[row].v;
      }
    }
  }
}
121 
// L2 + select kernel for k > 1, no re-use of ||c||^2.
//
// Selects the k smallest (centroidDistances[i] + productDistances[row][i])
// values per row, along with their column indices.
//
// Launch layout (see runL2SelectMin): 1-D grid with one block per row
// (blockIdx.x == row); blockDim.x == ThreadsPerBlock. Shared memory usage is
// the two smemK/smemV arrays below (kNumWarps * NumWarpQ entries each).
// initK is the sentinel "worse than any real distance" key used to seed the
// selection structure (callers pass Limits<T>::getMax()).
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices,
                             int k, T initK) {
  // Each block handles a single row of the distances (results)
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Per-warp selection output; the final k results land in smemK/smemV
  // after heap.reduce().
  __shared__ T smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Block-wide k-selection structure; -1 is the sentinel index paired with
  // initK for unfilled slots.
  BlockSelect<T, int, false, Comparator<T>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);

  int row = blockIdx.x;

  // Whole warps must participate in the selection
  int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
  int i = threadIdx.x;

  // Main loop over the warp-aligned prefix of the columns: every lane of
  // every warp takes the same trip count, so heap.add() is safe here.
  for (; i < limit; i += blockDim.x) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.add(v, i);
  }

  // Ragged tail past the last warp-aligned column: only some threads have a
  // valid element, so use the per-thread insertion path instead.
  if (i < productDistances.getSize(1)) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.addThreadQ(v, i);
  }

  // Merge all per-thread/per-warp queues; results are sorted into the first
  // k slots of smemK/smemV.
  heap.reduce();
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[row][i] = smemK[i];
    outIndices[row][i] = smemV[i];
  }
}
163 
// FIXME: no TVec specialization
//
// Host-side dispatcher: picks the k == 1 fast kernel or one of the
// compile-time-sized k-selection kernel instantiations, launches it on
// `stream`, and checks for launch errors. Supports k up to 1024.
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
                    Tensor<T, 1, true>& centroidDistances,
                    Tensor<T, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  // All tensors must agree on the number of queries and centroids, and the
  // outputs must be sized for exactly k results per query.
  FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
  FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
  FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);
  FAISS_ASSERT(k <= 1024);

  if (k == 1) {
    // Single-nearest-centroid kernel: each block covers kRowsPerBlock
    // consecutive query rows so ||c||^2 loads are shared across rows.
    constexpr int kThreadsPerBlock = 256;
    constexpr int kRowsPerBlock = 8;

    auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));
    auto block = dim3(kThreadsPerBlock);

    l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(productDistances, centroidDistances,
                                   outDistances, outIndices);
  } else {
    // General k-selection: one block per query row. The warp-queue and
    // thread-queue sizes are compile-time constants, so we bucket k into the
    // smallest instantiation that can hold it.
    constexpr int kThreadsPerBlock = 128;

    auto grid = dim3(outDistances.getSize(0));
    auto block = dim3(kThreadsPerBlock);

#define L2_SELECT_K_CASE(WARP_Q, THREAD_Q)                                  \
    do {                                                                    \
      l2SelectMinK<T, WARP_Q, THREAD_Q, kThreadsPerBlock>                   \
        <<<grid, block, 0, stream>>>(productDistances, centroidDistances,   \
                                     outDistances, outIndices,              \
                                     k, Limits<T>::getMax());               \
    } while (0)

    if (k <= 32) {
      L2_SELECT_K_CASE(32, 2);
    } else if (k <= 64) {
      L2_SELECT_K_CASE(64, 3);
    } else if (k <= 128) {
      L2_SELECT_K_CASE(128, 3);
    } else if (k <= 256) {
      L2_SELECT_K_CASE(256, 4);
    } else if (k <= 512) {
      L2_SELECT_K_CASE(512, 8);
    } else if (k <= 1024) {
      L2_SELECT_K_CASE(1024, 8);
    } else {
      // Unreachable: k <= 1024 was asserted above.
      FAISS_ASSERT(false);
    }

#undef L2_SELECT_K_CASE
  }

  // Surface any launch-configuration error immediately.
  CUDA_VERIFY(cudaGetLastError());
}
222 
// float32 entry point: forwards directly to the templated implementation.
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
                    Tensor<float, 1, true>& centroidDistances,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<float>(
    productDistances, centroidDistances, outDistances, outIndices, k, stream);
}
236 
#ifdef FAISS_USE_FLOAT16
// float16 entry point (only built when half support is enabled): forwards
// directly to the templated implementation.
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
                    Tensor<half, 1, true>& centroidDistances,
                    Tensor<half, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<half>(
    productDistances, centroidDistances, outDistances, outIndices, k, stream);
}
#endif
252 
253 } } // namespace