12 #include "Float16.cuh"
15 namespace faiss {
namespace gpu {
// Kernel that feeds the elements of one row of `in` into a BlockSelect heap
// and then copies k key/index pairs from shared memory into row `row` of
// `outK` / `outV`. The index recorded for each element is its column
// position `i` within the row.
//
// Parameters visible here:
//   in   - 2-D input tensor of keys (indexed as in[row][i])
//   outK - 2-D output tensor receiving keys
//   outV - 2-D output tensor receiving indices
// NOTE(review): the parameter list continues past the lines visible here;
// `initK`, `initV` and `k` are used below and are presumably declared
// there, along with the template parameters (K, IndexType, Dir, NumWarpQ,
// NumThreadQ, ThreadsPerBlock) -- confirm against the full file.
23 __global__
void blockSelect(Tensor<K, 2, true> in,
24 Tensor<K, 2, true> outK,
25 Tensor<IndexType, 2, true> outV,
// Number of warps in the block, derived from the compile-time block size.
29 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory key/index buffers, NumWarpQ slots per warp. They are handed
// to the heap below and read back in the final copy loop.
31 __shared__ K smemK[kNumWarps * NumWarpQ];
32 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selection structure constructed from the initial key/index values, the
// shared buffers and k.
34 BlockSelect<K, IndexType, Dir, Comparator<K>,
35 NumWarpQ, NumThreadQ, ThreadsPerBlock>
36 heap(initK, initV, smemK, smemV, k);
// Pointer to this thread's first element in the row.
// NOTE(review): `row` and `i` are defined on lines not visible here;
// presumably row comes from the block index and i starts at threadIdx.x.
42 K* inStart = in[row][i].data();
// Row length (dimension 1 of `in`) passed through utils::roundDown with
// kWarpSize; presumably the largest multiple of kWarpSize not exceeding the
// row length -- confirm against utils::roundDown.
45 int limit = utils::roundDown(in.getSize(1), kWarpSize);
// Main loop: add the current element with its column index, then advance
// both the index and the pointer by ThreadsPerBlock.
47 for (; i < limit; i += ThreadsPerBlock) {
48 heap.add(*inStart, (IndexType) i);
49 inStart += ThreadsPerBlock;
// Remainder: elements at or beyond `limit` but still inside the row are
// added through addThreadQ instead of add.
53 if (i < in.getSize(1)) {
54 heap.addThreadQ(*inStart, (IndexType) i);
// Output: each thread copies entries threadIdx.x, threadIdx.x +
// ThreadsPerBlock, ... (while < k) from the shared buffers to row `row`.
// NOTE(review): whatever finalizes the heap contents in smemK/smemV before
// this loop is not visible here -- confirm in the full file.
59 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
60 outK[row][i] = smemK[i];
61 outV[row][i] = smemV[i];
// Kernel that feeds (key, value) pairs read from one row of `inK` / `inV`
// into a BlockSelect heap and then copies k key/value pairs from shared
// memory into row `row` of `outK` / `outV`. Unlike blockSelect, the value
// stored with each key is read from `inV` rather than being the column
// position.
//
// Parameters visible here:
//   inK  - 2-D input tensor of keys (indexed as inK[row][i])
//   inV  - 2-D input tensor of values paired with inK (indexed the same way)
//   outK - 2-D output tensor receiving keys
//   outV - 2-D output tensor receiving values
// NOTE(review): the parameter list continues past the lines visible here;
// `initK`, `initV` and `k` are used below and are presumably declared
// there, along with the template parameters -- confirm against the full
// file.
71 __global__
void blockSelectPair(Tensor<K, 2, true> inK,
72 Tensor<IndexType, 2, true> inV,
73 Tensor<K, 2, true> outK,
74 Tensor<IndexType, 2, true> outV,
// Number of warps in the block, derived from the compile-time block size.
78 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Shared-memory key/value buffers, NumWarpQ slots per warp. They are handed
// to the heap below and read back in the final copy loop.
80 __shared__ K smemK[kNumWarps * NumWarpQ];
81 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selection structure constructed from the initial key/value, the shared
// buffers and k.
83 BlockSelect<K, IndexType, Dir, Comparator<K>,
84 NumWarpQ, NumThreadQ, ThreadsPerBlock>
85 heap(initK, initV, smemK, smemV, k);
// Pointers to this thread's first key and value in the row.
// NOTE(review): `row` and `i` are defined on lines not visible here;
// presumably row comes from the block index and i starts at threadIdx.x.
91 K* inKStart = inK[row][i].data();
92 IndexType* inVStart = inV[row][i].data();
// Row length (dimension 1 of `inK`) passed through utils::roundDown with
// kWarpSize; presumably the largest multiple of kWarpSize not exceeding the
// row length -- confirm against utils::roundDown.
95 int limit = utils::roundDown(inK.getSize(1), kWarpSize);
// Main loop: add the current key/value pair, then advance the index and
// both pointers by ThreadsPerBlock.
97 for (; i < limit; i += ThreadsPerBlock) {
98 heap.add(*inKStart, *inVStart);
99 inKStart += ThreadsPerBlock;
100 inVStart += ThreadsPerBlock;
// Remainder: elements at or beyond `limit` but still inside the row are
// added through addThreadQ instead of add.
104 if (i < inK.getSize(1)) {
105 heap.addThreadQ(*inKStart, *inVStart);
// Output: each thread copies entries threadIdx.x, threadIdx.x +
// ThreadsPerBlock, ... (while < k) from the shared buffers to row `row`.
// NOTE(review): whatever finalizes the heap contents in smemK/smemV before
// this loop is not visible here -- confirm in the full file.
110 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
111 outK[row][i] = smemK[i];
112 outV[row][i] = smemV[i];
// Declaration only (no body here): float-key block selection entry point.
//   in         - 2-D float tensor of input keys
//   outKeys    - 2-D float tensor receiving the selected keys
//   outIndices - 2-D int tensor receiving the selected indices
//   dir        - presumably the selection direction (largest vs. smallest);
//                TODO confirm which value means which in the definition
//   k          - presumably the number of entries selected per row; confirm
//   stream     - CUDA stream argument; presumably the stream the work is
//                issued on
// NOTE(review): presumably launches the blockSelect kernel; the definition
// is not visible here.
116 void runBlockSelect(Tensor<float, 2, true>& in,
117 Tensor<float, 2, true>& outKeys,
118 Tensor<int, 2, true>& outIndices,
119 bool dir,
int k, cudaStream_t stream);
// Declaration only (no body here): float-key block selection entry point
// that takes explicit key/index input pairs.
//   inKeys     - 2-D float tensor of input keys
//   inIndices  - 2-D int tensor of indices paired with inKeys
//   outKeys    - 2-D float tensor receiving the selected keys
//   outIndices - 2-D int tensor receiving the selected indices
//   dir        - presumably the selection direction (largest vs. smallest);
//                TODO confirm which value means which in the definition
//   k          - presumably the number of entries selected per row; confirm
//   stream     - CUDA stream argument; presumably the stream the work is
//                issued on
// NOTE(review): presumably launches the blockSelectPair kernel; the
// definition is not visible here.
121 void runBlockSelectPair(Tensor<float, 2, true>& inKeys,
122 Tensor<int, 2, true>& inIndices,
123 Tensor<float, 2, true>& outKeys,
124 Tensor<int, 2, true>& outIndices,
125 bool dir,
int k, cudaStream_t stream);
127 #ifdef FAISS_USE_FLOAT16
// Declaration only (no body here): half-precision-key overload of
// runBlockSelect, compiled only when FAISS_USE_FLOAT16 is defined.
//   in         - 2-D half tensor of input keys
//   outKeys    - 2-D half tensor receiving the selected keys
//   outIndices - 2-D int tensor receiving the selected indices
//   dir        - presumably the selection direction; TODO confirm in the
//                definition
//   k          - presumably the number of entries selected per row; confirm
//   stream     - CUDA stream argument; presumably the stream the work is
//                issued on
128 void runBlockSelect(Tensor<half, 2, true>& in,
129 Tensor<half, 2, true>& outKeys,
130 Tensor<int, 2, true>& outIndices,
131 bool dir,
int k, cudaStream_t stream);
// Declaration only (no body here): half-precision-key overload of
// runBlockSelectPair, compiled only when FAISS_USE_FLOAT16 is defined.
//   inKeys     - 2-D half tensor of input keys
//   inIndices  - 2-D int tensor of indices paired with inKeys
//   outKeys    - 2-D half tensor receiving the selected keys
//   outIndices - 2-D int tensor receiving the selected indices
//   dir        - presumably the selection direction; TODO confirm in the
//                definition
//   k          - presumably the number of entries selected per row; confirm
//   stream     - CUDA stream argument; presumably the stream the work is
//                issued on
133 void runBlockSelectPair(Tensor<half, 2, true>& inKeys,
134 Tensor<int, 2, true>& inIndices,
135 Tensor<half, 2, true>& outKeys,
136 Tensor<int, 2, true>& outIndices,
137 bool dir,
int k, cudaStream_t stream);