#include "Comparators.cuh"
#include "DeviceDefs.cuh"
#include "MergeNetworkBlock.cuh"
#include "MergeNetworkWarp.cuh"
#include "PtxUtils.cuh"
#include "Reductions.cuh"
#include "ReductionOperators.cuh"

namespace faiss {
namespace gpu {
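
// Overview (editorial summary of the code below): BlockSelect / WarpSelect
// perform top-k selection over a stream of (key, value) pairs in tiers.
// Each thread buffers candidates that beat the current k-th best in a small
// unsorted register queue; when any lane's queue fills, the warp sorts the
// thread queues and merges them into a sorted warp-wide queue; for
// BlockSelect, reduce() finally merge-sorts all warp queues in shared memory
// so that the first k elements are the result.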

// Dispatch on the number of warps in the block, expanding the block-wide
// merge ladder at compile time
template <int NumWarps, int NumThreads, typename K, typename V,
          int NumWarpQ, bool Dir, typename Comp>
struct FinalBlockMerge {};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // no merge required; a single warp's queue is already sorted
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp,
                   false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
                   NumWarpQ * 2, !Dir, Comp,
                   false>(sharedK, sharedV);
    }
};

template <int NumThreads, typename K, typename V, int NumWarpQ,
          bool Dir, typename Comp>
struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> {
    static inline __device__ void merge(K* sharedK, V* sharedV) {
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 2),
                   NumWarpQ, !Dir, Comp>(sharedK, sharedV);
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 4),
                   NumWarpQ * 2, !Dir, Comp>(sharedK, sharedV);
        // Final merge doesn't need to fully merge the second list
        blockMerge<NumThreads, K, V, NumThreads / (kWarpSize * 8),
                   NumWarpQ * 4, !Dir, Comp,
                   false>(sharedK, sharedV);
    }
};
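
// Worked expansion (derived from the specializations above): with
// NumThreads = 128 (4 warps) and NumWarpQ = 32, the <4> case performs
//   blockMerge<128, K, V, 2, 32, !Dir, Comp>         // merge warp-queue pairs
//   blockMerge<128, K, V, 1, 64, !Dir, Comp, false>  // final half-merge
// so four sorted 32-element warp queues become one sorted list in shared
// memory whose leading NumWarpQ entries contain the best candidates.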

template <typename K, typename V, bool Dir, typename Comp,
          int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
    static constexpr int kTotalWarpSortSize = NumWarpQ;

    __device__ inline BlockSelect(K initKVal, V initVVal,
                                  K* smemK, V* smemV, int k)
            : initK(initKVal), initV(initVVal), numVals(0),
              warpKTop(initKVal), sharedK(smemK), sharedV(smemV),
              kMinus1(k - 1) {
        static_assert(utils::isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(utils::isPowerOf2(NumWarpQ),
                      "warp queue must be power-of-2");

        // Fill the per-thread queues with the sentinel value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;
        warpK = sharedK + warpId * kTotalWarpSortSize;
        warpV = sharedV + warpId * kTotalWarpSortSize;

        // Fill our warp's segment of the shared-memory queue as well
#pragma unroll
        for (int i = laneId; i < NumWarpQ; i += kWarpSize) {
            warpK[i] = initK;
            warpV[i] = initV;
        }

        warpFence();
    }

    __device__ inline void addThreadQ(K k, V v) {
        if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
            // Rotate right; the tail element falls off the end
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }
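
    // Example (Dir = false, i.e. min-selection, NumThreadQ = 4): the queue
    // starts as [init, init, init, init]; accepted keys are pushed at slot 0,
    // so adds of 9, 7, 5 leave [5, 7, 9, init]. Provided checkThreadQ() runs
    // after every insert (as add() does), a lane merges as soon as it holds
    // NumThreadQ live entries, so only sentinels are ever rotated off the end.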

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lane has a full thread queue yet
            return;
        }

        // This has a trailing warpFence
        mergeWarpQ();

        // The warp queue now holds the best elements seen so far; the
        // thread queues can be reset
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // New candidates must beat at least this element
        warpKTop = warpK[kMinus1];

        warpFence();
    }

    /// Sorts the per-thread queues and merges them with the warp-wide
    /// queue, producing a single sorted list across both
    __device__ inline void mergeWarpQ() {
        int laneId = getLaneId();

        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

        constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;
        K warpKRegisters[kNumWarpQRegisters];
        V warpVRegisters[kNumWarpQRegisters];

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpKRegisters[i] = warpK[i * kWarpSize + laneId];
            warpVRegisters[i] = warpV[i * kWarpSize + laneId];
        }

        warpFence();

        // The warp queue is already sorted; merge the two sorted lists
        warpMergeAnyRegisters<K, V,
                              kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
                              false>(
                warpKRegisters, warpVRegisters, threadK, threadV);

        // Write the warp queue back out to shared memory
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i * kWarpSize + laneId] = warpKRegisters[i];
            warpV[i * kWarpSize + laneId] = warpVRegisters[i];
        }

        warpFence();
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have all warps dump and merge their queues; this will produce
        // the final per-warp results
        mergeWarpQ();

        // Block-wide dependency; up to this point all warps have been
        // completely independent
        __syncthreads();

        // All warp queues are contiguous in shared memory; merge them so
        // that the first k elements of sharedK/sharedV are the result
        FinalBlockMerge<kNumWarps, ThreadsPerBlock, K, V, NumWarpQ, Dir, Comp>::
                merge(sharedK, sharedV);

        // The block-wide merge has a trailing __syncthreads
    }

    // Sentinel key and value
    const K initK;
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th best element seen so far; candidates must beat this
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Queues for all warps, in shared memory
    K* sharedK;
    V* sharedV;

    // Our warp's queue (points into sharedK/sharedV)
    K* warpK;
    V* warpV;

    // Cached k - 1 value
    int kMinus1;
};
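
// A minimal usage sketch, not part of the original header: a hypothetical
// kernel selecting the k smallest values per row. `selectMinK` and its
// parameters are illustrative; Limits<float> is assumed to come from
// Limits.cuh, and the grid is assumed to launch one block per row.
//
//   template <int kThreads, int kWarpQ, int kThreadQ>
//   __global__ void selectMinK(
//           const float* in, int cols, float* outK, int* outV, int k) {
//       __shared__ float smemK[(kThreads / kWarpSize) * kWarpQ];
//       __shared__ int smemV[(kThreads / kWarpSize) * kWarpQ];
//
//       BlockSelect<float, int, false, Comparator<float>,
//                   kWarpQ, kThreadQ, kThreads>
//               heap(Limits<float>::getMax(), -1, smemK, smemV, k);
//
//       const float* row = in + (long)blockIdx.x * cols;
//
//       // Main portion: limit is a multiple of kWarpSize, so every lane of
//       // a warp takes the same trip count, as add() requires
//       int limit = (cols / kWarpSize) * kWarpSize;
//       int i = threadIdx.x;
//       for (; i < limit; i += kThreads) {
//           heap.add(row[i], i);
//       }
//
//       // Ragged tail: per-lane insert, then a warp-uniform check
//       if (i < cols) {
//           heap.addThreadQ(row[i], i);
//       }
//       heap.checkThreadQ();
//
//       heap.reduce();
//
//       for (int j = threadIdx.x; j < k; j += kThreads) {
//           outK[(long)blockIdx.x * k + j] = smemK[j];
//           outV[(long)blockIdx.x * k + j] = smemV[j];
//       }
//   }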

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, typename Comp,
          int NumThreadQ, int ThreadsPerBlock>
struct BlockSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k)
            : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) {}

    __device__ inline void addThreadQ(K k, V v) {
        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // The warp doesn't need to cooperate until the end
    }

    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
    }

    __device__ inline void reduce() {
        // Reduce within the warp first
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        // Each warp writes out a single value
        int laneId = getLaneId();
        int warpId = threadIdx.x / kWarpSize;

        if (laneId == 0) {
            sharedK[warpId] = pair.k;
            sharedV[warpId] = pair.v;
        }

        __syncthreads();

        // For the small blocks this is used with, a serial scan of the
        // per-warp results by thread 0 is fastest
        if (threadIdx.x == 0) {
            threadK = sharedK[0];
            threadV = sharedV[0];

#pragma unroll
            for (int i = 1; i < kNumWarps; ++i) {
                K k = sharedK[i];
                V v = sharedV[i];

                bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
                threadK = swap ? k : threadK;
                threadV = swap ? v : threadV;
            }

            // A thread's own smem reads/writes are ordered wrt itself
            sharedK[0] = threadK;
            sharedV[0] = threadV;
        }

        // In case other threads wish to read this value
        __syncthreads();
    }

    // threadK is the current k-th selection, threadV its value
    K threadK;
    V threadV;

    K* sharedK;
    V* sharedV;
};
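
// Usage note (illustrative, not from the original header): the k == 1
// specialization is a block-wide argmin/argmax, and the shared arrays only
// need kNumWarps entries, e.g.:
//
//   __shared__ float smemK[128 / kWarpSize];
//   __shared__ int smemV[128 / kWarpSize];
//   BlockSelect<float, int, false, Comparator<float>, 1, 1, 128>
//           argmin(Limits<float>::getMax(), -1, smemK, smemV, 1);
//   argmin.add(val, idx);  // per-thread candidate
//   argmin.reduce();       // sharedK[0] / sharedV[0] now hold the minimum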

/// Per-warp selection: same strategy as BlockSelect, but the warp queue
/// lives in registers (kNumWarpQRegisters per lane) instead of shared
/// memory, and each warp selects independently
template <typename K, typename V, bool Dir, typename Comp,
          int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect {
    static constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize;

    __device__ inline WarpSelect(K initKVal, V initVVal, int k)
            : initK(initKVal), initV(initVVal), numVals(0),
              warpKTop(initKVal), kLane((k - 1) % kWarpSize) {
        static_assert(utils::isPowerOf2(ThreadsPerBlock),
                      "threads must be a power-of-2");
        static_assert(utils::isPowerOf2(NumWarpQ),
                      "warp queue must be power-of-2");

        // Fill the per-thread queues with the sentinel value
#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // Fill the register-resident warp queue as well
#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            warpK[i] = initK;
            warpV[i] = initV;
        }
    }

    __device__ inline void addThreadQ(K k, V v) {
        if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) {
            // Rotate right; the tail element falls off the end
#pragma unroll
            for (int i = NumThreadQ - 1; i > 0; --i) {
                threadK[i] = threadK[i - 1];
                threadV[i] = threadV[i - 1];
            }

            threadK[0] = k;
            threadV[0] = v;
            ++numVals;
        }
    }

    __device__ inline void checkThreadQ() {
        bool needSort = (numVals == NumThreadQ);

#if CUDA_VERSION >= 9000
        needSort = __any_sync(0xffffffff, needSort);
#else
        needSort = __any(needSort);
#endif

        if (!needSort) {
            // no lane has a full thread queue yet
            return;
        }

        mergeWarpQ();

        // Reset the thread queues
        numVals = 0;

#pragma unroll
        for (int i = 0; i < NumThreadQ; ++i) {
            threadK[i] = initK;
            threadV[i] = initV;
        }

        // New candidates must beat at least this element
        warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane);
    }
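
    // Note on kLane: the warp queue stores rank r at register r / kWarpSize,
    // lane r % kWarpSize, so rank k - 1 lives at lane kLane of some register
    // row. Reading from the *last* row is exact when that row contains rank
    // k - 1 (e.g. k = 40, NumWarpQ = 64: lane 7 of row 1 is rank 39);
    // otherwise the value read ranks below k - 1, which is still a valid,
    // merely conservative, admission threshold.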

    /// Sorts the per-thread queues and merges them with the warp-wide
    /// queue, producing a single sorted list across both
    __device__ inline void mergeWarpQ() {
        // Sort all of the per-thread queues
        warpSortAnyRegisters<K, V, NumThreadQ, !Dir, Comp>(threadK, threadV);

        // The warp queue is already sorted; merge the two sorted lists
        warpMergeAnyRegisters<K, V,
                              kNumWarpQRegisters, NumThreadQ, !Dir, Comp,
                              false>(
                warpK, warpV, threadK, threadV);
    }

    /// WARNING: all threads in a warp must participate in this.
    /// Otherwise, you must call the constituent parts separately.
    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
        checkThreadQ();
    }

    __device__ inline void reduce() {
        // Have the warp dump and merge its queues; this produces the
        // final per-warp results
        mergeWarpQ();
    }

    /// Dump the final k selected values for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        int laneId = getLaneId();

#pragma unroll
        for (int i = 0; i < kNumWarpQRegisters; ++i) {
            int idx = i * kWarpSize + laneId;

            if (idx < k) {
                outK[idx] = warpK[i];
                outV[idx] = warpV[i];
            }
        }
    }

    // Sentinel key and value
    const K initK;
    const V initV;

    // Number of valid elements in our thread queue
    int numVals;

    // The k-th best element seen so far; candidates must beat this
    K warpKTop;

    // Thread queue values
    K threadK[NumThreadQ];
    V threadV[NumThreadQ];

    // Warp queue, striped across the warp's registers;
    // warpK[0] holds the best-ranked elements
    K warpK[kNumWarpQRegisters];
    V warpV[kNumWarpQRegisters];

    // The lane in the last warp-queue register row from which we read an
    // approximation to the k-th ranked element
    int kLane;
};
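
// A minimal usage sketch, not part of the original header: each warp
// independently selects the k smallest values of its own row.
// `warpSelectMinK` is hypothetical; Limits<float> is assumed to come from
// Limits.cuh, and the grid is assumed to cover all rows exactly.
//
//   template <int kThreads, int kWarpQ, int kThreadQ>
//   __global__ void warpSelectMinK(
//           const float* in, int cols, float* outK, int* outV, int k) {
//       WarpSelect<float, int, false, Comparator<float>,
//                  kWarpQ, kThreadQ, kThreads>
//               heap(Limits<float>::getMax(), -1, k);
//
//       int row = blockIdx.x * (kThreads / kWarpSize)
//               + threadIdx.x / kWarpSize;
//       const float* rowData = in + (long)row * cols;
//
//       int limit = (cols / kWarpSize) * kWarpSize;
//       int i = getLaneId();
//       for (; i < limit; i += kWarpSize) {
//           heap.add(rowData[i], i);        // all lanes participate
//       }
//       if (i < cols) {
//           heap.addThreadQ(rowData[i], i); // ragged tail
//       }
//       heap.checkThreadQ();
//
//       heap.reduce();
//       heap.writeOut(outK + (long)row * k, outV + (long)row * k, k);
//   }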

/// Specialization for k == 1 (NumWarpQ == 1)
template <typename K, typename V, bool Dir, typename Comp,
          int NumThreadQ, int ThreadsPerBlock>
struct WarpSelect<K, V, Dir, Comp, 1, NumThreadQ, ThreadsPerBlock> {
    static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __device__ inline WarpSelect(K initK, V initV, int k)
            : threadK(initK), threadV(initV) {}

    __device__ inline void addThreadQ(K k, V v) {
        bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK);
        threadK = swap ? k : threadK;
        threadV = swap ? v : threadV;
    }

    __device__ inline void checkThreadQ() {
        // The warp doesn't need to cooperate until the end
    }

    __device__ inline void add(K k, V v) {
        addThreadQ(k, v);
    }

    __device__ inline void reduce() {
        // Reduce within the warp
        Pair<K, V> pair(threadK, threadV);

        if (Dir) {
            pair = warpReduceAll<Pair<K, V>, Max<Pair<K, V>>>(
                    pair, Max<Pair<K, V>>());
        } else {
            pair = warpReduceAll<Pair<K, V>, Min<Pair<K, V>>>(
                    pair, Min<Pair<K, V>>());
        }

        // warpReduceAll returns the result in every lane, so this
        // broadcasts the winner across the warp
        threadK = pair.k;
        threadV = pair.v;
    }

    /// Dump the single selected value for this warp out
    __device__ inline void writeOut(K* outK, V* outV, int k) {
        if (getLaneId() == 0) {
            *outK = threadK;
            *outV = threadV;
        }
    }

    // threadK is the current k-th selection, threadV its value
    K threadK;
    V threadV;
};

} // namespace gpu
} // namespace faiss
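
// Usage note (illustrative): with NumWarpQ == 1, WarpSelect degenerates to a
// warp-wide argmin/argmax computed by a single Pair-valued reduction:
//
//   WarpSelect<float, int, false, Comparator<float>, 1, 1, kThreads>
//           sel(Limits<float>::getMax(), -1, 1);
//   sel.add(val, idx);            // per-thread candidate
//   sel.reduce();                 // every lane now holds the warp minimum
//   sel.writeOut(outK, outV, 1);  // lane 0 writes it out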