12 #include "Float16.cuh"
15 namespace faiss {
namespace gpu {
// Kernel: per-row k-selection. Each thread block handles exactly one row of
// `in` (grid must be launched with gridDim.x == in.getSize(0), 1-D block of
// ThreadsPerBlock threads). The block cooperatively selects the top-k
// elements of its row (direction given by Dir) into outK/outV.
//
// Preconditions:
//   - ThreadsPerBlock is a multiple of kWarpSize
//   - k <= NumWarpQ
//   - outK/outV have at least k columns per row
// Shared memory: kNumWarps * NumWarpQ entries each of K and IndexType
// (statically allocated below).
//
// NOTE(review): the template header, the initK/initV/k parameters, the
// row/i index setup, heap.reduce() and the closing braces were missing from
// the garbled source and have been reconstructed — confirm against the
// upstream file.
template <typename K,
          typename IndexType,
          bool Dir,
          int NumWarpQ,
          int NumThreadQ,
          int ThreadsPerBlock>
__global__ void blockSelect(Tensor<K, 2, true> in,
                            Tensor<K, 2, true> outK,
                            Tensor<IndexType, 2, true> outV,
                            K initK,
                            IndexType initV,
                            int k) {
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  __shared__ K smemK[kNumWarps * NumWarpQ];
  __shared__ IndexType smemV[kNumWarps * NumWarpQ];

  BlockSelect<K, IndexType, Dir, Comparator<K>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, initV, smemK, smemV, k);

  // One block per row; threads stride across the row's columns
  int row = blockIdx.x;
  int i = threadIdx.x;
  K* inStart = in[row][i].data();

  // Whole warps must participate in heap.add() (it contains warp-
  // synchronous steps), so only iterate over the warp-aligned prefix here
  int limit = utils::roundDown(in.getSize(1), kWarpSize);

  for (; i < limit; i += ThreadsPerBlock) {
    heap.add(*inStart, (IndexType) i);
    inStart += ThreadsPerBlock;
  }

  // Remainder tail (partial warp): per-thread queue insertion only
  if (i < in.getSize(1)) {
    heap.addThreadQ(*inStart, (IndexType) i);
  }

  // Merge per-warp/per-thread queues into smemK/smemV before reading them
  heap.reduce();

  // Write out the final k results; smemK/smemV hold them sorted after reduce
  for (int i = threadIdx.x; i < k; i += ThreadsPerBlock) {
    outK[row][i] = smemK[i];
    outV[row][i] = smemV[i];
  }
}
// Host-side launcher (defined elsewhere): selects k extrema per row of `in`
// (largest if dir is true, presumably — TODO confirm Dir semantics against
// the definition) into outKeys/outIndices, asynchronously on `stream`.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);
#ifdef FAISS_USE_FLOAT16
// Half-precision overload of the launcher above; same contract with
// `half` keys. Only available when built with FAISS_USE_FLOAT16.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outKeys,
                    Tensor<int, 2, true>& outIndices,
                    bool dir,
                    int k,
                    cudaStream_t stream);