Select.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

// Specialization for block-wide monotonic merges producing a merge sort,
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; single warp
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
  }
};
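
// Worked example of the ladder above (a reading of the code, not new
// logic): with 8 warps, the first blockMerge turns 8 sorted lists of
// NumWarpQ elements into 4 sorted lists of NumWarpQ * 2, the second into
// 2 lists of NumWarpQ * 4, and the last into a single list of
// NumWarpQ * 8. Only the last step passes the trailing `false` (skip
// fully merging the second list), since just the leading k <= NumWarpQ
// elements of the final list are ever read back as the top-k result.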

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
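//
// How it works (a summary of the code below): each thread buffers up to
// NumThreadQ candidate (key, value) pairs in registers. Whenever any
// lane's buffer fills, the whole warp sorts the thread queues and merges
// them into its sorted warp queue of NumWarpQ elements in shared memory.
// warpKTop caches the current k-th best key, so most candidates are
// rejected with a single compare. reduce() finally merges all warp queues
// block-wide, leaving the top-k at the front of the shared-memory arrays.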
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;

  __device__ inline BlockSelect(K initKVal,
                                V initVVal,
                                K* smemK,
                                V* smemV,
                                int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      sharedK(smemK),
      sharedV(smemV),
      kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill the warp queue (only the actual queue space needs
    // initialization, not where the per-thread queues are written for
    // merging)
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }

  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    // This has a trailing warpFence
    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = warpK[kMinus1];

    warpFence();
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpKRegisters, warpVRegisters, threadK, threadV);

    // Write back out the warp queue
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();

    // block-wide dependency; thus far, all warps have been completely
    // independent
    __syncthreads();

    // All warp queues are contiguous in smem.
    // Now, we have kNumWarps lists of NumWarpQ elements.
    // This is a power of 2.
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

    // The block-wide merge has a trailing syncthreads
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  // warpK[0] is highest (Dir) or lowest (!Dir)
  K* warpK;
  V* warpV;

  // This is a cached k-1 value
  int kMinus1;
};
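
// A minimal usage sketch of BlockSelect, kept under `#if 0` so it stays
// illustrative rather than part of the build. One block selects the k
// smallest values of one row of a row-major matrix. The kernel name and
// raw-pointer layout are hypothetical; Comparator<float> is assumed to
// come from Comparators.cuh.
#if 0
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void blockSelectRowMinK(const float* in, // numRows x numCols
                                   float* outK,     // numRows x k
                                   int* outV,       // numRows x k
                                   int numCols,
                                   int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Dir = false selects the smallest keys; FLT_MAX acts as the
  // "never selected" sentinel key
  BlockSelect<float, int, false, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(3.402823466e+38f, -1, smemK, smemV, k);

  const float* row = in + blockIdx.x * numCols;

  // Whole warps must participate in add(); handle the ragged tail with
  // addThreadQ() alone, which reduce() folds in via mergeWarpQ()
  int limit = (numCols / kWarpSize) * kWarpSize;
  int i = threadIdx.x;

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(row[i], i);
  }

  if (i < numCols) {
    heap.addThreadQ(row[i], i);
  }

  heap.reduce();

  // The k best (smallest) pairs now sit at the front of shared memory
  for (int j = threadIdx.x; j < k; j += ThreadsPerBlock) {
    outK[blockIdx.x * k + j] = smemK[j];
    outV[blockIdx.x * k + j] = smemV[j];
  }
}
#endif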

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      sharedK(smemK),
      sharedV(smemV),
      threadK(initK),
      threadV(initV) {
  }
284  __device__ inline void addThreadQ(K k, V v) {
285  bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
286  threadK = swap ? k : threadK;
287  threadV = swap ? v : threadV;
288  }
289 
290  __device__ inline void checkThreadQ() {
291  // We don't need to do anything here, since the warp doesn't
292  // cooperate until the end
293  }
294 
295  __device__ inline void add(K k, V v) {
296  addThreadQ(k, v);
297  }
298 
299  __device__ inline void reduce() {
300  // Reduce within the warp
301  Pair<K, V> pair(threadK, threadV);
302 
303  if (Dir) {
304  pair =
305  warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
306  } else {
307  pair =
308  warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
309  }
310 
311  // Each warp writes out a single value
312  int laneId = getLaneId();
313  int warpId = threadIdx.x / kWarpSize;
314 
315  if (laneId == 0) {
316  sharedK[warpId] = pair.k;
317  sharedV[warpId] = pair.v;
318  }
319 
320  __syncthreads();
321 
322  // We typically use this for small blocks (<= 128), just having the first
323  // thread in the block perform the reduction across warps is
324  // faster
325  if (threadIdx.x == 0) {
326  threadK = sharedK[0];
327  threadV = sharedV[0];
328 
329 #pragma unroll
330  for (int i = 1; i < kNumWarps; ++i) {
331  K k = sharedK[i];
332  V v = sharedV[i];
333 
334  bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
335  threadK = swap ? k : threadK;
336  threadV = swap ? v : threadV;
337  }
338 
339  // Hopefully a thread's smem reads/writes are ordered wrt
340  // itself, so no barrier needed :)
341  sharedK[0] = threadK;
342  sharedV[0] = threadV;
343  }
344 
345  // In case other threads wish to read this value
346  __syncthreads();
347  }
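
  // Note (a reading of the code above, not an upstream-documented
  // guarantee): after reduce(), every thread in the block can read the
  // single selected pair from sharedK[0] / sharedV[0], thanks to the
  // trailing __syncthreads().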

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;

  // Where we reduce in smem
  K* sharedK;
  V* sharedV;
};

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
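//
// Same two-level scheme as BlockSelect (a summary of the code below),
// except that the warp queue lives in registers, strided across the
// warp's lanes, and warps never cooperate: there is no shared memory
// and no block-wide merge phase, so each warp produces its own
// independent top-k via writeOut().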
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

  __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }

  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif

    if (!needSort) {
      // no lanes have triggered a sort
      return;
    }

    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }
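
  // Worked example of the warpKTop refresh above (assuming kWarpSize = 32;
  // a reading of the code, not new logic): with NumWarpQ = 128 there are
  // 4 registers per lane, and for k = 40, kLane = (40 - 1) % 32 = 7. The
  // shfl reads the last register from lane 7, i.e. sorted position
  // 3 * 32 + 7 = 103 >= 39, so the threshold is never better than the
  // true k-th element and no valid candidate is ever rejected.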

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpK, warpV, threadK, threadV);
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;

      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // warpK[0] is highest (Dir) or lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // The lane from which we load an approximation (no better than the true
  // k-th element) out of the last register in the warp queue (i.e.,
  // warpK[kNumWarpQRegisters - 1])
  int kLane;
};
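
// A minimal usage sketch of WarpSelect, kept under `#if 0` so it stays
// illustrative rather than part of the build. Each warp independently
// selects the k smallest values of its own row. The kernel name and
// raw-pointer layout are hypothetical; Comparator<float> is assumed to
// come from Comparators.cuh.
#if 0
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void warpSelectRowMinK(const float* in, // numRows x numCols
                                  float* outK,     // numRows x k
                                  int* outV,       // numRows x k
                                  int numRows,
                                  int numCols,
                                  int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  WarpSelect<float, int, false, Comparator<float>,
             NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(3.402823466e+38f, -1, k);

  // One row per warp; warps never cooperate, so no smem is required
  int warpId = threadIdx.x / kWarpSize;
  int row = blockIdx.x * kNumWarps + warpId;

  if (row >= numRows) {
    return; // all lanes of the warp share `row`, so the warp exits together
  }

  const float* rowData = in + row * numCols;
  int limit = (numCols / kWarpSize) * kWarpSize;
  int i = getLaneId();

  for (; i < limit; i += kWarpSize) {
    heap.add(rowData[i], i);
  }

  // Ragged tail: only in-range lanes insert; reduce() folds the pending
  // thread queues into the warp queue afterwards
  if (i < numCols) {
    heap.addThreadQ(rowData[i], i);
  }

  heap.reduce();
  heap.writeOut(outK + row * k, outV + row * k, k);
}
#endif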

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
        warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
        warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    threadK = pair.k;
    threadV = pair.v;
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    if (getLaneId() == 0) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is highest (Dir) or lowest (!Dir)
  K threadK;
  V threadV;
};
572 
573 } } // namespace