11 #include "blockselect/BlockSelectImpl.cuh"
13 namespace faiss {
namespace gpu {
// Block-select declarations for float data: one BLOCK_SELECT_DECL per
// (direction flag, size) pair. Every pair listed here is also used by
// runBlockSelect in this file, which issues a BLOCK_SELECT_CALL with the
// same three arguments.
// NOTE(review): BLOCK_SELECT_DECL is presumably provided by
// blockselect/BlockSelectImpl.cuh -- confirm what each expansion defines.

// Direction flag = true, sizes 1 through 1024.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, true, 1024);

// Direction flag = false, same set of sizes.
BLOCK_SELECT_DECL(float, false, 1);
BLOCK_SELECT_DECL(float, false, 32);
BLOCK_SELECT_DECL(float, false, 64);
BLOCK_SELECT_DECL(float, false, 128);
BLOCK_SELECT_DECL(float, false, 256);
BLOCK_SELECT_DECL(float, false, 512);
BLOCK_SELECT_DECL(float, false, 1024);
40 void runBlockSelect(Tensor<float, 2, true>& in,
41 Tensor<float, 2, true>& outK,
42 Tensor<int, 2, true>& outV,
43 bool dir,
int k, cudaStream_t stream) {
44 FAISS_ASSERT(k <= 1024);
48 BLOCK_SELECT_CALL(
float,
true, 1);
50 BLOCK_SELECT_CALL(
float,
true, 32);
52 BLOCK_SELECT_CALL(
float,
true, 64);
53 }
else if (k <= 128) {
54 BLOCK_SELECT_CALL(
float,
true, 128);
55 }
else if (k <= 256) {
56 BLOCK_SELECT_CALL(
float,
true, 256);
57 }
else if (k <= 512) {
58 BLOCK_SELECT_CALL(
float,
true, 512);
59 }
else if (k <= 1024) {
60 BLOCK_SELECT_CALL(
float,
true, 1024);
64 BLOCK_SELECT_CALL(
float,
false, 1);
66 BLOCK_SELECT_CALL(
float,
false, 32);
68 BLOCK_SELECT_CALL(
float,
false, 64);
69 }
else if (k <= 128) {
70 BLOCK_SELECT_CALL(
float,
false, 128);
71 }
else if (k <= 256) {
72 BLOCK_SELECT_CALL(
float,
false, 256);
73 }
else if (k <= 512) {
74 BLOCK_SELECT_CALL(
float,
false, 512);
75 }
else if (k <= 1024) {
76 BLOCK_SELECT_CALL(
float,
false, 1024);