10 #include "DeviceDefs.cuh"
11 #include "MergeNetworkUtils.cuh"
12 #include "PtxUtils.cuh"
13 #include "StaticUtils.h"
14 #include "WarpShuffles.cuh"
15 #include "../../FaissAssert.h"
18 namespace faiss {
namespace gpu {
/// Merges pairs of adjacent lists of length L, where L is no larger than the
/// number of threads in the block, so several pairs can be merged at once.
///
/// Template parameters (order matches the call sites in BlockMerge):
///  - NumThreads: number of threads assumed per block (power of 2)
///  - K, V:       key and value element types
///  - N:          number of list pairs; only consulted when AllThreads is
///                false, to mask off threads with threadIdx.x >= N * L
///  - L:          length of each individual list (power of 2, <= NumThreads)
///  - AllThreads: if true, every thread performs compare-exchanges
///  - Dir:        selects the comparison: swap on Comp::gt(ka, kb) when true,
///                on Comp::lt(ka, kb) when false
///  - Comp:       type providing static gt(a, b) / lt(a, b)
///  - FullMerge:  accepted for signature parity with blockMergeLarge; not
///                used by this function
///
/// listK / listV must be visible to every thread of the block (presumably
/// shared memory -- TODO confirm with callers). Every thread of the block
/// must call this function, since it contains block-wide barriers.
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool AllThreads,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

    // Which pair of lists this thread works on
    int mergeId = threadIdx.x / L;

    // This thread's index within its pair
    int tid = threadIdx.x % L;

    // Each pair occupies 2 * L consecutive elements
    listK += 2 * L * mergeId;
    listV += 2 * L * mergeId;

    // Threads beyond the N * L needed ones only participate in the barriers
    const bool active = AllThreads || (threadIdx.x < N * L);

    // First step: pos + stride == L + tid, so element (L - 1 - tid) of the
    // first list is compared against element tid of the second list.
    {
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;

        if (active) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;
        }
    }

    // The next stage reads slots written by other threads; the barrier is
    // kept outside the `active` branch so that all threads reach it.
    __syncthreads();

    // Remaining steps: compare-exchange at halving strides
    for (int stride = L / 2; stride > 0; stride /= 2) {
        int pos = 2 * tid - (tid & (stride - 1));

        if (active) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;
        }

        __syncthreads();
    }
}
/// Merges one pair of adjacent lists of length L, where L is at least the
/// number of threads in the block, so each thread handles L / NumThreads
/// compare-exchanges per stage.
///
/// Template parameters (order matches the call site in BlockMerge):
///  - NumThreads: number of threads assumed per block (power of 2)
///  - K, V:       key and value element types
///  - L:          length of each list (power of 2, >= kWarpSize and
///                >= NumThreads)
///  - Dir:        selects the comparison: swap on Comp::gt(ka, kb) when true,
///                on Comp::lt(ka, kb) when false
///  - Comp:       type providing static gt(a, b) / lt(a, b)
///  - FullMerge:  if false, the stages after the first only run half as many
///                iterations per thread, so they touch only positions
///                [0, L) of the 2 * L element pair
///
/// listK / listV must be visible to every thread of the block (presumably
/// shared memory -- TODO confirm with callers). Every thread of the block
/// must call this function, since it contains block-wide barriers.
template <
        int NumThreads,
        typename K,
        typename V,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(L >= kWarpSize, "merge list size must be >= 32");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

    // Number of compare-exchanges each thread performs in the first step
    constexpr int kLoopPerThread = L / NumThreads;

    // First step: pos + stride == L + tid, so element (L - 1 - tid) of the
    // first list is compared against element tid of the second list.
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
        int tid = loop * NumThreads + threadIdx.x;
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;

        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;
    }

    // The next stage reads slots written by other threads
    __syncthreads();

    // Without FullMerge, tid stays below L / 2 in the loops below, which
    // keeps pos + stride below L.
    constexpr int kSecondLoopPerThread =
            FullMerge ? kLoopPerThread : kLoopPerThread / 2;

    // Remaining steps: compare-exchange at halving strides
    for (int stride = L / 2; stride > 0; stride /= 2) {
        for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
            int tid = loop * NumThreads + threadIdx.x;
            int pos = 2 * tid - (tid & (stride - 1));

            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;
        }

        // Each stride level depends on the results of the previous one
        __syncthreads();
    }
}
/// Primary template for the block-wide merge dispatcher. The work is done by
/// the partial specializations on SmallerThanBlock (true / false); selecting
/// through a class template means only the chosen implementation, and its
/// static_asserts, is instantiated.
/// NOTE(review): reconstructed as an empty definition; confirm against the
/// upstream header whether the primary was defined or only declared.
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool SmallerThanBlock,
        bool FullMerge>
struct BlockMerge {};
/// Specialization for lists no longer than the block (L <= NumThreads):
/// NumThreads / L pairs of lists can be merged at the same time.
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
    /// Merges N pairs of lists of length L stored back to back starting at
    /// listK / listV (2 * L elements per pair).
    static inline __device__ void merge(K* listK, V* listV) {
        // How many pairs the block can process at once
        constexpr int kNumParallelMerges = NumThreads / L;
        // How many passes are needed to cover all N pairs
        constexpr int kNumIterations = N / kNumParallelMerges;

        static_assert(L <= NumThreads, "list must be <= NumThreads");
        static_assert(
                (N < kNumParallelMerges) ||
                        (kNumIterations * kNumParallelMerges == N),
                "improper selection of N and L");

        if (N < kNumParallelMerges) {
            // Fewer pairs than the block could handle: one pass in which
            // only the first N * L threads do compare-exchanges
            blockMergeSmall<
                    NumThreads,
                    K,
                    V,
                    N,
                    L,
                    false,
                    Dir,
                    Comp,
                    FullMerge>(listK, listV);
        } else {
            // Every thread is busy; each pass handles kNumParallelMerges
            // pairs, i.e. kNumParallelMerges * 2 * L elements
            for (int i = 0; i < kNumIterations; ++i) {
                int start = i * kNumParallelMerges * 2 * L;

                blockMergeSmall<
                        NumThreads,
                        K,
                        V,
                        N,
                        L,
                        true,
                        Dir,
                        Comp,
                        FullMerge>(listK + start, listV + start);
            }
        }
    }
};
/// Specialization for lists longer than the block: the whole block works on
/// one pair of lists at a time.
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
    /// Merges N pairs of lists of length L stored back to back starting at
    /// listK / listV (2 * L elements per pair), one pair after another.
    static inline __device__ void merge(K* listK, V* listV) {
        for (int i = 0; i < N; ++i) {
            // Offset of the i-th pair
            int start = i * 2 * L;

            blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
                    listK + start, listV + start);
        }
    }
};
274 template <
int NumThreads,
281 bool FullMerge =
true>
282 inline __device__
void blockMerge(K* listK, V* listV) {
283 constexpr
bool kSmallerThanBlock = (L <= NumThreads);