Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
MergeNetworkWarp.cuh
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include "DeviceDefs.cuh"
12 #include "MergeNetworkUtils.cuh"
13 #include "PtxUtils.cuh"
14 #include "StaticUtils.h"
15 #include "WarpShuffles.cuh"
16 
17 namespace faiss { namespace gpu {
18 
19 //
20 // This file contains functions to:
21 //
22 // -perform bitonic merges on pairs of sorted lists, held in
23 // registers. Each list contains N * kWarpSize (multiple of 32)
24 // elements for some N.
25 // The bitonic merge is implemented for arbitrary sizes;
26 // sorted list A of size N1 * kWarpSize registers
27 // sorted list B of size N2 * kWarpSize registers =>
28 // sorted list C of size (N1 + N2) * kWarpSize registers. N1 and N2
29 // are >= 1 and don't have to be powers of 2.
30 //
31 // -perform bitonic sorts on a set of N * kWarpSize key/value pairs
32 // held in registers, by using the above bitonic merge as a
33 // primitive.
34 // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports
35 // odd sizes and doesn't require the input to be a power of 2.
36 //
37 // The sort or merge network is completely statically instantiated via
38 // template specialization / expansion and constexpr, and it uses warp
39 // shuffles to exchange values between warp lanes.
40 //
41 // A note about comparisons:
42 //
43 // For a sorting network of keys only, we only need one
44 // comparison (a < b). However, what we really need to know is
45 // if one lane chooses to exchange a value, then the
46 // corresponding lane should also do the exchange.
47 // Thus, if one just uses the negation !(x < y) in the higher
48 // lane, this will also include the case where (x == y). Thus, one
49 // lane in fact performs an exchange and the other doesn't, but
50 // because the only value being exchanged is equivalent, nothing has
51 // changed.
52 // So, you can get away with just one comparison and its negation.
53 //
54 // If we're sorting keys and values, where equivalent keys can
55 // exist, then this is a problem, since we want to treat (x, v1)
56 // as not equivalent to (x, v2).
57 //
58 // To remedy this, you can either compare with a lexicographic
59 // ordering (a.k < b.k || (a.k == b.k && a.v < b.v)), which since
60 // we're predicating all of the choices results in 3 comparisons
61 // being executed, or we can invert the selection so that there is no
62 // middle choice of equality; the other lane will likewise
63 // check that (b.k > a.k) (the higher lane has the values
64 // swapped). Then, the first lane swaps if and only if the
65 // second lane swaps; if both lanes have equivalent keys, no
66 // swap will be performed. This results in only two comparisons
67 // being executed.
68 //
69 // If you don't consider values as well, then this does not produce a
70 // consistent ordering among (k, v) pairs with equivalent keys but
71 // different values; for us, we don't really care about ordering or
72 // stability here.
73 //
74 // I have tried both re-arranging the order in the higher lane to get
75 // away with one comparison or adding the value to the check; both
76 // result in greater register consumption or lower speed than just
77 // performing both < and > comparisons with the variables, so I just
78 // stick with this.
79 
// Merges kWarpSize / (2 * L) lists in parallel, exchanging values
// between lanes of the warp with xor shuffles.
//
// Each lane contributes a single key/value pair (k, v), updated in
// place. L must be a power of 2 and at most kWarpSize / 2 (both are
// enforced by static_assert), since the whole warp is needed for the
// shuffle merge.
//
// Template parameters:
//   L         - size of each list being merged
//   Dir       - selects which of Comp::gt / Comp::lt the lower lane of
//               an exchange uses (the upper lane uses the other one)
//   Comp      - comparator type providing static gt() and lt()
//   IsBitonic - if false, an extra first stage pairs lanes through the
//               mask 2 * L - 1 (reversing the exchange pattern, e.g.
//               for L == 8: 0 <-> 15, 1 <-> 14, ...), so the input does
//               not have to be sorted directionally. It is still
//               technically a bitonic sort.
template <typename K, typename V, int L,
          bool Dir, typename Comp, bool IsBitonic>
inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
  static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
  static_assert(L <= kWarpSize / 2, "merge list size must be <= 16");

  const int lane = getLaneId();

  if (!IsBitonic) {
    // Reversed first comparison stage: partner is lane ^ (2 * L - 1).
    constexpr int kMirrorMask = 2 * L - 1;

    K partnerK = shfl_xor(k, kMirrorMask);
    V partnerV = shfl_xor(v, kMirrorMask);

    // Whether this lane is the lesser lane of the exchanging pair
    const bool isLower = (lane & L) == 0;

    // Both gt and lt are evaluated across the pair (one per lane) so
    // that either both lanes exchange or neither does; see the note on
    // comparisons at the top of this file.
    const bool take = (isLower == Dir) ? Comp::gt(k, partnerK)
                                       : Comp::lt(k, partnerK);
    assign(take, k, partnerK);
    assign(take, v, partnerV);
  }

  // Remaining stages: the partner distance halves each iteration. When
  // the reversed stage above already ran, start at L / 2.
#pragma unroll
  for (int dist = IsBitonic ? L : L / 2; dist > 0; dist >>= 1) {
    K partnerK = shfl_xor(k, dist);
    V partnerV = shfl_xor(v, dist);

    // Whether this lane is the lesser lane of the exchanging pair
    const bool isLower = (lane & dist) == 0;

    const bool take = (isLower == Dir) ? Comp::gt(k, partnerK)
                                       : Comp::lt(k, partnerK);
    assign(take, k, partnerK);
    assign(take, v, partnerV);
  }
}
140 
// Template for performing a bitonic merge of an arbitrary set of
// registers.
//
// This is the primary template and is intentionally empty; the work is
// done by the partial specializations that follow, selected on N, Low
// and Pow2.
//
// Template parameters:
//   K, V - key and value types
//   N    - number of registers (per lane) being merged
//   Dir  - merge direction
//   Comp - comparator type providing static gt() and lt()
//   Low  - whether this step handles the low (true) or high (false)
//          part of the recursion
//   Pow2 - whether N is a power of 2
template <typename K, typename V, int N,
          bool Dir, typename Comp, bool Low, bool Pow2>
struct BitonicMergeStep {
};
147 
148 //
149 // Power-of-2 merge specialization
150 //
151 
// Base case of the merge recursion (N == 1): every merge eventually
// ends up here. With a single register per lane, the merge is carried
// out across the lanes of the warp.
template <typename K, typename V, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, 1, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[1], V v[1]) {
    K& key = k[0];
    V& val = v[0];

    // Use warp shuffles; the input is treated as already bitonic
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, true>(key, val);
  }
};
160 
// Power-of-2 merge step for N > 1 registers.
//
// Compare-exchanges register i against register i + N / 2 for the first
// N / 2 registers, then recursively merges the lower and the upper half
// as independent power-of-2 problems.
template <typename K, typename V, int N, bool Dir, typename Comp, bool Low>
struct BitonicMergeStep<K, V, N, Dir, Comp, Low, true> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(utils::isPowerOf2(N), "must be power of 2");
    static_assert(N > 1, "must be N > 1");

    constexpr int kHalf = N / 2;

    // Compare-exchange across the two halves
