12 #include "DeviceDefs.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include "../../FaissAssert.h"
19 namespace faiss {
namespace gpu {
// Block-cooperative merge of pairs of adjacent length-L lists, for the case
// L <= NumThreads (enforced by the static_asserts below).
//
// Thread layout: each run of L consecutive threadIdx.x values forms one merge
// group. Group `threadIdx.x / L` operates on the 2 * L keys (and the parallel
// 2 * L values) starting at offset 2 * L * (threadIdx.x / L) of listK/listV.
//
// Template parameters:
//   NumThreads - number of cooperating threads; must be a power of 2.
//   K, V       - key and value element types. Every value is moved in lockstep
//                with its key.
//   L          - length of each input list; power of 2 and <= NumThreads.
//   Dir        - when true, a compared pair (ka, kb) at positions
//                (pos, pos + stride) is exchanged if Comp::gt(ka, kb);
//                when false, it is exchanged if Comp::lt(ka, kb).
//   Comp       - type providing static gt / lt comparisons on K.
//
// NOTE(review): listK/listV are presumably __shared__ buffers, since threads
// read elements that other threads wrote. Confirm at the call site.
// NOTE(review): this copy appears to have lost lines. `ka` and `va` are used
// but never declared (presumably `listK[pos]` / `listV[pos]`). No barrier is
// visible between passes, although each pass reads elements written by other
// threads in the previous pass. The closing braces are also missing. Confirm
// against the upstream source.
22 template <
int NumThreads,
typename K,
typename V,
int L,
23 bool Dir,
typename Comp>
24 inline __device__
void blockMergeSmall(K* listK, V* listV) {
25 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
26 static_assert(utils::isPowerOf2(NumThreads),
27 "NumThreads must be a power-of-2");
28 static_assert(L <= NumThreads,
"merge list size must be <= NumThreads");
// Index of the list pair this thread works on.
31 int mergeId = threadIdx.x / L;
// This thread's index within its merge group, in [0, L).
34 int tid = threadIdx.x % L;
// Advance to this group's 2 * L-element region.
37 listK += 2 * L * mergeId;
38 listV += 2 * L * mergeId;
// First compare-exchange step. pos = L - 1 - tid and its partner
// pos + stride = L + tid, so element L-1-tid of the first list is paired with
// element tid of the second list (the first list is walked back-to-front).
42 int pos = L - 1 - tid;
43 int stride = 2 * tid + 1;
46 K kb = listK[pos + stride];
48 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
49 listK[pos] = swap ? kb : ka;
50 listK[pos + stride] = swap ? ka : kb;
// Apply the same exchange decision to the values.
53 V vb = listV[pos + stride];
54 listV[pos] = swap ? vb : va;
55 listV[pos + stride] = swap ? va : vb;
// Remaining passes compare-exchange elements `stride` apart, halving stride
// from L / 2 down to 1. The expression pos = 2 * tid - (tid & (stride - 1))
// maps the L threads of the group onto L disjoint (pos, pos + stride) pairs
// that together cover the 2 * L region.
60 for (
int stride = L / 2; stride > 0; stride /= 2) {
61 int pos = 2 * tid - (tid & (stride - 1));
64 K kb = listK[pos + stride];
66 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
67 listK[pos] = swap ? kb : ka;
68 listK[pos + stride] = swap ? ka : kb;
71 V vb = listV[pos + stride];
72 listV[pos] = swap ? vb : va;
73 listV[pos + stride] = swap ? va : vb;
// Block-cooperative merge of ONE pair of adjacent length-L lists (2 * L keys
// and 2 * L values starting at listK/listV), for the case L >= NumThreads.
// Each thread handles kLoopPerThread = L / NumThreads compare-exchange slots
// per pass, using virtual indices tid = loop * NumThreads + threadIdx.x.
//
// Template parameters:
//   NumThreads - number of cooperating threads; must be a power of 2.
//   K, V       - key and value element types. Every value is moved in lockstep
//                with its key.
//   L          - list length; power of 2, >= kWarpSize and >= NumThreads.
//   Dir        - when true, a pair (ka, kb) at (pos, pos + stride) is
//                exchanged if Comp::gt(ka, kb); otherwise if Comp::lt(ka, kb).
//   Comp       - type providing static gt / lt comparisons on K.
//
// NOTE(review): the virtual indices cover [0, L) only if threadIdx.x ranges
// over [0, NumThreads). This presumably assumes blockDim.x == NumThreads.
// Confirm at the launch site.
// NOTE(review): this copy appears to have lost lines. `ka` and `va` are used
// but never declared (presumably `listK[pos]` / `listV[pos]`). No barrier is
// visible between passes, although passes read elements written by other
// threads. The closing braces are also missing. Confirm against upstream.
80 template <
int NumThreads,
typename K,
typename V,
int L,
81 bool Dir,
typename Comp>
82 inline __device__
void blockMergeLarge(K* listK, V* listV) {
83 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
84 static_assert(L >= kWarpSize,
"merge list size must be >= 32");
85 static_assert(utils::isPowerOf2(NumThreads),
86 "NumThreads must be a power-of-2");
87 static_assert(L >= NumThreads,
"merge list size must be >= NumThreads");
// Number of compare-exchange slots each thread handles per pass.
91 constexpr
int kLoopPerThread = L / NumThreads;
// First step: slot tid compares position L - 1 - tid with its partner
// (L - 1 - tid) + (2 * tid + 1) = L + tid. This pairs the first list, walked
// back-to-front, with the second list walked front-to-back.
96 for (
int loop = 0; loop < kLoopPerThread; ++loop) {
97 int tid = loop * NumThreads + threadIdx.x;
98 int pos = L - 1 - tid;
99 int stride = 2 * tid + 1;
102 K kb = listK[pos + stride];
104 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
105 listK[pos] = swap ? kb : ka;
106 listK[pos + stride] = swap ? ka : kb;
109 V vb = listV[pos + stride];
110 listV[pos] = swap ? vb : va;
111 listV[pos + stride] = swap ? va : vb;
// Remaining passes compare-exchange elements `stride` apart, halving stride
// from L / 2 down to 1. Within a pass, the L slots map to disjoint
// (pos, pos + stride) pairs that cover the 2 * L region.
117 for (
int stride = L / 2; stride > 0; stride /= 2) {
119 for (
int loop = 0; loop < kLoopPerThread; ++loop) {
120 int tid = loop * NumThreads + threadIdx.x;
121 int pos = 2 * tid - (tid & (stride - 1));
124 K kb = listK[pos + stride];
126 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
127 listK[pos] = swap ? kb : ka;
128 listK[pos + stride] = swap ? ka : kb;
131 V vb = listV[pos + stride];
132 listV[pos] = swap ? vb : va;
133 listV[pos + stride] = swap ? va : vb;
// Primary template of the BlockMerge dispatcher. It merges N pairs of
// length-L lists with NumThreads threads. The bool SmallerThanBlock selects
// the implementation through partial specializations: the `false`
// specialization in this file calls blockMergeLarge, and the other presumably
// calls blockMergeSmall.
// NOTE(review): the declaration this template header introduces (presumably
// `struct BlockMerge {};`) is missing from this copy.
142 template <
int NumThreads,
143 typename K,
typename V,
int N,
int L,
144 bool Dir,
typename Comp,
bool SmallerThanBlock>
// merge() for the specialization that uses blockMergeSmall (L <= NumThreads).
// It merges N pairs of length-L lists stored back to back, with 2 * L
// keys/values per pair.
// NOTE(review): several lines are missing from this copy: the
// `struct BlockMerge<..., true> {` header that should precede this member,
// the `} else {` between the two branches, the remaining call arguments
// (presumably `listV + start`) and the closing braces.
149 template <
int NumThreads,
typename K,
typename V,
int N,
int L,
150 bool Dir,
typename Comp>
152 static inline __device__
void merge(K* listK, V* listV) {
// One blockMergeSmall call handles NumThreads / L pairs at once (one pair
// per group of L threads).
153 constexpr
int kNumParallelMerges = NumThreads / L;
// Number of batched calls needed when N is a multiple of kNumParallelMerges.
154 constexpr
int kNumIterations = N / kNumParallelMerges;
156 static_assert(L <= NumThreads,
"list must be <= NumThreads");
157 static_assert((N < kNumParallelMerges) ||
158 (kNumIterations * kNumParallelMerges == N),
159 "improper selection of N and L");
// Fewer pairs than thread groups: only the first N * L threads have a pair
// to merge.
// NOTE(review): if blockMergeSmall synchronizes with __syncthreads() (not
// visible in this copy), calling it under this thread-dependent guard would
// place a block barrier in divergent code, which is undefined behavior.
// Confirm, and if needed move the guard inside the helper around the
// compare-exchange work only.
161 if (N < kNumParallelMerges) {
163 if (threadIdx.x < N * L) {
164 blockMergeSmall<NumThreads, K, V, L, Dir, Comp>(listK, listV);
// Otherwise merge the pairs in kNumIterations batches. Each batch covers
// kNumParallelMerges pairs, i.e. kNumParallelMerges * 2 * L elements.
169 for (
int i = 0; i < kNumIterations; ++i) {
170 int start = i * kNumParallelMerges * 2 * L;
172 blockMergeSmall<NumThreads, K, V, L, Dir, Comp>(listK + start,
// BlockMerge specialization for SmallerThanBlock == false. It presumably
// applies when L > NumThreads (cf. kSmallerThanBlock = (L <= NumThreads) in
// blockMerge; confirm at the use site). The N pairs of length-L lists are
// merged one after another, each by the whole block via blockMergeLarge, at
// element offsets i * 2 * L.
// NOTE(review): the remaining call arguments (presumably `listV + start`)
// and the closing braces are missing from this copy.
180 template <
int NumThreads,
typename K,
typename V,
int N,
int L,
181 bool Dir,
typename Comp>
182 struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false> {
183 static inline __device__
void merge(K* listK, V* listV) {
// Each iteration merges one pair occupying 2 * L consecutive elements.
186 for (
int i = 0; i < N; ++i) {
187 int start = i * 2 * L;
189 blockMergeLarge<NumThreads, K, V, L, Dir, Comp>(listK + start,
195 template <
int NumThreads,
typename K,
typename V,
int N,
int L,
196 bool Dir,
typename Comp>
197 inline __device__
void blockMerge(K* listK, V* listV) {
198 constexpr
bool kSmallerThanBlock = (L <= NumThreads);