Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MergeNetworkWarp.cuh
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "DeviceDefs.cuh"
11 #include "MergeNetworkUtils.cuh"
12 #include "PtxUtils.cuh"
13 #include "StaticUtils.h"
14 #include "WarpShuffles.cuh"
15 
16 namespace faiss { namespace gpu {
17 
18 //
19 // This file contains functions to:
20 //
21 // -perform bitonic merges on pairs of sorted lists, held in
22 // registers. Each list contains N * kWarpSize (multiple of 32)
23 // elements for some N.
24 // The bitonic merge is implemented for arbitrary sizes;
25 // sorted list A of size N1 * kWarpSize registers
26 // sorted list B of size N2 * kWarpSize registers =>
27 // sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
28 // are >= 1 and don't have to be powers of 2.
29 //
30 // -perform bitonic sorts on a set of N * kWarpSize key/value pairs
31 // held in registers, by using the above bitonic merge as a
32 // primitive.
33 // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
34 // odd sizes and doesn't require the input to be a power of 2.
35 //
36 // The sort or merge network is completely statically instantiated via
37 // template specialization / expansion and constexpr, and it uses warp
38 // shuffles to exchange values between warp lanes.
39 //
40 // A note about comparisons:
41 //
42 // For a sorting network of keys only, we only need one
43 // comparison (a < b). However, what we really need to know is
44 // if one lane chooses to exchange a value, then the
45 // corresponding lane should also do the exchange.
46 // Thus, if one just uses the negation !(x < y) in the higher
47 // lane, this will also include the case where (x == y). Thus, one
48 // lane in fact performs an exchange and the other doesn't, but
49 // because the only value being exchanged is equivalent, nothing has
50 // changed.
51 // So, you can get away with just one comparison and its negation.
52 //
53 // If we're sorting keys and values, where equivalent keys can
54 // exist, then this is a problem, since we want to treat (x, v1)
55 // as not equivalent to (x, v2).
56 //
57 // To remedy this, you can either compare with a lexicographic
58 // ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since
59 // we're predicating all of the choices results in 3 comparisons
60 // being executed, or we can invert the selection so that there is no
61 // middle choice of equality; the other lane will likewise
62 // check that (b.k > a.k) (the higher lane has the values
63 // swapped). Then, the first lane swaps if and only if the
64 // second lane swaps; if both lanes have equivalent keys, no
65 // swap will be performed. This results in only two comparisons
66 // being executed.
67 //
68 // If you don't consider values as well, then this does not produce a
69 // consistent ordering among (k, v) pairs with equivalent keys but
70 // different values; for us, we don't really care about ordering or
71 // stability here.
72 //
73 // I have tried both re-arranging the order in the higher lane to get
74 // away with one comparison or adding the value to the check; both
75 // result in greater register consumption or lower speed than just
76 // performing both < and > comparisons with the variables, so I just
77 // stick with this.
78 
// Merges kWarpSize / 2L lists in parallel; every lane contributes one
// key/value pair and values travel between lanes via warp shuffles.
// Lists may hold at most 16 elements each, since a single merge
// consumes the 32 lanes of the warp.
//
// When IsBitonic is false, the first comparison stage pairs lanes in
// reversed order (lane ^ (2L - 1)), so the two input lists need not be
// sorted in opposite directions beforehand. When IsBitonic is true, the
// input is taken to be bitonic already and the first stage uses stride L.
//
// Template parameters:
//   K, V      - key and value types
//   L         - list size; power of 2, <= kWarpSize / 2
//   Dir       - selects which of Comp::gt / Comp::lt the lower lane of
//               each exchange uses
//   Comp      - comparator type providing static gt() and lt()
//   IsBitonic - see above
template <typename K, typename V, int L,
          bool Dir, typename Comp, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");

  const int lane = getLaneId();

  if (!IsBitonic) {
    // Reversed first stage. For lists of size 8 the partners are
    // 0 <-> 15, 1 <-> 14, ...
    K partnerK = shfl_xor(k, 2 * L - 1);
    V partnerV = shfl_xor(v, 2 * L - 1);

    // True for the lower-numbered lane of each exchanging pair
    const bool isLower = (lane & L) == 0;

    // The lower lane tests one strict inequality and the higher lane
    // tests the opposite one (see the note on comparisons at the top
    // of this file); Dir selects which lane gets which.
    bool take = (isLower == Dir) ? Comp::gt(k, partnerK)
                                 : Comp::lt(k, partnerK);
    assign(take, k, partnerK);
    assign(take, v, partnerV);
  }

  // Remaining stages: halve the exchange distance down to adjacent lanes
#pragma unroll
  for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
    K partnerK = shfl_xor(k, stride);
    V partnerV = shfl_xor(v, stride);

    // True for the lower-numbered lane of each exchanging pair
    const bool isLower = (lane & stride) == 0;

    bool take = (isLower == Dir) ? Comp::gt(k, partnerK)
                                 : Comp::lt(k, partnerK);
    assign(take, k, partnerK);
    assign(take, v, partnerV);
  }
}
139 
// Primary template for performing a bitonic merge of an arbitrary set
// of registers. It is intentionally empty: all behavior lives in the
// partial specializations below, which are selected on Pow2 (whether N
// is a power of 2) and, for non-power-of-2 sizes, on Low (whether this
// step handles the low or the high part of the parent merge).
//
// Template parameters:
//   K, V - key and value types
//   N    - number of registers (per lane) taking part in the merge
//   Dir  - comparison direction, forwarded to every compare-exchange
//   Comp - comparator type providing static gt() and lt()
//   Low  - true if this step is the low part of its parent merge
//   Pow2 - true if N is a power of 2
template <typename K, typename V, int N,
          bool Dir, typename Comp, bool Low, bool Pow2>
