// Faiss GPU — MergeNetworkBlock.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 #pragma once
11 
12 #include "DeviceDefs.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 #include "../../FaissAssert.h"
17 #include <cuda.h>
18 
19 namespace faiss { namespace gpu {
20 
// Merge pairs of lists smaller than blockDim.x (NumThreads)
//
// Each pair of adjacent sorted lists of length L occupies 2 * L
// contiguous elements of listK/listV; since L <= NumThreads, multiple
// pairs are merged in parallel by disjoint groups of L threads.
// Both lists of a pair are sorted in the same direction (per Dir/Comp),
// so this is not directly a bitonic merge: the first exchange below
// pairs element `tid` of the (conceptually reversed) second list with
// element L-1-tid of the first, which turns the 2*L span into a bitonic
// sequence that the subsequent strided passes sort.
//
// Preconditions (enforced below): L and NumThreads are powers of 2 and
// L <= NumThreads. Dir selects descending (true) vs. ascending order
// via Comp::gt/Comp::lt.
// NOTE(review): cross-thread exchanges here are ordered only by
// __syncthreads(), so listK/listV presumably point into shared memory —
// confirm at call sites. Also note __syncthreads() requires all threads
// of the block to reach this function together.
template <int NumThreads, typename K, typename V, int L,
          bool Dir, typename Comp>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

  // Which pair of lists we are merging
  int mergeId = threadIdx.x / L;

  // Which thread we are within the merge
  int tid = threadIdx.x % L;

  // listK points to a region of size N * 2 * L; advance to our pair
  listK += 2 * L * mergeId;
  listV += 2 * L * mergeId;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
  int pos = L - 1 - tid;
  int stride = 2 * tid + 1;

  K ka = listK[pos];
  K kb = listK[pos + stride];

  // Exchange so the pair is ordered per Dir (keys and values together)
  bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
  listK[pos] = swap ? kb : ka;
  listK[pos + stride] = swap ? ka : kb;

  V va = listV[pos];
  V vb = listV[pos + stride];
  listV[pos] = swap ? vb : va;
  listV[pos + stride] = swap ? va : vb;

  // Publish this pass's writes before any thread reads them next pass
  __syncthreads();

  // Standard bitonic-merge cleanup: halve the exchange stride each pass
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
    // First element of the compare/exchange pair for this thread
    int pos = 2 * tid - (tid & (stride - 1));

    K ka = listK[pos];
    K kb = listK[pos + stride];

    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;

    __syncthreads();
  }
}
78 
// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
//
// Same merge network as blockMergeSmall, but for a single pair of lists
// of length L >= NumThreads: each thread performs kLoopPerThread
// (= L / NumThreads) compare/exchange steps per pass instead of one.
// All threads of the block participate, so the __syncthreads() barriers
// between passes are reached uniformly.
//
// Preconditions (enforced below): L and NumThreads are powers of 2,
// L >= kWarpSize, and L >= NumThreads.
// NOTE(review): as with blockMergeSmall, cross-thread ordering relies
// solely on __syncthreads(), so listK/listV presumably reference shared
// memory — confirm at call sites.
template <int NumThreads, typename K, typename V, int L,
          bool Dir, typename Comp>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L >= kWarpSize, "merge list size must be >= 32");
  static_assert(utils::isPowerOf2(NumThreads),
                "NumThreads must be a power-of-2");
  static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

  // For L > NumThreads, each thread has to perform more work
  // per each stride.
  constexpr int kLoopPerThread = L / NumThreads;

  // It's not a bitonic merge, both lists are in the same direction,
  // so handle the first swap assuming the second list is reversed
#pragma unroll
  for (int loop = 0; loop < kLoopPerThread; ++loop) {
    // Virtual thread id: this thread stands in for L/NumThreads lanes
    int tid = loop * NumThreads + threadIdx.x;
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;

    K ka = listK[pos];
    K kb = listK[pos + stride];

    bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
    listK[pos] = swap ? kb : ka;
    listK[pos + stride] = swap ? ka : kb;

    V va = listV[pos];
    V vb = listV[pos + stride];
    listV[pos] = swap ? vb : va;
    listV[pos + stride] = swap ? va : vb;
  }

  // Each virtual lane's first-pass indices are disjoint, so a single
  // barrier after the whole loop suffices
  __syncthreads();

  // Bitonic-merge cleanup passes, each thread covering kLoopPerThread
  // compare/exchange pairs per stride
