10 #include "blockselect/BlockSelectImpl.cuh"
12 namespace faiss {
namespace gpu {
// Declarations of the float block-select variants that runBlockSelect
// dispatches to: one per (direction, WarpQ capacity) pair. The second
// argument is the selection direction and the third the compile-time queue
// capacity; presumably these expand to extern template / kernel declarations
// defined in BlockSelectImpl.cuh (TODO confirm against that header).

// dir == true variants
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, true, 1024);

// dir == false variants
BLOCK_SELECT_DECL(float, false, 1);
BLOCK_SELECT_DECL(float, false, 32);
BLOCK_SELECT_DECL(float, false, 64);
BLOCK_SELECT_DECL(float, false, 128);
BLOCK_SELECT_DECL(float, false, 256);
BLOCK_SELECT_DECL(float, false, 512);
BLOCK_SELECT_DECL(float, false, 1024);
// Dispatches a float k-selection over `in` to the compiled block-select
// variant whose queue capacity (WarpQ) is the smallest one that can hold k.
//
// Only the capacities 1, 32, 64, 128, 256, 512 and 1024 are instantiated,
// for both directions, so k is rounded up to the next capacity; k == 1 has
// its own specialization.
//
// Parameters:
//   in     - input values (presumably selected per row — TODO confirm
//            against BlockSelectImpl.cuh)
//   outK   - receives the selected float values
//   outV   - receives int results, presumably the indices of the selected
//            values (confirm)
//   dir    - selects the `true` or `false` instantiation; presumably
//            true = select largest, false = select smallest (confirm)
//   k      - number of elements to select; must be <= 1024 (asserted below)
//   stream - not referenced directly here; presumably consumed by the
//            BLOCK_SELECT_CALL macro expansion as the launch stream
//
// NOTE(review): k <= 0 is not rejected here and falls into the 32 bucket.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir,
                    int k,
                    cudaStream_t stream) {
  // 1024 is the largest instantiated capacity.
  FAISS_ASSERT(k <= 1024);

  if (dir) {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, true, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, true, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, true, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, true, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, true, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, true, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, true, 1024);
    }
  } else {
    if (k == 1) {
      BLOCK_SELECT_CALL(float, false, 1);
    } else if (k <= 32) {
      BLOCK_SELECT_CALL(float, false, 32);
    } else if (k <= 64) {
      BLOCK_SELECT_CALL(float, false, 64);
    } else if (k <= 128) {
      BLOCK_SELECT_CALL(float, false, 128);
    } else if (k <= 256) {
      BLOCK_SELECT_CALL(float, false, 256);
    } else if (k <= 512) {
      BLOCK_SELECT_CALL(float, false, 512);
    } else if (k <= 1024) {
      BLOCK_SELECT_CALL(float, false, 1024);
    }
  }
}
80 void runBlockSelectPair(Tensor<float, 2, true>& inK,
81 Tensor<int, 2, true>& inV,
82 Tensor<float, 2, true>& outK,
83 Tensor<int, 2, true>& outV,
84 bool dir,
int k, cudaStream_t stream) {
85 FAISS_ASSERT(k <= 1024);
89 BLOCK_SELECT_PAIR_CALL(
float,
true, 1);
91 BLOCK_SELECT_PAIR_CALL(
float,
true, 32);
93 BLOCK_SELECT_PAIR_CALL(
float,
true, 64);
94 }
else if (k <= 128) {
95 BLOCK_SELECT_PAIR_CALL(
float,
true, 128);
96 }
else if (k <= 256) {
97 BLOCK_SELECT_PAIR_CALL(
float,
true, 256);
98 }
else if (k <= 512) {
99 BLOCK_SELECT_PAIR_CALL(
float,
true, 512);
100 }
else if (k <= 1024) {
101 BLOCK_SELECT_PAIR_CALL(
float,
true, 1024);
105 BLOCK_SELECT_PAIR_CALL(
float,
false, 1);
106 }
else if (k <= 32) {
107 BLOCK_SELECT_PAIR_CALL(
float,
false, 32);
108 }
else if (k <= 64) {
109 BLOCK_SELECT_PAIR_CALL(
float,
false, 64);
110 }
else if (k <= 128) {
111 BLOCK_SELECT_PAIR_CALL(
float,
false, 128);
112 }
else if (k <= 256) {
113 BLOCK_SELECT_PAIR_CALL(
float,
false, 256);
114 }
else if (k <= 512) {
115 BLOCK_SELECT_PAIR_CALL(
float,
false, 512);
116 }
else if (k <= 1024) {
117 BLOCK_SELECT_PAIR_CALL(
float,
false, 1024);