faiss/gpu/utils/MergeNetworkBlock.cuh

/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "DeviceDefs.cuh"
#include "MergeNetworkUtils.cuh"
#include "PtxUtils.cuh"
#include "StaticUtils.h"
#include "WarpShuffles.cuh"
#include "../../FaissAssert.h"
#include <cuda.h>

namespace faiss { namespace gpu {

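// The routines below merge pairs of equal-length, already-sorted key/value
// lists in place using a merging network. blockMergeSmall handles list
// lengths no larger than the thread block, blockMergeLarge handles longer
// lists, and blockMerge dispatches between the two based on L vs. NumThreads.
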
// Merge pairs of lists smaller than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

  // Which pair of lists we are merging
  int mergeId = threadIdx.x / L;

  // Which thread we are within the merge
  int tid = threadIdx.x % L;

  // listK points to a region of size N * 2 * L
  listK += 2 * L * mergeId;
  listV += 2 * L * mergeId;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
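  // (For example, with L = 4: thread 0 pairs elements 3 and 4, thread 1
  // pairs 2 and 5, thread 2 pairs 1 and 6, and thread 3 pairs 0 and 7. In
  // general, element i of the first list is compared with element
  // 2 * L - 1 - i of the second, which is equivalent to reversing the
  // second list before a standard bitonic merge.)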
  int pos = L - 1 - tid;
  int stride = 2 * tid + 1;

  K& ka = listK[pos];
  K& kb = listK[pos + stride];

  bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
  swap(s, ka, kb);

  V& va = listV[pos];
  V& vb = listV[pos + stride];
  swap(s, va, vb);

  __syncthreads();

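  // After the crossover pass above, the lower L elements hold the "winning"
  // half and the upper L elements the other half, and each half is itself a
  // bitonic sequence; the strides below finish sorting each half with the
  // usual bitonic merging network.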
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
    int pos = 2 * tid - (tid & (stride - 1));

    K& ka = listK[pos];
    K& kb = listK[pos + stride];

    bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    swap(s, ka, kb);

    V& va = listV[pos];
    V& vb = listV[pos + stride];
    swap(s, va, vb);

    __syncthreads();
  }
}

// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L >= kWarpSize, "merge list size must be >= 32");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

  // For L > NumThreads, each thread has to perform more work per stride
  constexpr int kLoopPerThread = L / NumThreads;

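  // (For example, NumThreads = 128 and L = 256 gives kLoopPerThread = 2:
  // each thread performs two compare/swaps in each pass of the network.)
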
  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
#pragma unroll
  for (int loop = 0; loop < kLoopPerThread; ++loop) {
    int tid = loop * NumThreads + threadIdx.x;

    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    K& ka = listK[pos];
    K& kb = listK[pos + stride];

    bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    swap(s, ka, kb);

    V& va = listV[pos];
    V& vb = listV[pos + stride];
    swap(s, va, vb);
  }

  __syncthreads();

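  // If we don't need a full merge (FullMerge == false), only the first L
  // elements of the 2 * L region (the "winning" half after the crossover
  // above) need to be finished, so the remaining passes run half as many
  // loops per thread.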
  constexpr int kSecondLoopPerThread =
    FullMerge ? kLoopPerThread : kLoopPerThread / 2;

#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
      int tid = loop * NumThreads + threadIdx.x;
      int pos = 2 * tid - (tid & (stride - 1));

      K& ka = listK[pos];
      K& kb = listK[pos + stride];

      bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      swap(s, ka, kb);

      V& va = listV[pos];
      V& vb = listV[pos + stride];
      swap(s, va, vb);
    }

    __syncthreads();
  }
}

/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {
};

/// Merging lists smaller than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    constexpr int kNumParallelMerges = NumThreads / L;
    constexpr int kNumIterations = N / kNumParallelMerges;

    static_assert(L <= NumThreads, "list must be <= NumThreads");
    static_assert((N < kNumParallelMerges) ||
                  (kNumIterations * kNumParallelMerges == N),
                  "improper selection of N and L");

    if (N < kNumParallelMerges) {
      // We only need L threads per pair of lists to perform the merge
      if (threadIdx.x < N * L) {
        blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
          listK, listV);
      }
    } else {
      // All threads participate
#pragma unroll
      for (int i = 0; i < kNumIterations; ++i) {
        int start = i * kNumParallelMerges * 2 * L;

        blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
          listK + start, listV + start);
      }
    }
  }
};

/// Merging lists larger than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // Each pair of lists is merged sequentially
#pragma unroll
    for (int i = 0; i < N; ++i) {
      int start = i * 2 * L;

      blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
        listK + start, listV + start);
    }
  }
};

template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
  constexpr bool kSmallerThanBlock = (L <= NumThreads);

  BlockMerge<NumThreads, K, V, N, L, Dir, Comp, kSmallerThanBlock, FullMerge>::
    merge(listK, listV);
}

} } // namespace
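
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the library). It shows one way a
// kernel could invoke blockMerge on shared memory; the kernel and comparator
// names below are hypothetical, and only blockMerge itself is defined above.
//
//   struct LessThan {
//     static __device__ bool lt(float a, float b) { return a < b; }
//     static __device__ bool gt(float a, float b) { return a > b; }
//   };
//
//   // Launched with 128 threads: merges one pair (N = 1) of ascending
//   // 128-element lists (L = 128) of key/value pairs held contiguously in
//   // shared memory. With this comparator, Dir = true drives the
//   // compare/swap with Comp::gt and yields a single ascending list of
//   // 256 elements.
//   __global__ void exampleMergeKernel(float* keys, int* values) {
//     __shared__ float smemK[256];
//     __shared__ int smemV[256];
//
//     smemK[threadIdx.x] = keys[threadIdx.x];
//     smemK[threadIdx.x + 128] = keys[threadIdx.x + 128];
//     smemV[threadIdx.x] = values[threadIdx.x];
//     smemV[threadIdx.x + 128] = values[threadIdx.x + 128];
//     __syncthreads();
//
//     faiss::gpu::blockMerge<128, float, int, 1, 128, true, LessThan>(
//       smemK, smemV);
//
//     keys[threadIdx.x] = smemK[threadIdx.x];
//     keys[threadIdx.x + 128] = smemK[threadIdx.x + 128];
//     values[threadIdx.x] = smemV[threadIdx.x];
//     values[threadIdx.x + 128] = smemV[threadIdx.x + 128];
//   }
// ---------------------------------------------------------------------------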