struct BitonicMergeStep {
};
146 
147 //
148 // Power-of-2 merge specialization
149 //
150 
// Base case of the merge recursion (N == 1): every larger merge
// eventually bottoms out here, where the single remaining register per
// lane is merged across the warp.
template <typename K, typename V, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[1], V v[1]) {
    K& key = k[0];
    V& val = v[0];

    // The input is treated as already bitonic (IsBitonic = true), so
    // the shuffle-based merge starts at stride 16.
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(key, val);
  }
};
159 
// Power-of-2 merge step for N > 1 registers: compare-exchange register
// i against register i + N / 2, then merge each half recursively.
template <typename K, typename V, int N, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(utils::isPowerOf2(N), "must be power of 2");
    static_assert(N > 1, "must be N > 1");

    constexpr int kHalf = N / 2;

    // Compare-exchange between the two halves, entirely within this
    // lane's own registers
#pragma unroll
    for (int lo = 0; lo < kHalf; ++lo) {
      const int hi = lo + kHalf;

      bool doSwap = Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Lower half: copy out, merge recursively, copy back
    {
      K halfK[kHalf];
      V halfV[kHalf];

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        halfK[i] = k[i];
        halfV[i] = v[i];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, true, true>::merge(halfK,
                                                                  halfV);

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        k[i] = halfK[i];
        v[i] = halfV[i];
      }
    }

    // Upper half: copy out, merge recursively, copy back
    {
      K halfK[kHalf];
      V halfV[kHalf];

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        halfK[i] = k[kHalf + i];
        halfV[i] = v[kHalf + i];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, false, true>::merge(halfK,
                                                                   halfV);

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        k[kHalf + i] = halfK[i];
        v[kHalf + i] = halfV[i];
      }
    }
  }
};
218 
219 //
220 // Non-power-of-2 merge specialization
221 //
222 
// Non-power-of-2 merge step, low recursion (Low == true).
// With P = nextHighestPowerOf2(N), registers [0, N - P / 2) are
// compare-exchanged against registers [P / 2, N). The array is then
// split into a leading part of N - P / 2 registers and a trailing part
// of P / 2 registers, each merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, true, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
    constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = kNextHighestPowerOf2 / 2;

    // Compare-exchange at distance P / 2 for every pair that exists
