Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Select.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #pragma once
11 
12 #include "Comparators.cuh"
13 #include "DeviceDefs.cuh"
14 #include "MergeNetworkBlock.cuh"
15 #include "MergeNetworkWarp.cuh"
16 #include "PtxUtils.cuh"
17 #include "Reductions.cuh"
18 #include "ReductionOperators.cuh"
19 #include "Tensor.cuh"
20 
21 namespace faiss { namespace gpu {
22 
23 // Specialization for block-wide monotonic merges producing a merge sort
24 // since what we really want is a constexpr loop expansion
// FinalBlockMerge dispatches on the number of warps in the block and
// unrolls the exact sequence of block-wide merge steps needed to merge
// NumWarps sorted warp queues held contiguously in shared memory.
// The primary template is intentionally empty: only the power-of-2
// specializations below (1, 2, 4 and 8 warps) are valid instantiations,
// so any other warp count fails at compile time.
// NOTE(review): the `struct FinalBlockMerge {` declaration line was
// missing from the extracted source and has been restored here.
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};
30 
// Single-warp block: the lone warp queue is already sorted, so the
// block-wide merge step is a no-op.
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; single warp
  }
};
38 
// Two warps: a single block-wide merge step combines the two sorted
// warp queues of NumWarpQ elements each.
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Final merge doesn't need to fully merge the second list
    // (trailing `false` template argument): only the first NumWarpQ
    // outputs are consumed by the caller.
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
  }
};
48 
// Four warps: first merge the queues pairwise (4 lists of NumWarpQ ->
// 2 lists of NumWarpQ * 2), then merge the two resulting lists.
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Pairwise merge: must fully sort both output lists since they
    // feed the next merge step.
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
  }
};
60 
// Eight warps: three successive merge rounds, halving the number of
// lists and doubling their length each time (8xNumWarpQ -> 4x2NumWarpQ
// -> 2x4NumWarpQ -> one merged result).
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Intermediate rounds must fully sort their outputs, as those
    // lists feed the following round.
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
  }
};
74 
// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
//
// Block-wide top-k selection. Each thread maintains a small private
// queue of NumThreadQ candidates; each warp maintains a sorted queue
// of NumWarpQ elements in shared memory (smemK/smemV, provided by the
// caller and sized kNumWarps * NumWarpQ). Thread queues are lazily
// merged into the warp queue (mergeWarpQ) whenever any lane's queue
// fills; reduce() performs the final cross-warp merge, leaving the
// top-k results at the head of the shared arrays.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;

  // `initKVal`/`initVVal` are the sentinel key/value (the identity for
  // the selection direction); `k` is the number of results wanted,
  // which must be <= NumWarpQ.
  __device__ inline BlockSelect(K initKVal,
                                V initVVal,
                                K* smemK,
                                V* smemV,
                                int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      sharedK(smemK),
      sharedV(smemV),
      kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill warp queue (only the actual queue space is fine, not where
    // we write the per-thread queues for merging)
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }

  // Offer (k, v) to this thread's private queue. Only candidates that
  // beat warpKTop (the current k-th best) are admitted; the queue is
  // kept newest-first by rotating right.
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  // If any lane's thread queue is full, merge all thread queues into
  // the warp queue and reset them. All lanes of the warp must call
  // this together.
  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

    // NOTE(review): legacy mask-less warp vote; removed on Volta+ in
    // favor of __any_sync. Presumably targeted at pre-Volta
    // architectures when written -- confirm against DeviceDefs.cuh.
    if (!__any(needSort)) {
      return;
    }

    // This has a trailing warpFence
    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    warpKTop = warpK[kMinus1];

    warpFence();
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

    // Load the warp queue from shared memory into registers,
    // strided so lane L holds elements L, L+32, L+64, ...
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpKRegisters, warpVRegisters, threadK, threadV);

    // Write back out the warp queue
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  // Final block-wide reduction; afterwards the top-k results are in
  // sharedK/sharedV[0, k). All threads of the block must participate.
  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();

    // block-wide dep; thus far, all warps have been completely
    // independent
    __syncthreads();

    // All warp queues are contiguous in smem.
    // Now, we have kNumWarps lists of NumWarpQ elements.
    // This is a power of 2.
    // NOTE(review): the `FinalBlockMerge<...>::` qualification of this
    // call was lost in extraction and has been restored; without it the
    // bare `merge(...)` does not name any function in scope.
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

    // The block-wide merge has a trailing syncthreads
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  // warpK[0] is highest (Dir) or lowest (!Dir)
  K* warpK;
  V* warpV;

  // This is a cached k-1 value
  int kMinus1;
};
259 
260 /// Specialization for k == 1 (NumWarpQ == 1)
261 template <typename K,
262  typename V,
263  bool Dir,
264  typename Comp,
265  int NumThreadQ,
266  int ThreadsPerBlock>
267 struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
268  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
269 
270  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
271  sharedK(smemK),
272  sharedV(smemV),
273  threadK(initK),
274  threadV(initV) {
275  }
276 
277  __device__ inline void addThreadQ(K k, V v) {
278  bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
279  threadK = swap ? k : threadK;
280  threadV = swap ? v : threadV;
281  }
282 
283  __device__ inline void checkThreadQ() {
284  // We don't need to do anything here, since the warp doesn't
285  // cooperate until the end
286  }
287 
288  __device__ inline void add(K k, V v) {
289  addThreadQ(k, v);
290  }
291 
292  __device__ inline void reduce() {
293  // Reduce within the warp
294  Pair<K, V> pair(threadK, threadV);
295 
296  if (Dir) {
297  pair =
298  warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
299  } else {
300  pair =
301  warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
302  }
303 
304  // Each warp writes out a single value
305  int laneId = getLaneId();
306  int warpId = threadIdx.x / kWarpSize;
307 
308  if (laneId == 0) {
309  sharedK[warpId] = pair.k;
310  sharedV[warpId] = pair.v;
311  }
312 
313  __syncthreads();
314 
315  // We typically use this for small blocks (<= 128), just having the first
316  // thread in the block perform the reduction across warps is
317  // faster
318  if (threadIdx.x == 0) {
319  threadK = sharedK[0];
320  threadV = sharedV[0];
321 
322 #pragma unroll
323  for (int i = 1; i < kNumWarps; ++i) {
324  K k = sharedK[i];
325  V v = sharedV[i];
326 
327  bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
328  threadK = swap ? k : threadK;
329  threadV = swap ? v : threadV;
330  }
331 
332  // Hopefully a thread's smem reads/writes are ordered wrt
333  // itself, so no barrier needed :)
334  sharedK[0] = threadK;
335  sharedV[0] = threadV;
336  }
337 
338  // In case other threads wish to read this value
339  __syncthreads();
340  }
341 
342  // threadK is lowest (Dir) or highest (!Dir)
343  K threadK;
344  V threadV;
345 
346  // Where we reduce in smem
347  K* sharedK;
348  V* sharedV;
349 };
350 
351 //
352 // per-warp WarpSelect
353 //
354 
// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
//
// Per-warp top-k selection: like BlockSelect, but the warp queue lives
// entirely in registers (kNumWarpQRegisters per lane, strided across
// the warp) instead of shared memory, so no shared memory and no
// block-wide synchronization are required. Results are emitted via
// writeOut(). Requires NumWarpQ >= kWarpSize.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

  // `initKVal`/`initVVal` are the sentinel key/value for the selection
  // direction; `k` is the number of results wanted (<= NumWarpQ).
  __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
      initK(initKVal),
      initV(initVVal),
      numVals(0),
      warpKTop(initKVal),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }

  // Offer (k, v) to this thread's private queue; only candidates
  // beating warpKTop (the current k-th best) are admitted. Queue is
  // kept newest-first by rotating right.
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }

  // If any lane's thread queue is full, merge thread queues into the
  // register-resident warp queue and reset them. All lanes of the warp
  // must call this together.
  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);

    // NOTE(review): legacy mask-less warp vote; removed on Volta+ in
    // favor of __any_sync. Presumably targeted at pre-Volta
    // architectures when written -- confirm against DeviceDefs.cuh.
    if (!__any(needSort)) {
      return;
    }

    mergeWarpQ();

    // Any top-k elements have been merged into the warp queue; we're
    // free to reset the thread queues
    numVals = 0;

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // We have to beat at least this element
    // (broadcast lane kLane's last warp-queue register, which holds an
    // approximation >= the k-th element; `shfl` is presumably the
    // project's warp-shuffle wrapper -- see PtxUtils.cuh)
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V,
                          kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>(
      warpK, warpV, threadK, threadV);
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  // Final merge of any pending thread-queue entries into the warp
  // queue; call once after all add()s, before writeOut().
  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

    // Register i of lane L holds sorted element i * 32 + L; only the
    // first k elements are written.
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;

      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // Default element key
  const K initK;

  // Default element value
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // warpK[0] is highest (Dir) or lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // This is what lane we should load an approximation (>=k) to the
  // kth element from the last register in the warp queue (i.e.,
  // warpK[kNumWarpQRegisters - 1]).
  int kLane;
};
499 
/// Specialization for k == 1 (NumWarpQ == 1)
/// Selecting the single best element per warp needs no queues: every
/// lane tracks its own running best, and reduce() combines them with a
/// warp-wide shuffle reduction. writeOut() emits lane 0's result.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // `k` is ignored; it is 1 by construction of this specialization.
  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK),
      threadV(initV) {
  }

  // Keep (k, v) only if it beats this lane's current best under the
  // selection direction.
  __device__ inline void addThreadQ(K k, V v) {
    bool better = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    if (better) {
      threadK = k;
      threadV = v;
    }
  }

  __device__ inline void checkThreadQ() {
    // Nothing to do: lanes only cooperate in reduce()
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  // Warp-wide reduction of every lane's best; afterwards every lane
  // holds the warp's winning (key, value) pair.
  __device__ inline void reduce() {
    Pair<K, V> best(threadK, threadV);

    if (Dir) {
      best = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
        best, Max<Pair<K, V>>());
    } else {
      best = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
        best, Min<Pair<K, V>>());
    }

    threadK = best.k;
    threadV = best.v;
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    bool isFirstLane = (getLaneId() == 0);
    if (isFirstLane) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is lowest (Dir) or highest (!Dir)
  K threadK;
  V threadV;
};
558 
559 } } // namespace
A simple pair type for CUDA device usage.
Definition: Pair.cuh:21
__device__ void writeOut(K *outK, V *outV, int k)
Dump final k selected values for this warp out.
Definition: Select.cuh:547
__device__ void writeOut(K *outK, V *outV, int k)
Dump final k selected values for this warp out.
Definition: Select.cuh:460
__device__ void add(K k, V v)
Definition: Select.cuh:448
__device__ void add(K k, V v)
Definition: Select.cuh:208
__device__ void mergeWarpQ()
Definition: Select.cuh:434
__device__ void mergeWarpQ()
Definition: Select.cuh:171