#pragma unroll
    for (int lo = 0; lo < kHalf; ++lo) {
      const int hi = lo + kHalf;

      const bool doSwap =
        Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Lower half: copy out, merge, copy back
    {
      K lowK[kHalf];
      V lowV[kHalf];

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        lowK[j] = k[j];
        lowV[j] = v[j];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, true, true>::merge(lowK, lowV);

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        k[j] = lowK[j];
        v[j] = lowV[j];
      }
    }

    // Upper half: copy out, merge, copy back
    {
      K highK[kHalf];
      V highV[kHalf];

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        highK[j] = k[kHalf + j];
        highV[j] = v[kHalf + j];
      }

      BitonicMergeStep<K, V, kHalf, Dir, Comp, false, true>::merge(highK, highV);

#pragma unroll
      for (int j = 0; j < kHalf; ++j) {
        k[kHalf + j] = highK[j];
        v[kHalf + j] = highV[j];
      }
    }
  }
};
219 
220 //
221 // Non-power-of-2 merge specialization
222 //
223 
// Non-power-of-2 merge step, low recursion (Low == true).
//
// With P = nextHighestPowerOf2(N), register i is compare-exchanged with
// register i + P / 2 for i in [0, N - P / 2). The registers are then
// split into a leading part of N - P / 2 registers and a trailing part
// of P / 2 registers, each of which is merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, true, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
    constexpr int kLowSize = N - kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = kNextHighestPowerOf2 / 2;

    // Compare-exchange each of the first kLowSize registers against
    // the register kHighSize positions further on
