// MergeNetworkBlock.cuh — Faiss GPU: block-wide merge-network primitives
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "DeviceDefs.cuh"
11 #include "MergeNetworkUtils.cuh"
12 #include "PtxUtils.cuh"
13 #include "StaticUtils.h"
14 #include "WarpShuffles.cuh"
15 #include "../../FaissAssert.h"
16 #include <cuda.h>
17 
18 namespace faiss { namespace gpu {
19 
// Merge pairs of sorted lists smaller than blockDim.x (NumThreads).
//
// listK / listV hold N pairs of lists; each pair occupies a contiguous
// region of 2 * L elements, with both halves already sorted in the same
// direction. Dir selects the comparison (Comp::gt vs Comp::lt) used to
// decide swaps; the resulting sort direction depends on Comp's contract
// — confirm against the comparator definitions.
// AllThreads: true when every thread in the block has a pair to work on;
// when false, threads >= N * L skip the compare/exchange but still reach
// every barrier.
// FullMerge is accepted for interface symmetry with blockMergeLarge and
// is not referenced in this implementation.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool AllThreads,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

  // Which pair of lists we are merging
  int mergeId = threadIdx.x / L;

  // Which thread we are within the merge
  int tid = threadIdx.x % L;

  // listK points to a region of size N * 2 * L
  listK += 2 * L * mergeId;
  listV += 2 * L * mergeId;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed:
  // element pos in the first list pairs with its mirror-image partner
  // (pos + stride) in the second list.
  int pos = L - 1 - tid;
  int stride = 2 * tid + 1;

  if (AllThreads || (threadIdx.x < N * L)) {
    K ka = listK[pos];
    K kb = listK[pos + stride];

    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;

    // FIXME: is this a CUDA 9 compiler bug?
    // K& ka = listK[pos];
    // K& kb = listK[pos + stride];

    // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    // swap(s, ka, kb);

    // V& va = listV[pos];
    // V& vb = listV[pos + stride];
    // swap(s, va, vb);
  }

  // Barrier is deliberately outside the guard above so that all threads
  // in the block reach it, including non-participating ones.
  __syncthreads();

  // Cleanup phase: compare/exchange at halving strides until fully
  // sorted within each 2 * L region.
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
    // Position of this thread's lower element for the current stride
    int pos = 2 * tid - (tid & (stride - 1));

    if (AllThreads || (threadIdx.x < N * L)) {
      K ka = listK[pos];
      K kb = listK[pos + stride];

      bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      listK[pos] = swap ? kb : ka;
      listK[pos + stride] = swap ? ka : kb;

      V va = listV[pos];
      V vb = listV[pos + stride];
      listV[pos] = swap ? vb : va;
      listV[pos + stride] = swap ? va : vb;

      // FIXME: is this a CUDA 9 compiler bug?
      // K& ka = listK[pos];
      // K& kb = listK[pos + stride];

      // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      // swap(s, ka, kb);

      // V& va = listV[pos];
      // V& vb = listV[pos + stride];
      // swap(s, va, vb);
    }

    // Uniformly reached by the whole block (guard is inside the loop)
    __syncthreads();
  }
}
110 
// Merge pairs of sorted lists larger than blockDim.x (NumThreads).
//
// Operates on a single pair occupying a contiguous 2 * L region of
// listK / listV, both halves already sorted in the same direction.
// Since L >= NumThreads, each thread processes kLoopPerThread
// (= L / NumThreads) elements per stride.
// When FullMerge is false, the cleanup phase below only iterates over
// half the positions (kSecondLoopPerThread = kLoopPerThread / 2) —
// presumably because the caller only needs the lower half of the merged
// result; confirm against call sites.
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L >= kWarpSize, "merge list size must be >= 32");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

  // For L > NumThreads, each thread has to perform more work
  // per each stride.
  constexpr int kLoopPerThread = L / NumThreads;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
