11 #include "Comparators.cuh"
12 #include "DeviceDefs.cuh"
13 #include "MergeNetworkBlock.cuh"
14 #include "MergeNetworkWarp.cuh"
15 #include "PtxUtils.cuh"
16 #include "Reductions.cuh"
17 #include "ReductionOperators.cuh"
20 namespace faiss {
namespace gpu {
// Specialization for block-wide monotonic merges producing a merge sort,
// since what we really want is a constexpr loop expansion
template <int NumWarps,
          int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge {
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // no merge required; a single warp's queue is already sorted
  }
};
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp, false>(sharedK, sharedV);
  }
};
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp, false>(sharedK, sharedV);
  }
};
template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
  static inline __device__ void merge(K* sharedK, V* sharedV) {
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
               NumWarpQ, !Dir, Comp>(sharedK, sharedV);
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
               NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
    // Final merge doesn't need to fully merge the second list
    blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
               NumWarpQ * 4, !Dir, Comp, false>(sharedK, sharedV);
  }
};
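// Illustration: with ThreadsPerBlock = 128 (4 warps) and NumWarpQ = 32,
// FinalBlockMerge<4, ...>::merge first fully merges pairs of adjacent
// 32-element warp queues in shared memory into sorted 64-element lists,
// then half-merges those two lists (fullMerge = false), since only the
// first NumWarpQ elements of the final list are ever read as results.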
/// Block-wide k-selection: each thread maintains a small sorted queue in
/// registers, backed by a per-warp queue in shared memory
template <typename K, typename V, bool Dir, typename Comp,
          int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  static constexpr int kTotalWarpSortSize = NumWarpQ;

  __device__ inline BlockSelect(K initKVal, V initVVal,
                                K* smemK, V* smemV, int k) :
      initK(initKVal), initV(initVVal), numVals(0), warpKTop(initKVal),
      sharedK(smemK), sharedV(smemV), kMinus1(k - 1) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queues with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    warpK = sharedK + warpId * kTotalWarpSortSize;
    warpV = sharedV + warpId * kTotalWarpSortSize;

    // Fill our warp's segment of the shared-memory queue as well
#pragma unroll
    for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
      warpK[i] = initK;
      warpV[i] = initV;
    }

    warpFence();
  }
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right to make room at the head
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }
  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);
#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif
    if (!needSort) {
      // No lane has filled its thread queue yet
      return;
    }

    mergeWarpQ(); // has a trailing warpFence

    // The thread queues' contents are now in the warp queue; reset them
    numVals = 0;
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // New candidates must beat the current k-th best element
    warpKTop = warpK[kMinus1];
    warpFence();
  }
  /// Sorts the per-thread queues and merges them with the warp-wide
  /// queue, creating one sorted list across both
  __device__ inline void mergeWarpQ() {
    int laneId = getLaneId();

    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
    K warpKRegisters[kNumWarpQRegisters];
    V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpKRegisters[i] = warpK[i * kWarpSize + laneId];
      warpVRegisters[i] = warpV[i * kWarpSize + laneId];
    }

    warpFence();

    // The warp queue is already sorted; merge both sorted lists into
    // one sorted list
    warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ,
                          !Dir, Comp, false>(
      warpKRegisters, warpVRegisters, threadK, threadV);

    // Write the merged warp queue back out to shared memory
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i * kWarpSize + laneId] = warpKRegisters[i];
      warpV[i * kWarpSize + laneId] = warpVRegisters[i];
    }

    warpFence();
  }
  /// WARNING: all threads in a warp must participate in this;
  /// otherwise, call the constituent parts separately
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Have all warps dump and merge their queues; this will produce
    // the final sorted results in shared memory
    mergeWarpQ();

    // Block-wide dependency; thus far, all warps have been independent
    __syncthreads();

    // All warp queues are contiguous in smem; merge them into one final
    // sorted list (the block-wide merge has a trailing __syncthreads)
    FinalBlockMerge<kNumWarps, ThreadsPerBlock,
                    K, V, NumWarpQ, Dir, Comp>::merge(sharedK, sharedV);
  }

  // Default element key and value
  const K initK;
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element seen so far
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Queues for all warps, in shared memory
  K* sharedK;
  V* sharedV;

  // Our warp's queue (points into sharedK/sharedV)
  K* warpK;
  V* warpV;

  // Cached k - 1
  int kMinus1;
};
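// Design note: most candidates fail the warpKTop comparison in addThreadQ()
// and cost a single compare; only when some lane fills its thread queue does
// the whole warp pay for a register sort plus merge in checkThreadQ(), so
// the per-element cost stays small and warp-synchronous.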
/// BlockSelect specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, typename Comp,
          int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) :
      threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }
  __device__ inline void checkThreadQ() {
    // The warp doesn't cooperate until the end; nothing to do
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within each warp
    Pair<K, V> pair(threadK, threadV);
    if (Dir) {
      pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }

    // Each warp writes out a single value
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;
    if (laneId == 0) {
      sharedK[warpId] = pair.k;
      sharedV[warpId] = pair.v;
    }

    __syncthreads();

    // Have the first thread reduce across the per-warp results
    if (threadIdx.x == 0) {
      threadK = sharedK[0];
      threadV = sharedV[0];

#pragma unroll
      for (int i = 1; i < kNumWarps; ++i) {
        K k = sharedK[i];
        V v = sharedV[i];
        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
      }

      // A thread's own smem reads/writes are ordered wrt itself
      sharedK[0] = threadK;
      sharedV[0] = threadV;
    }

    // In case other threads wish to read this value
    __syncthreads();
  }

  // threadK is lowest (Dir) or highest (!Dir)
  K threadK;
  V threadV;

  // Where we reduce in smem
  K* sharedK;
  V* sharedV;
};
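// Usage sketch (illustrative only; not part of this header): one block per
// row selects the k smallest float keys, following the add()/reduce()
// protocol above. The kernel name and raw-pointer layout are assumptions of
// this sketch; Faiss's real entry points use Tensor arguments but follow the
// same structure. gridDim.x is assumed to equal the number of rows.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void blockSelectMinSketch(const float* in, int cols,
                                     float* outK, int* outV,
                                     float initK, int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
  __shared__ float smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  // Dir = false selects minima; initK must be the worst possible key
  // (e.g. FLT_MAX), and -1 marks an invalid index
  BlockSelect<float, int, false, Comparator<float>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);

  const float* rowIn = in + blockIdx.x * cols;

  // add() needs whole-warp participation, so the strided loop covers the
  // multiple-of-kWarpSize prefix; stragglers only feed their thread queue
  int limit = (cols / kWarpSize) * kWarpSize;
  int i = threadIdx.x;
  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(rowIn[i], i);
  }
  if (i < cols) {
    heap.addThreadQ(rowIn[i], i);
  }

  // reduce() merges leftover thread queues and sorts across all warps;
  // the k best pairs are then at the front of the shared queues
  heap.reduce();

  for (int j = threadIdx.x; j < k; j += ThreadsPerBlock) {
    outK[blockIdx.x * k + j] = smemK[j];
    outV[blockIdx.x * k + j] = smemV[j];
  }
}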
/// WarpSelect: the same selection algorithm, but operating within a
/// single warp, with the warp queue held in registers
template <typename K, typename V, bool Dir, typename Comp,
          int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect {
  static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

  __device__ inline WarpSelect(K initKVal, V initVVal, int k) :
      initK(initKVal), initV(initVVal), numVals(0), warpKTop(initKVal),
      kLane((k - 1) % kWarpSize) {
    static_assert(utils::isPowerOf2(ThreadsPerBlock),
                  "threads must be a power-of-2");
    static_assert(utils::isPowerOf2(NumWarpQ),
                  "warp queue must be power-of-2");

    // Fill the per-thread queue with the default value
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Fill the warp queue registers with the default value
#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      warpK[i] = initK;
      warpV[i] = initV;
    }
  }
  __device__ inline void addThreadQ(K k, V v) {
    if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
      // Rotate right to make room at the head
#pragma unroll
      for (int i = NumThreadQ - 1; i > 0; --i) {
        threadK[i] = threadK[i - 1];
        threadV[i] = threadV[i - 1];
      }

      threadK[0] = k;
      threadV[0] = v;
      ++numVals;
    }
  }
  __device__ inline void checkThreadQ() {
    bool needSort = (numVals == NumThreadQ);
#if CUDA_VERSION >= 9000
    needSort = __any_sync(0xffffffff, needSort);
#else
    needSort = __any(needSort);
#endif
    if (!needSort) {
      // No lane has filled its thread queue yet
      return;
    }

    mergeWarpQ();

    // The thread queues' contents are now in the warp queue; reset them
    numVals = 0;
#pragma unroll
    for (int i = 0; i < NumThreadQ; ++i) {
      threadK[i] = initK;
      threadV[i] = initV;
    }

    // Get the new k-th best element from the lane that holds it
    warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
  }
  /// Sorts the per-thread queues and merges them with the warp-wide
  /// register queue, creating one sorted list across both
  __device__ inline void mergeWarpQ() {
    // Sort all of the per-thread queues
    warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

    // The warp queue is already sorted; merge both sorted lists into
    // one sorted list
    warpMergeAnyRegisters<K, V, kNumWarpQRegisters, NumThreadQ,
                          !Dir, Comp, false>(warpK, warpV, threadK, threadV);
  }
  /// WARNING: all threads in a warp must participate in this;
  /// otherwise, call the constituent parts separately
  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
    checkThreadQ();
  }

  __device__ inline void reduce() {
    // Drain any remaining thread-queue contents into the warp queue
    mergeWarpQ();
  }

  /// Dump the final k selected values for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    int laneId = getLaneId();

#pragma unroll
    for (int i = 0; i < kNumWarpQRegisters; ++i) {
      int idx = i * kWarpSize + laneId;
      if (idx < k) {
        outK[idx] = warpK[i];
        outV[idx] = warpV[i];
      }
    }
  }

  // Default element key and value
  const K initK;
  const V initV;

  // Number of valid elements in our thread queue
  int numVals;

  // The k-th highest (Dir) or lowest (!Dir) element seen so far
  K warpKTop;

  // Thread queue values
  K threadK[NumThreadQ];
  V threadV[NumThreadQ];

  // Warp queue, held in registers; warpK[0] is highest (Dir) or
  // lowest (!Dir)
  K warpK[kNumWarpQRegisters];
  V warpV[kNumWarpQRegisters];

  // Lane from which to read the approximation to the k-th element,
  // held in warpK[kNumWarpQRegisters - 1]
  int kLane;
};
/// WarpSelect specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, typename Comp,
          int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
  static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __device__ inline WarpSelect(K initK, V initV, int k) :
      threadK(initK), threadV(initV) {
  }

  __device__ inline void addThreadQ(K k, V v) {
    bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
    threadK = swap ? k : threadK;
    threadV = swap ? v : threadV;
  }

  __device__ inline void checkThreadQ() {
    // The warp doesn't cooperate until the end; nothing to do
  }

  __device__ inline void add(K k, V v) {
    addThreadQ(k, v);
  }

  __device__ inline void reduce() {
    // Reduce within the warp
    Pair<K, V> pair(threadK, threadV);
    if (Dir) {
      pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(pair, Max<Pair<K, V>>());
    } else {
      pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(pair, Min<Pair<K, V>>());
    }
    threadK = pair.k;
    threadV = pair.v;
  }

  /// Dump the single selected value for this warp out
  __device__ inline void writeOut(K* outK, V* outV, int k) {
    if (getLaneId() == 0) {
      *outK = threadK;
      *outV = threadV;
    }
  }

  // threadK is lowest (Dir) or highest (!Dir)
  K threadK;
  V threadV;
};

} } // namespace faiss::gpu
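// Usage sketch for WarpSelect (illustrative only; not part of this header):
// each warp independently selects the k smallest float keys of one row and
// writes them out itself, with no shared memory or block-wide sync. The
// kernel name and raw-pointer layout are assumptions of this sketch.
template <int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void warpSelectMinSketch(const float* in, int rows, int cols,
                                    float* outK, int* outV,
                                    float initK, int k) {
  using namespace faiss::gpu;

  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  WarpSelect<float, int, false, Comparator<float>,
             NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, k);

  // One row per warp; whole warps exit together on the bounds check
  int warpId = threadIdx.x / kWarpSize;
  int row = blockIdx.x * kNumWarps + warpId;
  if (row >= rows) {
    return;
  }

  const float* rowIn = in + row * cols;

  // Whole-warp participation in add(), as with BlockSelect above
  int limit = (cols / kWarpSize) * kWarpSize;
  int i = getLaneId();
  for (; i < limit; i += kWarpSize) {
    heap.add(rowIn[i], i);
  }
  if (i < cols) {
    heap.addThreadQ(rowIn[i], i);
  }

  heap.reduce();
  heap.writeOut(outK + row * k, outV + row * k, k);
}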