Select.cuh

/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"
#include "Tensor.cuh"

namespace faiss { namespace gpu {

// Specialization for block-wide monotonic merges producing a merge sort
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; single warp
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
  }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp>(sharedK, sharedV);
  }
};

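// Illustration (not from the original source): with ThreadsPerBlock = 128
// (4 warps of kWarpSize = 32 threads) and NumWarpQ = 32, the instantiation
// FinalBlockMerge<4, 128, K, V, 32, Dir, Comp>::merge performs two monotonic
// merge passes over the contiguous warp queues in shared memory: the 4 sorted
// 32-element warp queues are first merged pairwise into 2 sorted 64-element
// lists, which are then merged into one sorted 128-element list.
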
// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      sharedK(smemK),
      sharedV(smemV),
      warpKTop(initK),
      kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value; values
    // can remain uninitialized
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill warp queue (only the actual queue space is fine, not where
    // we write the per-thread queues for merging)
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }

  __device__ inline void addThreadQ(K k, V v) {
    // If we're greater or equal to the highest element, then we don't
    // need to add. In the equal to case, the element we have is
    // sufficient.
    if (Dir ? Comp::gt(k, threadK[0]) : Comp::lt(k, threadK[0])) {
      threadK[0] = k;
      threadV[0] = v;

      // Perform in-register bubble sort of this new element
#pragma unroll
      for (int i = 1; i < NumThreadQ; ++i) {
        bool swap = Dir ? Comp::lt(threadK[i], threadK[i - 1]) :
                          Comp::gt(threadK[i], threadK[i - 1]);

        K tmpK = threadK[i];
        threadK[i] = swap ? threadK[i - 1] : tmpK;
        threadK[i - 1] = swap ? tmpK : threadK[i - 1];

        V tmpV = threadV[i];
        threadV[i] = swap ? threadV[i - 1] : tmpV;
        threadV[i - 1] = swap ? tmpV : threadV[i - 1];
      }
    }
  }

  __device__ inline void checkThreadQ() {
    // There is no need to merge queues if no thread-local queue is
    // better than the warp-wide queue
    bool needSort = (Dir ?
                     Comp::gt(threadK[0], warpKTop) :
                     Comp::lt(threadK[0], warpKTop));
    if (!__any(needSort)) {
      return;
    }

    // This has a trailing warpFence
    mergeWarpQ();

    // We have to beat at least this element
    warpKTop = warpK[kMinus1];

    warpFence();
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir, Comp>(
        warpKRegisters, warpVRegisters, threadK, threadV);

    // Write back out the warp queue
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();

    // Re-map the registers for our per-thread queues
    K tmpThreadK[NumThreadQ];
    V tmpThreadV[NumThreadQ];

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      tmpThreadK[i] = threadK[i];
      tmpThreadV[i] = threadV[i];
    }

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      // After merging, the data is in the order small -> large.
      // We wish to reload data in our thread queues, which have the
      // order large -> small by indexing, so it will be reverse order.
      // However, we also wish to give every thread an equal shot at the
      // largest elements, so we interleave the loading.
      threadK[i] = tmpThreadK[NumThreadQ - i - 1];
      threadV[i] = tmpThreadV[NumThreadQ - i - 1];
    }
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();

    // block-wide dep; thus far, all warps have been completely
    // independent
    __syncthreads();

    // All warp queues are contiguous in smem.
    // Now, we have kNumWarps lists of NumWarpQ elements.
    // This is a power of 2.
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

    // The block-wide merge has a trailing syncthreads
  }

  // threadK[0] is lowest (Dir) or highest (!Dir)
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  // warpK[0] is highest (Dir) or lowest (!Dir)
  K* warpK;
  V* warpV;

  // Warp queue head value cached in a register, so we needn't read
  // shared memory as frequently
  K warpKTop;

  // This is a cached k-1 value
  int kMinus1;
};

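// The sketch below (wrapped in `#if 0`, so it is not compiled) is not part of
// the original header; it is a minimal illustration of how BlockSelect is
// typically driven from a kernel. The kernel name `blockSelectMinSketch`, the
// raw-pointer layout, and the `initK` sentinel parameter (e.g. the maximum
// representable key when selecting minima) are assumptions made purely for
// the example; the Comparator<T> type is the one from Comparators.cuh.
#if 0
// Assumes one block per row, launched with ThreadsPerBlock threads per block.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void blockSelectMinSketch(const float* distances, // [numRows][numCols]
                                     int numCols,
                                     float* outK,            // [numRows][k]
                                     int* outV,              // [numRows][k]
                                     int k,
                                     float initK) {          // sentinel key
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Dir == false selects the k smallest keys
  BlockSelect<float, int, false, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
      heap(initK, -1, smemK, smemV, k);

  const float* row = distances + (size_t) blockIdx.x * numCols;

  // Whole warps must participate in add(); process the part of the row that
  // is a multiple of kWarpSize first, then handle the remainder per-thread
  int limit = (numCols / kWarpSize) * kWarpSize;

  int i = threadIdx.x;
  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(row[i], i);
  }
  if (i < numCols) {
    heap.addThreadQ(row[i], i);
  }

  heap.reduce();

  // After reduce(), the first k entries of the shared queues hold the result
  for (int j = threadIdx.x; j < k; j += ThreadsPerBlock) {
    outK[(size_t) blockIdx.x * k + j] = smemK[j];
    outV[(size_t) blockIdx.x * k + j] = smemV[j];
  }
}
#endif
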
/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      sharedK(smemK),
      sharedV(smemV),
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
          warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
          warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    // Each warp writes out a single value
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    if (laneId == 0) {
      sharedK[warpId] = pair.k;
      sharedV[warpId] = pair.v;
    }

    __syncthreads();

    // We typically use this for small blocks (<= 128); just having the
    // first thread in the block perform the reduction across warps is
    // faster
    if (threadIdx.x == 0) {
      threadK = sharedK[0];
      threadV = sharedV[0];

#pragma unroll
      for (int i = 1; i < kNumWarps; ++i) {
        K k = sharedK[i];
        V v = sharedV[i];

        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
      }

      // Hopefully a thread's smem reads/writes are ordered wrt
      // itself, so no barrier needed :)
      sharedK[0] = threadK;
      sharedV[0] = threadV;
    }

    // In case other threads wish to read this value
    __syncthreads();
  }

  // threadK is lowest (Dir) or highest (!Dir)
  K threadK;
  V threadV;

  // Where we reduce in smem
  K* sharedK;
  V* sharedV;
};

//
// per-warp WarpSelect
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      warpKTop(initK),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue keys with the default value; values
    // can remain uninitialized
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }

  __device__ inline void addThreadQ(K k, V v) {
    // If we're greater or equal to the highest element, then we don't
    // need to add. In the equal to case, the element we have is
    // sufficient.
    if (Dir ? Comp::gt(k, threadK[0]) : Comp::lt(k, threadK[0])) {
      threadK[0] = k;
      threadV[0] = v;

      // Perform in-register bubble sort of this new element
#pragma unroll
      for (int i = 1; i < NumThreadQ; ++i) {
        bool swap = Dir ? Comp::lt(threadK[i], threadK[i - 1]) :
                          Comp::gt(threadK[i], threadK[i - 1]);

        K tmpK = threadK[i];
        threadK[i] = swap ? threadK[i - 1] : tmpK;
        threadK[i - 1] = swap ? tmpK : threadK[i - 1];

        V tmpV = threadV[i];
        threadV[i] = swap ? threadV[i - 1] : tmpV;
        threadV[i - 1] = swap ? tmpV : threadV[i - 1];
      }
    }
  }

  __device__ inline void checkThreadQ() {
    // There is no need to merge queues if no thread-local queue is
    // better than the warp-wide queue
    bool needSort = (Dir ?
                     Comp::gt(threadK[0], warpKTop) :
                     Comp::lt(threadK[0], warpKTop));
    if (!__any(needSort)) {
      return;
    }

    mergeWarpQ();

    // We have to beat at least this element
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }

  /// This function handles sorting and merging together the
  /// per-thread queues with the warp-wide queue, creating a sorted
  /// list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted, and now that we've sorted the
    // per-thread queue, merge both sorted lists together, producing
    // one sorted list
    warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir, Comp>(
        warpK, warpV, threadK, threadV);

    // Re-map the registers for our per-thread queues
    K tmpThreadK[NumThreadQ];
    V tmpThreadV[NumThreadQ];

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      tmpThreadK[i] = threadK[i];
      tmpThreadV[i] = threadV[i];
    }

#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      // After merging, the data is in the order small -> large.
      // We wish to reload data in our thread queues, which have the
      // order large -> small by indexing, so it will be reverse order.
      // However, we also wish to give every thread an equal shot at the
      // largest elements, so we interleave the loading.
      threadK[i] = tmpThreadK[NumThreadQ - i - 1];
      threadV[i] = tmpThreadV[NumThreadQ - i - 1];
    }
  }

  /// WARNING: all threads in a warp must participate in this.
  /// Otherwise, you must call the constituent parts separately.
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final per-warp results
    mergeWarpQ();
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;

      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // threadK[0] is lowest (Dir) or highest (!Dir)
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // warpK[0] is highest (Dir) or lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // Warp queue head value cached in a register, so we needn't read
  // shared memory as frequently
  K warpKTop;

  // This is the lane from which we should load an approximation (>= k)
  // to the kth element, taken from the last register in the warp queue
  // (i.e., warpK[kNumWarpQRegisters - 1]).
  int kLane;
};

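// The sketch below (again wrapped in `#if 0`, not compiled) is not part of
// the original header; it illustrates the typical WarpSelect pattern, in
// which each warp independently selects over one row and writes its own k
// results via writeOut(). The kernel name `warpSelectMinSketch`, the
// raw-pointer layout, and the `initK` sentinel parameter are assumptions
// made purely for illustration.
#if 0
// Assumes one warp per row, launched with ThreadsPerBlock threads per block.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void warpSelectMinSketch(const float* distances, // [numRows][numCols]
                                    int numRows,
                                    int numCols,
                                    float* outK,             // [numRows][k]
                                    int* outV,               // [numRows][k]
                                    int k,
                                    float initK) {           // sentinel key
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Dir == false selects the k smallest keys
  WarpSelect<float, int, false, Comparator<float>,
             NumWarpQ, NumThreadQ, ThreadsPerBlock>
      heap(initK, -1, k);

  // One row per warp
  int row = blockIdx.x * kNumWarps + threadIdx.x / kWarpSize;
  if (row >= numRows) {
    return;
  }

  int laneId = getLaneId();
  const float* rowData = distances + (size_t) row * numCols;

  // Whole warps must participate in add(); handle the tail per-thread
  int limit = (numCols / kWarpSize) * kWarpSize;

  int i = laneId;
  for (; i < limit; i += kWarpSize) {
    heap.add(rowData[i], i);
  }
  if (i < numCols) {
    heap.addThreadQ(rowData[i], i);
  }

  heap.reduce();
  heap.writeOut(outK + (size_t) row * k, outV + (size_t) row * k, k);
}
#endif
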
/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K,
          typename V,
          bool Dir,
          typename Comp,
          int NumThreadQ,
          int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK),
      threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // We don't need to do anything here, since the warp doesn't
    // cooperate until the end
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);

    if (Dir) {
      pair =
          warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair =
          warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    threadK = pair.k;
    threadV = pair.v;
  }

  /// Dump final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    if (getLaneId() == 0) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is lowest (Dir) or highest (!Dir)
  K threadK;
  V threadV;
};

} } // namespace