Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MergeNetworkWarp.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 #pragma once
11 
12 #include "DeviceDefs.cuh"
13 #include "MergeNetworkUtils.cuh"
14 #include "PtxUtils.cuh"
15 #include "StaticUtils.h"
16 #include "WarpShuffles.cuh"
17 
18 namespace faiss { namespace gpu {
19 
20 //
21 // This file contains functions to:
22 //
23 // -perform bitonic merges on pairs of sorted lists, held in
24 // registers. Each list contains N * kWarpSize (multiple of 32)
25 // elements for some N.
26 // The bitonic merge is implemented for arbitrary sizes;
27 // sorted list A of size N1 * kWarpSize registers
28 // sorted list B of size N2 * kWarpSize registers =>
29 // sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
30 // are >= 1 and don't have to be powers of 2.
31 //
32 // -perform bitonic sorts on a set of N * kWarpSize key/value pairs
33 // held in registers, by using the above bitonic merge as a
34 // primitive.
35 // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
36 // odd sizes and doesn't require the input to be a power of 2.
37 //
38 // The sort or merge network is completely statically instantiated via
39 // template specialization / expansion and constexpr, and it uses warp
40 // shuffles to exchange values between warp lanes.
41 //
42 // A note about comparisons:
43 //
44 // For a sorting network of keys only, we only need one
45 // comparison (a < b). However, what we really need to know is
46 // if one lane chooses to exchange a value, then the
47 // corresponding lane should also do the exchange.
48 // Thus, if one just uses the negation !(x < y) in the higher
49 // lane, this will also include the case where (x == y). Thus, one
50 // lane in fact performs an exchange and the other doesn't, but
51 // because the only value being exchanged is equivalent, nothing has
52 // changed.
53 // So, you can get away with just one comparison and its negation.
54 //
55 // If we're sorting keys and values, where equivalent keys can
56 // exist, then this is a problem, since we want to treat (x, v1)
57 // as not equivalent to (x, v2).
58 //
59 // To remedy this, you can either compare with a lexicographic
60 // ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since
61 // we're predicating all of the choices results in 3 comparisons
62 // being executed, or we can invert the selection so that there is no
63 // middle choice of equality; the other lane will likewise
64 // check that (b.k > a.k) (the higher lane has the values
65 // swapped). Then, the first lane swaps if and only if the
66 // second lane swaps; if both lanes have equivalent keys, no
67 // swap will be performed. This results in only two comparisons
68 // being executed.
69 //
70 // If you don't consider values as well, then this does not produce a
71 // consistent ordering among (k, v) pairs with equivalent keys but
72 // different values; for us, we don't really care about ordering or
73 // stability here.
74 //
75 // I have tried both re-arranging the order in the higher lane to get
76 // away with one comparison or adding the value to the check; both
77 // result in greater register consumption or lower speed than just
78 // performing both < and > comparisons with the variables, so I just
79 // stick with this.
80 
// Merges kWarpSize / (2 * L) groups of lists in parallel. Every lane
// holds a single (key, value) element and trades it with a partner
// lane through warp shuffles.
//
// L must be a power of 2 and no larger than kWarpSize / 2 (both are
// checked by static_assert), since a full merge of two L-element lists
// needs 2 * L lanes.
//
// If IsBitonic is false, the first stage is reversed (lane i is paired
// with lane i ^ (2 * L - 1)), so we don't need to sort the two inputs
// directionally. It's still technically a bitonic sort. If IsBitonic is
// true, that reversed stage is skipped and the stride loop starts at L
// instead of L / 2.
//
// Dir picks which comparison each side of a pair uses: with Dir == true
// the lower lane of a pair uses Comp::gt and the upper lane Comp::lt;
// with Dir == false the two are exchanged. Both comparisons are kept on
// purpose (see the note about comparisons earlier in this file).
// NOTE(review): assign(flag, x, y) is presumed to set x = y when flag is
// true -- confirm against MergeNetworkUtils.cuh.
template <typename K, typename V, int L,
          bool Dir, typename Comp, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");

  const int lane = getLaneId();

  if (!IsBitonic) {
    // Reversed first stage. For example, with L == 8 the exchanges are
    // 0 <-> 15, 1 <-> 14, ...
    constexpr int kMirrorMask = 2 * L - 1;

    K peerK = shfl_xor(k, kMirrorMask);
    V peerV = shfl_xor(v, kMirrorMask);

    // True when this lane is the lesser lane of its pair
    const bool lowerLane = (lane & L) == 0;

    const bool take = Dir
        ? (lowerLane ? Comp::gt(k, peerK) : Comp::lt(k, peerK))
        : (lowerLane ? Comp::lt(k, peerK) : Comp::gt(k, peerK));

    assign(take, k, peerK);
    assign(take, v, peerV);
  }

  // The reversed stage above already covered distance L when it ran
  constexpr int kFirstStride = IsBitonic ? L : L / 2;

#pragma unroll
  for (int step = kFirstStride; step > 0; step >>= 1) {
    K peerK = shfl_xor(k, step);
    V peerV = shfl_xor(v, step);

    // True when this lane is the lesser lane of its pair
    const bool lowerLane = (lane & step) == 0;

    const bool take = Dir
        ? (lowerLane ? Comp::gt(k, peerK) : Comp::lt(k, peerK))
        : (lowerLane ? Comp::lt(k, peerK) : Comp::gt(k, peerK));

    assign(take, k, peerK);
    assign(take, v, peerV);
  }
}
141 
// Template for performing a bitonic merge of an arbitrary set of
// registers.
//
// The primary template is intentionally empty; the work is done by the
// partial specializations, which are selected on N and on the Low /
// Pow2 flags and each provide a static merge(K k[N], V v[N]).
//
// Template parameters:
//   K, V  - key and value types
//   N     - number of elements in the k / v arrays handed to merge()
//   Dir   - comparison direction (chooses between Comp::gt and Comp::lt)
//   Comp  - comparator type providing static lt() and gt()
//   Low   - distinguishes the "low" from the "high" recursion; only the
//           non-power-of-2 specializations behave differently on it
//   Pow2  - whether N is a power of 2
template <typename K, typename V, int N,
          bool Dir, typename Comp, bool Low, bool Pow2>
struct BitonicMergeStep {
};
148 
149 //
150 // Power-of-2 merge specialization
151 //
152 
// Base case of the merge recursion (N == 1): every merge eventually
// ends up here.
template <typename K, typename V, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Comp, Low, true> {
  /// Finishes the merge for a single element per lane by handing it to
  /// the warp-shuffle merge, with list size 16 and the input treated as
  /// already bitonic (IsBitonic == true).
  static inline __device__ void merge(K k[1], V v[1]) {
    constexpr int kListSize = 16;
    constexpr bool kInputIsBitonic = true;

    warpBitonicMergeLE16<K, V, kListSize, Dir, Comp, kInputIsBitonic>(*k, *v);
  }
};
161 
// Merge step for a power-of-2 number of elements N > 1 (Pow2 == true).
// The Low flag is accepted but does not alter the behaviour of this
// specialization.
template <typename K, typename V, int N, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Comp, Low, true> {
  /// Compare-exchanges element i with element i + N / 2 for every i in
  /// the first half, then merges each half independently through
  /// BitonicMergeStep<..., N / 2, ...>.
  /// NOTE(review): swap(flag, a, b) is presumed to exchange a and b
  /// when flag is true -- confirm against MergeNetworkUtils.cuh.
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(utils::isPowerOf2(N), "must be power of 2");
    static_assert(N > 1, "must be N > 1");

    constexpr int kHalf = N / 2;

    // Exchange stage between the two halves
#pragma unroll
    for (int lo = 0; lo < kHalf; ++lo) {
      const int hi = lo + kHalf;

      const bool doSwap =
          Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // First half: copy out, merge as the "low" part, copy back
    {
      K halfK[kHalf];
      V halfV[kHalf];

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        halfK[j] = k[j];
        halfV[j] = v[j];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, true, true>::merge(halfK, halfV);

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        k[j] = halfK[j];
        v[j] = halfV[j];
      }
    }

    // Second half: copy out, merge as the "high" part, copy back
    {
      K halfK[kHalf];
      V halfV[kHalf];

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        halfK[j] = k[kHalf + j];
        halfV[j] = v[kHalf + j];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, false, true>::merge(halfK, halfV);

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        k[kHalf + j] = halfK[j];
        v[kHalf + j] = halfV[j];
      }
    }
  }
};
220 
221 //
222 // Non-power-of-2 merge specialization
223 //
224 
// Low recursion: merge step for a non-power-of-2 N >= 3 with
// Low == true.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, true, false> {
  /// With P = utils::nextHighestPowerOf2(N), compare-exchanges element
  /// i with element i + P / 2 for i in [0, N - P / 2). It then merges
  /// the leading N - P / 2 elements as the "low" part and the trailing
  /// P / 2 elements as the "high" part.
  /// NOTE(review): swap(flag, a, b) is presumed to exchange a and b
  /// when flag is true -- confirm against MergeNetworkUtils.cuh.
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kPow2Ceil = utils::nextHighestPowerOf2(N);

    // For the low recursion the short part comes first
    constexpr int kLowSize = N - kPow2Ceil / 2;
    constexpr int kHighSize = kPow2Ceil / 2;

    // Exchange stage at distance P / 2
