Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MergeNetworkWarp.cuh
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 #pragma once
12 
13 #include "DeviceDefs.cuh"
14 #include "PtxUtils.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
17 
18 namespace faiss { namespace gpu {
19 
20 //
21 // This file contains functions to:
22 //
23 // -perform bitonic merges on pairs of sorted lists, held in
24 // registers. Each list contains N * kWarpSize (multiple of 32)
25 // elements for some N.
26 // The bitonic merge is implemented for arbitrary sizes;
27 // sorted list A of size N1 * kWarpSize registers
28 // sorted list B of size N2 * kWarpSize registers =>
29 // sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
30 // are >= 1 and don't have to be powers of 2.
31 //
32 // -perform bitonic sorts on a set of N * kWarpSize key/value pairs
33 // held in registers, by using the above bitonic merge as a
34 // primitive.
35 // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
36 // odd sizes and doesn't require the input to be a power of 2.
37 //
38 // The sort or merge network is completely statically instantiated via
39 // template specialization / expansion and constexpr, and it uses warp
40 // shuffles to exchange values between warp lanes.
41 //
42 // A note about comparisons:
43 //
44 // For a sorting network of keys only, we only need one
45 // comparison (a < b). However, what we really need to know is
46 // if one lane chooses to exchange a value, then the
47 // corresponding lane should also do the exchange.
48 // Thus, if one just uses the negation !(x < y) in the higher
49 // lane, this will also include the case where (x == y). Thus, one
50 // lane in fact performs an exchange and the other doesn't, but
51 // because the only value being exchanged is equivalent, nothing has
52 // changed.
53 // So, you can get away with just one comparison and its negation.
54 //
55 // If we're sorting keys and values, where equivalent keys can
56 // exist, then this is a problem, since we want to treat (x, v1)
57 // as not equivalent to (x, v2).
58 //
59 // To remedy this, you can either compare with a lexicographic
60 // ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since
61 // we're predicating all of the choices results in 3 comparisons
62 // being executed, or we can invert the selection so that there is no
63 // middle choice of equality; the other lane will likewise
64 // check that (b.k > a.k) (the higher lane has the values
65 // swapped). Then, the first lane swaps if and only if the
66 // second lane swaps; if both lanes have equivalent keys, no
67 // swap will be performed. This results in only two comparisons
68 // being executed.
69 //
70 // If you don't consider values as well, then this does not produce a
71 // consistent ordering among (k, v) pairs with equivalent keys but
72 // different values; for us, we don't really care about ordering or
73 // stability here.
74 //
75 // I have tried both re-arranging the order in the higher lane to get
76 // away with one comparison or adding the value to the check; both
77 // result in greater register consumption or lower speed than just
78 // performing both < and > comparisons with the variables, so I just
79 // stick with this.
80 
// Conditionally exchanges `x` and `y`.
//
// When `swap` is true the two values trade places; otherwise both are
// left as they were. The exchange is expressed as two selects rather
// than a branch.
template <typename T>
inline __device__ void swap(bool swap, T& x, T& y) {
  const T oldX = x;
  const T oldY = y;

  x = swap ? oldY : oldX;
  y = swap ? oldX : oldY;
}
87 
// Merges kWarpSize / 2L lists in parallel across the warp, exchanging
// values between lanes with warp shuffles.
//
// Each lane holds a single key/value pair in `k` / `v`. L is the size
// of each list being merged; it must be a power of 2 and at most
// kWarpSize / 2 (i.e., at most 16), since the shuffle merge needs all
// 32 threads.
//
// Dir selects which of Comp::gt / Comp::lt triggers an exchange in the
// lower lane of each pair (the upper lane uses the opposite one).
//
// If IsBitonic is false, an extra leading stage exchanges with the
// mirrored lane (xor with 2 * L - 1) so that the inputs need not have
// been sorted directionally. It is still technically a bitonic sort.
template <typename K, typename V, int L,
          bool Dir, typename Comp, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");

  const int lane = getLaneId();

  if (!IsBitonic) {
    // Reversed first comparison stage.
    // For example, merging a list of size 8 has the exchanges:
    // 0 <-> 15, 1 <-> 14, ...
    constexpr int kMirrorMask = 2 * L - 1;

    K partnerK = shfl_xor(k, kMirrorMask);
    V partnerV = shfl_xor(v, kMirrorMask);

    // True if this lane is the lesser lane of the exchanging pair
    const bool isLower = (lane & L) == 0;

    // Both comparisons (gt in one lane, lt in the other) are kept on
    // purpose: paired lanes then agree on whether to exchange, and
    // equal keys are never exchanged.
    const bool doSwap = (isLower == Dir) ?
      Comp::gt(k, partnerK) : Comp::lt(k, partnerK);

    swap(doSwap, k, partnerK);
    swap(doSwap, v, partnerV);
  }

  // Remaining butterfly stages; the partner distance halves each time.
#pragma unroll
  for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) {
    K partnerK = shfl_xor(k, stride);
    V partnerV = shfl_xor(v, stride);

    // True if this lane is the lesser lane of the exchanging pair
    const bool isLower = (lane & stride) == 0;

    const bool doSwap = (isLower == Dir) ?
      Comp::gt(k, partnerK) : Comp::lt(k, partnerK);

    swap(doSwap, k, partnerK);
    swap(doSwap, v, partnerV);
  }
}
148 
// Template for performing a bitonic merge of an arbitrary set of
// registers.
//
// The primary template is intentionally empty; all of the work is done
// by the partial specializations below, which provide a static
// `merge(K k[N], V v[N])`.
//
// K, V  - key and value types
// N     - number of registers per lane being merged
// Dir   - selects which of Comp::gt / Comp::lt triggers an exchange
// Comp  - comparator type providing static lt() and gt()
// Low   - whether this step is the low or the high half of its parent
//         (selects between the non-power-of-2 specializations)
// Pow2  - whether N is a power of 2 (selects the power-of-2
//         specializations)
template <typename K, typename V, int N,
          bool Dir, typename Comp, bool Low, bool Pow2>
