10 #include "Float16.cuh"
13 namespace faiss {
namespace gpu {
21 __global__
void warpSelect(Tensor<K, 2, true> in,
22 Tensor<K, 2, true> outK,
23 Tensor<IndexType, 2, true> outV,
27 constexpr
int kNumWarps = ThreadsPerBlock / kWarpSize;
29 WarpSelect<K, IndexType, Dir, Comparator<K>,
30 NumWarpQ, NumThreadQ, ThreadsPerBlock>
31 heap(initK, initV, k);
33 int warpId = threadIdx.x / kWarpSize;
34 int row = blockIdx.x * kNumWarps + warpId;
36 if (row >= in.getSize(0)) {
41 K* inStart = in[row][i].data();
44 int limit = utils::roundDown(in.getSize(1), kWarpSize);
46 for (; i < limit; i += kWarpSize) {
47 heap.add(*inStart, (IndexType) i);
52 if (i < in.getSize(1)) {
53 heap.addThreadQ(*inStart, (IndexType) i);
57 heap.writeOut(outK[row].data(),
/// Host-side warp-select entry point for a 2-D `float` tensor.
///
/// NOTE(review): only this declaration is visible here. The definition
/// presumably launches the `warpSelect` kernel in this header (one warp per
/// row of `in`) on `stream` -- confirm against the implementation file.
///
/// @param in          2-D input tensor (rank fixed by the type).
/// @param outKeys     2-D output tensor receiving the selected keys.
/// @param outIndices  2-D output tensor receiving the `int` index paired
///                    with each selected key (presumably the column in `in`;
///                    TODO confirm).
/// @param dir         selection direction. Which value selects the largest
///                    keys is not visible here -- TODO confirm.
/// @param k           number of entries to select per row (assumed).
/// @param stream      CUDA stream the work is issued on (assumed).
void runWarpSelect(Tensor<float, 2, true>& in,
                   Tensor<float, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir,
                   int k,
                   cudaStream_t stream);
#ifdef FAISS_USE_FLOAT16
/// Host-side warp-select entry point for a 2-D `half` tensor. It is only
/// declared when FAISS_USE_FLOAT16 is defined.
///
/// NOTE(review): only this declaration is visible here. The definition
/// presumably launches the `warpSelect` kernel in this header (one warp per
/// row of `in`) on `stream` -- confirm against the implementation file.
///
/// @param in          2-D input tensor of `half` keys (rank fixed by the type).
/// @param outKeys     2-D `half` output tensor receiving the selected keys.
/// @param outIndices  2-D output tensor receiving the `int` index paired
///                    with each selected key (presumably the column in `in`;
///                    TODO confirm).
/// @param dir         selection direction. Which value selects the largest
///                    keys is not visible here -- TODO confirm.
/// @param k           number of entries to select per row (assumed).
/// @param stream      CUDA stream the work is issued on (assumed).
void runWarpSelect(Tensor<half, 2, true>& in,
                   Tensor<half, 2, true>& outKeys,
                   Tensor<int, 2, true>& outIndices,
                   bool dir,
                   int k,
                   cudaStream_t stream);