11 #include "blockselect/BlockSelectImpl.cuh"
13 namespace faiss {
namespace gpu {
15 #ifdef FAISS_USE_FLOAT16
// Half-precision block-select specializations, one per (dir, size) pair.
//
// The size arguments (1, 32, 64, 128, 256, 512, 1024) are exactly the buckets
// that runBlockSelect(half, ...) in this file dispatches to via
// BLOCK_SELECT_CALL. It picks the smallest bucket >= k and asserts k <= 1024.
// Both directions (dir == true / dir == false) use the same set of sizes.
//
// NOTE(review): BLOCK_SELECT_DECL is presumably provided by the included
// "blockselect/BlockSelectImpl.cuh". Judging by its name, it emits a
// declaration whose definition lives in another translation unit. Confirm
// against that header before relying on this.

// dir == true
BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, true, 1024);

// dir == false
BLOCK_SELECT_DECL(half, false, 1);
BLOCK_SELECT_DECL(half, false, 32);
BLOCK_SELECT_DECL(half, false, 64);
BLOCK_SELECT_DECL(half, false, 128);
BLOCK_SELECT_DECL(half, false, 256);
BLOCK_SELECT_DECL(half, false, 512);
BLOCK_SELECT_DECL(half, false, 1024);
42 void runBlockSelect(Tensor<half, 2, true>& in,
43 Tensor<half, 2, true>& outK,
44 Tensor<int, 2, true>& outV,
45 bool dir,
int k, cudaStream_t stream) {
46 FAISS_ASSERT(k <= 1024);
50 BLOCK_SELECT_CALL(half,
true, 1);
52 BLOCK_SELECT_CALL(half,
true, 32);
54 BLOCK_SELECT_CALL(half,
true, 64);
55 }
else if (k <= 128) {
56 BLOCK_SELECT_CALL(half,
true, 128);
57 }
else if (k <= 256) {
58 BLOCK_SELECT_CALL(half,
true, 256);
59 }
else if (k <= 512) {
60 BLOCK_SELECT_CALL(half,
true, 512);
61 }
else if (k <= 1024) {
62 BLOCK_SELECT_CALL(half,
true, 1024);
66 BLOCK_SELECT_CALL(half,
false, 1);
68 BLOCK_SELECT_CALL(half,
false, 32);
70 BLOCK_SELECT_CALL(half,
false, 64);
71 }
else if (k <= 128) {
72 BLOCK_SELECT_CALL(half,
false, 128);
73 }
else if (k <= 256) {
74 BLOCK_SELECT_CALL(half,
false, 256);
75 }
else if (k <= 512) {
76 BLOCK_SELECT_CALL(half,
false, 512);
77 }
else if (k <= 1024) {
78 BLOCK_SELECT_CALL(half,
false, 1024);