#pragma unroll
  for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
      int tid = loop * NumThreads + threadIdx.x;
      int pos = 2 * tid - (tid & (stride - 1));

      K ka = listK[pos];
      K kb = listK[pos + stride];

      bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
      listK[pos] = swap ? kb : ka;
      listK[pos + stride] = swap ? ka : kb;

      V va = listV[pos];
      V vb = listV[pos + stride];
      listV[pos] = swap ? vb : va;
      listV[pos + stride] = swap ? va : vb;
    }

    __syncthreads();
  }
}
139 
/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases.
///
/// Deliberately empty primary template: only the two partial
/// specializations below (SmallerThanBlock == true / false) are ever
/// instantiated, so each side's static_asserts are evaluated only for
/// the case that actually applies.
template <int NumThreads,
          typename K, typename V, int N, int L,
          bool Dir, typename Comp, bool SmallerThanBlock>
struct BlockMerge {
};
147 
/// Merging lists smaller than a block (L <= NumThreads).
///
/// N pairs of length-L lists are merged; NumThreads / L pairs fit in
/// one blockMergeSmall call, so either a subset of threads handles all
/// N pairs at once (N < kNumParallelMerges), or all threads iterate
/// kNumIterations times over successive groups of pairs.
template <int NumThreads, typename K, typename V, int N, int L,
          bool Dir, typename Comp>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true> {
  static inline __device__ void merge(K* listK, V* listV) {
    // How many pairs of lists one blockMergeSmall call can merge
    constexpr int kNumParallelMerges = NumThreads / L;
    constexpr int kNumIterations = N / kNumParallelMerges;

    static_assert(L <= NumThreads, "list must be <= NumThreads");
    // N must either fit within a single partial call, or divide evenly
    // into full-width calls
    static_assert((N < kNumParallelMerges) ||
                  (kNumIterations * kNumParallelMerges == N),
                  "improper selection of N and L");

    if (N < kNumParallelMerges) {
      // We only need L threads per each list to perform the merge
      // NOTE(review): blockMergeSmall contains __syncthreads(), and
      // here threads with threadIdx.x >= N * L skip the call — a
      // barrier under divergence, which CUDA documents as undefined
      // behavior. Presumably tolerated on the targeted architectures;
      // verify before porting to newer hardware.
      if (threadIdx.x < N * L) {
        blockMergeSmall<NumThreads, K, V, L, Dir, Comp>(listK, listV);
      }
    } else {
      // All threads participate; each iteration merges
      // kNumParallelMerges pairs occupying 2 * L elements apiece
#pragma unroll
      for (int i = 0; i < kNumIterations; ++i) {
        int start = i * kNumParallelMerges * 2 * L;

        blockMergeSmall<NumThreads, K, V, L, Dir, Comp>(listK + start,
                                                        listV + start);
      }
    }
  }
};
178 
/// Merging lists larger than a block (L >= NumThreads).
///
/// With lists at least as long as the block, only one pair can be
/// merged at a time: the whole block walks the N pairs one after
/// another, each pair spanning 2 * L contiguous elements.
template <int NumThreads, typename K, typename V, int N, int L,
          bool Dir, typename Comp>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false> {
  static inline __device__ void merge(K* listK, V* listV) {
    // Each pair of lists is merged sequentially
    K* pairK = listK;
    V* pairV = listV;

#pragma unroll
    for (int pair = 0; pair < N; ++pair) {
      blockMergeLarge<NumThreads, K, V, L, Dir, Comp>(pairK, pairV);

      // Advance to the next 2 * L span
      pairK += 2 * L;
      pairV += 2 * L;
    }
  }
};
194 
/// Block-wide merge entry point: merges N pairs of sorted length-L
/// lists stored contiguously in listK/listV (keys) and listV (values),
/// dispatching at compile time to the smaller-than-block or
/// larger-than-block implementation.
template <int NumThreads, typename K, typename V, int N, int L,
          bool Dir, typename Comp>
inline __device__ void blockMerge(K* listK, V* listV) {
  constexpr bool kSmallerThanBlock = (L <= NumThreads);

  // Fix: the dispatch must be qualified through the BlockMerge
  // specialization; a bare merge(listK, listV) does not resolve.
  BlockMerge<NumThreads, K, V, N, L, Dir, Comp, kSmallerThanBlock>::
    merge(listK, listV);
}
203 
204 } } // namespace