13 #include "DeviceDefs.cuh"
14 #include "PtxUtils.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
18 namespace faiss {
namespace gpu {
// Conditional-exchange helper: takes a flag `swap` and two references x, y.
// NOTE(review): the `template` header that declares T and the function body
// are not present in this excerpt; presumably x and y are exchanged only when
// `swap` is true -- confirm against the full file.
82 inline __device__
void swap(
bool swap, T& x, T& y) {
// Warp-level compare/exchange stage operating on one (key, value) pair per
// lane, passed by reference as k and v.
//
// Template parameters:
//   K, V      - key and value types
//   L         - list size; statically required to be a power of 2 and at
//               most kWarpSize / 2
//   Dir       - direction flag (NOTE(review): the branches that test it are
//               not visible in this excerpt)
//   Comp      - comparator type exposing static gt() and lt()
//   IsBitonic - selects the starting stride of the loop below
95 template <
typename K,
typename V,
int L,
96 bool Dir,
typename Comp,
bool IsBitonic>
97 inline __device__
void warpBitonicMergeLE16(K& k, V& v) {
98 static_assert(utils::isPowerOf2(L),
"L must be a power-of-2");
99 static_assert(L <= kWarpSize / 2,
"merge list size must be <= 16");
// Lane id of the calling thread, as reported by getLaneId().
101 int laneId = getLaneId();
// Fetch the partner lane's key and value through shfl_xor, using the XOR
// lane mask 2 * L - 1.
// NOTE(review): the condition guarding this first stage is not visible here.
107 K otherK = shfl_xor(k, 2 * L - 1);
108 V otherV = shfl_xor(v, 2 * L - 1);
// `small` is true when bit L of the lane id is clear (L is a single bit
// because it is a power of 2).
111 bool small = !(laneId & L);
// Two alternative definitions of `s` follow; they differ only in that
// Comp::gt and Comp::lt trade places.
// NOTE(review): the conditionals selecting between them (presumably on Dir)
// and the statements that consume `s` are not visible in this excerpt.
117 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
122 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
// Remaining stages: the stride starts at L when IsBitonic is true and at
// L / 2 otherwise, and halves every iteration until it reaches 0. With
// L == 1 and IsBitonic false the initial stride is 0, so the loop body
// never runs.
129 for (
int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
// Partner for this stage is selected by the XOR lane mask `stride`.
130 K otherK = shfl_xor(k, stride);
131 V otherV = shfl_xor(v, stride);
// True when the `stride` bit of the lane id is clear.
134 bool small = !(laneId & stride);
// Same pair of alternative `s` definitions as in the first stage; the
// selecting branches are again not visible here.
137 bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
142 bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
// Template parameter list <K, V, N, Dir, Comp, Low, Pow2>.
// NOTE(review): the declaration this header introduces is not visible in this
// excerpt. Its arity and parameter order match the BitonicMergeStep<...> uses
// inside warpMergeAnyRegisters, so it is presumably the primary
// BitonicMergeStep template -- confirm against the full file.
151 template <
typename K,
typename V,
int N,
152 bool Dir,
typename Comp,
bool Low,
bool Pow2>
// merge() for the case of a single (key, value) pair per lane (k[1], v[1]).
// NOTE(review): the enclosing struct header is not visible in this excerpt;
// the parameter list omits N and Pow2, so this is presumably a partial
// specialization for N == 1 -- confirm against the full file.
161 template <
typename K,
typename V,
bool Dir,
typename Comp,
bool Low>
163 static inline __device__
void merge(K k[1], V v[1]) {
// Delegates to the warp-level stage with L = 16 and IsBitonic = true,
// operating on the only element of each array.
165 warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(k[0], v[0]);
// merge() over N (key, value) pairs held in k[] and v[], where N is
// statically required to be a power of 2 greater than 1.
// NOTE(review): the enclosing struct header, the declarations of ka/va and
// newK/newV, and several loop bodies are not visible in this excerpt.
169 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp,
bool Low>
171 static inline __device__
void merge(K k[N], V v[N]) {
172 static_assert(utils::isPowerOf2(N),
"must be power of 2");
173 static_assert(N > 1,
"must be N > 1");
// First pass: element i is paired with element i + N / 2.
176 for (
int i = 0; i < N / 2; ++i) {
180 K& kb = k[i + N / 2];
181 V& vb = v[i + N / 2];
// Comparison chosen by Dir: gt when Dir is true, lt otherwise.
// NOTE(review): the statements that use `s` are not visible here.
183 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// NOTE(review): the bodies of the next two loops are not visible in this
// excerpt; both iterate over the first N / 2 indices.
193 for (
int i = 0; i < N / 2; ++i) {
201 for (
int i = 0; i < N / 2; ++i) {
// Copy the upper half of k/v (indices N / 2 .. N - 1) into newK/newV.
212 for (
int i = 0; i < N / 2; ++i) {
213 newK[i] = k[i + N / 2];
214 newV[i] = v[i + N / 2];
// Copy newK/newV back into the upper half of k/v.
// NOTE(review): whatever processes newK/newV between the copy-out and this
// copy-back is not visible in this excerpt.
220 for (
int i = 0; i < N / 2; ++i) {
221 k[i + N / 2] = newK[i];
222 v[i + N / 2] = newV[i];
// merge() over N (key, value) pairs where N is statically required to be a
// non-power-of-2 and at least 3. In this variant the low part has
// N - kNextHighestPowerOf2 / 2 elements and the high part has
// kNextHighestPowerOf2 / 2 elements.
// NOTE(review): the enclosing struct header is not visible; the parameter
// list omits Low and Pow2, so this is presumably a partial specialization --
// confirm which Low value it corresponds to against the full file.
233 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
235 static inline __device__
void merge(K k[N], V v[N]) {
236 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
237 static_assert(N >= 3,
"must be N >= 3");
// Value returned by utils::nextHighestPowerOf2(N); presumably the smallest
// power of 2 greater than N.
239 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// First pass: element i is paired with element i + kNextHighestPowerOf2 / 2
// for the first N - kNextHighestPowerOf2 / 2 indices.
242 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
246 K& kb = k[i + kNextHighestPowerOf2 / 2];
247 V& vb = v[i + kNextHighestPowerOf2 / 2];
// Comparison chosen by Dir: gt when Dir is true, lt otherwise.
// NOTE(review): ka/va and the statements that use `s` are not visible here.
249 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Sizes of the two parts handled separately below; they sum to N.
254 constexpr
int kLowSize = N - kNextHighestPowerOf2 / 2;
255 constexpr
int kHighSize = kNextHighestPowerOf2 / 2;
// NOTE(review): body of this loop over the low part is not visible.
261 for (
int i = 0; i < kLowSize; ++i) {
// Whether the low-part size is itself a power of 2; passed as the last
// template argument of the merge call whose tail appears below.
266 constexpr
bool kLowIsPowerOf2 =
267 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
// NOTE(review): the start of this call expression is not visible.
272 kLowIsPowerOf2>::merge(newK, newV);
// NOTE(review): body of this loop over the low part is not visible.
275 for (
int i = 0; i < kLowSize; ++i) {
// Copy the high part (k/v starting at index kLowSize) into newK/newV.
286 for (
int i = 0; i < kHighSize; ++i) {
287 newK[i] = k[i + kLowSize];
288 newV[i] = v[i + kLowSize];
// Whether the high-part size is a power of 2; passed as the last template
// argument of the merge call whose tail appears below.
291 constexpr
bool kHighIsPowerOf2 =
292 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
// NOTE(review): the start of this call expression is not visible.
297 kHighIsPowerOf2>::merge(newK, newV);
// Write the merged high part back into k/v starting at index kLowSize.
300 for (
int i = 0; i < kHighSize; ++i) {
301 k[i + kLowSize] = newK[i];
302 v[i + kLowSize] = newV[i];
// merge() over N (key, value) pairs where N is statically required to be a
// non-power-of-2 and at least 3. In this variant the low part has
// kNextHighestPowerOf2 / 2 elements and the high part has
// N - kNextHighestPowerOf2 / 2 elements (the reverse of the sibling variant
// that precedes it in the file).
// NOTE(review): the enclosing struct header is not visible; the parameter
// list omits Low and Pow2, so this is presumably a partial specialization --
// confirm which Low value it corresponds to against the full file.
309 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
311 static inline __device__
void merge(K k[N], V v[N]) {
312 static_assert(!utils::isPowerOf2(N),
"must be non-power-of-2");
313 static_assert(N >= 3,
"must be N >= 3");
// Value returned by utils::nextHighestPowerOf2(N); presumably the smallest
// power of 2 greater than N.
315 constexpr
int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
// First pass: element i is paired with element i + kNextHighestPowerOf2 / 2
// for the first N - kNextHighestPowerOf2 / 2 indices.
318 for (
int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
322 K& kb = k[i + kNextHighestPowerOf2 / 2];
323 V& vb = v[i + kNextHighestPowerOf2 / 2];
// Comparison chosen by Dir: gt when Dir is true, lt otherwise.
// NOTE(review): ka/va and the statements that use `s` are not visible here.
325 bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
// Sizes of the two parts handled separately below; they sum to N.
330 constexpr
int kLowSize = kNextHighestPowerOf2 / 2;
331 constexpr
int kHighSize = N - kNextHighestPowerOf2 / 2;
// NOTE(review): body of this loop over the low part is not visible.
337 for (
int i = 0; i < kLowSize; ++i) {
// Whether the low-part size is a power of 2; passed as the last template
// argument of the merge call whose tail appears below.
342 constexpr
bool kLowIsPowerOf2 =
343 utils::isPowerOf2(kNextHighestPowerOf2 / 2);
// NOTE(review): the start of this call expression is not visible.
348 kLowIsPowerOf2>::merge(newK, newV);
// NOTE(review): body of this loop over the low part is not visible.
351 for (
int i = 0; i < kLowSize; ++i) {
// Copy the high part (k/v starting at index kLowSize) into newK/newV.
362 for (
int i = 0; i < kHighSize; ++i) {
363 newK[i] = k[i + kLowSize];
364 newV[i] = v[i + kLowSize];
// Whether the high-part size is itself a power of 2; passed as the last
// template argument of the merge call whose tail appears below.
367 constexpr
bool kHighIsPowerOf2 =
368 utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
// NOTE(review): the start of this call expression is not visible.
373 kHighIsPowerOf2>::merge(newK, newV);
// Write the merged high part back into k/v starting at index kLowSize.
376 for (
int i = 0; i < kHighSize; ++i) {
377 k[i + kLowSize] = newK[i];
378 v[i + kLowSize] = newV[i];
// Merges two per-lane (key, value) lists across the warp: list 1 is k1/v1 with
// N1 entries and list 2 is k2/v2 with N2 entries. After a cross-lane
// compare/exchange pass, BitonicMergeStep::merge is run on each list.
//
// Template parameters:
//   K, V   - key and value types
//   N1, N2 - per-lane lengths of the two lists
//   Dir    - direction flag selecting which comparator is used
//   Comp   - comparator type exposing static gt() and lt()
388 template <
typename K,
typename V,
int N1,
int N2,
bool Dir,
typename Comp>
389 inline __device__
void warpMergeAnyRegisters(K k1[N1], V v1[N1],
390 K k2[N2], V v2[N2]) {
// Only min(N1, N2) element pairs take part in the cross-lane pass.
391 constexpr
int kSmallestN = N1 < N2 ? N1 : N2;
394 for (
int i = 0; i < kSmallestN; ++i) {
// List 1 is walked from its last element backwards.
// NOTE(review): the declarations of kb/vb are not visible in this excerpt.
395 K& ka = k1[N1 - 1 - i];
396 V& va = v1[N1 - 1 - i];
// Fetch the partner lane's copies of both pairs through shfl_xor, using the
// XOR lane mask kWarpSize - 1.
401 K otherKa = shfl_xor(ka, kWarpSize - 1);
402 V otherVa = shfl_xor(va, kWarpSize - 1);
404 K otherKb = shfl_xor(kb, kWarpSize - 1);
405 V otherVb = shfl_xor(vb, kWarpSize - 1);
// Compare this lane's ka with the partner's kb (gt when Dir is true, lt
// otherwise) and pass the result to swap() for both key and value.
409 bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
410 swap(swapa, ka, otherKb);
411 swap(swapa, va, otherVb);
// Mirror comparison for kb against the partner's ka; note that the
// comparator choice is the opposite of swapa (lt when Dir is true).
415 bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
416 swap(swapb, kb, otherKa);
417 swap(swapb, vb, otherVa);
// Finish each list with BitonicMergeStep. The sixth template argument is
// true for list 1 and false for list 2; the seventh states whether the list
// length is a power of 2.
420 BitonicMergeStep<K, V, N1, Dir, Comp,
421 true, utils::isPowerOf2(N1)>::merge(k1, v1);
422 BitonicMergeStep<K, V, N2, Dir, Comp,
423 false, utils::isPowerOf2(N2)>::merge(k2, v2);
// sort() over N > 1 (key, value) pairs: the input is split into a first part
// of kSizeA = N / 2 elements and a second part of kSizeB = N - kSizeA
// elements, the two parts are merged with warpMergeAnyRegisters, and the
// results are copied back into k/v.
// NOTE(review): the enclosing struct header, the declarations of aK/aV/bK/bV,
// some loop bodies, and any processing of each part before the merge are not
// visible in this excerpt.
428 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
430 static inline __device__
void sort(K k[N], V v[N]) {
431 static_assert(N > 1,
"did not hit specialized case");
// For odd N the second part is one element longer than the first.
434 constexpr
int kSizeA = N / 2;
435 constexpr
int kSizeB = N - kSizeA;
// NOTE(review): body of this loop over the first part is not visible.
441 for (
int i = 0; i < kSizeA; ++i) {
// Copy the second part (k/v starting at index kSizeA) into bK/bV.
452 for (
int i = 0; i < kSizeB; ++i) {
453 bK[i] = k[i + kSizeA];
454 bV[i] = v[i + kSizeA];
// Merge the two parts across the warp.
460 warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);
// NOTE(review): body of this loop over the first part is not visible.
463 for (
int i = 0; i < kSizeA; ++i) {
// Write the second part back into k/v starting at index kSizeA.
469 for (
int i = 0; i < kSizeB; ++i) {
470 k[i + kSizeA] = bK[i];
471 v[i + kSizeA] = bV[i];
// sort() for the case of a single (key, value) pair per lane (k[1], v[1]).
// NOTE(review): the enclosing struct header is not visible in this excerpt;
// the parameter list omits N, so this is presumably a partial specialization
// for N == 1 -- confirm against the full file.
477 template <
typename K,
typename V,
bool Dir,
typename Comp>
479 static inline __device__
void sort(K k[1], V v[1]) {
// The five hard-coded stages below (L = 1 .. 16) assume a 32-lane warp.
482 static_assert(kWarpSize == 32,
"unexpected warp size");
// Run the warp-level stage with doubling list sizes, each time with
// IsBitonic = false, on the only element of each array.
484 warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
485 warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
486 warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
487 warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
488 warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
494 template <
typename K,
typename V,
int N,
bool Dir,
typename Comp>
495 inline __device__
void warpSortAnyRegisters(K k[N], V v[N]) {