11 #include "DeviceDefs.cuh"
12 #include "MergeNetworkUtils.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include "../../FaissAssert.h"
19 namespace faiss {
namespace gpu {
// blockMergeSmall: merge step for list length L <= NumThreads, where each
// group of L consecutive threads works on its own 2 * L-element segment
// (presumably a pair of already-sorted length-L lists -- confirm at callers).
//
// NOTE(review): this fragment is a lossy extraction. Leading integers such as
// "22" and "29" appear to be fused original line numbers, and many lines
// (remaining template parameters, the ka/va declarations, swap calls, closing
// braces, any barriers) are not present here. Comments below describe only the
// visible lines.
22 template <
int NumThreads,
29 inline __device__
void blockMergeSmall(K* listK, V* listV) {
30 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
31 static_assert(utils::isPowerOf2(NumThreads),
32 "NumThreads must be a power-of-2");
33 static_assert(L <= NumThreads,
"merge list size must be <= NumThreads");
// Which 2 * L-element segment this thread helps merge: each group of L
// consecutive threads shares one mergeId.
36 int mergeId = threadIdx.x / L;
// This thread's index within its group of L threads (0 .. L - 1).
39 int tid = threadIdx.x % L;
// Each merge owns a contiguous segment of 2 * L keys (and 2 * L values);
// advance both pointers to this group's segment.
42 listK += 2 * L * mergeId;
43 listV += 2 * L * mergeId;
// First step: compare element (L - 1 - tid) with element (L + tid), since
// pos + stride == L + tid. This pairs mirror-image positions around the
// midpoint of the segment -- presumably because both input lists are sorted
// in the same direction, so the second list is treated as if reversed.
47 int pos = L - 1 - tid;
48 int stride = 2 * tid + 1;
51 K& kb = listK[pos + stride];
// s: whether this pair is out of order; Dir selects Comp::gt vs Comp::lt.
// (ka is declared on a line not visible in this fragment -- presumably
// listK[pos]; verify.)
53 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
57 V& vb = listV[pos + stride];
// Remaining steps: the compare distance halves each iteration
// (L / 2, L / 4, ..., 1). With tid = q * stride + r (r < stride),
// pos = 2 * tid - (tid & (stride - 1)) = 2 * q * stride + r, i.e. the lower
// element of a pair inside the q-th block of 2 * stride elements.
63 for (
int stride = L / 2; stride > 0; stride /= 2) {
64 int pos = 2 * tid - (tid & (stride - 1));
67 K& kb = listK[pos + stride];
69 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
73 V& vb = listV[pos + stride];
// blockMergeLarge: merge step for list length L >= NumThreads (and >= the
// warp size). The whole block works on one 2 * L-element segment, with each
// thread handling several compare pairs per step.
//
// NOTE(review): this fragment is a lossy extraction. Leading integers appear
// to be fused original line numbers, and many lines (remaining template
// parameters, ka/va declarations, swaps, closing braces, any barriers) are
// missing. Comments below describe only the visible lines.
81 template <
int NumThreads,
88 inline __device__
void blockMergeLarge(K* listK, V* listV) {
89 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
90 static_assert(L >= kWarpSize,
"merge list size must be >= 32");
91 static_assert(utils::isPowerOf2(NumThreads),
92 "NumThreads must be a power-of-2");
93 static_assert(L >= NumThreads,
"merge list size must be >= NumThreads");
// Compare pairs each thread processes per step. The division is exact because
// L and NumThreads are powers of 2 with L >= NumThreads.
97 constexpr
int kLoopPerThread = L / NumThreads;
// First step: the virtual id tid spans 0 .. L - 1 across the loop (assuming
// threadIdx.x < NumThreads -- confirm launch config). Each tid compares
// element (L - 1 - tid) with element (L + tid) (pos + stride == L + tid),
// i.e. mirror-image positions around the midpoint of the segment.
102 for (
int loop = 0; loop < kLoopPerThread; ++loop) {
103 int tid = loop * NumThreads + threadIdx.x;
104 int pos = L - 1 - tid;
105 int stride = 2 * tid + 1;
108 K& kb = listK[pos + stride];
// s: whether this pair is out of order; Dir selects Comp::gt vs Comp::lt.
// (ka is declared on a line not visible here -- presumably listK[pos].)
110 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
114 V& vb = listV[pos + stride];
// Later steps: with FullMerge, tid spans 0 .. L - 1. Without it, only
// tid < L / 2 is processed. By the pos formula below, that only touches pairs
// with pos + stride < L, so only the first L elements of the segment are
// finalized (presumably the caller only needs that half).
// NOTE(review): if L == NumThreads and FullMerge is false, this evaluates to
// 0 and the later steps are skipped entirely. The static_assert above permits
// L == NumThreads; consider tightening it or confirm callers never hit that
// case.
120 constexpr
int kSecondLoopPerThread =
121 FullMerge ? kLoopPerThread : kLoopPerThread / 2;
// Compare distance halves each iteration (L / 2, ..., 1). With
// tid = q * stride + r (r < stride), pos = 2 * q * stride + r: the lower
// element of a pair inside the q-th block of 2 * stride elements.
124 for (
int stride = L / 2; stride > 0; stride /= 2) {
126 for (
int loop = 0; loop < kSecondLoopPerThread; ++loop) {
127 int tid = loop * NumThreads + threadIdx.x;
128 int pos = 2 * tid - (tid & (stride - 1));
131 K& kb = listK[pos + stride];
133 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
137 V& vb = listV[pos + stride];
// Primary template for BlockMerge, dispatched on SmallerThanBlock. The
// blockMerge entry point computes that flag as (L <= NumThreads). Partial
// specializations for true and false follow.
// NOTE(review): lossy extraction -- the remaining template parameters and the
// declaration itself are not visible in this fragment.
147 template <
int NumThreads,
154 bool SmallerThanBlock,
// BlockMerge specialization for SmallerThanBlock == true (L <= NumThreads).
// Each merge of a 2 * L-element segment is handled by a group of L threads,
// so NumThreads / L merges can run concurrently in one round.
//
// NOTE(review): lossy extraction -- leading integers appear to be fused
// original line numbers, and several lines (remaining template parameters,
// the rest of the blockMergeSmall call, else-branch scaffolding, closing
// braces) are missing.
160 template <
int NumThreads,
168 struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
169 static inline __device__
void merge(K* listK, V* listV) {
// Number of L-thread groups in the block, i.e. merges processed per round.
170 constexpr
int kNumParallelMerges = NumThreads / L;
// Rounds needed when N >= kNumParallelMerges (0 otherwise).
171 constexpr
int kNumIterations = N / kNumParallelMerges;
173 static_assert(L <= NumThreads,
"list must be <= NumThreads");
// Either all N merges fit in a single partial round, or N is an exact
// multiple of kNumParallelMerges so every round uses the full block.
174 static_assert((N < kNumParallelMerges) ||
175 (kNumIterations * kNumParallelMerges == N),
176 "improper selection of N and L");
178 if (N < kNumParallelMerges) {
// Only the first N * L threads (N groups of L) have a merge to perform.
// NOTE(review): blockMergeSmall might contain __syncthreads() (not visible in
// this fragment). If it does, calling it from only a subset of the block's
// threads is a divergent barrier, which is undefined behavior. Verify this;
// if so, have all threads call it and predicate only the compare/swap work.
180 if (threadIdx.x < N * L) {
181 blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
// Each round advances past kNumParallelMerges segments of 2 * L elements.
187 for (
int i = 0; i < kNumIterations; ++i) {
188 int start = i * kNumParallelMerges * 2 * L;
190 blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
191 listK + start, listV + start);
// BlockMerge specialization for SmallerThanBlock == false (L > NumThreads,
// per kSmallerThanBlock in blockMerge). The whole block cooperates on one
// merge at a time via blockMergeLarge, and the N merges run sequentially.
//
// NOTE(review): lossy extraction -- leading integers appear to be fused
// original line numbers, and several lines (remaining template parameters,
// closing braces) are missing.
198 template <
int NumThreads,
206 struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
207 static inline __device__
void merge(K* listK, V* listV) {
// Merge i operates on the 2 * L-element segment starting at offset i * 2 * L
// in both the key and value arrays.
210 for (
int i = 0; i < N; ++i) {
211 int start = i * 2 * L;
213 blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
214 listK + start, listV + start);
219 template <
int NumThreads,
226 bool FullMerge =
true>
227 inline __device__
void blockMerge(K* listK, V* listV) {
228 constexpr
bool kSmallerThanBlock = (L <= NumThreads);