13 #include "Float16.cuh"
16 namespace faiss {
namespace gpu {
// Kernel in which each warp of a block processes one row of the 2-D input
// `in`: the row's elements are fed into a WarpSelect heap together with their
// column index, and the heap's contents are then written out for that row.
//
// NOTE(review): this listing appears to be incomplete. The template header
// that would introduce K, IndexType, Dir, NumWarpQ, NumThreadQ and
// ThreadsPerBlock, the parameters initK / initV / k, the declaration of `i`,
// the closing braces and the remaining arguments of heap.writeOut are not
// visible here, and stray numeric tokens precede many lines. Confirm against
// the original file before relying on these comments.
24 __global__
void warpSelect(Tensor<K, 2, true> in,
25 Tensor<K, 2, true> outK,
26 Tensor<IndexType, 2, true> outV,
// Number of warps in a block, derived from the compile-time block size.
30 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
// Selection structure constructed from initK, initV and k (their
// declarations are not visible in this listing).
32 WarpSelect<K, IndexType, Dir, Comparator<K>,
33 NumWarpQ, NumThreadQ, ThreadsPerBlock>
34 heap(initK, initV, k);
// Index of this thread's warp within the block, and the input row that warp
// is responsible for.
36 int warpId = threadIdx.x / kWarpSize;
37 int row = blockIdx.x * kNumWarps + warpId;
// Guard for warps whose row index lies beyond the number of rows in `in`
// (the body of this branch is not visible here).
39 if (row >= in.getSize(0)) {
// Pointer to the first element of this warp's row.
43 K* inStart = in[row].data();
// Presumably the row length rounded down to a multiple of kWarpSize, so the
// loop below can advance in kWarpSize steps -- verify utils::roundDown.
47 int limit = utils::roundDown(in.getSize(1), kWarpSize);
// Main loop: adds element i of the row, with i as its index, then advances i
// by kWarpSize (the initial value of i is not visible in this listing).
49 for (; i < limit; i += kWarpSize) {
50 heap.add(inStart[i], (IndexType) i);
// Remainder handling: if i is still inside the row after the main loop, the
// element is added through addThreadQ instead of add.
54 if (i < in.getSize(1)) {
55 heap.addThreadQ(inStart[i], (IndexType) i);
// Writes the heap contents to this row of outK (remaining arguments are not
// visible in this listing).
59 heap.writeOut(outK[row].data(),
/// Host-side declaration of the warp-select entry point for float keys; the
/// definition is not part of this listing.
///
/// @param in          2-D tensor of float input keys.
/// @param outKeys     2-D tensor of float output keys.
/// @param outIndices  2-D tensor of int output indices.
/// @param dir         selection direction flag
///                    (NOTE(review): meaning of true/false not visible here --
///                    confirm against the definition).
/// @param k           number of entries to select (presumably per row --
///                    verify against the definition).
/// @param stream      CUDA stream handle (presumably the stream the work is
///                    issued on -- verify against the definition).
void runWarpSelect(Tensor<float, 2, true>& in,
                   Tensor<float, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir,
                   int k,
                   cudaStream_t stream);
68 #ifdef FAISS_USE_FLOAT16
/// Host-side declaration of the warp-select entry point for half-precision
/// keys; the definition is not part of this listing.
///
/// @param in          2-D tensor of half input keys.
/// @param outKeys     2-D tensor of half output keys.
/// @param outIndices  2-D tensor of int output indices.
/// @param dir         selection direction flag
///                    (NOTE(review): meaning of true/false not visible here --
///                    confirm against the definition).
/// @param k           number of entries to select (presumably per row --
///                    verify against the definition).
/// @param stream      CUDA stream handle (presumably the stream the work is
///                    issued on -- verify against the definition).
void runWarpSelect(Tensor<half, 2, true>& in,
                   Tensor<half, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir,
                   int k,
                   cudaStream_t stream);