9 #include "blockselect/BlockSelectImpl.cuh"
11 namespace faiss {
namespace gpu {
// Float block-select declarations: one BLOCK_SELECT_DECL per
// (direction flag, size) combination that the dispatch functions in this
// file reach through BLOCK_SELECT_CALL / BLOCK_SELECT_PAIR_CALL.
// NOTE(review): the macro is not defined in this file -- presumably it comes
// from blockselect/BlockSelectImpl.cuh; confirm its expansion there before
// adding or removing a size.

// Direction flag = true.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, true, 1024);

// Direction flag = false.
BLOCK_SELECT_DECL(float, false, 1);
BLOCK_SELECT_DECL(float, false, 32);
BLOCK_SELECT_DECL(float, false, 64);
BLOCK_SELECT_DECL(float, false, 128);
BLOCK_SELECT_DECL(float, false, 256);
BLOCK_SELECT_DECL(float, false, 512);
BLOCK_SELECT_DECL(float, false, 1024);

// Dispatches a float block-select to the BLOCK_SELECT_CALL variant that
// matches `dir` and the smallest listed size (1, 32, 64, 128, 256, 512, 1024)
// able to hold `k`.
//
// Parameters:
//   in, outK, outV, stream - not referenced directly in this function;
//       presumably picked up by name inside BLOCK_SELECT_CALL (macro body is
//       not visible here -- confirm in the blockselect implementation header).
//   dir - chooses between the `true` and `false` families of variants.
//   k   - requested selection size; must satisfy k <= 1024 (asserted).
//
// NOTE(review): the `k == 1` test for the size-1 variant was reconstructed;
// verify it against the upstream file. With this form any k < 1 falls through
// to the size-32 variant, since only the upper bound is asserted.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir,
                    int k,
                    cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, true, 1024);
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, false, 1024);
    }
  }
}

// Pair variant of the float block-select dispatcher: asserts k <= 1024, then
// invokes BLOCK_SELECT_PAIR_CALL(float, <flag>, <size>) with <size> drawn from
// {1, 32, 64, 128, 256, 512, 1024}. Each visible `else if (k <= N)` guard is
// followed by the call whose size argument is N.
// NOTE(review): this excerpt looks damaged -- statements carry stray leading
// numbers, and the branch on `dir`, the first guards of each chain and the
// closing braces are not visible here. Confirm against the complete file
// before relying on the control flow.
// NOTE(review): inK, inV, outK, outV and stream are not referenced directly;
// presumably BLOCK_SELECT_PAIR_CALL picks them up by name -- verify in the
// macro definition.
79 void runBlockSelectPair(Tensor<float, 2, true>& inK,
80 Tensor<int, 2, true>& inV,
81 Tensor<float, 2, true>& outK,
82 Tensor<int, 2, true>& outV,
83 bool dir,
int k, cudaStream_t stream) {
// Only the upper bound on k is checked.
84 FAISS_ASSERT(k <= 1024);
// Calls with flag = true, sizes 1 through 1024.
88 BLOCK_SELECT_PAIR_CALL(
float,
true, 1);
90 BLOCK_SELECT_PAIR_CALL(
float,
true, 32);
92 BLOCK_SELECT_PAIR_CALL(
float,
true, 64);
93 }
else if (k <= 128) {
94 BLOCK_SELECT_PAIR_CALL(
float,
true, 128);
95 }
else if (k <= 256) {
96 BLOCK_SELECT_PAIR_CALL(
float,
true, 256);
97 }
else if (k <= 512) {
98 BLOCK_SELECT_PAIR_CALL(
float,
true, 512);
99 }
else if (k <= 1024) {
100 BLOCK_SELECT_PAIR_CALL(
float,
true, 1024);
// Calls with flag = false, sizes 1 through 1024.
104 BLOCK_SELECT_PAIR_CALL(
float,
false, 1);
105 }
else if (k <= 32) {
106 BLOCK_SELECT_PAIR_CALL(
float,
false, 32);
107 }
else if (k <= 64) {
108 BLOCK_SELECT_PAIR_CALL(
float,
false, 64);
109 }
else if (k <= 128) {
110 BLOCK_SELECT_PAIR_CALL(
float,
false, 128);
111 }
else if (k <= 256) {
112 BLOCK_SELECT_PAIR_CALL(
float,
false, 256);
113 }
else if (k <= 512) {
114 BLOCK_SELECT_PAIR_CALL(
float,
false, 512);
115 }
else if (k <= 1024) {
116 BLOCK_SELECT_PAIR_CALL(
float,
false, 1024);