Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MergeNetworkBlock.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include "DeviceDefs.cuh"
12 #include "MergeNetworkUtils.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include "../../FaissAssert.h"
17 #include <cuda.h>
18 
19 namespace faiss { namespace gpu {
20 
// Merge pairs of lists smaller than blockDim.x (NumThreads).
//
// listK/listV point to a region holding N pairs of adjacent sorted
// lists, each of length L (N * 2 * L elements total). Both lists of a
// pair are sorted in the same direction Dir (true => compare with
// Comp::gt, false => with Comp::lt); each pair is merged in place
// into one sorted list of length 2 * L.
//
// L threads cooperate on each pair. If AllThreads is true, every
// thread in the block takes part in the compare/swap steps; otherwise
// only threads with threadIdx.x < N * L do, but ALL threads still
// reach the __syncthreads() barriers (the barriers are outside the
// divergent `if`, which is required for correctness).
//
// NOTE(review): FullMerge is accepted for signature parity with
// blockMergeLarge but is unused here — this path always performs the
// full merge.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool AllThreads,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

  // Which pair of lists we are merging
  int mergeId = threadIdx.x / L;

  // Which thread we are within the merge
  int tid = threadIdx.x % L;

  // listK points to a region of size N * 2 * L; advance to the start
  // of the pair this thread works on
  listK += 2 * L * mergeId;
  listV += 2 * L * mergeId;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed:
  // thread tid compares element (L - 1 - tid) of the first list with
  // element tid of the second list (pos + stride == L + tid)
  int pos = L - 1 - tid;
  int stride = 2 * tid + 1;

  if (AllThreads || (threadIdx.x < N * L)) {
    K ka = listK[pos];
    K kb = listK[pos + stride];

    // Conditional-select form of compare-and-swap; see FIXME below for
    // why the reference/swap() form is avoided
    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;

    // FIXME: is this a CUDA 9 compiler bug?
    // K& ka = listK[pos];
    // K& kb = listK[pos + stride];

    // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    // swap(s, ka, kb);

    // V& va = listV[pos];
    // V& vb = listV[pos + stride];
    // swap(s, va, vb);
  }

  __syncthreads();

  // Standard merging-network cleanup: after the reversed first pass,
  // each halved stride compares elements `stride` apart until sorted
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
    // First element of the compare/swap pair for this thread at this
    // stride; the mask keeps pairs within their stride-sized group
    int pos = 2 * tid - (tid & (stride - 1));

    if (AllThreads || (threadIdx.x < N * L)) {
      K ka = listK[pos];
      K kb = listK[pos + stride];

      bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      listK[pos] = swap ? kb : ka;
      listK[pos + stride] = swap ? ka : kb;

      V va = listV[pos];
      V vb = listV[pos + stride];
      listV[pos] = swap ? vb : va;
      listV[pos + stride] = swap ? va : vb;

      // FIXME: is this a CUDA 9 compiler bug?
      // K& ka = listK[pos];
      // K& kb = listK[pos + stride];

      // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      // swap(s, ka, kb);

      // V& va = listV[pos];
      // V& vb = listV[pos + stride];
      // swap(s, va, vb);
    }

    // Barrier must be hit by all threads (it is outside the divergent
    // branch) before the next, smaller stride reads what was written
    __syncthreads();
  }
}
111 
// Merge pairs of sorted lists larger than blockDim.x (NumThreads).
//
// listK/listV point to one pair of adjacent sorted lists, each of
// length L (2 * L elements), both sorted in direction Dir; the pair is
// merged in place. Because L >= NumThreads, each thread performs
// L / NumThreads compare/swap operations per network stage.
//
// If FullMerge is false, only half of the second-phase work is done
// (kSecondLoopPerThread below), which — given pos = 2*tid - (tid &
// (stride-1)) grows with tid — updates only the lower-index half of
// the region. Presumably this is used when only the best L results
// are kept afterwards; callers relying on the upper half must pass
// FullMerge = true.
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L >= kWarpSize, "merge list size must be >= 32");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

  // For L > NumThreads, each thread has to perform more work
  // per each stride.
  constexpr int kLoopPerThread = L / NumThreads;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