#pragma unroll
    for (int lo = 0; lo < kLowSize; ++lo) {
      const int hi = lo + kHighSize;

      const bool doSwap =
        Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: kLowSize registers, may or may not be a power of 2
    {
      K lowK[kLowSize];
      V lowV[kLowSize];

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        lowK[j] = k[j];
        lowV[j] = v[j];
      }

      // FIXME: the compiler does not accept utils::isPowerOf2(kLowSize)
      // here (compiler bug?), so the size expression is spelled out.
      constexpr bool kLowIsPowerOf2 =
        utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(lowK, lowV);

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        k[j] = lowK[j];
        v[j] = lowV[j];
      }
    }

    // Trailing part: kHighSize registers, starting at offset kLowSize
    {
      K highK[kHighSize];
      V highV[kHighSize];

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        highK[j] = k[kLowSize + j];
        highV[j] = v[kLowSize + j];
      }

      // FIXME: the compiler does not accept utils::isPowerOf2(kHighSize)
      // here (compiler bug?), so the size expression is spelled out.
      constexpr bool kHighIsPowerOf2 =
        utils::isPowerOf2(kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(highK, highV);

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        k[kLowSize + j] = highK[j];
        v[kLowSize + j] = highV[j];
      }
    }
  }
};
299 
// Non-power-of-2 merge step, high recursion (Low == false).
//
// With P = nextHighestPowerOf2(N), register i is compare-exchanged with
// register i + P / 2 for i in [0, N - P / 2). The registers are then
// split into a leading part of P / 2 registers and a trailing part of
// N - P / 2 registers, each of which is merged recursively.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicMergeStep<K, V, N, Dir, Comp, false, false> {
  static inline __device__ void merge(K k[N], V v[N]) {
    static_assert(!utils::isPowerOf2(N), "must be non-power-of-2");
    static_assert(N >= 3, "must be N >= 3");

    constexpr int kNextHighestPowerOf2 = utils::nextHighestPowerOf2(N);
    constexpr int kLowSize = kNextHighestPowerOf2 / 2;
    constexpr int kHighSize = N - kNextHighestPowerOf2 / 2;

    // Compare-exchange each of the first kHighSize registers against
    // the register kLowSize positions further on
#pragma unroll
    for (int lo = 0; lo < kHighSize; ++lo) {
      const int hi = lo + kLowSize;

      const bool doSwap =
        Dir ? Comp::gt(k[lo], k[hi]) : Comp::lt(k[lo], k[hi]);
      swap(doSwap, k[lo], k[hi]);
      swap(doSwap, v[lo], v[hi]);
    }

    // Leading part: kLowSize registers
    {
      K lowK[kLowSize];
      V lowV[kLowSize];

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        lowK[j] = k[j];
        lowV[j] = v[j];
      }

      // FIXME: the compiler does not accept utils::isPowerOf2(kLowSize)
      // here (compiler bug?), so the size expression is spelled out.
      constexpr bool kLowIsPowerOf2 =
        utils::isPowerOf2(kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kLowSize, Dir, Comp,
                       true, // low
                       kLowIsPowerOf2>::merge(lowK, lowV);

#pragma unroll
      for (int j = 0; j < kLowSize; ++j) {
        k[j] = lowK[j];
        v[j] = lowV[j];
      }
    }

    // Trailing part: kHighSize registers, starting at offset kLowSize;
    // may or may not be a power of 2
    {
      K highK[kHighSize];
      V highV[kHighSize];

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        highK[j] = k[kLowSize + j];
        highV[j] = v[kLowSize + j];
      }

      // FIXME: the compiler does not accept utils::isPowerOf2(kHighSize)
      // here (compiler bug?), so the size expression is spelled out.
      constexpr bool kHighIsPowerOf2 =
        utils::isPowerOf2(N - kNextHighestPowerOf2 / 2);
      BitonicMergeStep<K, V, kHighSize, Dir, Comp,
                       false, // high
                       kHighIsPowerOf2>::merge(highK, highV);

#pragma unroll
      for (int j = 0; j < kHighSize; ++j) {
        k[kLowSize + j] = highK[j];
        v[kLowSize + j] = highV[j];
      }
    }
  }
};
375 
/// Merges two sets of registers across the warp of any size;
/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a
/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any
/// value >= 1.
///
/// The first stage pairs register N1 - 1 - i of list 1 with register i
/// of list 2, for i in [0, min(N1, N2)), exchanging through the lane
/// mask kWarpSize - 1. Afterwards list 1 is merged as the low part and,
/// only when FullMerge is true, list 2 is updated and merged as the
/// high part. With FullMerge == false the contents of k2 / v2 are left
/// untouched by this function.
template <typename K,
          typename V,
          int N1,
          int N2,
          bool Dir,
          typename Comp,
          bool FullMerge = true>
inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1],
                                             K k2[N2], V v2[N2]) {
  constexpr int kNumPairs = (N2 < N1) ? N2 : N1;
  constexpr int kFlipMask = kWarpSize - 1;

#pragma unroll
  for (int p = 0; p < kNumPairs; ++p) {
    // List 1 is walked from its last register backwards, list 2 from
    // its first register forwards
    const int idxA = N1 - 1 - p;

    K& keyA = k1[idxA];
    V& valA = v1[idxA];

    K& keyB = k2[p];
    V& valB = v2[p];

    // The partner lane's view of A has to be captured before A is
    // overwritten below; it is only needed when list 2 is updated too
    K partnerKeyA;
    V partnerValA;

    if (FullMerge) {
      partnerKeyA = shfl_xor(keyA, kFlipMask);
      partnerValA = shfl_xor(valA, kFlipMask);
    }

    K partnerKeyB = shfl_xor(keyB, kFlipMask);
    V partnerValB = shfl_xor(valB, kFlipMask);

    // A is always first in the list, so the lane id plays no part in
    // this comparison
    const bool takeForA =
      Dir ? Comp::gt(keyA, partnerKeyB) : Comp::lt(keyA, partnerKeyB);
    assign(takeForA, keyA, partnerKeyB);
    assign(takeForA, valA, partnerValB);

    // B is always second in the list, so the lane id plays no part in
    // this comparison either. If list 2 is not wanted, skip updating it.
    if (FullMerge) {
      const bool takeForB =
        Dir ? Comp::lt(keyB, partnerKeyA) : Comp::gt(keyB, partnerKeyA);
      assign(takeForB, keyB, partnerKeyA);
      assign(takeForB, valB, partnerValA);
    }
  }

  BitonicMergeStep<K, V, N1, Dir, Comp,
                   true, utils::isPowerOf2(N1)>::merge(k1, v1);

  if (FullMerge) {
    // Only if we care about N2 do we need to bother merging it fully
    BitonicMergeStep<K, V, N2, Dir, Comp,
                     false, utils::isPowerOf2(N2)>::merge(k2, v2);
  }
}
437 
// Recursive template that uses the above bitonic merge to perform a
// bitonic sort.
//
// General case (N > 1): the N registers are split into a first part of
// N / 2 registers and a second part of N - N / 2 registers. Each part
// is sorted recursively, then the two sorted parts are merged with
// warpMergeAnyRegisters and written back to k / v in place.
// The N == 1 case is handled by a dedicated specialization.
template <typename K, typename V, int N, bool Dir, typename Comp>
struct BitonicSortStep {
  static inline __device__ void sort(K k[N], V v[N]) {
    static_assert(N > 1, "did not hit specialized case");

    // Sort recursively
    constexpr int kSizeA = N / 2;
    constexpr int kSizeB = N - kSizeA;

    K aK[kSizeA];
    V aV[kSizeA];

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      aK[i] = k[i];
      aV[i] = v[i];
    }

    // Sort the first part on its own
    BitonicSortStep<K, V, kSizeA, Dir, Comp>::sort(aK, aV);

    K bK[kSizeB];
    V bV[kSizeB];

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      bK[i] = k[i + kSizeA];
      bV[i] = v[i + kSizeA];
    }

    // Sort the second part on its own
    BitonicSortStep<K, V, kSizeB, Dir, Comp>::sort(bK, bV);

    // Merge halves
    warpMergeAnyRegisters<K, V, kSizeA, kSizeB, Dir, Comp>(aK, aV, bK, bV);

#pragma unroll
    for (int i = 0; i < kSizeA; ++i) {
      k[i] = aK[i];
      v[i] = aV[i];
    }

#pragma unroll
    for (int i = 0; i < kSizeB; ++i) {
      k[i + kSizeA] = bK[i];
      v[i + kSizeA] = bV[i];
    }
  }
};
487 
// Single warp (N == 1) sorting specialization.
//
// With one register per lane, the sort is a chain of warp-shuffle
// merges with the list size doubling at each step (1, 2, 4, 8, 16).
// Every step is instantiated with IsBitonic == false.
template <typename K, typename V, bool Dir, typename Comp>
struct BitonicSortStep<K, V, 1, Dir, Comp> {
  static inline __device__ void sort(K k[1], V v[1]) {
    // The list sizes below are hard-coded for a 32-lane warp (they go
    // from 1 up to kWarpSize / 2 in powers of 2); update this code if
    // the warp size ever changes
    static_assert(kWarpSize == 32, "unexpected warp size");

    K& key = k[0];
    V& val = v[0];

    warpBitonicMergeLE16<K, V, 1, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 2, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 4, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 8, Dir, Comp, false>(key, val);
    warpBitonicMergeLE16<K, V, 16, Dir, Comp, false>(key, val);
  }
};
503 
/// Sort a list of kWarpSize * N elements in registers, where N is an
/// arbitrary >= 1.
///
/// k and v hold this lane's N keys and N values; they are sorted in
/// place. This is a thin entry point that forwards to the statically
/// instantiated BitonicSortStep network for the given N.
template <typename K, typename V, int N, bool Dir, typename Comp>
inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) {
  BitonicSortStep<K, V, N, Dir, Comp>::sort(k, v);
}
510 
511 } } // namespace