Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BlockSelectHalf.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include "blockselect/BlockSelectImpl.cuh"
10 
11 namespace faiss { namespace gpu {
12 
13 #ifdef FAISS_USE_FLOAT16
14 
// Warp queue size -> thread queue size, as recorded for the specializations
// declared below (the left-hand column matches the third BLOCK_SELECT_DECL
// argument):
//
//      1 -> 1
//     32 -> 2
//     64 -> 3
//    128 -> 3
//    256 -> 4
//    512 -> 8
//   1024 -> 8
//
// One declaration per (direction, warp queue size) combination for `half`
// keys; the `true` and `false` variants of each size are listed together.
// BLOCK_SELECT_DECL is presumably provided by
// blockselect/BlockSelectImpl.cuh, the only header included here.
// NOTE(review): this grouping assumes the declarations are independent of
// one another and of their order -- confirm against the macro definition.

BLOCK_SELECT_DECL(half, true, 1);
BLOCK_SELECT_DECL(half, false, 1);

BLOCK_SELECT_DECL(half, true, 32);
BLOCK_SELECT_DECL(half, false, 32);

BLOCK_SELECT_DECL(half, true, 64);
BLOCK_SELECT_DECL(half, false, 64);

BLOCK_SELECT_DECL(half, true, 128);
BLOCK_SELECT_DECL(half, false, 128);

BLOCK_SELECT_DECL(half, true, 256);
BLOCK_SELECT_DECL(half, false, 256);

BLOCK_SELECT_DECL(half, true, 512);
BLOCK_SELECT_DECL(half, false, 512);

BLOCK_SELECT_DECL(half, true, 1024);
BLOCK_SELECT_DECL(half, false, 1024);
39 
// Routes a half-precision block-select request to the specialization that
// matches the requested direction and k.
//
// Bucket choice: k == 1 uses the size-1 specialization; any other k uses the
// smallest of 32, 64, 128, 256, 512, 1024 that is >= k. Because only
// `k == 1` is tested for equality, values of k below 1 also land in the
// 32 bucket.
//
// Parameters:
//   in      input tensor of half values
//   outK    output tensor of half values
//   outV    output tensor of int values
//   dir     picks the `true` or `false` family of specializations
//           (NOTE(review): the ordering each value stands for is not visible
//           in this file -- confirm in BlockSelectImpl.cuh)
//   k       number of elements to select; asserted to be <= 1024
//   stream  CUDA stream handed to this dispatcher
//
// BLOCK_SELECT_CALL is given only the type, a literal direction and the
// bucket size; it presumably refers to in/outK/outV/dir/k/stream by name,
// which is why the parameter names must stay exactly as they are and why the
// runtime `dir` has to be branched on explicitly.
void runBlockSelect(Tensor<half, 2, true>& in,
                    Tensor<half, 2, true>& outK,
                    Tensor<int, 2, true>& outV,
                    bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // Handle the `false` direction first and leave; everything after this
  // block belongs to the `true` direction.
  if (!dir) {
    if (k == 1)   { BLOCK_SELECT_CALL(half, false, 1);   return; }
    if (k <= 32)  { BLOCK_SELECT_CALL(half, false, 32);  return; }
    if (k <= 64)  { BLOCK_SELECT_CALL(half, false, 64);  return; }
    if (k <= 128) { BLOCK_SELECT_CALL(half, false, 128); return; }
    if (k <= 256) { BLOCK_SELECT_CALL(half, false, 256); return; }
    if (k <= 512) { BLOCK_SELECT_CALL(half, false, 512); return; }
    if (k <= 1024) {
      BLOCK_SELECT_CALL(half, false, 1024);
    }
    return;
  }

  if (k == 1)   { BLOCK_SELECT_CALL(half, true, 1);   return; }
  if (k <= 32)  { BLOCK_SELECT_CALL(half, true, 32);  return; }
  if (k <= 64)  { BLOCK_SELECT_CALL(half, true, 64);  return; }
  if (k <= 128) { BLOCK_SELECT_CALL(half, true, 128); return; }
  if (k <= 256) { BLOCK_SELECT_CALL(half, true, 256); return; }
  if (k <= 512) { BLOCK_SELECT_CALL(half, true, 512); return; }
  if (k <= 1024) {
    BLOCK_SELECT_CALL(half, true, 1024);
  }
}
80 
// Routes a half-precision key/value block-select request to the
// specialization that matches the requested direction and k.
//
// Bucket choice: k == 1 uses the size-1 specialization; any other k uses the
// smallest of 32, 64, 128, 256, 512, 1024 that is >= k. Because only
// `k == 1` is tested for equality, values of k below 1 also land in the
// 32 bucket.
//
// Parameters:
//   inK     input tensor of half values
//   inV     input tensor of int values
//   outK    output tensor of half values
//   outV    output tensor of int values
//   dir     picks the `true` or `false` family of specializations
//           (NOTE(review): the ordering each value stands for is not visible
//           in this file -- confirm in BlockSelectImpl.cuh)
//   k       number of elements to select; asserted to be <= 1024
//   stream  CUDA stream handed to this dispatcher
//
// BLOCK_SELECT_PAIR_CALL is given only the type, a literal direction and the
// bucket size; it presumably refers to inK/inV/outK/outV/dir/k/stream by
// name, which is why the parameter names must stay exactly as they are and
// why the runtime `dir` has to be branched on explicitly.
void runBlockSelectPair(Tensor<half, 2, true>& inK,
                        Tensor<int, 2, true>& inV,
                        Tensor<half, 2, true>& outK,
                        Tensor<int, 2, true>& outV,
                        bool dir, int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);

  // Handle the `false` direction first and leave; everything after this
  // block belongs to the `true` direction.
  if (!dir) {
    if (k == 1)   { BLOCK_SELECT_PAIR_CALL(half, false, 1);   return; }
    if (k <= 32)  { BLOCK_SELECT_PAIR_CALL(half, false, 32);  return; }
    if (k <= 64)  { BLOCK_SELECT_PAIR_CALL(half, false, 64);  return; }
    if (k <= 128) { BLOCK_SELECT_PAIR_CALL(half, false, 128); return; }
    if (k <= 256) { BLOCK_SELECT_PAIR_CALL(half, false, 256); return; }
    if (k <= 512) { BLOCK_SELECT_PAIR_CALL(half, false, 512); return; }
    if (k <= 1024) {
      BLOCK_SELECT_PAIR_CALL(half, false, 1024);
    }
    return;
  }

  if (k == 1)   { BLOCK_SELECT_PAIR_CALL(half, true, 1);   return; }
  if (k <= 32)  { BLOCK_SELECT_PAIR_CALL(half, true, 32);  return; }
  if (k <= 64)  { BLOCK_SELECT_PAIR_CALL(half, true, 64);  return; }
  if (k <= 128) { BLOCK_SELECT_PAIR_CALL(half, true, 128); return; }
  if (k <= 256) { BLOCK_SELECT_PAIR_CALL(half, true, 256); return; }
  if (k <= 512) { BLOCK_SELECT_PAIR_CALL(half, true, 512); return; }
  if (k <= 1024) {
    BLOCK_SELECT_PAIR_CALL(half, true, 1024);
  }
}
122 
123 #endif
124 
125 } } // namespace