10 #include "blockselect/BlockSelectImpl.cuh"
12 namespace faiss {
namespace gpu {
// One BLOCK_SELECT_DECL invocation per (second argument, size) combination
// for element type float. The macro is not defined in this file; it is
// presumably provided by blockselect/BlockSelectImpl.cuh -- confirm there
// what each invocation actually declares.
//
// The size list (1, 32, 64, 128, 256, 512, 1024) is the same set that
// runBlockSelect passes to BLOCK_SELECT_CALL, so the two lists should be
// kept in step when a size is added or removed.

// Second argument `true`.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, true, 1024);

// Second argument `false`.
BLOCK_SELECT_DECL(float, false, 1);
BLOCK_SELECT_DECL(float, false, 32);
BLOCK_SELECT_DECL(float, false, 64);
BLOCK_SELECT_DECL(float, false, 128);
BLOCK_SELECT_DECL(float, false, 256);
BLOCK_SELECT_DECL(float, false, 512);
BLOCK_SELECT_DECL(float, false, 1024);
39 void runBlockSelect(Tensor<float, 2, true>& in,
40 Tensor<float, 2, true>& outK,
41 Tensor<int, 2, true>& outV,
42 bool dir,
int k, cudaStream_t stream) {
43 FAISS_ASSERT(k <= 1024);
47 BLOCK_SELECT_CALL(
float,
true, 1);
49 BLOCK_SELECT_CALL(
float,
true, 32);
51 BLOCK_SELECT_CALL(
float,
true, 64);
52 }
else if (k <= 128) {
53 BLOCK_SELECT_CALL(
float,
true, 128);
54 }
else if (k <= 256) {
55 BLOCK_SELECT_CALL(
float,
true, 256);
56 }
else if (k <= 512) {
57 BLOCK_SELECT_CALL(
float,
true, 512);
58 }
else if (k <= 1024) {
59 BLOCK_SELECT_CALL(
float,
true, 1024);
63 BLOCK_SELECT_CALL(
float,
false, 1);
65 BLOCK_SELECT_CALL(
float,
false, 32);
67 BLOCK_SELECT_CALL(
float,
false, 64);
68 }
else if (k <= 128) {
69 BLOCK_SELECT_CALL(
float,
false, 128);
70 }
else if (k <= 256) {
71 BLOCK_SELECT_CALL(
float,
false, 256);
72 }
else if (k <= 512) {
73 BLOCK_SELECT_CALL(
float,
false, 512);
74 }
else if (k <= 1024) {
75 BLOCK_SELECT_CALL(
float,
false, 1024);