/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "L2Select.cuh"
#include "../../FaissAssert.h"

#include "../utils/DeviceUtils.h"
#include "../utils/MathOperators.cuh"
#include "../utils/Pair.cuh"
#include "../utils/Reductions.cuh"
#include "../utils/Select.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Tensor.cuh"

namespace faiss { namespace gpu {

// L2 + select kernel for k == 1, implements re-use of ||c||^2.
//
// Grid layout: 1-D; block b is responsible for rows
// [b * kRowsPerBlock, (b + 1) * kRowsPerBlock) of `productDistances`.
// The final per-element distance is centroidDistances[col] added to
// productDistances[row][col]; each row's minimum (value, column) pair is
// written to outDistances[row][0] / outIndices[row][0].
// Static shared memory: kRowsPerBlock * (kBlockSize / kWarpSize) pairs,
// used as scratch by blockReduceAll.
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices) {
  // Each block handles kRowsPerBlock rows of the distances (results)
  Pair<T, int> threadMin[kRowsPerBlock];
  __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

  T distance[kRowsPerBlock];

#pragma unroll
  for (int i = 0; i < kRowsPerBlock; ++i) {
    threadMin[i].k = Limits<T>::getMax();
    threadMin[i].v = -1;
  }

  // blockIdx.x: which chunk of rows we are responsible for updating
  int rowStart = blockIdx.x * kRowsPerBlock;

  // FIXME: if we have exact multiples, don't need this
  // Only the last block can own a partial chunk of rows.
  bool endRow = (blockIdx.x == gridDim.x - 1);

  if (endRow) {
    if (productDistances.getSize(0) % kRowsPerBlock == 0) {
      endRow = false;
    }
  }

  if (endRow) {
    // Tail chunk: may hold fewer than kRowsPerBlock rows, so handle one
    // row at a time with a block-wide reduction per row.
    for (int row = rowStart; row < productDistances.getSize(0); ++row) {
      for (int col = threadIdx.x; col < productDistances.getSize(1);
           col += blockDim.x) {
        distance[0] = Math<T>::add(centroidDistances[col],
                                   productDistances[row][col]);

        if (Math<T>::lt(distance[0], threadMin[0].k)) {
          threadMin[0].k = distance[0];
          threadMin[0].v = col;
        }
      }

      // Reduce within the block
      threadMin[0] =
        blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
          threadMin[0], Min<Pair<T, int> >(), blockMin);

      if (threadIdx.x == 0) {
        outDistances[row][0] = threadMin[0].k;
        outIndices[row][0] = threadMin[0].v;
      }

      // so we can use the shared memory again
      __syncthreads();

      threadMin[0].k = Limits<T>::getMax();
      threadMin[0].v = -1;
    }
  } else {
    // Full chunk: sweep columns once, re-using each centroidDistances[col]
    // load across all kRowsPerBlock rows.
    for (int col = threadIdx.x; col < productDistances.getSize(1);
         col += blockDim.x) {
      T centroidDistance = centroidDistances[col];

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = productDistances[rowStart + row][col];
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = Math<T>::add(distance[row], centroidDistance);
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        if (Math<T>::lt(distance[row], threadMin[row].k)) {
          threadMin[row].k = distance[row];
          threadMin[row].v = col;
        }
      }
    }

    // Reduce within the block
    blockReduceAll<kRowsPerBlock,
                   Pair<T, int>, Min<Pair<T, int> >, false, false>(
      threadMin, Min<Pair<T, int> >(), blockMin);

    if (threadIdx.x == 0) {
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        outDistances[rowStart + row][0] = threadMin[row].k;
        outIndices[rowStart + row][0] = threadMin[row].v;
      }
    }
  }
}

// L2 + select kernel for k > 1, no re-use of ||c||^2.
//
// Grid layout: 1-D; one block per row of `productDistances`. Selects the
// k smallest values of (centroidDistances[i] + productDistances[row][i])
// via BlockSelect and writes them, with their column indices, to
// outDistances[row][0..k) / outIndices[row][0..k). `initK` seeds the
// selection (the largest representable T for a min-selection).
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices,
                             int k, T initK) {
  // Each block handles a single row of the distances (results)
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ T smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  BlockSelect<T, int, false, Comparator<T>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);

  int row = blockIdx.x;

  // Whole warps must participate in the selection
  int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
  int i = threadIdx.x;

  for (; i < limit; i += blockDim.x) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.add(v, i);
  }

  // Handle the final non-warp-multiple remainder of columns, if any
  if (i < productDistances.getSize(1)) {
    T v = Math<T>::add(centroidDistances[i],
                       productDistances[row][i]);
    heap.addThreadQ(v, i);
  }

  heap.reduce();

  // After reduce(), the k best (distance, index) pairs are in shared memory
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    outDistances[row][i] = smemK[i];
    outIndices[row][i] = smemV[i];
  }
}

// FIXME: no TVec specialization
//
// Host-side dispatcher: for each row of `productDistances`, adds
// `centroidDistances` element-wise and selects the k smallest results,
// writing distances/column indices to `outDistances`/`outIndices`.
// Launches asynchronously on `stream`. Requires k <= 1024 (largest
// selection size supported by the kernels below).
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
                    Tensor<T, 1, true>& centroidDistances,
                    Tensor<T, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
  FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
  FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);
  FAISS_ASSERT(k <= 1024);

  if (k == 1) {
    // Specialized min kernel; each block covers kRowsPerBlock rows
    constexpr int kThreadsPerBlock = 256;
    constexpr int kRowsPerBlock = 8;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));

    l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(productDistances, centroidDistances,
                                   outDistances, outIndices);
  } else {
    // General k-selection; one block per output row
    constexpr int kThreadsPerBlock = 128;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(outDistances.getSize(0));

#define RUN_L2_SELECT(NUM_WARP_Q, NUM_THREAD_Q)                           \
    do {                                                                  \
      l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>         \
        <<<grid, block, 0, stream>>>(productDistances, centroidDistances, \
                                     outDistances, outIndices,            \
                                     k, Limits<T>::getMax());             \
    } while (0)

    // Pick the smallest warp-queue size that can hold k
    if (k <= 32) {
      RUN_L2_SELECT(32, 2);
    } else if (k <= 64) {
      RUN_L2_SELECT(64, 3);
    } else if (k <= 128) {
      RUN_L2_SELECT(128, 3);
    } else if (k <= 256) {
      RUN_L2_SELECT(256, 4);
    } else if (k <= 512) {
      RUN_L2_SELECT(512, 8);
    } else if (k <= 1024) {
      RUN_L2_SELECT(1024, 8);
    } else {
      // Unreachable given the k <= 1024 assertion above
      FAISS_ASSERT(false);
    }

// Macro hygiene: keep RUN_L2_SELECT scoped to this function
#undef RUN_L2_SELECT
  }

  CUDA_TEST_ERROR();
}

// float32 entry point; forwards to the templated implementation.
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
                    Tensor<float, 1, true>& centroidDistances,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<float>(productDistances,
                        centroidDistances,
                        outDistances,
                        outIndices,
                        k,
                        stream);
}

#ifdef FAISS_USE_FLOAT16
// float16 entry point; forwards to the templated implementation.
// Only compiled when half-precision support is enabled.
void runL2SelectMin(Tensor<half, 2, true>& productDistances,
                    Tensor<half, 1, true>& centroidDistances,
                    Tensor<half, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<half>(productDistances,
                       centroidDistances,
                       outDistances,
                       outIndices,
                       k,
                       stream);
}
#endif

} } // namespace