#pragma unroll
  for (int loop = 0; loop < kLoopPerThread; ++loop) {
    int tid = loop * NumThreads + threadIdx.x;
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    K ka = listK[pos];
    K kb = listK[pos + stride];

    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;

    // FIXME: is this a CUDA 9 compiler bug?
    // K& ka = listK[pos];
    // K& kb = listK[pos + stride];

    // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    // swap(s, ka, kb);

    // V& va = listV[pos];
    // V& vb = listV[pos + stride];
    // swap(s, va, vb);
  }

  // Every thread runs the loop above unconditionally, so the barrier is
  // uniformly reached.
  __syncthreads();

  // With FullMerge off, only process the first half of the positions in
  // the cleanup strides.
  constexpr int kSecondLoopPerThread =
      FullMerge ? kLoopPerThread : kLoopPerThread / 2;

#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
      int tid = loop * NumThreads + threadIdx.x;
      int pos = 2 * tid - (tid & (stride - 1));

      K ka = listK[pos];
      K kb = listK[pos + stride];

      bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      listK[pos] = swap ? kb : ka;
      listK[pos + stride] = swap ? ka : kb;

      V va = listV[pos];
      V vb = listV[pos + stride];
      listV[pos] = swap ? vb : va;
      listV[pos + stride] = swap ? va : vb;

      // FIXME: is this a CUDA 9 compiler bug?
      // K& ka = listK[pos];
      // K& kb = listK[pos + stride];

      // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      // swap(s, ka, kb);

      // V& va = listV[pos];
      // V& vb = listV[pos + stride];
      // swap(s, va, vb);
    }

    __syncthreads();
  }
}
201 
/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases.
///
/// Only the two partial specializations below (SmallerThanBlock ==
/// true / false) define merge(); this primary template is intentionally
/// empty so that an unmatched instantiation fails at compile time.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {
};
215 
/// Merging lists smaller than a block: multiple pairs of lists can be
/// merged in parallel by partitioning the block's threads among them.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // How many list pairs the block can merge at once, and how many
    // sweeps over the data that implies
    constexpr int kNumParallelMerges = NumThreads / L;
    constexpr int kNumIterations = N / kNumParallelMerges;

    static_assert(L <= NumThreads, "list must be <= NumThreads");
    static_assert((N < kNumParallelMerges) ||
                  (kNumIterations * kNumParallelMerges == N),
                  "improper selection of N and L");

    if (N >= kNumParallelMerges) {
      // Every thread has a pair to work on; sweep across all pairs
#pragma unroll
      for (int pass = 0; pass < kNumIterations; ++pass) {
        int offset = pass * kNumParallelMerges * 2 * L;

        blockMergeSmall<NumThreads, K, V, N, L, true, Dir, Comp, FullMerge>(
            listK + offset, listV + offset);
      }
    } else {
      // Fewer pairs than the block can handle at once; only the first
      // N * L threads participate in the merge
      blockMergeSmall<NumThreads, K, V, N, L, false, Dir, Comp, FullMerge>(
          listK, listV);
    }
  }
};
251 
/// Merging lists larger than a block: each pair is too big to share the
/// block with another, so pairs are merged one after another.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // Walk the N pairs sequentially, handing each contiguous 2 * L
    // region to blockMergeLarge
#pragma unroll
    for (int pair = 0; pair < N; ++pair) {
      K* curK = listK + pair * 2 * L;
      V* curV = listV + pair * 2 * L;

      blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(curK, curV);
    }
  }
};
273 
// Block-wide merge entry point: merges N pairs of sorted lists of
// length L held in listK / listV (keys and associated values).
//
// Dispatches at compile time between the smaller-than-block and
// larger-than-block implementations based on whether a list fits within
// NumThreads. FullMerge defaults to true (produce all 2 * L merged
// elements per pair).
//
// NOTE(review): the extracted source dropped the `BlockMerge<...>::`
// qualifier line, leaving a bare `merge(...)` call that cannot compile
// and an unused kSmallerThanBlock; restored the compile-time dispatch.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
  constexpr bool kSmallerThanBlock = (L <= NumThreads);

  BlockMerge<NumThreads, K, V, N, L, Dir, Comp, kSmallerThanBlock, FullMerge>::
      merge(listK, listV);
}
288 
289 } } // namespace