12 #include "DeviceDefs.cuh"
13 #include "MergeNetworkUtils.cuh"
14 #include "PtxUtils.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
18 namespace faiss {
namespace gpu {
// Cross-lane compare-exchange merge in which every lane holds one key/value
// pair (k, v). Partner values are read from another lane with shfl_xor and
// the local pair is updated through assign().
//
// Template parameters:
//   K, V      - key and value types
//   L         - list size; must be a power of 2 and at most kWarpSize / 2
//   Dir       - not referenced in the lines visible here; presumably selects
//               between the two mirrored comparison variants below (confirm)
//   Comp      - comparator type providing static gt() and lt()
//   IsBitonic - when true the stride loop starts at L, otherwise at L / 2
//
// NOTE(review): assign(s, dst, src) presumably copies src into dst when s is
// true -- confirm against its definition, which is not part of this view.
// NOTE(review): several structural lines (branches, closing braces) are not
// visible in this view, so the comments describe only the visible statements.
88 template <
typename K,
typename V,
int L,
89 bool Dir,
typename Comp,
bool IsBitonic>
90 inline __device__
void warpBitonicMergeLE16(K& k, V& v) {
91 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
92 static_assert(L <= kWarpSize / 2,
"merge list size must be <= 16");
94 int laneId = getLaneId();
// First exchange: fetch the pair held by the lane whose id is this lane's id
// XOR (2 * L - 1).
// NOTE(review): since the loop below starts at L / 2 only when IsBitonic is
// false, this step is presumably guarded by a !IsBitonic check -- confirm.
100 K otherK = shfl_xor(k, 2 * L - 1);
101 V otherV = shfl_xor(v, 2 * L - 1);
// True for lanes whose lane-id bit of value L is clear.
104 bool small = !(laneId & L);
// Variant A: `small` lanes test gt, the other lanes test lt.
110 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
111 assign(s, k, otherK);
112 assign(s, v, otherV);
// Variant B: the mirror image (lt for `small` lanes, gt for the others).
// The branch that separates variants A and B is not visible in this view.
115 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
116 assign(s, k, otherK);
117 assign(s, v, otherV);
// Remaining stages: the stride is halved each iteration until it reaches 0,
// and each lane exchanges with the lane whose id differs by XOR stride.
122 for (
int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
123 K otherK = shfl_xor(k, stride);
124 V otherV = shfl_xor(v, stride);
// True for lanes whose lane-id bit of value `stride` is clear.
127 bool small = !(laneId & stride);
// Same two mirrored comparison variants as in the first exchange.
130 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
131 assign(s, k, otherK);
132 assign(s, v, otherV);
135 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
136 assign(s, k, otherK);
137 assign(s, v, otherV);
// Seven-parameter template header; the declaration it introduces is not
// visible in this view.
// NOTE(review): BitonicMergeStep is instantiated further down in this file
// with seven arguments in this same order (K, V, size, Dir, Comp, bool, bool),
// so this is presumably its primary template -- confirm against the full file.
144 template <
typename K,
typename V,
int N,
145 bool Dir,
typename Comp,
bool Low,
bool Pow2>
// Merge for arrays of length 1: forwards the single pair k[0], v[0] to
// warpBitonicMergeLE16 with L = 16 and IsBitonic = true.
// The Low parameter is not used by the visible body.
// NOTE(review): the enclosing struct/specialization line is not visible here.
154 template <
typename K,
typename V,
bool Dir,
typename Comp,
bool Low>
156 static inline __device__
void merge(K k[1], V v[1]) {
158 warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(k[0], v[0]);
// Merge for arrays whose length N is a power of 2 greater than 1 (both
// enforced by the static_asserts below).
// NOTE(review): the enclosing struct line, the ka/va and newK/newV
// declarations, the swap statements and the recursive calls are not visible
// in this view; comments cover only the visible statements.
162 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp,
bool Low>
164 static inline __device__
void merge(K k[N], V v[N]) {
165 static_assert(utils::isPowerOf2(N),
"must be power of 2");
166 static_assert(N > 1,
"must be N > 1");
// Pair element i with element i + N / 2 and compare the keys.
// ka presumably refers to k[i] -- its declaration is not visible here.
169 for (
int i = 0; i < N / 2; ++i) {
173 K& kb = k[i + N / 2];
174 V& vb = v[i + N / 2];
// Dir picks which comparison (gt or lt) produces the flag s.
176 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// The bodies of the next two loops over the first N / 2 indices are not
// visible in this view.
186 for (
int i = 0; i < N / 2; ++i) {
194 for (
int i = 0; i < N / 2; ++i) {
// Copy the upper half (indices N / 2 .. N - 1) into newK / newV.
205 for (
int i = 0; i < N / 2; ++i) {
206 newK[i] = k[i + N / 2];
207 newV[i] = v[i + N / 2];
// Write newK / newV back into the upper half of k / v.
213 for (
int i = 0; i < N / 2; ++i) {
214 k[i + N / 2] = newK[i];
215 v[i + N / 2] = newV[i];
// Merge for arrays whose length N is not a power of 2 and is at least 3 (both
// enforced by the static_asserts below). In this variant the low part has
// N - nextHighestPowerOf2(N) / 2 elements and the high part has
// nextHighestPowerOf2(N) / 2 elements.
// NOTE(review): the enclosing struct line, the ka/va and newK/newV
// declarations, the swap statements and the start of the recursive merge
// calls are not visible in this view.
226 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
228 static inline __device__
void merge(K k[N], V v[N]) {
229 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
230 static_assert(N >= 3,
"must be N >= 3");
232 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Pair element i with element i + kNextHighestPowerOf2 / 2; only the first
// N - kNextHighestPowerOf2 / 2 indices have such a partner inside the array.
// ka presumably refers to k[i] -- its declaration is not visible here.
235 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
239 K& kb = k[i + kNextHighestPowerOf2 / 2];
240 V& vb = v[i + kNextHighestPowerOf2 / 2];
// Dir picks which comparison (gt or lt) produces the flag s.
242 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Sizes of the two parts handled separately below.
247 constexpr
int kLowSize = N - kNextHighestPowerOf2 / 2;
248 constexpr
int kHighSize = kNextHighestPowerOf2 / 2;
// Loop over the low part; its body is not visible in this view.
254 for (
int i = 0; i < kLowSize; ++i) {
// Whether the low part's size is a power of 2; passed as the last template
// argument of the merge call on newK / newV.
259 constexpr
bool kLowIsPowerOf2 =
260 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
265 kLowIsPowerOf2>::merge(newK, newV);
// Second loop over the low part; its body is not visible in this view.
268 for (
int i = 0; i < kLowSize; ++i) {
// Copy the high part (starting at index kLowSize) into newK / newV.
279 for (
int i = 0; i < kHighSize; ++i) {
280 newK[i] = k[i + kLowSize];
281 newV[i] = v[i + kLowSize];
// Whether the high part's size is a power of 2; passed as the last template
// argument of the merge call on newK / newV.
284 constexpr
bool kHighIsPowerOf2 =
285 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
290 kHighIsPowerOf2>::merge(newK, newV);
// Write newK / newV back into the high part of k / v.
293 for (
int i = 0; i < kHighSize; ++i) {
294 k[i + kLowSize] = newK[i];
295 v[i + kLowSize] = newV[i];
// Merge for arrays whose length N is not a power of 2 and is at least 3 (both
// enforced by the static_asserts below). In this variant the low part has
// nextHighestPowerOf2(N) / 2 elements and the high part has
// N - nextHighestPowerOf2(N) / 2 elements.
// NOTE(review): the enclosing struct line, the ka/va and newK/newV
// declarations, the swap statements and the start of the recursive merge
// calls are not visible in this view.
302 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
304 static inline __device__
void merge(K k[N], V v[N]) {
305 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
306 static_assert(N >= 3,
"must be N >= 3");
308 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Pair element i with element i + kNextHighestPowerOf2 / 2; only the first
// N - kNextHighestPowerOf2 / 2 indices have such a partner inside the array.
// ka presumably refers to k[i] -- its declaration is not visible here.
311 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
315 K& kb = k[i + kNextHighestPowerOf2 / 2];
316 V& vb = v[i + kNextHighestPowerOf2 / 2];
// Dir picks which comparison (gt or lt) produces the flag s.
318 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Sizes of the two parts handled separately below.
323 constexpr
int kLowSize = kNextHighestPowerOf2 / 2;
324 constexpr
int kHighSize = N - kNextHighestPowerOf2 / 2;
// Loop over the low part; its body is not visible in this view.
330 for (
int i = 0; i < kLowSize; ++i) {
// Whether the low part's size is a power of 2; passed as the last template
// argument of the merge call on newK / newV.
335 constexpr
bool kLowIsPowerOf2 =
336 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
341 kLowIsPowerOf2>::merge(newK, newV);
// Second loop over the low part; its body is not visible in this view.
344 for (
int i = 0; i < kLowSize; ++i) {
// Copy the high part (starting at index kLowSize) into newK / newV.
355 for (
int i = 0; i < kHighSize; ++i) {
356 newK[i] = k[i + kLowSize];
357 newV[i] = v[i + kLowSize];
// Whether the high part's size is a power of 2; passed as the last template
// argument of the merge call on newK / newV.
360 constexpr
bool kHighIsPowerOf2 =
361 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
366 kHighIsPowerOf2>::merge(newK, newV);
// Write newK / newV back into the high part of k / v.
369 for (
int i = 0; i < kHighSize; ++i) {
370 k[i + kLowSize] = newK[i];
371 v[i + kLowSize] = newV[i];
// Merges two per-lane key/value arrays, (k1, v1) of length N1 and (k2, v2) of
// length N2, exchanging data between lanes with shfl_xor and then running a
// BitonicMergeStep on each array.
// NOTE(review): most of the template parameter list (V, N1, N2, Dir, Comp),
// the kb/vb and otherKa/otherVa declarations, and any use of FullMerge are
// not visible in this view; comments cover only the visible statements.
381 template <
typename K,
387 bool FullMerge =
true>
388 inline __device__
void warpMergeAnyRegisters(K k1[N1], V v1[N1],
389 K k2[N2], V v2[N2]) {
// Only min(N1, N2) element pairs take part in the cross-array exchange.
390 constexpr
int kSmallestN = N1 < N2 ? N1 : N2;
// k1 / v1 are walked from the last element towards the first.
// kb / vb presumably refer to the matching element of k2 / v2 -- confirm.
393 for (
int i = 0; i < kSmallestN; ++i) {
394 K& ka = k1[N1 - 1 - i];
395 V& va = v1[N1 - 1 - i];
// Fetch both pairs from the lane whose id is this lane's id XOR
// (kWarpSize - 1).
405 otherKa = shfl_xor(ka, kWarpSize - 1);
406 otherVa = shfl_xor(va, kWarpSize - 1);
409 K otherKb = shfl_xor(kb, kWarpSize - 1);
410 V otherVb = shfl_xor(vb, kWarpSize - 1);
// The `a` element is compared against the remote `b` element; Dir picks gt
// or lt.
414 bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
415 assign(swapa, ka, otherKb);
416 assign(swapa, va, otherVb);
// The `b` element is compared against the remote `a` element with the
// opposite comparison for the same Dir.
421 bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
422 assign(swapb, kb, otherKa);
423 assign(swapb, vb, otherVa);
// Finish each array with its own merge step. The sixth template argument is
// true for k1 / v1 and false for k2 / v2; the seventh records whether that
// array's length is a power of 2.
430 BitonicMergeStep<K, V, N1, Dir, Comp,
431 true, utils::isPowerOf2(N1)>::merge(k1, v1);
434 BitonicMergeStep<K, V, N2, Dir, Comp,
435 false, utils::isPowerOf2(N2)>::merge(k2, v2);
// Sort for arrays of length N > 1: the input is split into a first part of
// N / 2 elements and a second part holding the remaining N - N / 2 elements;
// the two parts are merged with warpMergeAnyRegisters and copied back into
// k / v.
// NOTE(review): the enclosing struct line, the aK/aV/bK/bV declarations, the
// bodies of the loops over the first part and any recursive sort calls are
// not visible in this view.
441 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
443 static inline __device__
void sort(K k[N], V v[N]) {
// N == 1 is expected to be handled by a separate specialization.
444 static_assert(N > 1,
"did not hit specialized case");
447 constexpr
int kSizeA = N / 2;
448 constexpr
int kSizeB = N - kSizeA;
// Loop over the first part; its body is not visible in this view.
454 for (
int i = 0; i < kSizeA; ++i) {
// Copy the second part (starting at index kSizeA) into bK / bV.
465 for (
int i = 0; i < kSizeB; ++i) {
466 bK[i] = k[i + kSizeA];
467 bV[i] = v[i + kSizeA];
// Merge the two parts; FullMerge is left at its default.
473 warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);
// Loop over the first part; its body is not visible in this view.
476 for (
int i = 0; i < kSizeA; ++i) {
// Write bK / bV back into the second part of k / v.
482 for (
int i = 0; i < kSizeB; ++i) {
483 k[i + kSizeA] = bK[i];
484 v[i + kSizeA] = bV[i];
// Sort for arrays of length 1: runs warpBitonicMergeLE16 on the single pair
// k[0], v[0] with L = 1, 2, 4, 8 and 16 in turn, each time with
// IsBitonic = false.
// NOTE(review): the enclosing struct/specialization line is not visible here.
490 template <
typename K,
typename V,
bool Dir,
typename Comp>
492 static inline __device__
void sort(K k[1], V v[1]) {
// The fixed sequence of list sizes below ends at 16, i.e. half of a 32-lane
// warp, so any other warp size is rejected at compile time.
495 static_assert(kWarpSize == 32,
"unexpected warp size");
497 warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
498 warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
499 warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
500 warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
501 warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
507 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
508 inline __device__
void warpSortAnyRegisters(K k[N], V v[N]) {