struct BitonicMergeStep {
};
155 
156 //
157 // Power-of-2 merge specialization
158 //
159 
// Base case of the register merge recursion (N == 1): every merge
// eventually bottoms out here, where the single remaining register per
// lane is merged across the warp.
template <typename K, typename V, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[1], V v[1]) {
    K& key = k[0];
    V& val = v[0];

    // One register per lane is left; finish with warp shuffles
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(key, val);
  }
};
168 
// Power-of-2 merge step for N > 1 registers per lane.
//
// Compare-exchanges register i against register i + N / 2 for the
// whole lower half, then recursively merges each half independently.
template <typename K, typename V, int N, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(utils::isPowerOf2(N), "must be power of 2");
    static_assert(N > 1, "must be N > 1");

    constexpr int kHalf = N / 2;

    // Exchange stage between the two halves of the register set
#pragma unroll
    for (int i = 0; i < kHalf; ++i) {
      const bool doSwap = Dir ?
        Comp::gt(k[i], k[i + kHalf]) : Comp::lt(k[i], k[i + kHalf]);

      swap(doSwap, k[i], k[i + kHalf]);
      swap(doSwap, v[i], v[i + kHalf]);
    }

    // Lower half: copy out, merge recursively, copy back
    {
      K lowK[kHalf];
      V lowV[kHalf];

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        lowK[i] = k[i];
        lowV[i] = v[i];
      }

      BitonicMergeStep<K, V, N / 2, Dir, Comp, true, true>::merge(lowK, lowV);

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        k[i] = lowK[i];
        v[i] = lowV[i];
      }
    }

    // Upper half: copy out, merge recursively, copy back
    {
      K highK[kHalf];
      V highV[kHalf];

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        highK[i] = k[kHalf + i];
        highV[i] = v[kHalf + i];
      }

      BitonicMergeStep<K, V, N / 2, Dir, Comp, false, true>::merge(highK, highV);

#pragma unroll
      for (int i = 0; i < kHalf; ++i) {
        k[kHalf + i] = highK[i];
        v[kHalf + i] = highV[i];
      }
    }
  }
};
227 
228 //
229 // Non-power-of-2 merge specialization
230 //
231 
// Low recursion: non-power-of-2 merge step used for the low half of
// its parent.
//
// With P = nextHighestPowerOf2(N), the registers are split into a low
// part of size N - P / 2 followed by a high part of size P / 2 (always
// a power of 2). Register i is compare-exchanged with register
// i + P / 2 for every i in the low part, and then both parts are
// merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, true, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);

    constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = kNextHighestPowerOf2 / 2;

    // FIXME: compiler doesn't like utils::isPowerOf2(kLowSize) /
    // utils::isPowerOf2(kHighSize) here (compiler bug?), so the sizes
    // are spelled out again in full.
    constexpr bool kLowIsPowerOf2 =
      utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
    constexpr bool kHighIsPowerOf2 =
      utils::isPowerOf2(kNextHighestPowerOf2 / 2);

    // Exchange stage: only the first kLowSize registers have a partner
    // kHighSize (= P / 2) positions further on.