#pragma unroll
    for (int lo = 0; lo < N - kNextHighestPowerOf2 / 2; ++lo) {
      const int hi = lo + kNextHighestPowerOf2 / 2;

      bool doSwap = Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: copy out, merge recursively, copy back
    {
      K partK[kLowSize];
      V partV[kLowSize];

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        partK[i] = k[i];
        partV[i] = v[i];
      }

      // FIXME: the size is spelled out instead of using kLowSize because
      // the compiler rejected utils::isPowerOf2(kLowSize) (possible
      // compiler bug)
      constexpr bool kLowIsPowerOf2 =
        utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        k[i] = partK[i];
        v[i] = partV[i];
      }
    }

    // Trailing part: copy out, merge recursively, copy back
    {
      K partK[kHighSize];
      V partV[kHighSize];

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        partK[i] = k[kLowSize + i];
        partV[i] = v[kLowSize + i];
      }

      // FIXME: the size is spelled out instead of using kHighSize because
      // the compiler rejected utils::isPowerOf2(kHighSize) (possible
      // compiler bug)
      constexpr bool kHighIsPowerOf2 =
        utils::isPowerOf2(kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        k[kLowSize + i] = partK[i];
        v[kLowSize + i] = partV[i];
      }
    }
  }
};
298 
// Non-power-of-2 merge step, high recursion (Low == false).
// With P = nextHighestPowerOf2(N), registers [0, N - P / 2) are
// compare-exchanged against registers [P / 2, N). The array is then
// split into a leading part of P / 2 registers and a trailing part of
// N - P / 2 registers, each merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, false, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
    constexpr int kLowSize = kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;

    // Compare-exchange at distance P / 2 for every pair that exists
