11 #include "DeviceDefs.cuh"
12 #include "MergeNetworkUtils.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
17 namespace faiss {
namespace gpu {
// Warp-wide merge stage over one (key, value) pair held by each lane.
//
// Each lane exchanges its pair with a partner lane via shfl_xor. It then keeps
// either its own pair or the partner's, as decided by Comp.
//
// Template parameters:
//   K, V      - key and value types.
//   L         - list size; must be a power of 2 and <= kWarpSize / 2
//               (checked by the static_asserts below).
//   Dir       - presumably selects which of the two duplicated
//               compare/assign orientations below runs. The selecting
//               conditional is not visible in this excerpt.
//   Comp      - comparator type providing static gt() and lt() on keys.
//   IsBitonic - the stride loop starts at L when true and at L / 2 when
//               false. Presumably the mirrored first exchange
//               (partner = laneId ^ (2 * L - 1)) runs only when IsBitonic is
//               false, replacing the skipped stride-L stage. The guard is not
//               visible here; confirm.
//
// k, v - this lane's key and value; updated in place.
//
// NOTE(review): the static_assert message says "<= 16". That matches the
// condition L <= kWarpSize / 2 only when kWarpSize == 32.
// NOTE(review): shfl_xor is defined in WarpShuffles.cuh. Presumably every lane
// of the warp must call this function together; confirm the mask it uses.
// NOTE(review): assign(s, dst, src) is defined elsewhere. It is assumed here
// to copy src into dst when s is true.
87 template <
typename K,
typename V,
int L,
88 bool Dir,
typename Comp,
bool IsBitonic>
89 inline __device__
void warpBitonicMergeLE16(K& k, V& v) {
90 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
91 static_assert(L <= kWarpSize / 2,
"merge list size must be <= 16");
// Presumably this lane's index within the warp (getLaneId is not defined here).
93 int laneId = getLaneId();
// Mirrored exchange. Because L is a power of 2, 2 * L - 1 masks all the low
// bits, so lane i of each (2 * L)-lane group pairs with lane 2 * L - 1 - i.
99 K otherK = shfl_xor(k, 2 * L - 1);
100 V otherV = shfl_xor(v, 2 * L - 1);
// True when bit L of laneId is clear, i.e. this lane is in the lower half of
// its (2 * L)-lane group.
103 bool small = !(laneId & L);
// First orientation: the lower lane takes the partner's pair when its own key
// is Comp::gt the partner's; the upper lane when it is Comp::lt.
109 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
110 assign(s, k, otherK);
111 assign(s, v, otherV);
// Second orientation: the same exchange with gt and lt swapped.
114 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
115 assign(s, k, otherK);
116 assign(s, v, otherV);
// Remaining stages: stride halves on each iteration down to 1, and lane i
// pairs with lane i ^ stride.
121 for (
int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
122 K otherK = shfl_xor(k, stride);
123 V otherV = shfl_xor(v, stride);
// True when this lane's stride bit is clear (the lower lane of the pair).
126 bool small = !(laneId & stride);
// Same two compare/assign orientations as in the mirrored exchange.
129 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
130 assign(s, k, otherK);
131 assign(s, v, otherV);
134 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
135 assign(s, k, otherK);
136 assign(s, v, otherV);
// Template parameter list (K, V, N, Dir, Comp, Low, Pow2). The declaration it
// introduces is not visible in this excerpt. Its arity and order match the
// BitonicMergeStep<K, V, N, Dir, Comp, Low, Pow2>::merge calls made by
// warpMergeAnyRegisters, so it is presumably the primary BitonicMergeStep
// template that the merge() specializations below complete.
//   N    - number of (key, value) elements in each lane's k[] / v[] arrays.
//   Low  - warpMergeAnyRegisters passes true for its first list and false
//          for its second.
//   Pow2 - callers pass utils::isPowerOf2(N).
143 template <
typename K,
typename V,
int N,
144 bool Dir,
typename Comp,
bool Low,
bool Pow2>
// merge() for one element per lane: the template list has no N and the
// signature takes k[1], v[1]. It delegates to warpBitonicMergeLE16 with
// L = 16 and IsBitonic = true. That function's stride loop therefore runs
// strides 16, 8, 4, 2, 1 across the 32 lanes (assuming kWarpSize == 32).
// Presumably the 32 keys are expected to already form a bitonic sequence.
// Low is not used by this body.
153 template <
typename K,
typename V,
bool Dir,
typename Comp,
bool Low>
155 static inline __device__
void merge(K k[1], V v[1]) {
157 warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(k[0], v[0]);
// merge() for N > 1 elements per lane, where N is a power of 2. The
// static_asserts below enforce both conditions. Visible steps:
//  1. For i < N / 2, element i is compared with element i + N / 2. The
//     condition is Comp::gt when Dir is true and Comp::lt otherwise.
//  2. Further loops over the N / 2 halves; their bodies are not visible here
//     (presumably each half is merged recursively).
//  3. The upper half k[N/2 .. N), v[N/2 .. N) is copied into temporaries
//     newK / newV and later copied back.
// NOTE(review): this excerpt does not show how ka is defined or what is done
// with s (presumably a conditional swap of ka/kb and va/vb).
161 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp,
bool Low>
163 static inline __device__
void merge(K k[N], V v[N]) {
164 static_assert(utils::isPowerOf2(N),
"must be power of 2");
165 static_assert(N > 1,
"must be N > 1");
// Compare element i with its counterpart half the array away.
168 for (
int i = 0; i < N / 2; ++i) {
172 K& kb = k[i + N / 2];
173 V& vb = v[i + N / 2];
175 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
185 for (
int i = 0; i < N / 2; ++i) {
193 for (
int i = 0; i < N / 2; ++i) {
// Copy the upper half out into temporaries.
204 for (
int i = 0; i < N / 2; ++i) {
205 newK[i] = k[i + N / 2];
206 newV[i] = v[i + N / 2];
// Copy the temporaries back into the upper half.
212 for (
int i = 0; i < N / 2; ++i) {
213 k[i + N / 2] = newK[i];
214 v[i + N / 2] = newV[i];
// merge() for a non-power-of-2 N >= 3; the static_asserts below enforce both
// conditions. With P = utils::nextHighestPowerOf2(N):
//  - for i < N - P / 2, element i is compared with element i + P / 2
//    (Comp::gt when Dir is true, Comp::lt otherwise);
//  - the array is then split into a low part of kLowSize = N - P / 2 elements
//    and a high part of kHighSize = P / 2 elements. Each part is copied into
//    temporaries newK / newV and passed to a nested merge() call (only the
//    tail of each call is visible). The high part is then copied back into
//    k[kLowSize ..) / v[kLowSize ..).
// In this variant the low part holds the N - P / 2 leftover elements. The
// enclosing specialization's template arguments are not visible; presumably
// this is the Low = true, Pow2 = false case. Confirm.
// NOTE(review): kLowIsPowerOf2 repeats the kLowSize expression. kHighIsPowerOf2
// tests P / 2, which is always a power of 2 if nextHighestPowerOf2 returns one.
225 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
227 static inline __device__
void merge(K k[N], V v[N]) {
228 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
229 static_assert(N >= 3,
"must be N >= 3");
231 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Compare the leading N - P / 2 elements with the elements P / 2 later.
234 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
238 K& kb = k[i + kNextHighestPowerOf2 / 2];
239 V& vb = v[i + kNextHighestPowerOf2 / 2];
241 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Split sizes: leftover (non-power-of-2) part low, power-of-2 part high.
246 constexpr
int kLowSize = N - kNextHighestPowerOf2 / 2;
247 constexpr
int kHighSize = kNextHighestPowerOf2 / 2;
// Low part: copy out, nested merge, copy back.
253 for (
int i = 0; i < kLowSize; ++i) {
258 constexpr
bool kLowIsPowerOf2 =
259 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
264 kLowIsPowerOf2>::merge(newK, newV);
267 for (
int i = 0; i < kLowSize; ++i) {
// High part: copy out, nested merge, copy back.
278 for (
int i = 0; i < kHighSize; ++i) {
279 newK[i] = k[i + kLowSize];
280 newV[i] = v[i + kLowSize];
283 constexpr
bool kHighIsPowerOf2 =
284 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
289 kHighIsPowerOf2>::merge(newK, newV);
292 for (
int i = 0; i < kHighSize; ++i) {
293 k[i + kLowSize] = newK[i];
294 v[i + kLowSize] = newV[i];
// merge() for a non-power-of-2 N >= 3; the static_asserts below enforce both
// conditions. With P = utils::nextHighestPowerOf2(N):
//  - for i < N - P / 2, element i is compared with element i + P / 2
//    (Comp::gt when Dir is true, Comp::lt otherwise);
//  - the array is then split into a low part of kLowSize = P / 2 elements and
//    a high part of kHighSize = N - P / 2 elements. Each part is copied into
//    temporaries newK / newV and passed to a nested merge() call (only the
//    tail of each call is visible). The high part is then copied back into
//    k[kLowSize ..) / v[kLowSize ..).
// In this variant the high part holds the N - P / 2 leftover elements. The
// enclosing specialization's template arguments are not visible; presumably
// this is the Low = false, Pow2 = false case. Confirm.
// NOTE(review): kHighIsPowerOf2 repeats the kHighSize expression. kLowIsPowerOf2
// tests P / 2, which is always a power of 2 if nextHighestPowerOf2 returns one.
301 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
303 static inline __device__
void merge(K k[N], V v[N]) {
304 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
305 static_assert(N >= 3,
"must be N >= 3");
307 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// Compare the leading N - P / 2 elements with the elements P / 2 later.
310 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
314 K& kb = k[i + kNextHighestPowerOf2 / 2];
315 V& vb = v[i + kNextHighestPowerOf2 / 2];
317 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Split sizes: power-of-2 part low, leftover (non-power-of-2) part high.
322 constexpr
int kLowSize = kNextHighestPowerOf2 / 2;
323 constexpr
int kHighSize = N - kNextHighestPowerOf2 / 2;
// Low part: copy out, nested merge, copy back.
329 for (
int i = 0; i < kLowSize; ++i) {
334 constexpr
bool kLowIsPowerOf2 =
335 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
340 kLowIsPowerOf2>::merge(newK, newV);
343 for (
int i = 0; i < kLowSize; ++i) {
// High part: copy out, nested merge, copy back.
354 for (
int i = 0; i < kHighSize; ++i) {
355 newK[i] = k[i + kLowSize];
356 newV[i] = v[i + kLowSize];
359 constexpr
bool kHighIsPowerOf2 =
360 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
365 kHighIsPowerOf2>::merge(newK, newV);
368 for (
int i = 0; i < kHighSize; ++i) {
369 k[i + kLowSize] = newK[i];
370 v[i + kLowSize] = newV[i];
// Merges two per-lane register lists across the warp: k1/v1 with N1 elements
// and k2/v2 with N2 elements, both updated in place.
// For i < min(N1, N2), list 1 is walked from its end (ka = k1[N1 - 1 - i]).
// The partner kb from list 2 is defined on a line not visible here. Values are
// exchanged through shfl_xor with lane mask kWarpSize - 1 (the mirrored lane).
// Then:
//  - ka/va take the mirrored lane's kb/vb when swapa holds
//    (Comp::gt(ka, otherKb) if Dir, else Comp::lt);
//  - kb/vb take the mirrored lane's ka/va when swapb holds (the opposite
//    comparison).
// Finally BitonicMergeStep::merge is applied to list 1 with Low = true and to
// list 2 with Low = false. Each call passes Pow2 = utils::isPowerOf2(size).
// NOTE(review): only K and FullMerge (default true) are visible in the
// template parameter list; V, N1, N2, Dir and Comp are used but declared on
// lines not shown. This excerpt does not show how FullMerge is used;
// presumably it gates the otherKa exchange and the final per-list merges.
// Confirm.
380 template <
typename K,
386 bool FullMerge =
true>
387 inline __device__
void warpMergeAnyRegisters(K k1[N1], V v1[N1],
388 K k2[N2], V v2[N2]) {
389 constexpr
int kSmallestN = N1 < N2 ? N1 : N2;
392 for (
int i = 0; i < kSmallestN; ++i) {
393 K& ka = k1[N1 - 1 - i];
394 V& va = v1[N1 - 1 - i];
// otherKa/otherVa are declared on a line not visible here. Below, they only
// feed the swapb update of list 2.
404 otherKa = shfl_xor(ka, kWarpSize - 1);
405 otherVa = shfl_xor(va, kWarpSize - 1);
408 K otherKb = shfl_xor(kb, kWarpSize - 1);
409 V otherVb = shfl_xor(vb, kWarpSize - 1);
// List 1 side of the exchange.
413 bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
414 assign(swapa, ka, otherKb);
415 assign(swapa, va, otherVb);
// List 2 side of the exchange, using the opposite comparison.
420 bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
421 assign(swapb, kb, otherKa);
422 assign(swapb, vb, otherVa);
// Merge each list within itself: Low = true for list 1, false for list 2.
429 BitonicMergeStep<K, V, N1, Dir, Comp,
430 true, utils::isPowerOf2(N1)>::merge(k1, v1);
433 BitonicMergeStep<K, V, N2, Dir, Comp,
434 false, utils::isPowerOf2(N2)>::merge(k2, v2);
// sort() for N > 1 elements per lane. The static_assert message suggests that
// N == 1 is expected to reach a specialization instead.
// Splits each lane's array into A, the first kSizeA = N / 2 elements, and B,
// the remaining kSizeB = N - N / 2. A and B are copied into temporaries
// aK/aV and bK/bV, merged with warpMergeAnyRegisters (FullMerge left at its
// default), and both halves are copied back into k/v.
// NOTE(review): this excerpt does not show the declarations of aK/aV/bK/bV,
// the bodies of the A copy loops, or anything between the copies and the
// merge (presumably recursive sorts of A and B).
440 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
442 static inline __device__
void sort(K k[N], V v[N]) {
443 static_assert(N > 1,
"did not hit specialized case");
// Split point: A gets the smaller half when N is odd.
446 constexpr
int kSizeA = N / 2;
447 constexpr
int kSizeB = N - kSizeA;
453 for (
int i = 0; i < kSizeA; ++i) {
// Copy the B half out of k/v.
464 for (
int i = 0; i < kSizeB; ++i) {
465 bK[i] = k[i + kSizeA];
466 bV[i] = v[i + kSizeA];
472 warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);
475 for (
int i = 0; i < kSizeA; ++i) {
// Copy the merged B half back into k/v.
481 for (
int i = 0; i < kSizeB; ++i) {
482 k[i + kSizeA] = bK[i];
483 v[i + kSizeA] = bV[i];
// sort() for one element per lane. It runs warpBitonicMergeLE16 with
// L = 1, 2, 4, 8, 16 in turn, each with IsBitonic = false, over the warp's
// 32 lanes. This is the stage schedule of a bitonic sort with one key per
// lane, and presumably leaves the 32 lane values ordered by Comp in the
// orientation chosen by Dir. Requires kWarpSize == 32 (static_assert).
489 template <
typename K,
typename V,
bool Dir,
typename Comp>
491 static inline __device__
void sort(K k[1], V v[1]) {
494 static_assert(kWarpSize == 32,
"unexpected warp size");
496 warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
497 warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
498 warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
499 warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
500 warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
506 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
507 inline __device__
void warpSortAnyRegisters(K k[N], V v[N]) {