12 #include "Comparators.cuh"
13 #include "DeviceDefs.cuh"
14 #include "MergeNetworkBlock.cuh"
15 #include "MergeNetworkWarp.cuh"
16 #include "PtxUtils.cuh"
17 #include "Reductions.cuh"
18 #include "ReductionOperators.cuh"
21 namespace faiss {
namespace gpu {
// NOTE(review): this chunk appears to be a garbled extraction of a GPU
// k-selection header; the leading numeric tokens on many lines (e.g. "25",
// "26") are fused original line numbers, and interior lines are missing
// throughout — reconstruct from the upstream source before building.
// Primary template header for the final cross-warp merge helper, dispatched
// on the number of warps per block; the struct body is not visible here.
25 template <
int NumWarps,
26 int NumThreads,
typename K,
typename V,
int NumWarpQ,
27 bool Dir,
typename Comp>
// Presumably the single-warp specialization: with one warp per block there is
// no cross-warp data to combine, so merge() should be a no-op — TODO confirm;
// the enclosing struct line and the body's closing brace are missing here.
31 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
32 bool Dir,
typename Comp>
// Merge the per-warp sorted queues held in shared memory into one result.
34 static inline __device__
void merge(K* sharedK, V* sharedV) {
// Presumably the two-warp specialization: a single blockMerge pass combines
// the two per-warp queues of size NumWarpQ held in shared memory.
39 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
40 bool Dir,
typename Comp>
42 static inline __device__
void merge(K* sharedK, V* sharedV) {
// NumThreads / (kWarpSize * 2) merge groups; !Dir because the per-warp
// queues are presumably stored in reverse order, and the trailing `false`
// presumably requests a partial merge (keep only the best NumWarpQ results)
// — TODO confirm against the blockMerge declaration in MergeNetworkBlock.cuh.
44 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
45 NumWarpQ, !Dir, Comp,
false>(sharedK, sharedV);
// Presumably the four-warp specialization: two successive blockMerge passes,
// doubling the merged queue size each pass (NumWarpQ, then NumWarpQ * 2).
49 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
50 bool Dir,
typename Comp>
52 static inline __device__
void merge(K* sharedK, V* sharedV) {
// First pass: merge pairs of warp queues (full merge — no trailing flag).
53 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
54 NumWarpQ, !Dir, Comp>(sharedK, sharedV);
// Second pass: merge the resulting NumWarpQ * 2 queues; `false` presumably
// keeps only the best results (partial merge) — TODO confirm.
56 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
57 NumWarpQ * 2, !Dir, Comp,
false>(sharedK, sharedV);
// Presumably the eight-warp specialization: three successive blockMerge
// passes, doubling the merged queue size each time (NumWarpQ, *2, *4).
61 template <
int NumThreads,
typename K,
typename V,
int NumWarpQ,
62 bool Dir,
typename Comp>
64 static inline __device__
void merge(K* sharedK, V* sharedV) {
// Pass 1: merge adjacent pairs of warp queues.
65 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
66 NumWarpQ, !Dir, Comp>(sharedK, sharedV);
// Pass 2: merge the doubled queues.
67 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
68 NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
// Pass 3: final merge; `false` presumably keeps only the best results.
70 blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
71 NumWarpQ * 4, !Dir, Comp,
false>(sharedK, sharedV);
// ---- BlockSelect interior fragments (the struct header is not visible) ----
// Number of warps cooperating in the block-wide selection.
85 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory queue length carved out per warp.
86 static constexpr
int kTotalWarpSortSize = NumWarpQ;
// Compile-time preconditions: the merge networks require power-of-2 sizes.
100 static_assert(utils::isPowerOf2(ThreadsPerBlock),
101 "threads must be a power-of-2");
102 static_assert(utils::isPowerOf2(NumWarpQ),
103 "warp queue must be power-of-2");
// Constructor fragment: presumably fills the per-thread queue with the
// sentinel value (the loop body lines are missing from this view).
107 for (
int i = 0; i < NumThreadQ; ++i) {
// Locate this thread's warp and carve out its slice of the shared queues.
112 int laneId = getLaneId();
113 int warpId = threadIdx.x / kWarpSize;
114 warpK = sharedK + warpId * kTotalWarpSortSize;
115 warpV = sharedV + warpId * kTotalWarpSortSize;
// Initialize this warp's shared queue, strided across the warp's lanes.
119 for (
int i = laneId; i < NumWarpQ; i += kWarpSize) {
// Insert (k, v) into the per-thread queue if it beats the cached warp-queue
// admission threshold warpKTop; the queue is kept as a small shift register.
127 __device__
inline void addThreadQ(K k, V v) {
128 if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
// Shift existing entries down one slot to make room at index 0 (the line
// writing the new element into threadK[0]/threadV[0] is missing from view).
131 for (
int i = NumThreadQ - 1; i > 0; --i) {
132 threadK[i] = threadK[i - 1];
133 threadV[i] = threadV[i - 1];
// If any lane's thread queue is full, merge all thread queues into the warp
// queue and reset them; otherwise return early.
142 __device__
inline void checkThreadQ() {
143 bool needSort = (numVals == NumThreadQ);
// NOTE(review): `__any` is the legacy (pre-Volta) warp vote intrinsic; it was
// removed for sm_70+ in favor of `__any_sync(0xffffffff, needSort)` —
// confirm the targeted compute capabilities before building with a modern
// toolchain.
145 if (!__any(needSort)) {
// Reset the thread queues after the merge (the mergeWarpQ() call between
// these lines is missing from this view).
157 for (
int i = 0; i < NumThreadQ; ++i) {
// Refresh the cached admission threshold from the warp queue.
// NOTE(review): kMinus1 is presumably k - 1 cached at construction — confirm.
163 warpKTop = warpK[kMinus1];
// mergeWarpQ fragment: merge the sorted per-thread queues of all lanes into
// this warp's shared-memory queue, staging the warp queue through registers.
172 int laneId = getLaneId();
// Sort the per-thread queues across the warp into one sorted sequence.
175 warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);
// Each lane holds NumWarpQ / kWarpSize elements of the warp queue.
177 constexpr
int kNumWarpQRegisters = NumWarpQ / kWarpSize;
178 K warpKRegisters[kNumWarpQRegisters];
179 V warpVRegisters[kNumWarpQRegisters];
// Load the warp queue from shared memory, lane-strided.
182 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
183 warpKRegisters[i] = warpK[i * kWarpSize + laneId];
184 warpVRegisters[i] = warpV[i * kWarpSize + laneId];
// Merge the sorted thread queues into the register-resident warp queue;
// `false` presumably requests a partial merge (keep best NumWarpQ) — confirm.
192 warpMergeAnyRegisters<K, V,
193 kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
false>(
194 warpKRegisters, warpVRegisters, threadK, threadV);
// Write the merged warp queue back to shared memory, lane-strided.
198 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
199 warpK[i * kWarpSize + laneId] = warpKRegisters[i];
200 warpV[i * kWarpSize + laneId] = warpVRegisters[i];
// Public entry point: push one candidate, then presumably flush via
// checkThreadQ (the body lines are missing from this view).
208 __device__
inline void add(K k, V v) {
// Final block-wide reduction over all warp queues in shared memory.
213 __device__
inline void reduce() {
// Dispatch to the NumWarps-specialized final merge (presumably
// FinalBlockMerge<kNumWarps, ...>::merge — the qualifier is not visible).
226 merge(sharedK, sharedV);
// Per-thread queue storage (held in registers).
244 K threadK[NumThreadQ];
245 V threadV[NumThreadQ];
// Specialization of BlockSelect for NumWarpQ == 1 (k == 1): each thread
// tracks a single best (k, v) pair, then the block reduces to one winner.
261 template <
typename K,
// (the remaining template parameters are missing from this view)
267 struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
268 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Constructor: presumably seeds threadK/threadV with the sentinel value
// (the initializer list and body are missing from this view).
270 __device__
inline BlockSelect(K initK, V initV, K* smemK, V* smemV,
int k) :
// Keep the single best element seen by this thread (branchless select).
277 __device__
inline void addThreadQ(K k, V v) {
278 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
279 threadK = swap ? k : threadK;
280 threadV = swap ? v : threadV;
// No queue to flush when k == 1 — presumably a no-op.
283 __device__
inline void checkThreadQ() {
288 __device__
inline void add(K k, V v) {
// Reduce: per-warp winners are published to shared memory, then thread 0
// combines them serially (some lines are missing from this view).
292 __device__
inline void reduce() {
305 int laneId = getLaneId();
306 int warpId = threadIdx.x / kWarpSize;
// Publish this warp's best pair.
// NOTE(review): the warp reduction producing `pair` is not visible here.
309 sharedK[warpId] = pair.k;
310 sharedV[warpId] = pair.v;
// Thread 0 scans the kNumWarps per-warp winners (kNumWarps is small).
318 if (threadIdx.x == 0) {
319 threadK = sharedK[0];
320 threadV = sharedV[0];
323 for (
int i = 1; i < kNumWarps; ++i) {
327 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
328 threadK = swap ? k : threadK;
329 threadV = swap ? v : threadV;
// Publish the block-wide winner in slot 0.
334 sharedK[0] = threadK;
335 sharedV[0] = threadV;
// ---- WarpSelect interior fragments: each warp selects independently; the
// warp queue is held in registers distributed across lanes, not shared
// memory (the struct header is not visible in this chunk). ----
357 template <
typename K,
// (the remaining template parameters are missing from this view)
365 static constexpr
int kNumWarpQRegisters = NumWarpQ / kWarpSize;
// Constructor: kLane is the lane that owns the k-th (threshold) element of
// the lane-strided, register-resident warp queue.
367 __device__
inline WarpSelect(K initKVal, V initVVal,
int k) :
372 kLane((k - 1) % kWarpSize) {
// Compile-time preconditions: the merge networks require power-of-2 sizes.
373 static_assert(utils::isPowerOf2(ThreadsPerBlock),
374 "threads must be a power-of-2");
375 static_assert(utils::isPowerOf2(NumWarpQ),
376 "warp queue must be power-of-2");
// Presumably seeds the thread queue and warp queue with the sentinel value
// (the loop bodies are missing from this view).
380 for (
int i = 0; i < NumThreadQ; ++i) {
387 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
// Insert (k, v) into the per-thread shift-register queue if it beats the
// cached warp threshold (same scheme as the block-wide variant above).
393 __device__
inline void addThreadQ(K k, V v) {
394 if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
397 for (
int i = NumThreadQ - 1; i > 0; --i) {
398 threadK[i] = threadK[i - 1];
399 threadV[i] = threadV[i - 1];
// Flush full thread queues into the register-resident warp queue.
408 __device__
inline void checkThreadQ() {
409 bool needSort = (numVals == NumThreadQ);
// NOTE(review): legacy `__any` — removed for sm_70+; modern code uses
// `__any_sync(0xffffffff, needSort)`. Confirm target architectures.
411 if (!__any(needSort)) {
422 for (
int i = 0; i < NumThreadQ; ++i) {
// Refresh the admission threshold by broadcasting the last warp-queue
// element from the lane that owns it.
// NOTE(review): `shfl` is presumably a project wrapper over __shfl/_sync.
428 warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
// mergeWarpQ fragment: sort the thread queues across the warp, then merge
// them into the register-resident warp queue (partial merge via `false` —
// presumably keeps only the best NumWarpQ results; confirm).
436 warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);
441 warpMergeAnyRegisters<K, V,
442 kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
false>(
443 warpK, warpV, threadK, threadV);
// Public entry point (body missing from this view).
448 __device__
inline void add(K k, V v) {
// Final reduction for this warp (body missing from this view).
453 __device__
inline void reduce() {
// Write this warp's final results to outK/outV, lane-strided; presumably an
// `if (idx < k)` guard precedes the stores (that line is missing from view).
460 __device__
inline void writeOut(K* outK, V* outV,
int k) {
461 int laneId = getLaneId();
464 for (
int i = 0; i < kNumWarpQRegisters; ++i) {
465 int idx = i * kWarpSize + laneId;
468 outK[idx] = warpK[i];
469 outV[idx] = warpV[i];
// Per-thread queue (registers).
487 K threadK[NumThreadQ];
488 V threadV[NumThreadQ];
// Warp queue distributed lane-strided across the warp, in registers.
491 K warpK[kNumWarpQRegisters];
492 V warpV[kNumWarpQRegisters];
// Specialization of WarpSelect for NumWarpQ == 1 (k == 1): single best pair
// per thread, presumably warp-reduced in reduce(); lane 0 writes the result.
501 template <
typename K,
// (the remaining template parameters are missing from this view)
507 struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
508 static constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Constructor (initializer list and body missing from this view).
510 __device__
inline WarpSelect(K initK, V initV,
int k) :
// Branchless single-element update of this thread's best pair.
515 __device__
inline void addThreadQ(K k, V v) {
516 bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
517 threadK = swap ? k : threadK;
518 threadV = swap ? v : threadV;
// Nothing to flush when k == 1 — presumably a no-op.
521 __device__
inline void checkThreadQ() {
526 __device__
inline void add(K k, V v) {
// Presumably a warp-wide reduction of the best pair (body not visible).
530 __device__
inline void reduce() {
// Lane 0 writes the warp's single winner (stores not visible in this view).
547 __device__
inline void writeOut(K* outK, V* outV,
int k) {
548 if (getLaneId() == 0) {
A simple pair type for CUDA device usage.
__device__ void writeOut(K* outK, V* outV, int k)
    Dump final k selected values for this warp out.
__device__ void add(K k, V v)
__device__ void mergeWarpQ()