#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"
namespace faiss {
namespace gpu {
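
// Overview (added commentary): this header implements warp- and
// block-wide top-k selection. Each thread buffers candidate
// (key, value) pairs in a small per-thread queue guarded by a single
// comparison against warpKTop; whenever any lane's queue fills, the
// warp sorts the thread queues with a bitonic network and merges them
// into a warp-wide queue (shared memory for BlockSelect, registers
// for WarpSelect). reduce() performs the final merge across queues.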

// Block-wide monotonic merge of the per-warp queues into a single
// final sorted list, dispatched on the number of warps in the block.
template <
        int NumWarps,
        int NumThreads,
        typename K,
        typename V,
        int NumWarpQ,
        bool Dir,
        typename Comp>
struct FinalBlockMerge {
};

template <
        int NumThreads,
        typename K,
        typename V,
        int NumWarpQ,
        bool Dir,
        typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // no merge required; a single warp's queue is already sorted
    }
};

template <
        int NumThreads,
        typename K,
        typename V,
        int NumWarpQ,
        bool Dir,
        typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
    }
};

template <
        int NumThreads,
        typename K,
        typename V,
        int NumWarpQ,
        bool Dir,
        typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
                   NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
    }
};

template <
        int NumThreads,
        typename K,
        typename V,
        int NumWarpQ,
        bool Dir,
        typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp>(sharedK, sharedV);
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
                   NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
                   NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
    }
};
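
// Example (added commentary): with NumThreads = 128 (four warps,
// hence four sorted warp queues of NumWarpQ keys each in shared
// memory), FinalBlockMerge<4, ...>::merge() first merges adjacent
// pairs of NumWarpQ-element lists into two sorted 2 * NumWarpQ lists,
// then merges those two. The trailing `false` on the last blockMerge
// skips fully sorting the tail of the result, since callers only
// ever read the first k <= NumWarpQ keys.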

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <
        typename K,
        typename V,
        bool Dir,
        typename Comp,
        int NumWarpQ,
        int NumThreadQ,
        int ThreadsPerBlock>
struct BlockSelect {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
    static constexpr int kTotalWarpSortSize = NumWarpQ;

    __device__ inline BlockSelect(K initKVal, V initVVal,
                                  K* smemK, V* smemV, int k) :
            initK(initKVal),
            initV(initVVal),
            numVals(0),
            warpKTop(initKVal),
            sharedK(smemK),
            sharedV(smemV),
            kMinus1(k - 1) {
        static_assert(utils::isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(utils::isPowerOf2(NumWarpQ),
                      "warp queue must be power-of-2");

        // Fill the per-thread queue with the default (sentinel) value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;
        warpK = sharedK + warpId * kTotalWarpSortSize;
        warpV = sharedV + warpId * kTotalWarpSortSize;

        // Fill this warp's shared-memory queue with the default value
#pragma unroll
        for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
            warpK[i] = initK;
            warpV[i] = initV;
        }

        warpFence();
    }

    __device__ inline void addThreadQ(K k, V v) {
        if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
            // Rotate right; the new element enters at the head
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lane's thread queue is full yet
            return;
        }

        // This has a trailing warpFence
        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue;
        // the thread queues can be reset
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element
        warpKTop = warpK[kMinus1];

        warpFence();
    }
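
    // Note (added commentary): between checkThreadQ() calls, warpK
    // holds the best NumWarpQ keys merged so far, and warpKTop caches
    // the k-th best of them (warpK[k - 1]), so addThreadQ() can reject
    // non-contenders with a single register compare.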

    /// Sorts the per-thread queues and merges them with the warp-wide
    /// queue, producing a single sorted list across both
    __device__ inline void mergeWarpQ() {
        int laneId = getLaneId();

        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

        constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
        K warpKRegisters[kNumWarpQRegisters];
        V warpVRegisters[kNumWarpQRegisters];

        // Load the warp queue from shared memory, striped so each
        // lane holds kNumWarpQRegisters of the elements
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpKRegisters[i] = warpK[i * kWarpSize + laneId];
            warpVRegisters[i] = warpV[i * kWarpSize + laneId];
        }

        warpFence();

        // The warp queue is already sorted; merge both sorted lists
        // together, producing one sorted list
        warpMergeAnyRegisters<K, V,
                              kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
                              false>(
                warpKRegisters, warpVRegisters, threadK, threadV);

        // Write the merged warp queue back out to shared memory
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i * kWarpSize + laneId] = warpKRegisters[i];
            warpV[i * kWarpSize + laneId] = warpVRegisters[i];
        }

        warpFence();
    }

    /// WARNING: all threads in a warp must participate in this;
    /// otherwise, call the constituent parts separately
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this produces
        // the final per-warp results in shared memory
        mergeWarpQ();

        // block-wide dependency; up to this point, all warps have
        // been completely independent
        __syncthreads();

        // All warp queues are contiguous in shared memory: kNumWarps
        // sorted lists of NumWarpQ elements each
        FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                        K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);

        // The block-wide merge has a trailing __syncthreads
    }

    // Default element key and value
    const K initK;
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th best element seen so far; candidates must beat this
    K warpKTop;

    // Per-thread queues; threadK[0] is the best
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Shared-memory queues for all warps in the block
    K* sharedK;
    V* sharedV;

    // Our warp's queue (points into sharedK/sharedV)
    K* warpK;
    V* warpV;

    // Cached k - 1
    int kMinus1;
};

/// Specialization for k == 1 (NumWarpQ == 1)
template <
        typename K,
        typename V,
        bool Dir,
        typename Comp,
        int NumThreadQ,
        int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline BlockSelect(K initK, V initV,
                                  K* smemK, V* smemV, int k) :
            threadK(initK),
            threadV(initV),
            sharedK(smemK),
            sharedV(smemV) {
    }

    __device__ inline void addThreadQ(K k, V v) {
        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // nothing to do; the warp only cooperates in reduce()
    }

    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
    }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        // Each warp writes out a single value
        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;

        if (laneId == 0) {
            sharedK[warpId] = pair.k;
            sharedV[warpId] = pair.v;
        }

        __syncthreads();

        // We typically use this for small blocks; having the first
        // thread reduce across warps is faster than another full
        // parallel reduction
        if (threadIdx.x == 0) {
            threadK = sharedK[0];
            threadV = sharedV[0];

#pragma unroll
            for (int i = 1; i < kNumWarps; ++i) {
                K k = sharedK[i];
                V v = sharedV[i];

                bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
                threadK = swap ? k : threadK;
                threadV = swap ? v : threadV;
            }

            // The final result is broadcast through shared memory
            sharedK[0] = threadK;
            sharedV[0] = threadV;
        }

        // In case other threads wish to read the result
        __syncthreads();
    }

    // threadK is the best element seen by this thread so far
    K threadK;
    V threadV;

    // Shared memory for the cross-warp reduction
    K* sharedK;
    V* sharedV;
};
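
// A minimal usage sketch (added commentary, not part of the library
// proper): one way a kernel could drive BlockSelect to find the k
// smallest float keys in an array. The kernel name, parameters, and
// sentinel argument are hypothetical; real callers follow the same
// add/checkThreadQ/reduce pattern.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void blockSelectMinSketch(const float* distances,
                                     int n,
                                     float maxSentinel, // e.g. +infinity
                                     float* outK,
                                     int* outV,
                                     int k) {
    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
    __shared__ float smemK[kNumWarps * NumWarpQ];
    __shared__ int smemV[kNumWarps * NumWarpQ];

    // Dir = false selects smallest keys; Comparator<float> is the
    // default lt/gt comparator from Comparators.cuh
    BlockSelect<float, int, false, Comparator<float>,
                NumWarpQ, NumThreadQ, ThreadsPerBlock>
            heap(maxSentinel, -1, smemK, smemV, k);

    // add() is warp-synchronous, so every lane must make the same
    // number of calls; the ragged tail goes through addThreadQ alone
    int limit = (n / kWarpSize) * kWarpSize;
    int i = threadIdx.x;
    for (; i < limit; i += ThreadsPerBlock) {
        heap.add(distances[i], i);
    }
    if (i < n) {
        heap.addThreadQ(distances[i], i);
    }
    heap.checkThreadQ();

    heap.reduce();

    // reduce() leaves the k best, sorted, in smemK/smemV[0..k-1]
    for (int j = threadIdx.x; j < k; j += ThreadsPerBlock) {
        outK[j] = smemK[j];
        outV[j] = smemV[j];
    }
}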

//
// per-warp WarpSelect: each warp performs an independent selection
//

// `Dir` true, produce largest values.
// `Dir` false, produce smallest values.
template <
        typename K,
        typename V,
        bool Dir,
        typename Comp,
        int NumWarpQ,
        int NumThreadQ,
        int ThreadsPerBlock>
struct WarpSelect {
    static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

    __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
            initK(initKVal),
            initV(initVVal),
            numVals(0),
            warpKTop(initKVal),
            kLane((k - 1) % kWarpSize) {
        static_assert(utils::isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(utils::isPowerOf2(NumWarpQ),
                      "warp queue must be power-of-2");

        // Fill the per-thread queue with the default value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // Fill the warp queue, held in registers, with the default value
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i] = initK;
            warpV[i] = initV;
        }
    }

    __device__ inline void addThreadQ(K k, V v) {
        if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
            // Rotate right; the new element enters at the head
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lane's thread queue is full yet
            return;
        }

        mergeWarpQ();

        // Any top-k elements have been merged into the warp queue;
        // the thread queues can be reset
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // We have to beat at least this element: broadcast the k-th
        // best key from the lane that holds it
        warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
    }

    /// Sorts the per-thread queues and merges them with the warp-wide
    /// queue held in registers, producing a single sorted list
    __device__ inline void mergeWarpQ() {
        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

        // The warp queue is already sorted; merge both sorted lists
        // together, producing one sorted list
        warpMergeAnyRegisters<K, V,
                              kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
                              false>(
                warpK, warpV, threadK, threadV);
    }

    /// WARNING: all threads in a warp must participate in this;
    /// otherwise, call the constituent parts separately
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have the warp dump and merge its queues; this produces the
        // final results
        mergeWarpQ();
    }

    /// Dump final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        int laneId = getLaneId();

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            int idx = i * kWarpSize + laneId;

            if (idx < k) {
                outK[idx] = warpK[i];
                outV[idx] = warpV[i];
            }
        }
    }

    // Default element key and value
    const K initK;
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th best element seen so far; candidates must beat this
    K warpKTop;

    // Per-thread queues; threadK[0] is the best
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Warp queue, striped across the lanes' registers
    K warpK[kNumWarpQRegisters];
    V warpV[kNumWarpQRegisters];

    // The lane whose last warp queue register holds the k-th best
    // key; used to broadcast warpKTop
    int kLane;
};
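
// A minimal usage sketch (added commentary, not part of the library
// proper): each warp independently selects the k smallest keys from
// its own row of a row-major matrix. The kernel name, parameters,
// and row-per-warp mapping are hypothetical; real callers follow the
// same pattern.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void warpSelectMinSketch(const float* distances, // numRows x rowLength
                                    int numRows,
                                    int rowLength,
                                    float maxSentinel, // e.g. +infinity
                                    float* outK,       // numRows x k
                                    int* outV,         // numRows x k
                                    int k) {
    constexpr int kWarpsPerBlock = ThreadsPerBlock / kWarpSize;
    int row = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpSize;

    if (row >= numRows) {
        return; // the whole warp exits together; row is warp-uniform
    }

    WarpSelect<float, int, false, Comparator<float>,
               NumWarpQ, NumThreadQ, ThreadsPerBlock>
            heap(maxSentinel, -1, k);

    const float* rowStart = distances + row * rowLength;

    // Every lane must make the same number of add() calls; the ragged
    // tail of the row goes through addThreadQ alone
    int limit = (rowLength / kWarpSize) * kWarpSize;
    int i = getLaneId();
    for (; i < limit; i += kWarpSize) {
        heap.add(rowStart[i], i);
    }
    if (i < rowLength) {
        heap.addThreadQ(rowStart[i], i);
    }
    heap.checkThreadQ();

    heap.reduce();
    heap.writeOut(outK + row * k, outV + row * k, k);
}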

/// Specialization for k == 1 (NumWarpQ == 1)
template <
        typename K,
        typename V,
        bool Dir,
        typename Comp,
        int NumThreadQ,
        int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline WarpSelect(K initK, V initV, int k) :
            threadK(initK),
            threadV(initV) {
    }

    __device__ inline void addThreadQ(K k, V v) {
        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // nothing to do; the warp only cooperates in reduce()
    }

    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
    }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        threadK = pair.k;
        threadV = pair.v;
    }

    /// Dump the single selected value for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        if (getLaneId() == 0) {
            *outK = threadK;
            *outV = threadV;
        }
    }

    // threadK is the best element seen by this thread so far
    K threadK;
    V threadV;
};

} // namespace gpu
} // namespace faiss