#pragma unroll
  for (int loop = 0; loop < kLoopPerThread; ++loop) {
    // Virtual thread id: each physical thread covers kLoopPerThread
    // positions, NumThreads apart
    int tid = loop * NumThreads + threadIdx.x;
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    K ka = listK[pos];
    K kb = listK[pos + stride];

    // Conditional-select form of compare-and-swap; see FIXME below
    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;

    // FIXME: is this a CUDA 9 compiler bug?
    // K& ka = listK[pos];
    // K& kb = listK[pos + stride];

    // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    // swap(s, ka, kb);

    // V& va = listV[pos];
    // V& vb = listV[pos + stride];
    // swap(s, va, vb);
  }

  __syncthreads();

  // When FullMerge is false, only half the virtual threads run the
  // cleanup phase (only the lower half of the region is finished)
  constexpr int kSecondLoopPerThread =
      FullMerge ? kLoopPerThread : kLoopPerThread / 2;

  // Merging-network cleanup: halve the stride each round, with every
  // thread handling kSecondLoopPerThread compare/swaps per round
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
      int tid = loop * NumThreads + threadIdx.x;
      // First element of the compare/swap pair; the mask keeps pairs
      // within their stride-sized group
      int pos = 2 * tid - (tid & (stride - 1));

      K ka = listK[pos];
      K kb = listK[pos + stride];

      bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      listK[pos] = swap ? kb : ka;
      listK[pos + stride] = swap ? ka : kb;

      V va = listV[pos];
      V vb = listV[pos + stride];
      listV[pos] = swap ? vb : va;
      listV[pos + stride] = swap ? va : vb;

      // FIXME: is this a CUDA 9 compiler bug?
      // K& ka = listK[pos];
      // K& kb = listK[pos + stride];

      // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      // swap(s, ka, kb);

      // V& va = listV[pos];
      // V& vb = listV[pos + stride];
      // swap(s, va, vb);
    }

    // All writes at this stride must be visible before the next,
    // smaller stride reads them
    __syncthreads();
  }
}
202 
/// Dispatch shell that selects between the smaller-than-block and
/// larger-than-block merge implementations via the SmallerThanBlock
/// flag. The unspecialized form is deliberately empty: only the two
/// partial specializations below provide merge(), which keeps the
/// static_asserts of the wrong-sized path from firing on the case
/// that is not selected.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {};
216 
/// Specialization for lists that fit within a block (L <= NumThreads):
/// pairs are merged in parallel, kMergesPerPass pairs at a time.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // How many list pairs the block can merge simultaneously, and how
    // many full passes are needed to cover all N pairs
    constexpr int kMergesPerPass = NumThreads / L;
    constexpr int kPasses = N / kMergesPerPass;

    static_assert(L <= NumThreads, "list must be <= NumThreads");
    static_assert((N < kMergesPerPass) ||
                  (kPasses * kMergesPerPass == N),
                  "improper selection of N and L");

    if (N < kMergesPerPass) {
      // Too few pairs to occupy the whole block: only the first N * L
      // threads participate in the merge
      blockMergeSmall<NumThreads, K, V, N, L, false, Dir, Comp, FullMerge>(
          listK, listV);
      return;
    }

    // Every thread participates; sweep the region one pass at a time
#pragma unroll
    for (int pass = 0; pass < kPasses; ++pass) {
      int offset = pass * kMergesPerPass * 2 * L;

      blockMergeSmall<NumThreads, K, V, N, L, true, Dir, Comp, FullMerge>(
          listK + offset, listV + offset);
    }
  }
};
252 
/// Specialization for lists larger than a block (L >= NumThreads):
/// the N pairs are merged one after another, each merge using the
/// entire thread block.
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // Each pair occupies 2 * L contiguous elements; process sequentially
#pragma unroll
    for (int pair = 0; pair < N; ++pair) {
      K* pairK = listK + pair * 2 * L;
      V* pairV = listV + pair * 2 * L;

      blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
          pairK, pairV);
    }
  }
};
274 
/// Block-wide entry point: merges N pairs of adjacent sorted lists of
/// length L (both lists of a pair sorted in direction Dir) in place at
/// listK/listV. Must be called by all threads of the block (the
/// implementations synchronize with __syncthreads()). FullMerge
/// defaults to true; passing false lets the large-list path skip half
/// of its cleanup work (see blockMergeLarge).
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
  // Choose the parallel small-list path when a pair fits within the
  // block, otherwise the sequential large-list path
  constexpr bool kSmallerThanBlock = (L <= NumThreads);

  // BUG FIX: the call was an unqualified `merge(listK, listV)` which
  // does not name any function in scope (and left kSmallerThanBlock
  // unused); dispatch must go through the BlockMerge class template
  BlockMerge<NumThreads, K, V, N, L, Dir, Comp, kSmallerThanBlock, FullMerge>::
      merge(listK, listV);
}
289 
290 } } // namespace