Select.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

// Specialization for block-wide monotonic merges producing a merge sort
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; single warp
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
  }
};

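// The pattern above generalizes: merging the queues of 2^N warps takes N
// rounds of blockMerge, with the merged list length doubling each round;
// only the final round may skip fully sorting the second list, since only
// the head of the merged result is ultimately consumed.
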
// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;

  __device__ inline BlockSelect(K initKVal,
                                V initVVal,
                                K* smemK,
                                V* smemV,
                                int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      sharedK(smemK),
      sharedV(smemV),
      kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill warp queue (only the actual queue space is fine, not where
    // we write the per-thread queues for merging)
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }

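  // A new candidate is only considered if it beats warpKTop, the current
  // k-th best element seen by the warp; accepted candidates are pushed at
  // threadK[0] and remain unsorted until the thread queue fills
  // (see checkThreadQ / mergeWarpQ below).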
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    // This has a trailing warpFence
    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = warpK[kMinus1];

    warpFence();
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
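
    // Load our warp's queue from smem into registers, lane-strided: lane
    // l holds elements l, l + kWarpSize, l + 2 * kWarpSize, ...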
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpKRegisters, warpVRegisters, threadK, threadV);

    // Write back out the warp queue
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();

    // block-wide dep; thus far, all warps have been completely
    // independent
    __syncthreads();

    // All warp queues are contiguous in smem.
    // Now, we have kNumWarps lists of NumWarpQ elements.
    // kNumWarps is a power of 2.
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

    // The block-wide merge has a trailing syncthreads
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  // warpK[0] is highest (Dir) or lowest (!Dir)
  K* warpK;
  V* warpV;

  // This is a cached k-1 value
  int kMinus1;
};

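// A minimal usage sketch for BlockSelect (a hypothetical kernel, not part
// of this header): each block selects the k smallest float keys of one row,
// assuming k <= kWarpQ. `Comparator` is from Comparators.cuh; `Limits` and
// `utils::roundDown` come from other Faiss GPU headers (Limits.cuh,
// StaticUtils.h); names such as `in`, `outK` and `outV` are illustrative
// assumptions.
//
//   constexpr int kThreads = 128;
//   constexpr int kWarpQ = 32;
//   constexpr int kThreadQ = 2;
//
//   __global__ void blockSelectMinK(Tensor<float, 2, true> in,
//                                   Tensor<float, 2, true> outK,
//                                   Tensor<int, 2, true> outV,
//                                   int k) {
//     constexpr int kNumWarps = kThreads / kWarpSize;
//     __shared__ float smemK[kNumWarps * kWarpQ];
//     __shared__ int smemV[kNumWarps * kWarpQ];
//
//     // Dir = false selects smallest; init keys with the max possible value
//     BlockSelect<float, int, false, Comparator<float>,
//                 kWarpQ, kThreadQ, kThreads>
//       heap(Limits<float>::getMax(), -1, smemK, smemV, k);
//
//     int row = blockIdx.x;
//     const float* inRow = in[row].data();
//     int num = in.getSize(1);
//     int limit = utils::roundDown(num, kWarpSize);
//
//     // Whole warps participate in add() together
//     int i = threadIdx.x;
//     for (; i < limit; i += kThreads) {
//       heap.add(inRow[i], i);
//     }
//
//     // Tail elements: not all lanes have data, so only addThreadQ();
//     // reduce() below performs the outstanding merge
//     if (i < num) {
//       heap.addThreadQ(inRow[i], i);
//     }
//
//     heap.reduce();
//
//     // reduce() leaves the k best pairs at the front of smemK/smemV
//     for (int j = threadIdx.x; j < k; j += kThreads) {
//       outK[row][j] = smemK[j];
//       outV[row][j] = smemV[j];
//     }
//   }
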
/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      sharedK(smemK),
      sharedV(smemV),
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
        warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
        warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    // Each warp writes out a single value
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    if (laneId == 0) {
      sharedK[warpId] = pair.k;
      sharedV[warpId] = pair.v;
    }

    __syncthreads();

    // We typically use this for small blocks (<= 128 threads); just
    // having the first thread in the block perform the reduction across
    // warps is faster
    if (threadIdx.x == 0) {
      threadK = sharedK[0];
      threadV = sharedV[0];

#pragma unroll
      for (int i = 1; i < kNumWarps; ++i) {
        K k = sharedK[i];
        V v = sharedV[i];

        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
      }

      // A thread's smem reads/writes are ordered with respect to itself,
      // so no barrier is needed here
      sharedK[0] = threadK;
      sharedV[0] = threadV;
    }

    // In case other threads wish to read this value
    __syncthreads();
  }

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;

  // Where we reduce in smem
  K* sharedK;
  V* sharedV;
};
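
// After BlockSelect<..., 1, ...>::reduce(), the single selected pair is
// visible to all threads in the block in sharedK[0] / sharedV[0].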

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

  __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }

  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpK, warpV, threadK, threadV);
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;

      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // warpK[0] is highest (Dir) or lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // The lane in the warp from which we load warpK[kNumWarpQRegisters - 1]
  // (the last register row of the warp queue) as an approximation (>=)
  // of the k-th element.
  int kLane;
};

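// A minimal usage sketch for WarpSelect (a hypothetical kernel, not part of
// this header): each warp independently selects the k smallest float keys
// of its own row, assuming k <= kWarpQ. As above, `Limits` and
// `utils::roundDown` come from other Faiss GPU headers; names such as `in`,
// `outK` and `outV` are illustrative assumptions.
//
//   constexpr int kThreads = 128;
//   constexpr int kWarpQ = 32;
//   constexpr int kThreadQ = 2;
//
//   __global__ void warpSelectMinK(Tensor<float, 2, true> in,
//                                  Tensor<float, 2, true> outK,
//                                  Tensor<int, 2, true> outV,
//                                  int k) {
//     WarpSelect<float, int, false, Comparator<float>,
//                kWarpQ, kThreadQ, kThreads>
//       heap(Limits<float>::getMax(), -1, k);
//
//     int warpId = threadIdx.x / kWarpSize;
//     int row = blockIdx.x * (kThreads / kWarpSize) + warpId;
//
//     if (row < in.getSize(0)) {
//       const float* inRow = in[row].data();
//       int num = in.getSize(1);
//       int limit = utils::roundDown(num, kWarpSize);
//
//       // All lanes of the warp participate in add() together
//       int i = getLaneId();
//       for (; i < limit; i += kWarpSize) {
//         heap.add(inRow[i], i);
//       }
//
//       // Tail elements: only some lanes have data, so addThreadQ() alone;
//       // reduce() performs the outstanding merge
//       if (i < num) {
//         heap.addThreadQ(inRow[i], i);
//       }
//
//       heap.reduce();
//       heap.writeOut(outK[row].data(), outV[row].data(), k);
//     }
//   }
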
/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
        warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
        warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    threadK = pair.k;
    threadV = pair.v;
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    if (getLaneId() == 0) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;
};
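
// For k == 1, only lane 0 writes the single selected pair in writeOut().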

} } // namespace