#pragma unroll
    for (int lo = 0; lo < kLowSize; ++lo) {
      const int hi = lo + kHighSize;

      const bool doSwap =
          Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: copy out, merge as "low", copy back
    {
      K partK[kLowSize];
      V partV[kLowSize];

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        partK[j] = k[j];
        partV[j] = v[j];
      }

      // FIXME: the size expression is spelled out rather than written as
      // utils::isPowerOf2(kLowSize); the compiler did not like that form
      // (possible compiler bug).
      constexpr bool kLowIsPowerOf2 =
          utils::isPowerOf2(N - kPow2Ceil / 2);

      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        k[j] = partK[j];
        v[j] = partV[j];
      }
    }

    // Trailing part: copy out, merge as "high", copy back
    {
      K partK[kHighSize];
      V partV[kHighSize];

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        partK[j] = k[kLowSize + j];
        partV[j] = v[kLowSize + j];
      }

      // FIXME: the size expression is spelled out rather than written as
      // utils::isPowerOf2(kHighSize); the compiler did not like that form
      // (possible compiler bug).
      constexpr bool kHighIsPowerOf2 =
          utils::isPowerOf2(kPow2Ceil / 2);

      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        k[kLowSize + j] = partK[j];
        v[kLowSize + j] = partV[j];
      }
    }
  }
};
300 
// High recursion: merge step for a non-power-of-2 N >= 3 with
// Low == false.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, false, false> {
  /// With P = utils::nextHighestPowerOf2(N), compare-exchanges element
  /// i with element i + P / 2 for i in [0, N - P / 2). It then merges
  /// the leading P / 2 elements as the "low" part and the trailing
  /// N - P / 2 elements as the "high" part.
  /// NOTE(review): swap(flag, a, b) is presumed to exchange a and b
  /// when flag is true -- confirm against MergeNetworkUtils.cuh.
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kPow2Ceil = utils::nextHighestPowerOf2(N);

    // For the high recursion the power-of-2 part comes first
    constexpr int kLowSize = kPow2Ceil / 2;
    constexpr int kHighSize = N - kPow2Ceil / 2;

    // Exchange stage at distance P / 2