#pragma unroll
    for (int lo = 0; lo < N - kNextHighestPowerOf2 / 2; ++lo) {
      const int hi = lo + kNextHighestPowerOf2 / 2;

      bool doSwap = Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: copy out, merge recursively, copy back
    {
      K partK[kLowSize];
      V partV[kLowSize];

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        partK[i] = k[i];
        partV[i] = v[i];
      }

      // FIXME: the size is spelled out instead of using kLowSize because
      // the compiler rejected utils::isPowerOf2(kLowSize) (possible
      // compiler bug)
      constexpr bool kLowIsPowerOf2 =
        utils::isPowerOf2(kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        k[i] = partK[i];
        v[i] = partV[i];
      }
    }

    // Trailing part: copy out, merge recursively, copy back
    {
      K partK[kHighSize];
      V partV[kHighSize];

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        partK[i] = k[kLowSize + i];
        partV[i] = v[kLowSize + i];
      }

      // FIXME: the size is spelled out instead of using kHighSize because
      // the compiler rejected utils::isPowerOf2(kHighSize) (possible
      // compiler bug)
      constexpr bool kHighIsPowerOf2 =
        utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        k[kLowSize + i] = partK[i];
        v[kLowSize + i] = partV[i];
      }
    }
  }
};
374 
/// Merges two sorted key/value register lists across the warp:
/// (k1, v1) of size kWarpSize * N1 and (k2, v2) of size kWarpSize * N2,
/// for any N1, N2 >= 1.
///
/// If FullMerge is false, only (k1, v1) is updated and fully merged;
/// (k2, v2) is read but left as passed in.
template <typename K,
          typename V,
          int N1,
          int N2,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1],
                                             K k2[N2], V v2[N2]) {
  // Number of register pairs that have a partner in both lists
  constexpr int kNumPairs = (N2 < N1) ? N2 : N1;

#pragma unroll
  for (int i = 0; i < kNumPairs; ++i) {
    // Walk the first list from its end and the second from its start
    const int ia = N1 - 1 - i;
    const int ib = i;

    K peerKa;
    V peerVa;

    if (FullMerge) {
      // Only needed when the second list is updated as well
      peerKa = shfl_xor(k1[ia], kWarpSize - 1);
      peerVa = shfl_xor(v1[ia], kWarpSize - 1);
    }

    K peerKb = shfl_xor(k2[ib], kWarpSize - 1);
    V peerVb = shfl_xor(v2[ib], kWarpSize - 1);

    // The first-list element always takes the first position of its
    // pair, so the lane id is not needed to pick the comparison
    bool takeA = Dir ? Comp::gt(k1[ia], peerKb) : Comp::lt(k1[ia], peerKb);
    assign(takeA, k1[ia], peerKb);
    assign(takeA, v1[ia], peerVb);

    // Likewise the second-list element always takes the second
    // position. Without FullMerge the second list is not updated.
    if (FullMerge) {
      bool takeB = Dir ? Comp::lt(k2[ib], peerKa) : Comp::gt(k2[ib], peerKa);
      assign(takeB, k2[ib], peerKa);
      assign(takeB, v2[ib], peerVa);
    }
  }

  BitonicMergeStep<K, V, N1, Dir, Comp,
                   true, utils::isPowerOf2(N1)>::merge(k1, v1);

  if (FullMerge) {
    // The second list is only merged fully when the caller wants it
    BitonicMergeStep<K, V, N2, Dir, Comp,
                     false, utils::isPowerOf2(N2)>::merge(k2, v2);
  }
}
436 
// Recursive template that uses the above bitonic merge to perform a
// bitonic sort of N registers per lane (N > 1; N == 1 is handled by a
// dedicated specialization).
//
// The input is split into a first part of N / 2 registers and a second
// part of N - N / 2 registers. Each part is sorted recursively, and the
// two sorted parts are then combined with warpMergeAnyRegisters.
//
// Template parameters:
//   K, V - key and value types
//   N    - number of registers per lane to sort
//   Dir  - comparison direction, forwarded to the sub-sorts and the merge
//   Comp - comparator type providing static gt() and lt()
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicSortStep {
  static inline __device__ void sort(K k[N], V v[N]) {
    static_assert(N > 1, "did not hit specialized case");

    // Sort recursively
    constexpr int kSizeA = N / 2;
    constexpr int kSizeB = N - kSizeA;

    K aK[kSizeA];
    V aV[kSizeA];

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      aK[i] = k[i];
      aV[i] = v[i];
    }

    // Sort the first part. warpMergeAnyRegisters takes two sorted lists
    // and pairs them in reversed order itself, so both parts are sorted
    // with the same Dir.
    BitonicSortStep<K, V, kSizeA, Dir, Comp>::sort(aK, aV);

    K bK[kSizeB];
    V bV[kSizeB];

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      bK[i] = k[i + kSizeA];
      bV[i] = v[i + kSizeA];
    }

    // Sort the second part
    BitonicSortStep<K, V, kSizeB, Dir, Comp>::sort(bK, bV);

    // Merge halves
    warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);

    // Write the merged result back into the caller's arrays
#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      k[i] = aK[i];
      v[i] = aV[i];
    }

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      k[i + kSizeA] = bK[i];
      v[i + kSizeA] = bV[i];
    }
  }
};
486 
// Sorting specialization for one register per lane (N == 1): the
// kWarpSize values held across the warp are sorted with shuffle-based
// merges only.
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSortStep<K, V, 1, Dir, Comp> {
  static inline __device__ void sort(K k[1], V v[1]) {
    // The merge sizes below are written out for a 32-lane warp; they
    // must double from 1 up to kWarpSize / 2. Update them if the warp
    // size ever changes.
    static_assert(kWarpSize == 32, "unexpected warp size");

    K& key = k[0];
    V& val = v[0];

    // Each call merges pairs of lists of the given size, with the
    // reversed first stage (IsBitonic = false)
    warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(key, val);
  }
};
502 
/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary value >= 1. The key/value arrays are sorted in place.
///
/// Template parameters:
///   K, V - key and value types
///   N    - number of registers per lane
///   Dir  - comparison direction, forwarded to the sort network
///   Comp - comparator type providing static gt() and lt()
template <typename K, typename V, int N, bool Dir, typename Comp>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
  // Entry point into the statically instantiated sort network
  BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}
509 
510 } } // namespace