10 #include "DeviceDefs.cuh"
11 #include "MergeNetworkUtils.cuh"
12 #include "PtxUtils.cuh"
13 #include "StaticUtils.h"
14 #include "WarpShuffles.cuh"
16 namespace faiss {
namespace gpu {
// Bitonic merge of one (key, value) pair per lane across groups of 2 * L
// lanes of a warp; k and v are updated in place.
//
// Template parameters:
//   K, V      - key / value types exchanged with shfl_xor.
//   L         - half-width of each lane group; must be a power of 2 and
//               <= kWarpSize / 2 (both enforced by static_assert below).
//   Dir       - ordering direction; presumably selects between the two
//               comparison forms below (the selecting branch is not visible
//               in this excerpt).
//   Comp      - comparator providing static gt(a, b) / lt(a, b) (assumed to
//               mean a > b / a < b).
//   IsBitonic - the butterfly loop starts at stride L when true, at L / 2
//               otherwise.
//
// assign(s, x, y) comes from MergeNetworkUtils.cuh; it presumably does
// `if (s) x = y` -- confirm. Comments below assume that meaning.
//
// NOTE(review): this excerpt skips source lines (the embedded original line
// numbers jump, e.g. 99 -> 102 -> 108 -> 113), so the enclosing conditionals
// and closing braces are not visible here; comments describe only the
// visible statements.
86 template <
typename K,
typename V,
int L,
87 bool Dir,
typename Comp,
bool IsBitonic>
88 inline __device__
void warpBitonicMergeLE16(K& k, V& v) {
89 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
90 static_assert(L <= kWarpSize / 2,
"merge list size must be <= 16");
// Lane index within the warp (getLaneId is provided by PtxUtils.cuh).
92 int laneId = getLaneId();
// Mirrored exchange: the partner lane is laneId ^ (2 * L - 1), so lane j of
// each aligned 2L-lane group pairs with lane 2L - 1 - j of that group
// (assuming shfl_xor is an XOR-butterfly warp shuffle).
98 K otherK = shfl_xor(k, 2 * L - 1);
99 V otherV = shfl_xor(v, 2 * L - 1);
// True for lanes in the lower half (bit L of laneId clear) of their group.
102 bool small = !(laneId & L);
// This form: the lower lane takes the partner's pair when its key is
// Comp-greater, the upper lane when its key is Comp-smaller. The lower lane
// therefore ends with the Comp-smaller key and the upper lane with the
// larger one.
108 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
109 assign(s, k, otherK);
110 assign(s, v, otherV);
// Opposite form: the lower lane ends with the Comp-larger key.
113 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
114 assign(s, k, otherK);
115 assign(s, v, otherV);
// Butterfly stages: exchange with lane laneId ^ stride for
// stride = (IsBitonic ? L : L / 2), halving down to 1.
120 for (
int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
121 K otherK = shfl_xor(k, stride);
122 V otherV = shfl_xor(v, stride);
// True when bit `stride` of laneId is clear, i.e. this lane is the lower
// partner of the pair.
125 bool small = !(laneId & stride);
// Same two comparison forms as above: the lower lane keeps the
// Comp-smaller key in this one...
128 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
129 assign(s, k, otherK);
130 assign(s, v, otherV);
// ...and the Comp-larger key in this one.
133 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
134 assign(s, k, otherK);
135 assign(s, v, otherV);
// Template parameter list (fragment). This is presumably the primary
// BitonicMergeStep template: its seven parameters match the
// BitonicMergeStep<K, V, N, Dir, Comp, Low, Pow2> uses in
// warpMergeAnyRegisters. At those call sites:
//   Low  - true when merging the first list (k1), false for the second (k2);
//   Pow2 - callers pass utils::isPowerOf2(N).
// NOTE(review): the declaration this parameter list belongs to (original
// lines 144-151) is not visible in this excerpt.
142 template <
typename K,
typename V,
int N,
143 bool Dir,
typename Comp,
bool Low,
bool Pow2>
// merge() for a single register per lane (k[0], v[0]). The template list has
// no N, so this is presumably the N == 1 specialization; its struct header
// (original line 153) is not visible here.
// It delegates to warpBitonicMergeLE16 with L = 16 and IsBitonic = true.
// That call spans 2 * 16 = 32 lanes, the whole warp when kWarpSize == 32.
// Low does not appear in the visible body.
152 template <
typename K,
typename V,
bool Dir,
typename Comp,
bool Low>
154 static inline __device__
void merge(K k[1], V v[1]) {
156 warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(k[0], v[0]);
// merge() for N registers per lane, where N is a power of 2 and N > 1 (both
// enforced by static_assert).
// Visible steps:
//   1. For i in [0, N/2), compute a swap flag for the pair
//      (ka, k[i + N/2]) with Dir ? Comp::gt : Comp::lt. The declaration of
//      ka/va and the exchange itself are not visible in this excerpt.
//   2. Two further loops over [0, N/2) whose bodies are not visible.
//   3. Copy the upper half k[N/2, N) / v[N/2, N) into newK/newV, then later
//      write newK/newV back into that upper half. Any processing in between
//      (presumably a recursive merge of the half) is not visible.
// NOTE(review): original lines are missing throughout (see the embedded line
// numbers); confirm against the full source before relying on this summary.
160 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp,
bool Low>
162 static inline __device__
void merge(K k[N], V v[N]) {
163 static_assert(utils::isPowerOf2(N),
"must be power of 2");
164 static_assert(N > 1,
"must be N > 1");
// Pair each entry of the lower half with the entry N/2 positions later.
167 for (
int i = 0; i < N / 2; ++i) {
171 K& kb = k[i + N / 2];
172 V& vb = v[i + N / 2];
174 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
184 for (
int i = 0; i < N / 2; ++i) {
192 for (
int i = 0; i < N / 2; ++i) {
// Copy the upper half out into the temporaries.
203 for (
int i = 0; i < N / 2; ++i) {
204 newK[i] = k[i + N / 2];
205 newV[i] = v[i + N / 2];
// Write the temporaries back into the upper half.
211 for (
int i = 0; i < N / 2; ++i) {
212 k[i + N / 2] = newK[i];
213 v[i + N / 2] = newV[i];
// merge() for N registers per lane, where N is not a power of 2 and N >= 3
// (both enforced by static_assert). Let P = utils::nextHighestPowerOf2(N).
//   1. For i in [0, N - P/2), compute a swap flag for the pair
//      (ka, k[i + P/2]) with Dir ? Comp::gt : Comp::lt. ka is presumably
//      k[i]; its declaration is not visible.
//   2. Split the list into a leading part of kLowSize = N - P/2 entries and
//      a trailing part of kHighSize = P/2 entries.
//   3. Merge each part on temporary copies (newK/newV) via a merge() call
//      whose last template argument says whether that part's size is a
//      power of 2. The rest of that template-id is not visible; it is
//      presumably BitonicMergeStep.
// The index arithmetic is consistent with treating the N entries as the
// upper N slots of a virtual P-entry list (padding at the front). Under that
// view the compared pairs are P/2 apart, and the first virtual half holds
// only the leading N - P/2 real entries.
// NOTE(review): several original lines (e.g. the copy into newK for the
// leading part) are missing from this excerpt.
224 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
226 static inline __device__
void merge(K k[N], V v[N]) {
227 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
228 static_assert(N >= 3,
"must be N >= 3");
230 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Only the first N - P/2 entries have a partner P/2 positions later.
233 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
237 K& kb = k[i + kNextHighestPowerOf2 / 2];
238 V& vb = v[i + kNextHighestPowerOf2 / 2];
240 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Leading part is the short one here; kLowSize + kHighSize == N.
245 constexpr
int kLowSize = N - kNextHighestPowerOf2 / 2;
246 constexpr
int kHighSize = kNextHighestPowerOf2 / 2;
252 for (
int i = 0; i < kLowSize; ++i) {
// Same value as utils::isPowerOf2(kLowSize).
257 constexpr
bool kLowIsPowerOf2 =
258 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
263 kLowIsPowerOf2>::merge(newK, newV);
266 for (
int i = 0; i < kLowSize; ++i) {
// Copy the trailing part k[kLowSize, N) into the temporaries.
277 for (
int i = 0; i < kHighSize; ++i) {
278 newK[i] = k[i + kLowSize];
279 newV[i] = v[i + kLowSize];
282 constexpr
bool kHighIsPowerOf2 =
283 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
288 kHighIsPowerOf2>::merge(newK, newV);
// Write the merged trailing part back.
291 for (
int i = 0; i < kHighSize; ++i) {
292 k[i + kLowSize] = newK[i];
293 v[i + kLowSize] = newV[i];
// merge() for N registers per lane, where N is not a power of 2 and N >= 3
// (both enforced by static_assert). Let P = utils::nextHighestPowerOf2(N).
//   1. For i in [0, N - P/2), compute a swap flag for the pair
//      (ka, k[i + P/2]) with Dir ? Comp::gt : Comp::lt. ka is presumably
//      k[i]; its declaration is not visible.
//   2. Split the list into a leading part of kLowSize = P/2 entries and a
//      trailing part of kHighSize = N - P/2 entries.
//   3. Merge each part on temporary copies (newK/newV) via a merge() call
//      whose last template argument says whether that part's size is a
//      power of 2. The rest of that template-id is not visible; it is
//      presumably BitonicMergeStep.
// The index arithmetic is consistent with treating the N entries as the
// lower N slots of a virtual P-entry list (padding at the back). Under that
// view the first P/2 real entries form the first virtual half.
// NOTE(review): several original lines (e.g. the copy into newK for the
// leading part) are missing from this excerpt.
300 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
302 static inline __device__
void merge(K k[N], V v[N]) {
303 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
304 static_assert(N >= 3,
"must be N >= 3");
306 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Only the first N - P/2 entries have a partner P/2 positions later.
309 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
313 K& kb = k[i + kNextHighestPowerOf2 / 2];
314 V& vb = v[i + kNextHighestPowerOf2 / 2];
316 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Leading part has P/2 entries here; kLowSize + kHighSize == N.
321 constexpr
int kLowSize = kNextHighestPowerOf2 / 2;
322 constexpr
int kHighSize = N - kNextHighestPowerOf2 / 2;
328 for (
int i = 0; i < kLowSize; ++i) {
333 constexpr
bool kLowIsPowerOf2 =
334 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
339 kLowIsPowerOf2>::merge(newK, newV);
342 for (
int i = 0; i < kLowSize; ++i) {
// Copy the trailing part k[kLowSize, N) into the temporaries.
353 for (
int i = 0; i < kHighSize; ++i) {
354 newK[i] = k[i + kLowSize];
355 newV[i] = v[i + kLowSize];
// Same value as utils::isPowerOf2(kHighSize).
358 constexpr
bool kHighIsPowerOf2 =
359 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
364 kHighIsPowerOf2>::merge(newK, newV);
// Write the merged trailing part back.
367 for (
int i = 0; i < kHighSize; ++i) {
368 k[i + kLowSize] = newK[i];
369 v[i + kLowSize] = newV[i];
// Merges two per-lane register lists across the warp: k1/v1 (N1 entries per
// lane) and k2/v2 (N2 entries per lane). FullMerge defaults to true.
//
// For i in [0, min(N1, N2)), two cross-lane updates run. The first is
// between k1[N1 - 1 - i] and the k2 entry held by lane
// laneId ^ (kWarpSize - 1). The second is between the k2 entry kb and the
// k1 entry held by that same mirrored lane. kb is presumably k2[i]; its
// declaration is not visible. Afterwards each list is merged on its own with
// BitonicMergeStep: Low = true for k1, Low = false for k2, and Pow2 =
// utils::isPowerOf2 of the list length.
//
// NOTE(review): the template list is incomplete here. V, N1, N2, Dir and Comp
// are used, but their declarations (original lines 380-384) are not visible.
// Several body lines are also missing, e.g. any FullMerge guards, so
// comments describe only the visible statements.
379 template <
typename K,
385 bool FullMerge =
true>
386 inline __device__
void warpMergeAnyRegisters(K k1[N1], V v1[N1],
387 K k2[N2], V v2[N2]) {
388 constexpr
int kSmallestN = N1 < N2 ? N1 : N2;
391 for (
int i = 0; i < kSmallestN; ++i) {
// Walk list 1 from its last entry backwards.
392 K& ka = k1[N1 - 1 - i];
393 V& va = v1[N1 - 1 - i];
// Fetch the mirrored lane's k1 pair (lane 31 - laneId for a 32-lane warp,
// assuming shfl_xor is an XOR warp shuffle). The otherKa/otherVa
// declarations are not visible.
403 otherKa = shfl_xor(ka, kWarpSize - 1);
404 otherVa = shfl_xor(va, kWarpSize - 1);
// Fetch the mirrored lane's k2 pair.
407 K otherKb = shfl_xor(kb, kWarpSize - 1);
408 V otherVb = shfl_xor(vb, kWarpSize - 1);
// ka/va take the partner's k2 pair when ka is Comp-greater (Dir) or
// Comp-smaller (!Dir) than it. assign() is assumed to do `if (s) x = y`.
412 bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
413 assign(swapa, ka, otherKb);
414 assign(swapa, va, otherVb);
// Symmetric update of kb/vb from the partner's k1 pair, with the
// comparison direction reversed.
419 bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
420 assign(swapb, kb, otherKa);
421 assign(swapb, vb, otherVa);
// Finish by merging each list independently.
428 BitonicMergeStep<K, V, N1, Dir, Comp,
429 true, utils::isPowerOf2(N1)>::merge(k1, v1);
432 BitonicMergeStep<K, V, N2, Dir, Comp,
433 false, utils::isPowerOf2(N2)>::merge(k2, v2);
// sort() for N > 1 registers per lane. The static_assert message indicates
// that N == 1 is expected to be handled by a specialization.
// It splits the list into a first part of kSizeA = N / 2 entries and a
// second part of kSizeB = N - kSizeA entries, held in aK/aV and bK/bV. It
// merges them with warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>,
// then writes both parts back into k/v.
// NOTE(review): the body of the aK/aV copy loop and lines between the copies
// and the merge (original lines 453-470) are missing from this excerpt.
// Presumably each part is sorted recursively there -- confirm.
439 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
441 static inline __device__
void sort(K k[N], V v[N]) {
442 static_assert(N > 1,
"did not hit specialized case");
445 constexpr
int kSizeA = N / 2;
446 constexpr
int kSizeB = N - kSizeA;
452 for (
int i = 0; i < kSizeA; ++i) {
// Copy the second part k[kSizeA, N) into bK/bV.
463 for (
int i = 0; i < kSizeB; ++i) {
464 bK[i] = k[i + kSizeA];
465 bV[i] = v[i + kSizeA];
471 warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);
474 for (
int i = 0; i < kSizeA; ++i) {
// Write the merged second part back.
480 for (
int i = 0; i < kSizeB; ++i) {
481 k[i + kSizeA] = bK[i];
482 v[i + kSizeA] = bV[i];
// sort() for a single register per lane (k[0], v[0]). It requires
// kWarpSize == 32 and applies warpBitonicMergeLE16 with L = 1, 2, 4, 8, 16
// (IsBitonic = false). This presumably merges sorted runs of L lanes into
// runs of 2L until the whole 32-lane warp forms one run.
488 template <
typename K,
typename V,
bool Dir,
typename Comp>
490 static inline __device__
void sort(K k[1], V v[1]) {
// The merge widths below (up to 2 * 16 lanes) assume a 32-lane warp.
493 static_assert(kWarpSize == 32,
"unexpected warp size");
495 warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
496 warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
497 warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
498 warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
499 warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
505 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
506 inline __device__
void warpSortAnyRegisters(K k[N], V v[N]) {