11 #include "DeviceDefs.cuh"
12 #include "MergeNetworkUtils.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include "../../FaissAssert.h"
19 namespace faiss {
namespace gpu {
// blockMergeSmall: compare-exchange passes over key/value windows of 2 * L
// elements, for the case where one window needs at most NumThreads threads
// (L <= NumThreads is asserted below).
//
// Consecutive groups of L threads each work on their own window: group
// `mergeId` starts at offset 2 * L * mergeId in both listK and listV.
//
// NOTE(review): most of the template parameter list is not visible in this
// excerpt. The body uses K, V, N, L, AllThreads, Dir and Comp in addition to
// NumThreads; confirm their declaration order against the full file.
// NOTE(review): `ka` and `va` are used below but their definitions are not
// visible here; presumably they are listK[pos] / listV[pos] -- confirm.
// NOTE(review): no barrier between passes is visible in this excerpt;
// confirm the synchronization between steps against the full file.
22 template <
int NumThreads,
31 inline __device__
void blockMergeSmall(K* listK, V* listV) {
32 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
33 static_assert(utils::isPowerOf2(NumThreads),
34 "NumThreads must be a power-of-2");
35 static_assert(L <= NumThreads,
"merge list size must be <= NumThreads");
// Which 2 * L window this thread belongs to.
38 int mergeId = threadIdx.x / L;
// This thread's index inside its window's group of L threads.
41 int tid = threadIdx.x % L;
// Rebase both pointers onto this group's window.
44 listK += 2 * L * mergeId;
45 listV += 2 * L * mergeId;
// First pass: pos + stride == L + tid, so thread `tid` pairs element
// L - 1 - tid with element L + tid (mirrored around the window midpoint).
49 int pos = L - 1 - tid;
50 int stride = 2 * tid + 1;
// Unless AllThreads is set, only threads with threadIdx.x < N * L do work.
52 if (AllThreads || (threadIdx.x < N * L)) {
54 K kb = listK[pos + stride];
// Dir selects which comparator (Comp::gt or Comp::lt) triggers the swap.
56 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
57 listK[pos] = swap ? kb : ka;
58 listK[pos + stride] = swap ? ka : kb;
// Values are moved with the same swap decision as their keys.
61 V vb = listV[pos + stride];
62 listV[pos] = swap ? vb : va;
63 listV[pos + stride] = swap ? va : vb;
// Remaining passes: stride halves from L / 2 down to 1.
80 for (
int stride = L / 2; stride > 0; stride /= 2) {
// stride is a power of 2 here, so (tid & (stride - 1)) == tid % stride and
// pos == 2 * stride * (tid / stride) + tid % stride: an element in the lower
// half of a 2 * stride wide segment, paired with pos + stride in the upper
// half.
81 int pos = 2 * tid - (tid & (stride - 1));
// Same participation guard as in the first pass.
83 if (AllThreads || (threadIdx.x < N * L)) {
85 K kb = listK[pos + stride];
87 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
88 listK[pos] = swap ? kb : ka;
89 listK[pos + stride] = swap ? ka : kb;
92 V vb = listV[pos + stride];
93 listV[pos] = swap ? vb : va;
94 listV[pos + stride] = swap ? va : vb;
// blockMergeLarge: compare-exchange passes over a single key/value window of
// 2 * L elements, for the case where L >= NumThreads (asserted below), so each
// thread handles several pairs per pass by looping.
//
// NOTE(review): most of the template parameter list is not visible in this
// excerpt. The body uses K, V, L, Dir, Comp and FullMerge in addition to
// NumThreads; confirm their declaration order against the full file.
// NOTE(review): `ka` and `va` are used below but their definitions are not
// visible here; presumably they are listK[pos] / listV[pos] -- confirm.
// NOTE(review): no barrier between passes is visible in this excerpt;
// confirm the synchronization between steps against the full file.
113 template <
int NumThreads,
120 inline __device__
void blockMergeLarge(K* listK, V* listV) {
121 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
122 static_assert(L >= kWarpSize,
"merge list size must be >= 32");
123 static_assert(utils::isPowerOf2(NumThreads),
124 "NumThreads must be a power-of-2");
125 static_assert(L >= NumThreads,
"merge list size must be >= NumThreads");
// Number of pairs each thread processes in the first pass.
129 constexpr
int kLoopPerThread = L / NumThreads;
// First pass: the virtual index `tid` advances by NumThreads per iteration.
// Assuming threadIdx.x < NumThreads (TODO confirm launch configuration), it
// covers 0 .. L - 1 across the block.
134 for (
int loop = 0; loop < kLoopPerThread; ++loop) {
135 int tid = loop * NumThreads + threadIdx.x;
// pos + stride == L + tid: element L - 1 - tid is paired with element
// L + tid (mirrored around the window midpoint).
136 int pos = L - 1 - tid;
137 int stride = 2 * tid + 1;
140 K kb = listK[pos + stride];
// Dir selects which comparator (Comp::gt or Comp::lt) triggers the swap.
142 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
143 listK[pos] = swap ? kb : ka;
144 listK[pos + stride] = swap ? ka : kb;
// Values are moved with the same swap decision as their keys.
147 V vb = listV[pos + stride];
148 listV[pos] = swap ? vb : va;
149 listV[pos + stride] = swap ? va : vb;
// When FullMerge is false the remaining passes run half as many iterations
// per thread, so `tid` stays below L / 2 and the touched indices stay below L
// (only the first half of the window is processed further).
// NOTE(review): if L == NumThreads and FullMerge is false this evaluates to 0
// and the remaining passes do no work; the static_assert above permits
// L == NumThreads -- confirm that callers never instantiate that case.
165 constexpr
int kSecondLoopPerThread =
166 FullMerge ? kLoopPerThread : kLoopPerThread / 2;
// Remaining passes: stride halves from L / 2 down to 1.
169 for (
int stride = L / 2; stride > 0; stride /= 2) {
171 for (
int loop = 0; loop < kSecondLoopPerThread; ++loop) {
172 int tid = loop * NumThreads + threadIdx.x;
// stride is a power of 2 here, so pos == 2 * stride * (tid / stride) +
// tid % stride: an element in the lower half of a 2 * stride wide segment,
// paired with pos + stride in the upper half.
173 int pos = 2 * tid - (tid & (stride - 1));
176 K kb = listK[pos + stride];
178 bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
179 listK[pos] = swap ? kb : ka;
180 listK[pos + stride] = swap ? ka : kb;
183 V vb = listV[pos + stride];
184 listV[pos] = swap ? vb : va;
185 listV[pos + stride] = swap ? va : vb;
// NOTE(review): only part of this template header is visible in this excerpt.
// It is presumably the primary BlockMerge template that the specializations
// later in the file select on via the bool SmallerThanBlock parameter --
// confirm against the full file.
205 template <
int NumThreads,
212 bool SmallerThanBlock,
// BlockMerge specialization with the eighth template argument fixed to
// `true`: the case where L <= NumThreads (asserted below), so NumThreads / L
// windows can be processed side by side through blockMergeSmall.
//
// NOTE(review): most of the template parameter list is not visible in this
// excerpt; confirm it against the full file.
218 template <
int NumThreads,
226 struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
// Dispatches the N merges over listK / listV to blockMergeSmall.
227 static inline __device__
void merge(K* listK, V* listV) {
// How many windows the block can handle at once (L threads per window).
228 constexpr
int kNumParallelMerges = NumThreads / L;
// How many rounds are needed to cover N windows.
229 constexpr
int kNumIterations = N / kNumParallelMerges;
231 static_assert(L <= NumThreads,
"list must be <= NumThreads");
// N must either be smaller than the parallel capacity or an exact multiple
// of it.
232 static_assert((N < kNumParallelMerges) ||
233 (kNumIterations * kNumParallelMerges == N),
234 "improper selection of N and L");
236 if (N < kNumParallelMerges) {
// Fewer merges than the block can run at once. The `false` argument
// presumably binds to blockMergeSmall's AllThreads parameter, enabling its
// threadIdx.x < N * L guard -- confirm the parameter order.
// NOTE(review): the call's arguments are not visible in this excerpt.
238 blockMergeSmall<NumThreads, K, V, N, L, false, Dir, Comp, FullMerge>(
// NOTE(review): the loop below is presumably the `else` branch of the `if`
// above; the `else` itself is not visible in this excerpt.
243 for (
int i = 0; i < kNumIterations; ++i) {
// Each round covers kNumParallelMerges windows of 2 * L elements.
244 int start = i * kNumParallelMerges * 2 * L;
// `true` here presumably binds to AllThreads (every thread participates)
// -- confirm the parameter order.
246 blockMergeSmall<NumThreads, K, V, N, L, true, Dir, Comp, FullMerge>(
247 listK + start, listV + start);
// BlockMerge specialization with the eighth template argument fixed to
// `false`: each of the N windows is handled one after another by
// blockMergeLarge.
//
// NOTE(review): most of the template parameter list is not visible in this
// excerpt; confirm it against the full file.
254 template <
int NumThreads,
262 struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
// Runs blockMergeLarge once per window.
263 static inline __device__
void merge(K* listK, V* listV) {
266 for (
int i = 0; i < N; ++i) {
// Window i starts 2 * L elements after window i - 1.
267 int start = i * 2 * L;
269 blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
270 listK + start, listV + start);
275 template <
int NumThreads,
282 bool FullMerge =
true>
283 inline __device__
void blockMerge(K* listK, V* listV) {
284 constexpr
bool kSmallerThanBlock = (L <= NumThreads);