11 #include "Float16.cuh"
14 namespace faiss {
namespace gpu {
// blockSelect: per-row k-selection kernel.
// Feeds the elements of one row of `in` into a BlockSelect structure and then
// copies the first `k` entries of the selection (keys into outK, the column
// index of each key into outV) to the same row of the outputs.
// NOTE(review): the template header, the remaining parameters (initK, initV, k,
// ...) and the definitions of `row` and of the first `i` are in lines not shown
// here; presumably row == blockIdx.x (one block per row) and i starts at
// threadIdx.x — TODO confirm.
22 __global__
void blockSelect(Tensor<K, 2, true> in,
23 Tensor<K, 2, true> outK,
24 Tensor<IndexType, 2, true> outV,
// Number of warps in the block; used to size the shared scratch buffers.
28 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory scratch handed to BlockSelect: NumWarpQ slots per warp for
// keys (smemK) and indices (smemV). The final selection is read back from
// these arrays in the copy-out loop at the end of the kernel.
30 __shared__ K smemK[kNumWarps * NumWarpQ];
31 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Block-wide selection structure ordered by Comparator<K> in direction Dir,
// keeping k results. initK/initV are presumably the sentinel key/index used to
// pre-fill the queues — confirm against BlockSelect.
33 BlockSelect<K, IndexType, Dir, Comparator<K>,
34 NumWarpQ, NumThreadQ, ThreadsPerBlock>
35 heap(initK, initV, smemK, smemV, k);
// Pointer to this thread's first element in the current row.
41 K* inStart = in[row][i].data();
// Main-loop bound: the column count rounded down to a multiple of kWarpSize.
// If i starts at threadIdx.x and ThreadsPerBlock is a multiple of kWarpSize,
// all lanes of a warp agree on `i < limit`, so heap.add is only ever called by
// whole warps (presumably required by BlockSelect::add — confirm).
44 int limit = utils::roundDown(in.getSize(1), kWarpSize);
// Each thread advances ThreadsPerBlock columns per iteration. The column index
// itself is stored as the value paired with each key.
46 for (; i < limit; i += ThreadsPerBlock) {
47 heap.add(*inStart, (IndexType) i);
48 inStart += ThreadsPerBlock;
// Leftover columns in [limit, in.getSize(1)) number fewer than kWarpSize, and
// each is covered by exactly one thread. They are inserted with addThreadQ,
// presumably a per-thread insert that does not need whole-warp participation.
52 if (i < in.getSize(1)) {
53 heap.addThreadQ(*inStart, (IndexType) i);
// NOTE(review): the lines between the tail insert and the copy-out are not
// shown; presumably a heap.reduce() (with block synchronisation) runs there so
// smemK/smemV hold the final selection before being read — confirm.
// Copy the first k selected (key, index) pairs to this row of the outputs,
// each thread striding by ThreadsPerBlock.
58 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
59 outK[row][i] = smemK[i];
60 outV[row][i] = smemV[i];
// blockSelectPair: per-row k-selection over explicit (key, value) pairs.
// Feeds the keys of one row of inK, each paired with the value at the same
// position in inV, into a BlockSelect structure. It then copies the first `k`
// entries of the selection to the same row of outK/outV.
// NOTE(review): the template header, the remaining parameters (initK, initV, k,
// ...) and the definitions of `row` and of the first `i` are in lines not shown
// here; presumably row == blockIdx.x (one block per row) and i starts at
// threadIdx.x — TODO confirm.
70 __global__
void blockSelectPair(Tensor<K, 2, true> inK,
71 Tensor<IndexType, 2, true> inV,
72 Tensor<K, 2, true> outK,
73 Tensor<IndexType, 2, true> outV,
// Number of warps in the block; used to size the shared scratch buffers.
77 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory scratch handed to BlockSelect: NumWarpQ slots per warp for
// keys (smemK) and values (smemV). The final selection is read back from
// these arrays in the copy-out loop at the end of the kernel.
79 __shared__ K smemK[kNumWarps * NumWarpQ];
80 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Block-wide selection structure ordered by Comparator<K> in direction Dir,
// keeping k results. initK/initV are presumably the sentinel key/value used to
// pre-fill the queues — confirm against BlockSelect.
82 BlockSelect<K, IndexType, Dir, Comparator<K>,
83 NumWarpQ, NumThreadQ, ThreadsPerBlock>
84 heap(initK, initV, smemK, smemV, k);
// Pointers to this thread's first key and first value in the current row.
// The two pointers are advanced in lockstep below.
90 K* inKStart = inK[row][i].data();
91 IndexType* inVStart = inV[row][i].data();
// Main-loop bound: the key column count rounded down to a multiple of
// kWarpSize. If i starts at threadIdx.x and ThreadsPerBlock is a multiple of
// kWarpSize, all lanes of a warp agree on `i < limit`, so heap.add is only
// called by whole warps.
// NOTE(review): only inK's width is consulted; inV is assumed to have the same
// shape — confirm at the call site.
94 int limit = utils::roundDown(inK.getSize(1), kWarpSize);
// Each thread advances ThreadsPerBlock columns per iteration, inserting the
// (key, value) pair found at that column.
96 for (; i < limit; i += ThreadsPerBlock) {
97 heap.add(*inKStart, *inVStart);
98 inKStart += ThreadsPerBlock;
99 inVStart += ThreadsPerBlock;
// Leftover columns in [limit, inK.getSize(1)) number fewer than kWarpSize, and
// each is covered by exactly one thread. They are inserted with addThreadQ,
// presumably a per-thread insert that does not need whole-warp participation.
103 if (i < inK.getSize(1)) {
104 heap.addThreadQ(*inKStart, *inVStart);
// NOTE(review): the lines between the tail insert and the copy-out are not
// shown; presumably a heap.reduce() (with block synchronisation) runs there so
// smemK/smemV hold the final selection before being read — confirm.
// Copy the first k selected (key, value) pairs to this row of the outputs,
// each thread striding by ThreadsPerBlock.
109 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
110 outK[row][i] = smemK[i];
111 outV[row][i] = smemV[i];
// Host-side entry point for per-row k-selection over float keys.
// The definition lives elsewhere; only the prototype is declared here.
//   in         : input keys, one selection problem per row (2-D tensor)
//   outKeys    : output keys (presumably the k selected keys of each row)
//   outIndices : output int indices (presumably the column of each selected
//                key within its input row)
//   dir        : selection direction; presumably true selects the largest and
//                false the smallest keys — confirm against the definition
//   k          : number of elements to select per row
//   stream     : CUDA stream the work is issued on (presumably asynchronous
//                with respect to the host)
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);
// Host-side entry point for per-row k-selection over float keys paired with
// caller-supplied int values.
// The definition lives elsewhere; only the prototype is declared here.
//   inKeys     : input keys, one selection problem per row (2-D tensor)
//   inIndices  : int value paired with each input key (presumably the same
//                shape as inKeys — confirm)
//   outKeys    : output keys (presumably the k selected keys of each row)
//   outIndices : output values carried along with the selected keys
//   dir        : selection direction; presumably true selects the largest and
//                false the smallest keys — confirm against the definition
//   k          : number of elements to select per row
//   stream     : CUDA stream the work is issued on (presumably asynchronous
//                with respect to the host)
void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<float, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir,
                        int k,
                        cudaStream_t stream);
126 #ifdef FAISS_USE_FLOAT16
// Half-precision overload of runBlockSelect: per-row k-selection over half
// keys. It sits inside the FAISS_USE_FLOAT16 conditional, so it is only
// declared when that macro is defined. The definition lives elsewhere.
//   in         : input keys, one selection problem per row (2-D tensor)
//   outKeys    : output keys (presumably the k selected keys of each row)
//   outIndices : output int indices (presumably the column of each selected
//                key within its input row)
//   dir        : selection direction; presumably true selects the largest and
//                false the smallest keys — confirm against the definition
//   k          : number of elements to select per row
//   stream     : CUDA stream the work is issued on (presumably asynchronous
//                with respect to the host)
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);
// Half-precision overload of runBlockSelectPair: per-row k-selection over half
// keys paired with caller-supplied int values. It sits inside the
// FAISS_USE_FLOAT16 conditional, so it is only declared when that macro is
// defined. The definition lives elsewhere.
//   inKeys     : input keys, one selection problem per row (2-D tensor)
//   inIndices  : int value paired with each input key (presumably the same
//                shape as inKeys — confirm)
//   outKeys    : output keys (presumably the k selected keys of each row)
//   outIndices : output values carried along with the selected keys
//   dir        : selection direction; presumably true selects the largest and
//                false the smallest keys — confirm against the definition
//   k          : number of elements to select per row
//   stream     : CUDA stream the work is issued on (presumably asynchronous
//                with respect to the host)
void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
                        Tensor<int, 2, true>& inIndices,
                        Tensor<half, 2, true>& outKeys,
                        Tensor<int, 2, true>& outIndices,
                        bool dir,
                        int k,
                        cudaStream_t stream);