#pragma unroll
    for (int i = 0; i < kLowSize; ++i) {
      const bool doSwap = Dir ?
        Comp::gt(k[i], k[i + kHighSize]) : Comp::lt(k[i], k[i + kHighSize]);

      swap(doSwap, k[i], k[i + kHighSize]);
      swap(doSwap, v[i], v[i + kHighSize]);
    }

    // Low part: registers [0, kLowSize)
    {
      K lowK[kLowSize];
      V lowV[kLowSize];

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        lowK[i] = k[i];
        lowV[i] = v[i];
      }

      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(lowK, lowV);

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        k[i] = lowK[i];
        v[i] = lowV[i];
      }
    }

    // High part: registers [kLowSize, N)
    {
      K highK[kHighSize];
      V highV[kHighSize];

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        highK[i] = k[kLowSize + i];
        highV[i] = v[kLowSize + i];
      }

      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(highK, highV);

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        k[kLowSize + i] = highK[i];
        v[kLowSize + i] = highV[i];
      }
    }
  }
};
307 
// High recursion: non-power-of-2 merge step used for the high half of
// its parent.
//
// With P = nextHighestPowerOf2(N), the registers are split into a low
// part of size P / 2 (always a power of 2) followed by a high part of
// size N - P / 2. Register i is compare-exchanged with register
// i + P / 2 for the first N - P / 2 registers, and then both parts are
// merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, false, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);

    constexpr int kLowSize = kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;

    // FIXME: compiler doesn't like utils::isPowerOf2(kLowSize) /
    // utils::isPowerOf2(kHighSize) here (compiler bug?), so the sizes
    // are spelled out again in full.
    constexpr bool kLowIsPowerOf2 =
      utils::isPowerOf2(kNextHighestPowerOf2 / 2);
    constexpr bool kHighIsPowerOf2 =
      utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);

    // Exchange stage: only the first kHighSize registers have a partner
    // kLowSize (= P / 2) positions further on.
#pragma unroll
    for (int i = 0; i < kHighSize; ++i) {
      const bool doSwap = Dir ?
        Comp::gt(k[i], k[i + kLowSize]) : Comp::lt(k[i], k[i + kLowSize]);

      swap(doSwap, k[i], k[i + kLowSize]);
      swap(doSwap, v[i], v[i + kLowSize]);
    }

    // Low part: registers [0, kLowSize)
    {
      K lowK[kLowSize];
      V lowV[kLowSize];

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        lowK[i] = k[i];
        lowV[i] = v[i];
      }

      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(lowK, lowV);

#pragma unroll
      for (int i = 0; i < kLowSize; ++i) {
        k[i] = lowK[i];
        v[i] = lowV[i];
      }
    }

    // High part: registers [kLowSize, N)
    {
      K highK[kHighSize];
      V highV[kHighSize];

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        highK[i] = k[kLowSize + i];
        highV[i] = v[kLowSize + i];
      }

      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(highK, highV);

