MergeNetworkBlock.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "DeviceDefs.cuh"
#include "MergeNetworkUtils.cuh"
#include "PtxUtils.cuh"
#include "StaticUtils.h"
#include "WarpShuffles.cuh"
#include "../../FaissAssert.h"
#include <cuda.h>

namespace faiss { namespace gpu {

// Merge pairs of lists smaller than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

  // Which pair of lists we are merging
  int mergeId = threadIdx.x / L;

  // Which thread we are within the merge
  int tid = threadIdx.x % L;

  // listK points to a region of size N * 2 * L
  listK += 2 * L * mergeId;
  listV += 2 * L * mergeId;

  // It's not a bitonic merge; both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
  int pos = L - 1 - tid;
  int stride = 2 * tid + 1;

  K& ka = listK[pos];
  K& kb = listK[pos + stride];

  bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
  swap(s, ka, kb);

  V& va = listV[pos];
  V& vb = listV[pos + stride];
  swap(s, va, vb);

  __syncthreads();

#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
    int pos = 2 * tid - (tid & (stride - 1));

    K& ka = listK[pos];
    K& kb = listK[pos + stride];

    bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    swap(s, ka, kb);

    V& va = listV[pos];
    V& vb = listV[pos + stride];
    swap(s, va, vb);

    __syncthreads();
  }
}

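// A worked trace of blockMergeSmall (illustrative only; assumes L = 4,
// Dir = true, and a comparator whose gt is operator> on int, so a pair is
// swapped whenever ka > kb). Merging the ascending lists [1, 3, 5, 7] and
// [2, 4, 6, 8], laid out as one 2 * L key region:
//
//   input:                                      [1, 3, 5, 7, 2, 4, 6, 8]
//   reversed first step, pairs (3,4) (2,5) (1,6) (0,7):
//                                               [1, 3, 4, 2, 7, 5, 6, 8]
//   stride = 2, pairs (0,2) (1,3) (4,6) (5,7):
//                                               [1, 2, 4, 3, 6, 5, 7, 8]
//   stride = 1, pairs (0,1) (2,3) (4,5) (6,7):
//                                               [1, 2, 3, 4, 5, 6, 7, 8]
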
// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
template <int NumThreads,
          typename K,
          typename V,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L >= kWarpSize, "merge list size must be >= 32");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

  // For L > NumThreads, each thread has to perform more work
  // per stride.
  constexpr int kLoopPerThread = L / NumThreads;

  // It's not a bitonic merge; both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
#pragma unroll
  for (int loop = 0; loop < kLoopPerThread; ++loop) {
    int tid = loop * NumThreads + threadIdx.x;
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    K& ka = listK[pos];
    K& kb = listK[pos + stride];

    bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    swap(s, ka, kb);

    V& va = listV[pos];
    V& vb = listV[pos + stride];
    swap(s, va, vb);
  }

  __syncthreads();

  constexpr int kSecondLoopPerThread =
      FullMerge ? kLoopPerThread : kLoopPerThread / 2;

#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
      int tid = loop * NumThreads + threadIdx.x;
      int pos = 2 * tid - (tid & (stride - 1));

      K& ka = listK[pos];
      K& kb = listK[pos + stride];

      bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      swap(s, ka, kb);

      V& va = listV[pos];
      V& vb = listV[pos + stride];
      swap(s, va, vb);
    }

    __syncthreads();
  }
}

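// Arithmetic sketch (illustrative, not from the original comments): with
// L = 1024 and NumThreads = 128, kLoopPerThread = 8, so each thread handles
// eight (pos, pos + stride) pairs per phase. With FullMerge = false,
// kSecondLoopPerThread = 4: the strided phases only cover tids
// 0..(L / 2 - 1), enough to leave the low L results fully merged when the
// upper half of the output will be discarded anyway.
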
/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool SmallerThanBlock,
          bool FullMerge>
struct BlockMerge {
};

/// Merging lists smaller than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    constexpr int kNumParallelMerges = NumThreads / L;
    constexpr int kNumIterations = N / kNumParallelMerges;

    static_assert(L <= NumThreads, "list must be <= NumThreads");
    static_assert((N < kNumParallelMerges) ||
                  (kNumIterations * kNumParallelMerges == N),
                  "improper selection of N and L");

    if (N < kNumParallelMerges) {
      // We only need L threads per list to perform the merge
      if (threadIdx.x < N * L) {
        blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
            listK, listV);
      }
    } else {
      // All threads participate
#pragma unroll
      for (int i = 0; i < kNumIterations; ++i) {
        int start = i * kNumParallelMerges * 2 * L;

        blockMergeSmall<NumThreads, K, V, L, Dir, Comp, FullMerge>(
            listK + start, listV + start);
      }
    }
  }
};

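// Arithmetic sketch (illustrative): NumThreads = 128 and L = 32 give
// kNumParallelMerges = 4, so four pairs of lists are merged concurrently.
// N = 8 pairs then takes kNumIterations = 2 passes over the data, while
// N = 2 < kNumParallelMerges takes the first branch, with only the first
// N * L = 64 threads participating.
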
/// Merging lists larger than a block
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
  static inline __device__ void merge(K* listK, V* listV) {
    // Each pair of lists is merged sequentially
#pragma unroll
    for (int i = 0; i < N; ++i) {
      int start = i * 2 * L;

      blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
          listK + start, listV + start);
    }
  }
};

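// Arithmetic sketch (illustrative): N = 4 pairs of L = 1024 element lists
// are handled as four sequential blockMergeLarge calls, each merging one
// 2 * L = 2048 element region in place.
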
template <int NumThreads,
          typename K,
          typename V,
          int N,
          int L,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
  constexpr bool kSmallerThanBlock = (L <= NumThreads);

  BlockMerge<NumThreads, K, V, N, L, Dir, Comp,
             kSmallerThanBlock, FullMerge>::merge(listK, listV);
}
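
// Usage sketch (hypothetical; the comparator and kernel below are not part
// of Faiss). A 128-thread block merges N = 4 pairs of sorted 32-element
// key/value lists staged in shared memory:
//
//   struct IntComp {
//     static inline __device__ bool lt(int a, int b) { return a < b; }
//     static inline __device__ bool gt(int a, int b) { return a > b; }
//   };
//
//   __global__ void mergePairs(int* k, int* v) {
//     constexpr int kTotal = 4 * 2 * 32; // N * 2 * L elements
//     __shared__ int sK[kTotal];
//     __shared__ int sV[kTotal];
//
//     for (int i = threadIdx.x; i < kTotal; i += blockDim.x) {
//       sK[i] = k[i];
//       sV[i] = v[i];
//     }
//     __syncthreads();
//
//     // Dir = true swaps whenever ka > kb, so each merged pair of lists
//     // comes out ascending under IntComp
//     blockMerge<128, int, int, 4, 32, true, IntComp>(sK, sV);
//
//     for (int i = threadIdx.x; i < kTotal; i += blockDim.x) {
//       k[i] = sK[i];
//       v[i] = sV[i];
//     }
//   }
//
// launched as mergePairs<<<1, 128>>>(k, v).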

} } // namespace