faiss/gpu/utils/MergeNetworkWarp.cuh

513 lines
15 KiB
Plaintext

/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include "DeviceDefs.cuh"
#include "MergeNetworkUtils.cuh"
#include "PtxUtils.cuh"
#include "StaticUtils.h"
#include "WarpShuffles.cuh"
namespace faiss { namespace gpu {
//
// This file contains functions to:
//
// -perform bitonic merges on pairs of sorted lists, held in
// registers. Each list contains N * kWarpSize (multiple of 32)
// elements for some N.
// The bitonic merge is implemented for arbitrary sizes;
// sorted list A of size N1 * kWarpSize registers
// sorted list B of size N2 * kWarpSize registers =>
// sorted list C if size (N1 + N2) * kWarpSize registers. N1 and N2
// are >= 1 and don't have to be powers of 2.
//
// -perform bitonic sorts on a set of N * kWarpSize key/value pairs
// held in registers, by using the above bitonic merge as a
// primitive.
// N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
// odd sizes and doesn't require the input to be a power of 2.
//
// The sort or merge network is completely statically instantiated via
// template specialization / expansion and constexpr, and it uses warp
// shuffles to exchange values between warp lanes.
//
// A note about comparsions:
//
// For a sorting network of keys only, we only need one
// comparison (a < b). However, what we really need to know is
// if one lane chooses to exchange a value, then the
// corresponding lane should also do the exchange.
// Thus, if one just uses the negation !(x < y) in the higher
// lane, this will also include the case where (x == y). Thus, one
// lane in fact performs an exchange and the other doesn't, but
// because the only value being exchanged is equivalent, nothing has
// changed.
// So, you can get away with just one comparison and its negation.
//
// If we're sorting keys and values, where equivalent keys can
// exist, then this is a problem, since we want to treat (x, v1)
// as not equivalent to (x, v2).
//
// To remedy this, you can either compare with a lexicographic
// ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since
// we're predicating all of the choices results in 3 comparisons
// being executed, or we can invert the selection so that there is no
// middle choice of equality; the other lane will likewise
// check that (b.k > a.k) (the higher lane has the values
// swapped). Then, the first lane swaps if and only if the
// second lane swaps; if both lanes have equivalent keys, no
// swap will be performed. This results in only two comparisons
// being executed.
//
// If you don't consider values as well, then this does not produce a
// consistent ordering among (k, v) pairs with equivalent keys but
// different values; for us, we don't really care about ordering or
// stability here.
//
// I have tried both re-arranging the order in the higher lane to get
// away with one comparison or adding the value to the check; both
// result in greater register consumption or lower speed than just
// perfoming both < and > comparisons with the variables, so I just
// stick with this.
// This function merges kWarpSize / 2L lists in parallel using warp
// shuffles.
// It works on at most size-16 lists, as we need 32 threads for this
// shuffle merge.
//
// If IsBitonic is false, the first stage is reversed, so we don't
// need to sort directionally. It's still technically a bitonic sort.
template <typename K, typename V, int L,
bool Dir, typename Comp, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");
int laneId = getLaneId();
if (!IsBitonic) {
// Reverse the first comparison stage.
// For example, merging a list of size 8 has the exchanges:
// 0 <-> 15, 1 <-> 14, ...
K otherK = shfl_xor(k, 2 * L - 1);
V otherV = shfl_xor(v, 2 * L - 1);
// Whether we are the lesser thread in the exchange
bool small = !(laneId & L);
if (Dir) {
// See the comment above how performing both of these
// comparisons in the warp seems to win out over the
// alternatives in practice
bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
assign(s, k, otherK);
assign(s, v, otherV);
} else {
bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
assign(s, k, otherK);
assign(s, v, otherV);
}
}
#pragma unroll
for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
K otherK = shfl_xor(k, stride);
V otherV = shfl_xor(v, stride);
// Whether we are the lesser thread in the exchange
bool small = !(laneId & stride);
if (Dir) {
bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK);
assign(s, k, otherK);
assign(s, v, otherV);
} else {
bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK);
assign(s, k, otherK);
assign(s, v, otherV);
}
}
}
// Template for performing a bitonic merge of an arbitrary set of
// registers
template <typename K, typename V, int N,
bool Dir, typename Comp, bool Low, bool Pow2>
struct BitonicMergeStep {
};
//
// Power-of-2 merge specialization
//
// All merges eventually call this
template <typename K, typename V, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Comp, Low, true> {
static inline __device__ void merge(K k[1], V v[1]) {
// Use warp shuffles
warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(k[0], v[0]);
}
};
template <typename K, typename V, int N, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Comp, Low, true> {
static inline __device__ void merge(K k[N], V v[N]) {
static_assert(utils::isPowerOf2(N), "must be power of 2");
static_assert(N > 1, "must be N > 1");
#pragma unroll
for (int i = 0; i < N / 2; ++i) {
K& ka = k[i];
V& va = v[i];
K& kb = k[i + N / 2];
V& vb = v[i + N / 2];
bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
swap(s, ka, kb);
swap(s, va, vb);
}
{
K newK[N / 2];
V newV[N / 2];
#pragma unroll
for (int i = 0; i < N / 2; ++i) {
newK[i] = k[i];
newV[i] = v[i];
}
BitonicMergeStep<K, V, N / 2, Dir, Comp, true, true>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < N / 2; ++i) {
k[i] = newK[i];
v[i] = newV[i];
}
}
{
K newK[N / 2];
V newV[N / 2];
#pragma unroll
for (int i = 0; i < N / 2; ++i) {
newK[i] = k[i + N / 2];
newV[i] = v[i + N / 2];
}
BitonicMergeStep<K, V, N / 2, Dir, Comp, false, true>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < N / 2; ++i) {
k[i + N / 2] = newK[i];
v[i + N / 2] = newV[i];
}
}
}
};
//
// Non-power-of-2 merge specialization
//
// Low recursion
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, true, false> {
static inline __device__ void merge(K k[N], V v[N]) {
static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
static_assert(N >= 3, "must be N >= 3");
constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
#pragma unroll
for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
K& ka = k[i];
V& va = v[i];
K& kb = k[i + kNextHighestPowerOf2 / 2];
V& vb = v[i + kNextHighestPowerOf2 / 2];
bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
swap(s, ka, kb);
swap(s, va, vb);
}
constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
constexpr int kHighSize = kNextHighestPowerOf2 / 2;
{
K newK[kLowSize];
V newV[kLowSize];
#pragma unroll
for (int i = 0; i < kLowSize; ++i) {
newK[i] = k[i];
newV[i] = v[i];
}
constexpr bool kLowIsPowerOf2 =
utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
// FIXME: compiler doesn't like this expression? compiler bug?
// constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize);
BitonicMergeStep<K, V, kLowSize, Dir, Comp,
true, // low
kLowIsPowerOf2>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < kLowSize; ++i) {
k[i] = newK[i];
v[i] = newV[i];
}
}
{
K newK[kHighSize];
V newV[kHighSize];
#pragma unroll
for (int i = 0; i < kHighSize; ++i) {
newK[i] = k[i + kLowSize];
newV[i] = v[i + kLowSize];
}
constexpr bool kHighIsPowerOf2 =
utils::isPowerOf2(kNextHighestPowerOf2 / 2);
// FIXME: compiler doesn't like this expression? compiler bug?
// constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize);
BitonicMergeStep<K, V, kHighSize, Dir, Comp,
false, // high
kHighIsPowerOf2>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < kHighSize; ++i) {
k[i + kLowSize] = newK[i];
v[i + kLowSize] = newV[i];
}
}
}
};
// High recursion
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, false, false> {
static inline __device__ void merge(K k[N], V v[N]) {
static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
static_assert(N >= 3, "must be N >= 3");
constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
#pragma unroll
for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) {
K& ka = k[i];
V& va = v[i];
K& kb = k[i + kNextHighestPowerOf2 / 2];
V& vb = v[i + kNextHighestPowerOf2 / 2];
bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
swap(s, ka, kb);
swap(s, va, vb);
}
constexpr int kLowSize = kNextHighestPowerOf2 / 2;
constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;
{
K newK[kLowSize];
V newV[kLowSize];
#pragma unroll
for (int i = 0; i < kLowSize; ++i) {
newK[i] = k[i];
newV[i] = v[i];
}
constexpr bool kLowIsPowerOf2 =
utils::isPowerOf2(kNextHighestPowerOf2 / 2);
// FIXME: compiler doesn't like this expression? compiler bug?
// constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize);
BitonicMergeStep<K, V, kLowSize, Dir, Comp,
true, // low
kLowIsPowerOf2>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < kLowSize; ++i) {
k[i] = newK[i];
v[i] = newV[i];
}
}
{
K newK[kHighSize];
V newV[kHighSize];
#pragma unroll
for (int i = 0; i < kHighSize; ++i) {
newK[i] = k[i + kLowSize];
newV[i] = v[i + kLowSize];
}
constexpr bool kHighIsPowerOf2 =
utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
// FIXME: compiler doesn't like this expression? compiler bug?
// constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize);
BitonicMergeStep<K, V, kHighSize, Dir, Comp,
false, // high
kHighIsPowerOf2>::merge(newK, newV);
#pragma unroll
for (int i = 0; i < kHighSize; ++i) {
k[i + kLowSize] = newK[i];
v[i + kLowSize] = newV[i];
}
}
}
};
/// Merges two sets of registers across the warp of any size;
/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a
/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any
/// value >= 1
template <typename K,
typename V,
int N1,
int N2,
bool Dir,
typename Comp,
bool FullMerge = true>
inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1],
K k2[N2], V v2[N2]) {
constexpr int kSmallestN = N1 < N2 ? N1 : N2;
#pragma unroll
for (int i = 0; i < kSmallestN; ++i) {
K& ka = k1[N1 - 1 - i];
V& va = v1[N1 - 1 - i];
K& kb = k2[i];
V& vb = v2[i];
K otherKa;
V otherVa;
if (FullMerge) {
// We need the other values
otherKa = shfl_xor(ka, kWarpSize - 1);
otherVa = shfl_xor(va, kWarpSize - 1);
}
K otherKb = shfl_xor(kb, kWarpSize - 1);
V otherVb = shfl_xor(vb, kWarpSize - 1);
// ka is always first in the list, so we needn't use our lane
// in this comparison
bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb);
assign(swapa, ka, otherKb);
assign(swapa, va, otherVb);
// kb is always second in the list, so we needn't use our lane
// in this comparison
if (FullMerge) {
bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa);
assign(swapb, kb, otherKa);
assign(swapb, vb, otherVa);
} else {
// We don't care about updating elements in the second list
}
}
BitonicMergeStep<K, V, N1, Dir, Comp,
true, utils::isPowerOf2(N1)>::merge(k1, v1);
if (FullMerge) {
// Only if we care about N2 do we need to bother merging it fully
BitonicMergeStep<K, V, N2, Dir, Comp,
false, utils::isPowerOf2(N2)>::merge(k2, v2);
}
}
// Recursive template that uses the above bitonic merge to perform a
// bitonic sort
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicSortStep {
static inline __device__ void sort(K k[N], V v[N]) {
static_assert(N > 1, "did not hit specialized case");
// Sort recursively
constexpr int kSizeA = N / 2;
constexpr int kSizeB = N - kSizeA;
K aK[kSizeA];
V aV[kSizeA];
#pragma unroll
for (int i = 0; i < kSizeA; ++i) {
aK[i] = k[i];
aV[i] = v[i];
}
BitonicSortStep<K, V, kSizeA, Dir, Comp>::sort(aK, aV);
K bK[kSizeB];
V bV[kSizeB];
#pragma unroll
for (int i = 0; i < kSizeB; ++i) {
bK[i] = k[i + kSizeA];
bV[i] = v[i + kSizeA];
}
BitonicSortStep<K, V, kSizeB, Dir, Comp>::sort(bK, bV);
// Merge halves
warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);
#pragma unroll
for (int i = 0; i < kSizeA; ++i) {
k[i] = aK[i];
v[i] = aV[i];
}
#pragma unroll
for (int i = 0; i < kSizeB; ++i) {
k[i + kSizeA] = bK[i];
v[i + kSizeA] = bV[i];
}
}
};
// Single warp (N == 1) sorting specialization
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSortStep<K, V, 1, Dir, Comp> {
static inline __device__ void sort(K k[1], V v[1]) {
// Update this code if this changes
// should go from 1 -> kWarpSize in multiples of 2
static_assert(kWarpSize == 32, "unexpected warp size");
warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(k[0], v[0]);
warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(k[0], v[0]);
warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(k[0], v[0]);
warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(k[0], v[0]);
warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(k[0], v[0]);
}
};
/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary >= 1
template <typename K, typename V, int N, bool Dir, typename Comp>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}
} } // namespace