13 #include "Comparators.cuh"
14 #include "DeviceDefs.cuh"
15 #include "MergeNetworkBlock.cuh"
16 #include "MergeNetworkWarp.cuh"
17 #include "PtxUtils.cuh"
18 #include "Reductions.cuh"
19 #include "ReductionOperators.cuh"
22 namespace faiss {
namespace gpu {
// NOTE(review): extraction-garbled fragment — the stray leading numerals
// ("26", "27", "28") appear to be original-file line numbers fused into the
// text, and the struct declaration this template header introduces is not
// visible in this chunk. Code tokens kept byte-identical.
// Parameter list: warp count, block size, key/value types, warp-queue
// length, sort direction, and comparator type.
26 template <
int NumWarps,
27 int NumThreads,
typename K,
typename V,
int NumWarpQ,
28 bool Dir,
typename Comp>
// NOTE(review): garbled fragment (fused line numbers; the struct declaration
// and closing braces are missing). This header lacks the NumWarps parameter,
// so it presumably belongs to a partial specialization of the block-wide
// merge helper — TODO confirm against the original file.
32 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
33 bool Dir,
typename Comp>
// merge() has no visible body here — likely the single-warp case, where no
// cross-warp merge work is required (assumption; verify).
35 static inline __device__
void merge(K* sharedK, V* sharedV) {
// NOTE(review): garbled fragment (fused line numbers; enclosing struct and
// closing braces missing). This merge() performs a single blockMerge pass
// over the shared-memory queues: NumThreads / (kWarpSize * 2) merge groups,
// list length NumWarpQ, with the direction inverted (!Dir) — consistent
// with merging pairs of adjacent warp queues.
40 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
41 bool Dir,
typename Comp>
43 static inline __device__
void merge(K* sharedK, V* sharedV) {
44 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
45 NumWarpQ, !Dir, Comp>(sharedK, sharedV);
// NOTE(review): garbled fragment (fused line numbers; enclosing struct and
// closing braces missing). Two successive blockMerge passes: the second
// doubles the list length (NumWarpQ -> NumWarpQ * 2) while halving the merge
// group count (kWarpSize * 2 -> kWarpSize * 4 divisor) — a tree merge of
// four warp queues down to one sorted list.
49 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
50 bool Dir,
typename Comp>
52 static inline __device__
void merge(K* sharedK, V* sharedV) {
53 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
54 NumWarpQ, !Dir, Comp>(sharedK, sharedV);
55 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
56 NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
// NOTE(review): garbled fragment (fused line numbers; enclosing struct and
// closing braces missing). Three successive blockMerge passes, each doubling
// the merged list length (NumWarpQ, * 2, * 4) while halving the number of
// merge groups — a three-level tree merge of eight warp queues.
60 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
61 bool Dir,
typename Comp>
63 static inline __device__
void merge(K* sharedK, V* sharedV) {
64 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
65 NumWarpQ, !Dir, Comp>(sharedK, sharedV);
66 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
67 NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
68 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
69 NumWarpQ * 4, !Dir, Comp>(sharedK, sharedV);
// Number of warps in the block, fixed at compile time.
83 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Length of each warp's slice of the shared-memory queue arrays
// (used as the per-warp stride below: warpId * kTotalWarpSortSize).
84 static constexpr
int kTotalWarpSortSize = NumWarpQ;
// Constructor fragment (garbled extraction: the member-initializer list
// after ':' and both loop bodies are missing from this chunk).
86 __device__
inline BlockSelect(K initK, V initV, K* smemK, V* smemV,
int k) :
// Compile-time sanity checks on block size and warp-queue length.
91 static_assert(utils::isPowerOf2(ThreadsPerBlock),
92 "threads must be a power-of-2");
93 static_assert(utils::isPowerOf2(NumWarpQ),
94 "warp queue must be power-of-2");
// Per-thread queue init loop; body missing — presumably fills
// threadK/threadV with the initK/initV sentinels (TODO confirm).
99 for (
int i = 0; i < NumThreadQ; ++i) {
// Point this warp's queue pointers at its slice of the shared arrays.
104 int laneId = getLaneId();
105 int warpId = threadIdx.x / kWarpSize;
106 warpK = sharedK + warpId * kTotalWarpSortSize;
107 warpV = sharedV + warpId * kTotalWarpSortSize;
// Lane-strided loop over the warp queue; body missing — presumably
// initializes warpK/warpV with the sentinels (TODO confirm).
111 for (
int i = laneId; i < NumWarpQ; i += kWarpSize) {
// addThreadQ fragment: admit (k, v) into the per-thread queue only when it
// beats the current head threadK[0] in the requested direction, then bubble
// it into sorted position using branchless compare/select swaps.
119 __device__
inline void addThreadQ(K k, V v) {
123 if (Dir ? Comp::gt(k, threadK[0]) : Comp::lt(k, threadK[0])) {
// NOTE(review): the lines defining tmpK/tmpV (and the head insertion)
// are missing from this garbled fragment; the selects below read them.
129 for (
int i = 1; i < NumThreadQ; ++i) {
130 bool swap = Dir ? Comp::lt(threadK[i], threadK[i - 1]) :
131 Comp::gt(threadK[i], threadK[i - 1]);
// Branchless conditional swap of adjacent key entries.
134 threadK[i] = swap ? threadK[i - 1] : tmpK;
135 threadK[i - 1] = swap ? tmpK : threadK[i - 1];
// Same swap applied to the paired values.
138 threadV[i] = swap ? threadV[i - 1] : tmpV;
139 threadV[i - 1] = swap ? tmpV : threadV[i - 1];
// checkThreadQ fragment: warp-wide early-out when no lane's thread-queue
// head beats warpKTop; the merge that follows on a hit is missing from this
// chunk, after which the warp-queue boundary is re-cached in warpKTop.
144 __device__
inline void checkThreadQ() {
147 bool needSort = (Dir ?
148 Comp::gt(threadK[0], warpKTop) :
149 Comp::lt(threadK[0], warpKTop));
// NOTE(review): legacy mask-less __any — removed for compute capability
// 7.0+ (independent thread scheduling); migrate to
// __any_sync(0xffffffff, needSort) when updating this file.
150 if (!__any(needSort)) {
158 warpKTop = warpK[kMinus1];
// mergeWarpQ fragment: merge the sorted per-thread queues into the
// shared-memory warp queue, staging the warp queue through registers.
167 int laneId = getLaneId();
// Sort each lane's per-thread queue (direction inverted for the merge).
170 warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);
172 constexpr
int kNumWarpQRegisters = NumWarpQ / kWarpSize;
// Load the warp queue into registers, lane-strided across the warp.
173 K warpKRegisters[kNumWarpQRegisters];
174 V warpVRegisters[kNumWarpQRegisters];
177 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
178 warpKRegisters[i] = warpK[i * kWarpSize + laneId];
179 warpVRegisters[i] = warpV[i * kWarpSize + laneId];
// Merge the register-resident warp queue with the per-thread queues.
187 warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir, Comp>(
188 warpKRegisters, warpVRegisters, threadK, threadV);
// Write the merged warp queue back to shared memory, lane-strided.
192 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
193 warpK[i * kWarpSize + laneId] = warpKRegisters[i];
194 warpV[i * kWarpSize + laneId] = warpVRegisters[i];
// Reverse the per-thread queue order via a temporary copy — presumably to
// restore the ordering addThreadQ expects after the merge (TODO confirm;
// intermediate lines are missing from this garbled fragment).
200 K tmpThreadK[NumThreadQ];
201 V tmpThreadV[NumThreadQ];
204 for (
int i = 0; i < NumThreadQ; ++i) {
205 tmpThreadK[i] = threadK[i];
206 tmpThreadV[i] = threadV[i];
210 for (
int i = 0; i < NumThreadQ; ++i) {
216 threadK[i] = tmpThreadK[NumThreadQ - i - 1];
217 threadV[i] = tmpThreadV[NumThreadQ - i - 1];
// add(): body missing from this garbled fragment.
223 __device__
inline void add(K k, V v) {
// reduce() fragment: performs a final merge of the warp queues held in
// shared memory; any surrounding synchronization lines are missing here.
228 __device__
inline void reduce() {
241 merge(sharedK, sharedV);
// Per-thread queue storage: NumThreadQ key/value pairs kept in registers.
247 K threadK[NumThreadQ];
248 V threadV[NumThreadQ];
// Partial specialization of BlockSelect for NumWarpQ == 1: each thread
// tracks only a single best (k, v) pair. NOTE(review): the template
// parameter list and the constructor's member-initializer list are
// truncated in this garbled fragment.
268 template <
typename K,
274 struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
275 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
277 __device__
inline BlockSelect(K initK, V initV, K* smemK, V* smemV,
int k) :
// Single-element thread queue: branchless select keeps the better pair.
284 __device__
inline void addThreadQ(K k, V v) {
285 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
286 threadK = swap ? k : threadK;
287 threadV = swap ? v : threadV;
// checkThreadQ(): no body visible — presumably a no-op in the
// NumWarpQ == 1 case (TODO confirm).
290 __device__
inline void checkThreadQ() {
// add(): body missing from this garbled fragment.
295 __device__
inline void add(K k, V v) {
// reduce() fragment: each warp publishes its winning pair to shared
// memory, then thread 0 serially selects the overall best.
299 __device__
inline void reduce() {
// (missing lines: the warp-level reduction that produces `pair` —
// TODO confirm against the original file)
312 int laneId = getLaneId();
313 int warpId = threadIdx.x / kWarpSize;
316 sharedK[warpId] = pair.k;
317 sharedV[warpId] = pair.v;
// Thread 0 scans the per-warp winners.
325 if (threadIdx.x == 0) {
326 threadK = sharedK[0];
327 threadV = sharedV[0];
330 for (
int i = 1; i < kNumWarps; ++i) {
// (missing lines: loads of k/v from sharedK[i]/sharedV[i] — TODO confirm)
334 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
335 threadK = swap ? k : threadK;
336 threadV = swap ? v : threadV;
// Publish the final winner in slot 0 of the shared arrays.
341 sharedK[0] = threadK;
342 sharedV[0] = threadV;
// WarpSelect fragment: like BlockSelect, but the warp queue is held in
// registers (kNumWarpQRegisters entries per lane, lane-strided) rather than
// shared memory. NOTE(review): the template parameter list and most of the
// constructor initializer list are truncated in this garbled fragment.
364 template <
typename K,
372 static constexpr
int kNumWarpQRegisters = NumWarpQ / kWarpSize;
374 __device__
inline WarpSelect(K initK, V initV,
int k) :
// kLane: the lane that holds the (k-1)-th (boundary) element of the
// lane-strided register queue.
376 kLane((k - 1) % kWarpSize) {
377 static_assert(utils::isPowerOf2(ThreadsPerBlock),
378 "threads must be a power-of-2");
379 static_assert(utils::isPowerOf2(NumWarpQ),
380 "warp queue must be power-of-2");
// Init loops; bodies missing — presumably fill threadK/threadV and
// warpK/warpV with the init sentinels (TODO confirm).
385 for (
int i = 0; i < NumThreadQ; ++i) {
392 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
// addThreadQ fragment (same scheme as BlockSelect's): admit (k, v) only if
// it beats the thread-queue head, then bubble it into place with branchless
// compare/select swaps.
398 __device__
inline void addThreadQ(K k, V v) {
402 if (Dir ? Comp::gt(k, threadK[0]) : Comp::lt(k, threadK[0])) {
// NOTE(review): the lines defining tmpK/tmpV (and the head insertion)
// are missing from this garbled fragment; the selects below read them.
408 for (
int i = 1; i < NumThreadQ; ++i) {
409 bool swap = Dir ? Comp::lt(threadK[i], threadK[i - 1]) :
410 Comp::gt(threadK[i], threadK[i - 1]);
// Branchless conditional swap of adjacent keys, then paired values.
413 threadK[i] = swap ? threadK[i - 1] : tmpK;
414 threadK[i - 1] = swap ? tmpK : threadK[i - 1];
417 threadV[i] = swap ? threadV[i - 1] : tmpV;
418 threadV[i - 1] = swap ? tmpV : threadV[i - 1];
// checkThreadQ fragment: warp-wide early-out when no lane's thread-queue
// head beats warpKTop; afterwards the queue boundary is re-broadcast.
423 __device__
inline void checkThreadQ() {
426 bool needSort = (Dir ?
427 Comp::gt(threadK[0], warpKTop) :
428 Comp::lt(threadK[0], warpKTop));
// NOTE(review): legacy mask-less __any — removed for compute capability
// 7.0+; migrate to __any_sync(0xffffffff, needSort).
429 if (!__any(needSort)) {
// Broadcast the boundary element of the register warp queue from kLane.
// NOTE(review): `shfl` looks like a project wrapper (PtxUtils/DeviceDefs);
// confirm it maps to the *_sync shuffle variants on Volta+.
436 warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
// mergeWarpQ fragment: sort the per-thread queues, merge them with the
// register-resident warp queue, then reverse the thread-queue order.
444 warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);
// Merge directly into the register warp queue (no shared-memory staging,
// unlike BlockSelect).
449 warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ, !Dir, Comp>(
450 warpK, warpV, threadK, threadV);
// Reverse the per-thread queue via a temporary copy — presumably restoring
// the ordering addThreadQ expects (TODO confirm; intermediate lines are
// missing from this garbled fragment).
453 K tmpThreadK[NumThreadQ];
454 V tmpThreadV[NumThreadQ];
457 for (
int i = 0; i < NumThreadQ; ++i) {
458 tmpThreadK[i] = threadK[i];
459 tmpThreadV[i] = threadV[i];
463 for (
int i = 0; i < NumThreadQ; ++i) {
469 threadK[i] = tmpThreadK[NumThreadQ - i - 1];
470 threadV[i] = tmpThreadV[NumThreadQ - i - 1];
// add() and reduce(): bodies missing from this garbled fragment.
476 __device__
inline void add(K k, V v) {
481 __device__
inline void reduce() {
// writeOut fragment: each lane writes its slice of the register warp queue
// to the output arrays, lane-strided (idx = i * kWarpSize + laneId).
488 __device__
inline void writeOut(K* outK, V* outV,
int k) {
489 int laneId = getLaneId();
492 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
493 int idx = i * kWarpSize + laneId;
// NOTE(review): a bounds guard (likely `if (idx < k)`) appears to be
// missing between the idx computation and the stores in this garbled
// fragment — confirm against the original file.
496 outK[idx] = warpK[i];
497 outV[idx] = warpV[i];
// Per-thread queue storage in registers.
503 K threadK[NumThreadQ];
504 V threadV[NumThreadQ];
// Warp queue, lane-strided across the warp's registers.
507 K warpK[kNumWarpQRegisters];
508 V warpV[kNumWarpQRegisters];
// Partial specialization of WarpSelect for NumWarpQ == 1: single best pair
// per thread. NOTE(review): the template parameter list and the
// constructor's member-initializer list are truncated in this garbled
// fragment.
521 template <
typename K,
527 struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
528 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
530 __device__
inline WarpSelect(K initK, V initV,
int k) :
// Single-element thread queue: branchless select keeps the better pair.
535 __device__
inline void addThreadQ(K k, V v) {
536 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
537 threadK = swap ? k : threadK;
538 threadV = swap ? v : threadV;
// checkThreadQ(): no body visible — presumably a no-op in the
// NumWarpQ == 1 case (TODO confirm).
541 __device__
inline void checkThreadQ() {
// add() and reduce(): bodies missing from this garbled fragment.
546 __device__
inline void add(K k, V v) {
550 __device__
inline void reduce() {
// writeOut fragment for the single-element case: only lane 0 writes the
// result; the store lines following the guard are missing from this chunk.
567 __device__
inline void writeOut(K* outK, V* outV,
int k) {
568 if (getLaneId() == 0) {
A simple pair type for CUDA device usage.
__device__ void writeOut(K *outK, V *outV, int k)
Dump the final k selected values for this warp out.
__device__ void add(K k, V v)
Add a new value to the selection queues.
__device__ void mergeWarpQ()
Merge the per-thread queues into the warp-wide queue.