10 #include "Float16.cuh"
13 namespace faiss {
namespace gpu {
// blockSelect: k-selection over one row of `in`, writing the selected keys to
// outK[row] and their column indices to outV[row].
// NOTE(review): this excerpt is incomplete. The following are not shown:
//   - the template header (presumably declaring K, IndexType, Dir, NumWarpQ,
//     NumThreadQ, ThreadsPerBlock);
//   - the initK / initV / k parameters;
//   - the definitions of `row` and `i`;
//   - any reduce / barrier step before the copy-out loop;
//   - the closing braces.
// Comments below describe only the visible statements.
21 __global__
void blockSelect(Tensor<K, 2, true> in,
22 Tensor<K, 2, true> outK,
23 Tensor<IndexType, 2, true> outV,
// Warps per block. ThreadsPerBlock is presumably a multiple of kWarpSize;
// the visible code does not enforce this (TODO confirm).
27 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Block-shared scratch with NumWarpQ key/value slots per warp, handed to the
// BlockSelect object below.
29 __shared__ K smemK[kNumWarps * NumWarpQ];
30 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selection state, constructed with (initK, initV), presumably the fill pair
// for empty slots, and bounded to k results. Dir / Comparator<K> presumably
// choose the selection direction; see BlockSelect in Select.cuh to confirm.
32 BlockSelect<K, IndexType, Dir, Comparator<K>,
33 NumWarpQ, NumThreadQ, ThreadsPerBlock>
34 heap(initK, initV, smemK, smemV, k);
// Address of this thread's first element, in[row][i]. `row` and `i` are
// defined on lines not included in this excerpt.
40 K* inStart = in[row][i].data();
// Warp-aligned prefix length of the row. Presumably chosen so that all lanes
// of a warp call heap.add() together on every iteration.
43 int limit = utils::roundDown(in.getSize(1), kWarpSize);
// Scan the warp-aligned prefix with stride ThreadsPerBlock. The column index
// `i` is the value paired with each key.
45 for (; i < limit; i += ThreadsPerBlock) {
46 heap.add(*inStart, (IndexType) i);
47 inStart += ThreadsPerBlock;
// Columns in [limit, in.getSize(1)) that the aligned loop skipped. They go
// through addThreadQ rather than add, presumably because only part of a warp
// holds a valid element here.
51 if (i < in.getSize(1)) {
52 heap.addThreadQ(*inStart, (IndexType) i);
// Copy the first k selected entries from shared memory into this row of the
// outputs.
// NOTE(review): this reads smemK/smemV slots written by other threads, so a
// reduction plus __syncthreads() must run before it. That step is not
// visible here; confirm.
57 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
58 outK[row][i] = smemK[i];
59 outV[row][i] = smemV[i];
// blockSelectPair: k-selection over one row of inK. The value paired with
// each key is read from inV at the same position, rather than being the
// column index. Results are written to outK[row] / outV[row].
// NOTE(review): this excerpt is incomplete. The following are not shown:
//   - the template header;
//   - the initK / initV / k parameters;
//   - the definitions of `row` and `i`;
//   - any reduce / barrier step before the copy-out loop;
//   - the closing braces.
// Comments below describe only the visible statements.
69 __global__
void blockSelectPair(Tensor<K, 2, true> inK,
70 Tensor<IndexType, 2, true> inV,
71 Tensor<K, 2, true> outK,
72 Tensor<IndexType, 2, true> outV,
// Warps per block. ThreadsPerBlock is presumably a multiple of kWarpSize;
// the visible code does not enforce this (TODO confirm).
76 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Block-shared scratch with NumWarpQ key/value slots per warp, handed to the
// BlockSelect object below.
78 __shared__ K smemK[kNumWarps * NumWarpQ];
79 __shared__ IndexType smemV[kNumWarps * NumWarpQ];
// Selection state, constructed with (initK, initV), presumably the fill pair
// for empty slots, and bounded to k results. See BlockSelect in Select.cuh
// to confirm.
81 BlockSelect<K, IndexType, Dir, Comparator<K>,
82 NumWarpQ, NumThreadQ, ThreadsPerBlock>
83 heap(initK, initV, smemK, smemV, k);
// Key and value cursors at column `i` of this row. Both advance in lockstep
// below.
// NOTE(review): only inK.getSize(1) bounds the scan. inV is assumed to have
// at least as many columns; this is not checked here.
89 K* inKStart = inK[row][i].data();
90 IndexType* inVStart = inV[row][i].data();
// Warp-aligned prefix length of the row. Presumably chosen so that all lanes
// of a warp call heap.add() together on every iteration.
93 int limit = utils::roundDown(inK.getSize(1), kWarpSize);
// Scan the warp-aligned prefix with stride ThreadsPerBlock, feeding
// (key, value) pairs read from inK / inV.
95 for (; i < limit; i += ThreadsPerBlock) {
96 heap.add(*inKStart, *inVStart);
97 inKStart += ThreadsPerBlock;
98 inVStart += ThreadsPerBlock;
// Columns in [limit, inK.getSize(1)) that the aligned loop skipped. They go
// through addThreadQ rather than add, presumably because only part of a warp
// holds a valid element here.
102 if (i < inK.getSize(1)) {
103 heap.addThreadQ(*inKStart, *inVStart);
// Copy the first k selected entries from shared memory into this row of the
// outputs.
// NOTE(review): this reads smem slots written by other threads, so a
// reduction plus __syncthreads() must run before it. That step is not
// visible here; confirm.
108 for (
int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
109 outK[row][i] = smemK[i];
110 outV[row][i] = smemV[i];
// Host-side entry point for per-row k-selection on float keys.
// NOTE(review): only this declaration is visible. Presumably the definition
// launches the blockSelect kernel on `stream`; confirm.
//   in          input keys
//   outKeys     receives the selected keys for each row
//   outIndices  receives the index paired with each selected key
//               (presumably its column in `in`)
//   dir         selection direction, forwarded to the kernel
//   k           number of results per row
//   stream      CUDA stream the work is issued on
void runBlockSelect(
    Tensor<float, 2, true>& in,
    Tensor<float, 2, true>& outKeys,
    Tensor<int, 2, true>& outIndices,
    bool dir,
    int k,
    cudaStream_t stream);
// Host-side entry point for per-row k-selection on float keys, where the
// value paired with each key comes from `inIndices` at the same position.
// NOTE(review): only this declaration is visible. Presumably the definition
// launches the blockSelectPair kernel on `stream`; confirm.
//   inKeys      input keys
//   inIndices   values paired element-wise with inKeys
//   outKeys     receives the selected keys for each row
//   outIndices  receives the values paired with the selected keys
//   dir         selection direction, forwarded to the kernel
//   k           number of results per row
//   stream      CUDA stream the work is issued on
void runBlockSelectPair(
    Tensor<float, 2, true>& inKeys,
    Tensor<int, 2, true>& inIndices,
    Tensor<float, 2, true>& outKeys,
    Tensor<int, 2, true>& outIndices,
    bool dir,
    int k,
    cudaStream_t stream);
125 #ifdef FAISS_USE_FLOAT16
// Half-precision overload of runBlockSelect. It has the same parameter roles
// as the float version, with `half` keys.
// NOTE(review): declaration only. Presumably it launches blockSelect
// instantiated for half on `stream`; confirm.
//   in          input keys
//   outKeys     receives the selected keys for each row
//   outIndices  receives the index paired with each selected key
//   dir         selection direction, forwarded to the kernel
//   k           number of results per row
//   stream      CUDA stream the work is issued on
void runBlockSelect(
    Tensor<half, 2, true>& in,
    Tensor<half, 2, true>& outKeys,
    Tensor<int, 2, true>& outIndices,
    bool dir,
    int k,
    cudaStream_t stream);
// Half-precision overload of runBlockSelectPair. The value paired with each
// key comes from `inIndices` at the same position.
// NOTE(review): declaration only. Presumably it launches blockSelectPair
// instantiated for half on `stream`; confirm.
//   inKeys      input keys
//   inIndices   values paired element-wise with inKeys
//   outKeys     receives the selected keys for each row
//   outIndices  receives the values paired with the selected keys
//   dir         selection direction, forwarded to the kernel
//   k           number of results per row
//   stream      CUDA stream the work is issued on
void runBlockSelectPair(
    Tensor<half, 2, true>& inKeys,
    Tensor<int, 2, true>& inIndices,
    Tensor<half, 2, true>& outKeys,
    Tensor<int, 2, true>& outIndices,
    bool dir,
    int k,
    cudaStream_t stream);