Select.cuh
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

// Specialization for block-wide monotonic merges producing a merge sort
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; single warp
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
  }
};
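
// Each specialization above performs log2(NumWarps) block-wide merge steps:
// step i pairwise-merges sorted lists of NumWarpQ * 2^(i-1) elements,
// halving the number of lists. Only the final step passes the trailing
// `false` (no full merge of the second list), since at most the first
// NumWarpQ elements of the result are ever consumed.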

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;
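
  // Note: callers must provide kNumWarps * kTotalWarpSortSize elements of
  // shared memory for each of smemK / smemV; each warp owns a contiguous
  // slice of that space, and reduce() merges the slices in place.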

  __device__ inline BlockSelect(K initKVal,
                                V initVVal,
                                K* smemK,
                                V* smemV,
                                int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      sharedK(smemK),
      sharedV(smemV),
      kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill the warp queue with the default value (only the actual queue
    // space needs initializing, not any scratch area used when merging
    // the per-thread queues)
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }

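  // Note: the thread queue is an unsorted insertion buffer of recent
  // candidates that beat warpKTop; it is only sorted lazily, in
  // mergeWarpQ(), once some lane's queue fills up (see checkThreadQ()).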
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    // This has a trailing warpFence
    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = warpK[kMinus1];

    warpFence();
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
        warpKRegisters, warpVRegisters, threadK, threadV);

    // Write back out the warp queue
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();

    // block-wide dep; thus far, all warps have been completely
    // independent
    __syncthreads();

    // All warp queues are contiguous in smem.
    // Now, we have kNumWarps lists of NumWarpQ elements.
    // This is a power of 2.
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

    // The block-wide merge has a trailing syncthreads
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  // warpK[0] is highest (Dir) or lowest (!Dir)
  K* warpK;
  V* warpV;

  // This is a cached k-1 value
  int kMinus1;
};
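
// Usage sketch (illustrative only, not part of the Faiss API): a kernel in
// which each block selects the k smallest keys of one row. The kernel name
// `blockSelectMinK`, the sentinel, and the layout are assumptions; requires
// k <= NumWarpQ. Note how the tail of the row (less than a warp's multiple)
// goes through addThreadQ() alone, so that add()/checkThreadQ() are only
// ever called by whole warps in lockstep.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void blockSelectMinK(const float* in, int numCols,
                                float* outK, int* outV, int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // One warp queue per warp, contiguous in shared memory
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Dir = false selects smallest; 3.402823466e+38f (FLT_MAX) is a
  // "never wins" sentinel key, -1 a sentinel index
  BlockSelect<float, int, false, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
      heap(3.402823466e+38f, -1, smemK, smemV, k);

  const float* row = in + blockIdx.x * numCols;

  // Whole-warp-multiple portion: every lane of a warp calls add() together
  int limit = (numCols / kWarpSize) * kWarpSize;
  int i = threadIdx.x;
  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(row[i], i);
  }

  // Remainder: only some lanes have data, so enqueue without the
  // warp-synchronous check; reduce() merges pending entries anyway
  if (i < numCols) {
    heap.addThreadQ(row[i], i);
  }

  heap.reduce();

  // After reduce(), the first k entries of smemK/smemV hold the results
  for (int j = threadIdx.x; j < k; j += ThreadsPerBlock) {
    outK[blockIdx.x * k + j] = smemK[j];
    outV[blockIdx.x * k + j] = smemV[j];
  }
}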

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      sharedK(smemK),
      sharedV(smemV),
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
        warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
        warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    // Each warp writes out a single value
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    if (laneId == 0) {
      sharedK[warpId] = pair.k;
      sharedV[warpId] = pair.v;
    }

    __syncthreads();

    // We typically use this for small blocks (<= 128); just having the
    // first thread in the block perform the reduction across warps is
    // faster
    if (threadIdx.x == 0) {
      threadK = sharedK[0];
      threadV = sharedV[0];

#pragma unroll
      for (int i = 1; i < kNumWarps; ++i) {
        K k = sharedK[i];
        V v = sharedV[i];

        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
      }

      // A thread's smem reads/writes are ordered with respect to itself,
      // so no barrier is needed between reading the per-warp values and
      // writing the result
      sharedK[0] = threadK;
      sharedV[0] = threadV;
    }

    // In case other threads wish to read this value
    __syncthreads();
  }

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;

  // Where we reduce in smem
  K* sharedK;
  V* sharedV;
};
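
// Note: for the k == 1 specialization above, reduce() leaves the single
// winning (key, value) pair in sharedK[0] / sharedV[0]; any thread in the
// block may read it after the trailing __syncthreads().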

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
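
  // Note: unlike BlockSelect, the warp queue lives entirely in registers,
  // distributed across the warp: each lane holds kNumWarpQRegisters
  // entries, so NumWarpQ must be at least kWarpSize (and, per the
  // static_assert below, a power of 2).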

  __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }

  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
        warpK, warpV, threadK, threadV);
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;

      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // warpK[0] is highest (Dir) or lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // The lane from which we should load an approximation (>= the k-th
  // element) of the k-th element, taken from the last register in the
  // warp queue (i.e., warpK[kNumWarpQRegisters - 1])
  int kLane;
};
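
// Usage sketch (illustrative only, not part of the Faiss API): each warp
// independently selects the k smallest keys of its own row and writes them
// out with writeOut(). The kernel name `warpSelectMinK` and the launch
// geometry are assumptions.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void warpSelectMinK(const float* in, int numCols,
                               float* outK, int* outV, int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // One row per warp
  int row = blockIdx.x * kNumWarps + threadIdx.x / kWarpSize;

  WarpSelect<float, int, false, Comparator<float>,
             NumWarpQ, NumThreadQ, ThreadsPerBlock>
      heap(3.402823466e+38f /* FLT_MAX sentinel */, -1, k);

  const float* rowIn = in + row * numCols;

  // Whole-warp-multiple portion: every lane calls add() in lockstep
  int limit = (numCols / kWarpSize) * kWarpSize;
  int i = getLaneId();
  for (; i < limit; i += kWarpSize) {
    heap.add(rowIn[i], i);
  }

  // Remainder for a row length that isn't a warp multiple; reduce()'s
  // mergeWarpQ() folds any pending thread-queue entries in
  if (i < numCols) {
    heap.addThreadQ(rowIn[i], i);
  }

  heap.reduce();
  heap.writeOut(outK + row * k, outV + row * k, k);
}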

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
        warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
        warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    threadK = pair.k;
    threadV = pair.v;
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    if (getLaneId() == 0) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;
};

} } // namespace