#pragma unroll
      for (int i = 0; i < kHighSize; ++i) {
        k[kLowSize + i] = highK[i];
        v[kLowSize + i] = highV[i];
      }
    }
  }
};
383 
/// Merges two sets of registers across the warp of any size;
/// i.e., merges a sorted k/v list of size kWarpSize * N1 (k1 / v1)
/// with a sorted k/v list of size kWarpSize * N2 (k2 / v2), where N1
/// and N2 are any value >= 1.
///
/// The first stage pairs the registers of list 1, walked from its last
/// register backwards, with the registers of list 2, walked from its
/// first register forwards, exchanging through a shuffle with lane
/// mask kWarpSize - 1. After that, each list is finished independently
/// with BitonicMergeStep (list 1 as the low half, list 2 as the high
/// half).
template <typename K, typename V, int N1, int N2, bool Dir, typename Comp>
inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1],
                                             K k2[N2], V v2[N2]) {
  // Only as many register pairs as the shorter list has take part in
  // the first stage
  constexpr int kPairs = N1 < N2 ? N1 : N2;
  constexpr int kLaneMask = kWarpSize - 1;

#pragma unroll
  for (int i = 0; i < kPairs; ++i) {
    const int idxA = N1 - 1 - i;

    K& keyA = k1[idxA];
    V& valA = v1[idxA];

    K& keyB = k2[i];
    V& valB = v2[i];

    // Fetch all four partner values before anything is overwritten
    K partnerKeyA = shfl_xor(keyA, kLaneMask);
    V partnerValA = shfl_xor(valA, kLaneMask);

    K partnerKeyB = shfl_xor(keyB, kLaneMask);
    V partnerValB = shfl_xor(valB, kLaneMask);

    // keyA is always first in the list, so the lane id plays no part
    // in choosing the comparison
    const bool takeB = Dir ?
      Comp::gt(keyA, partnerKeyB) : Comp::lt(keyA, partnerKeyB);
    swap(takeB, keyA, partnerKeyB);
    swap(takeB, valA, partnerValB);

    // keyB is always second in the list, so again the lane id plays no
    // part in choosing the comparison
    const bool takeA = Dir ?
      Comp::lt(keyB, partnerKeyA) : Comp::gt(keyB, partnerKeyA);
    swap(takeA, keyB, partnerKeyA);
    swap(takeA, valB, partnerValA);
  }

  // Finish each list on its own
  BitonicMergeStep<K, V, N1, Dir, Comp,
                   true, utils::isPowerOf2(N1)>::merge(k1, v1);
  BitonicMergeStep<K, V, N2, Dir, Comp,
                   false, utils::isPowerOf2(N2)>::merge(k2, v2);
}
425 
// Recursive template that uses the above bitonic merge to perform a
// bitonic sort of N registers per lane (N > 1; N == 1 is handled by a
// dedicated specialization).
//
// The registers are split into a first part of size N / 2 and a second
// part of size N - N / 2. Each part is sorted recursively, the two
// sorted parts are merged with warpMergeAnyRegisters, and the result is
// written back into k / v.
//
// K, V  - key and value types
// N     - number of registers per lane
// Dir   - selects which of Comp::gt / Comp::lt triggers an exchange
// Comp  - comparator type providing static lt() and gt()
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicSortStep {
  static inline __device__ void sort(K k[N], V v[N]) {
    static_assert(N > 1, "did not hit specialized case");

    // Sort recursively
    constexpr int kSizeA = N / 2;
    constexpr int kSizeB = N - kSizeA;

    K aK[kSizeA];
    V aV[kSizeA];

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      aK[i] = k[i];
      aV[i] = v[i];
    }

    // The first part must be sorted before it can be merged
    BitonicSortStep<K, V, kSizeA, Dir, Comp>::sort(aK, aV);

    K bK[kSizeB];
    V bV[kSizeB];

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      bK[i] = k[i + kSizeA];
      bV[i] = v[i + kSizeA];
    }

    // Likewise for the second part
    BitonicSortStep<K, V, kSizeB, Dir, Comp>::sort(bK, bV);

    // Merge halves
    warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);

    // Write the merged result back to the caller's registers
#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      k[i] = aK[i];
      v[i] = aV[i];
    }

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      k[i + kSizeA] = bK[i];
      v[i + kSizeA] = bV[i];
    }
  }
};
475 
// Single warp (N == 1) sorting specialization.
//
// Sorts the one key/value pair held by each lane across the warp by
// running the shuffle merge with list sizes 1, 2, 4, 8 and 16 in turn,
// each time with IsBitonic == false.
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSortStep<K, V, 1, Dir, Comp> {
  static inline __device__ void sort(K k[1], V v[1]) {
    // The sequence of merges below is written out for a 32-wide warp;
    // update this code if this changes.
    // It should go from 1 -> kWarpSize in multiples of 2
    static_assert(kWarpSize == 32, "unexpected warp size");

    K& key = k[0];
    V& val = v[0];

    warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(key, val);
  }
};
491 
/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary >= 1.
///
/// k / v hold the N keys and N values owned by the calling lane; they
/// are sorted in place. This is a thin entry point that forwards to
/// BitonicSortStep, which picks the N == 1 specialization or the
/// recursive case as appropriate.
template <typename K, typename V, int N, bool Dir, typename Comp>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
  BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}
498 
499 } } // namespace