#pragma unroll
    for (int lo = 0; lo < kHighSize; ++lo) {
      const int hi = lo + kLowSize;

      const bool doSwap =
          Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: copy out, merge as "low", copy back
    {
      K partK[kLowSize];
      V partV[kLowSize];

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        partK[j] = k[j];
        partV[j] = v[j];
      }

      // FIXME: the size expression is spelled out rather than written as
      // utils::isPowerOf2(kLowSize); the compiler did not like that form
      // (possible compiler bug).
      constexpr bool kLowIsPowerOf2 =
          utils::isPowerOf2(kPow2Ceil / 2);

      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        k[j] = partK[j];
        v[j] = partV[j];
      }
    }

    // Trailing part: copy out, merge as "high", copy back
    {
      K partK[kHighSize];
      V partV[kHighSize];

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        partK[j] = k[kLowSize + j];
        partV[j] = v[kLowSize + j];
      }

      // FIXME: the size expression is spelled out rather than written as
      // utils::isPowerOf2(kHighSize); the compiler did not like that form
      // (possible compiler bug).
      constexpr bool kHighIsPowerOf2 =
          utils::isPowerOf2(N - kPow2Ceil / 2);

      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(partK, partV);

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        k[kLowSize + j] = partK[j];
        v[kLowSize + j] = partV[j];
      }
    }
  }
};
376 
/// Merges two sets of registers across the warp of any size; i.e.,
/// merges a sorted k/v list of size kWarpSize * N1 (k1 / v1) with a
/// sorted k/v list of size kWarpSize * N2 (k2 / v2), where N1 and N2
/// are any value >= 1.
///
/// The first stage walks the first list from its last register
/// backwards and the second list from its first register forwards,
/// pairing each lane with lane ^ (kWarpSize - 1). Each list is then
/// finished with BitonicMergeStep.
///
/// When FullMerge is false, the second list is neither updated in the
/// first stage nor merged afterwards; only k1 / v1 receive results.
/// NOTE(review): assign(flag, x, y) is presumed to set x = y when flag
/// is true -- confirm against MergeNetworkUtils.cuh.
template <typename K,
          typename V,
          int N1,
          int N2,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1],
                                             K k2[N2], V v2[N2]) {
  constexpr int kPairs = N1 < N2 ? N1 : N2;
  constexpr int kMirrorMask = kWarpSize - 1;

#pragma unroll
  for (int head = 0; head < kPairs; ++head) {
    const int tail = N1 - 1 - head;

    K& firstK = k1[tail];
    V& firstV = v1[tail];

    K& secondK = k2[head];
    V& secondV = v2[head];

    // Mirrored copies of the first-list element; these are only needed
    // (and only fetched) when the second list is updated as well. They
    // must be captured before firstK / firstV are overwritten below.
    K mirroredFirstK;
    V mirroredFirstV;

    if (FullMerge) {
      mirroredFirstK = shfl_xor(firstK, kMirrorMask);
      mirroredFirstV = shfl_xor(firstV, kMirrorMask);
    }

    K mirroredSecondK = shfl_xor(secondK, kMirrorMask);
    V mirroredSecondV = shfl_xor(secondV, kMirrorMask);

    // The first-list element always comes first in the pair, so the
    // lane id is not needed to pick the comparison
    const bool takeForFirst = Dir ? Comp::gt(firstK, mirroredSecondK)
                                  : Comp::lt(firstK, mirroredSecondK);
    assign(takeForFirst, firstK, mirroredSecondK);
    assign(takeForFirst, firstV, mirroredSecondV);

    // The second-list element always comes second in the pair, so the
    // lane id is not needed here either. When FullMerge is false the
    // second list is left untouched.
    if (FullMerge) {
      const bool takeForSecond = Dir ? Comp::lt(secondK, mirroredFirstK)
                                     : Comp::gt(secondK, mirroredFirstK);
      assign(takeForSecond, secondK, mirroredFirstK);
      assign(takeForSecond, secondV, mirroredFirstV);
    }
  }

  BitonicMergeStep<K, V, N1, Dir, Comp,
                   true, utils::isPowerOf2(N1)>::merge(k1, v1);

  if (FullMerge) {
    // Only if we care about N2 do we need to bother merging it fully
    BitonicMergeStep<K, V, N2, Dir, Comp,
                     false, utils::isPowerOf2(N2)>::merge(k2, v2);
  }
}
438 
// Recursive template that uses the above bitonic merge to perform a
// bitonic sort.
//
// Template parameters:
//   K, V  - key and value types
//   N     - number of elements in the k / v arrays; this primary
//           template handles N > 1, while N == 1 is covered by a
//           dedicated specialization
//   Dir   - sort direction, forwarded unchanged to both recursive sorts
//           and to the merge
//   Comp  - comparator type, forwarded to the merge
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicSortStep {
  /// Sorts k / v in place: splits them into a first part of N / 2
  /// elements and a second part of N - N / 2 elements, sorts each part
  /// recursively, then merges the two parts with warpMergeAnyRegisters.
  static inline __device__ void sort(K k[N], V v[N]) {
    static_assert(N > 1, "did not hit specialized case");

    // Sort recursively
    constexpr int kSizeA = N / 2;
    constexpr int kSizeB = N - kSizeA;

    K aK[kSizeA];
    V aV[kSizeA];

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      aK[i] = k[i];
      aV[i] = v[i];
    }

    // Sort the first part
    BitonicSortStep<K, V, kSizeA, Dir, Comp>::sort(aK, aV);

    K bK[kSizeB];
    V bV[kSizeB];

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      bK[i] = k[i + kSizeA];
      bV[i] = v[i + kSizeA];
    }

    // Sort the second part in the same direction as the first; the
    // merge below walks the first list backwards against the second, so
    // no directional (opposite-order) sort is needed here
    BitonicSortStep<K, V, kSizeB, Dir, Comp>::sort(bK, bV);

    // Merge halves
    warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      k[i] = aK[i];
      v[i] = aV[i];
    }

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      k[i + kSizeA] = bK[i];
      v[i + kSizeA] = bV[i];
    }
  }
};
488 
// Single warp (N == 1) sorting specialization: the base case of the
// sort recursion.
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSortStep<K, V, 1, Dir, Comp> {
  /// Sorts one element per lane across the warp by running the
  /// warp-shuffle merge with list sizes 1, 2, 4, 8 and 16, each time
  /// with IsBitonic == false.
  static inline __device__ void sort(K k[1], V v[1]) {
    // The list of merge sizes below is written out for a 32-lane warp
    // (sizes double from 1 up to kWarpSize / 2); update it if the warp
    // size ever changes
    static_assert(kWarpSize == 32, "unexpected warp size");

    K& key = k[0];
    V& val = v[0];

    warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(key, val);
  }
};
504 
/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary value >= 1.
///
/// k and v are sorted in place; the work is delegated to
/// BitonicSortStep<K, V, N, Dir, Comp>, so Dir and Comp carry the same
/// meaning as they do there.
template <typename K, typename V, int N, bool Dir, typename Comp>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
  BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}
511 
512 } } // namespace