Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectFloat.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include "blockselect/BlockSelectImpl.cuh"
10 
11 namespace faiss { namespace gpu {
12 
// warp Q -> thread Q, one row per declared size (the first column matches
// the last BLOCK_SELECT_DECL argument used below):
//    1 -> 1
//   32 -> 2
//   64 -> 3
//  128 -> 3
//  256 -> 4
//  512 -> 8
// 1024 -> 8

// float variants, listed as one (true, false) pair per size.
// NOTE(review): this grouping assumes BLOCK_SELECT_DECL only emits
// declarations, so their relative order is irrelevant -- confirm against
// blockselect/BlockSelectImpl.cuh.
BLOCK_SELECT_DECL(float, true, 1);
BLOCK_SELECT_DECL(float, false, 1);

BLOCK_SELECT_DECL(float, true, 32);
BLOCK_SELECT_DECL(float, false, 32);

BLOCK_SELECT_DECL(float, true, 64);
BLOCK_SELECT_DECL(float, false, 64);

BLOCK_SELECT_DECL(float, true, 128);
BLOCK_SELECT_DECL(float, false, 128);

BLOCK_SELECT_DECL(float, true, 256);
BLOCK_SELECT_DECL(float, false, 256);

BLOCK_SELECT_DECL(float, true, 512);
BLOCK_SELECT_DECL(float, false, 512);

BLOCK_SELECT_DECL(float, true, 1024);
BLOCK_SELECT_DECL(float, false, 1024);
37 
// Dispatches to the BLOCK_SELECT_CALL(float, <dir>, <size>) variant that
// matches the runtime arguments.
//
// The size is chosen from k: exactly 1 selects the size-1 variant; any other
// k selects the smallest of 32, 64, 128, 256, 512, 1024 that is >= k. Only
// the upper bound is asserted, so a non-positive k is not rejected here and
// takes the 32 variant. `dir` selects between the `true` and `false`
// variants.
//
// in, outK, outV and stream are not touched directly in this function;
// presumably the macro expansion forwards them -- see
// blockselect/BlockSelectImpl.cuh to confirm.
void runBlockSelect(Tensor<float, 2, true>& in,
                    Tensor<float, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // The macro arguments have to be literals, so every (direction, size)
  // combination is spelled out. Size is resolved first, direction second.
  if (k == 1) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 1);
    } else {
      BLOCK_SELECT_CALL(float, false, 1);
    }
  } else if (k <= 32) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 32);
    } else {
      BLOCK_SELECT_CALL(float, false, 32);
    }
  } else if (k <= 64) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 64);
    } else {
      BLOCK_SELECT_CALL(float, false, 64);
    }
  } else if (k <= 128) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 128);
    } else {
      BLOCK_SELECT_CALL(float, false, 128);
    }
  } else if (k <= 256) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 256);
    } else {
      BLOCK_SELECT_CALL(float, false, 256);
    }
  } else if (k <= 512) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 512);
    } else {
      BLOCK_SELECT_CALL(float, false, 512);
    }
  } else if (k <= 1024) {
    if (dir) {
      BLOCK_SELECT_CALL(float, true, 1024);
    } else {
      BLOCK_SELECT_CALL(float, false, 1024);
    }
  }
}
78 
// Dispatches to the BLOCK_SELECT_PAIR_CALL(float, <dir>, <size>) variant
// that matches the runtime arguments.
//
// The size is chosen from k: exactly 1 selects the size-1 variant; any other
// k selects the smallest of 32, 64, 128, 256, 512, 1024 that is >= k. Only
// the upper bound is asserted, so a non-positive k is not rejected here and
// takes the 32 variant. `dir` selects between the `true` and `false`
// variants.
//
// inK, inV, outK, outV and stream are not touched directly in this function;
// presumably the macro expansion forwards them -- see
// blockselect/BlockSelectImpl.cuh to confirm.
void runBlockSelectPair(Tensor<float, 2, true>& inK,
                        Tensor<int, 2, true>& inV,
                        Tensor<float, 2, true>& outK,
                        Tensor<int, 2, true>& outV,
                        bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // The macro arguments have to be literals, so every (direction, size)
  // combination is spelled out. Size is resolved first, direction second.
  if (k == 1) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 1);
    }
  } else if (k <= 32) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 32);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 32);
    }
  } else if (k <= 64) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 64);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 64);
    }
  } else if (k <= 128) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 128);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 128);
    }
  } else if (k <= 256) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 256);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 256);
    }
  } else if (k <= 512) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 512);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 512);
    }
  } else if (k <= 1024) {
    if (dir) {
      BLOCK_SELECT_PAIR_CALL(float, true, 1024);
    } else {
      BLOCK_SELECT_PAIR_CALL(float, false, 1024);
    }
  }
}
120 
121 } } // namespace