9 #include "blockselect/BlockSelectImpl.cuh"
11 namespace faiss {
namespace gpu {
13 #ifdef FAISS_USE_FLOAT16
// Per-configuration declarations for half-precision block selection.
// One BLOCK_SELECT_DECL is issued for every (bool, size) pair that the
// BLOCK_SELECT_CALL / BLOCK_SELECT_PAIR_CALL dispatch in this file can pick:
// sizes 1, 32, 64, 128, 256, 512 and 1024, for both `true` and `false`.
// NOTE(review): the macro itself is not defined in this file (presumably it
// comes from blockselect/BlockSelectImpl.cuh) -- confirm what it expands to.

// Instantiations with the boolean argument set to `true`.
BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, true, 1024);

// Instantiations with the boolean argument set to `false`.
BLOCK_SELECT_DECL(half, false, 1);
BLOCK_SELECT_DECL(half, false, 32);
BLOCK_SELECT_DECL(half, false, 64);
BLOCK_SELECT_DECL(half, false, 128);
BLOCK_SELECT_DECL(half, false, 256);
BLOCK_SELECT_DECL(half, false, 512);
BLOCK_SELECT_DECL(half, false, 1024);
// Dispatches a half-precision block selection to the BLOCK_SELECT_CALL
// instantiation that matches the requested direction flag and k.
//
// Parameters:
//   in     - input tensor of half values.
//   outK   - output tensor of half values.
//   outV   - output tensor of int values.
//   dir    - picks the `true` or the `false` family of instantiations.
//   k      - requested selection size; asserted to be <= 1024.
//   stream - CUDA stream handed to this call.
//
// k is rounded up to the smallest instantiated size that can hold it
// (1, 32, 64, 128, 256, 512 or 1024); k == 1 has its own instantiation.
//
// NOTE(review): BLOCK_SELECT_CALL is defined outside this file (presumably in
// blockselect/BlockSelectImpl.cuh) and presumably refers to the local names
// in, outK, outV, dir, k and stream -- do not rename the parameters without
// checking the macro.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  // Largest instantiated size is 1024; anything bigger has no kernel.
  FAISS_ASSERT(k <= 1024);

  if (dir) {
    // `true` family of instantiations.
    if (k == 1) {
      BLOCK_SELECT_CALL(half, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(half, true, 1024);
    }
  } else {
    // `false` family of instantiations; same size buckets as above.
    if (k == 1) {
      BLOCK_SELECT_CALL(half, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(half, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(half, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(half, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(half, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(half, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(half, false, 1024);
    }
  }
}

81 void runBlockSelectPair(Tensor<half, 2, true>& inK,
82 Tensor<int, 2, true>& inV,
83 Tensor<half, 2, true>& outK,
84 Tensor<int, 2, true>& outV,
85 bool dir,
int k, cudaStream_t stream) {
86 FAISS_ASSERT(k <= 1024);
90 BLOCK_SELECT_PAIR_CALL(half,
true, 1);
92 BLOCK_SELECT_PAIR_CALL(half,
true, 32);
94 BLOCK_SELECT_PAIR_CALL(half,
true, 64);
95 }
else if (k <= 128) {
96 BLOCK_SELECT_PAIR_CALL(half,
true, 128);
97 }
else if (k <= 256) {
98 BLOCK_SELECT_PAIR_CALL(half,
true, 256);
99 }
else if (k <= 512) {
100 BLOCK_SELECT_PAIR_CALL(half,
true, 512);
101 }
else if (k <= 1024) {
102 BLOCK_SELECT_PAIR_CALL(half,
true, 1024);
106 BLOCK_SELECT_PAIR_CALL(half,
false, 1);
107 }
else if (k <= 32) {
108 BLOCK_SELECT_PAIR_CALL(half,
false, 32);
109 }
else if (k <= 64) {
110 BLOCK_SELECT_PAIR_CALL(half,
false, 64);
111 }
else if (k <= 128) {
112 BLOCK_SELECT_PAIR_CALL(half,
false, 128);
113 }
else if (k <= 256) {
114 BLOCK_SELECT_PAIR_CALL(half,
false, 256);
115 }
else if (k <= 512) {
116 BLOCK_SELECT_PAIR_CALL(half,
false, 512);
117 }
else if (k <= 1024) {
118 BLOCK_SELECT_PAIR_CALL(half,
false, 1024);