Faiss
PQScanMultiPassPrecomputed.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>

namespace faiss { namespace gpu {
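// Note on the precomputed terms used below: with precomputed tables, the
// IVFPQ L2 distance is evaluated as term1 + term2 + term3, where term1
// depends only on the (query, probed coarse centroid) pair, term2 depends
// only on the (coarse centroid, PQ code) pair and is precomputed ahead of
// time, and term3 depends only on the (query, PQ code) pair. The kernel
// below therefore only needs term2 + term3 summed per code, which
// loadPrecomputedTerm assembles in shared memory.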
// For precomputed codes, this calculates and loads code distances
// into smem
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if it is evenly divisible,
  // then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data by float4 for efficiency, and then handle any remainder
    // limitVec is the number of whole vec words we can load, in terms
    // of whole blocks performing the load
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}
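// Scan kernel: one thread block handles one (query, probe) pair
// (gridDim.x = nprobe, gridDim.y = number of queries in the tile).
// The block first sums term2 + term3 into shared memory, then strides
// over the codes of the selected inverted list, combining term1 with one
// shared memory lookup per sub-quantizer, and writes the raw distances
// to its slice of `distance` as given by prefixSumOffsets.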
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single query
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
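// Runs one tile of queries end to end:
// (1) prefix-sum the selected list lengths so every (query, probe) pair
//     knows where to write its distances (runCalcListOffsets),
// (2) launch pqScanPrecomputedMultiPass to fill the flat distance array, and
// (3) k-select in two passes: first within chunks of probes
//     (runPass1SelectLists), then across those partial results to produce
//     the final (distance, index) output (runPass2SelectLists).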
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to a distance, and write out the distances for
  // all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                 \
    do {                                                              \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();         \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();         \
                                                                      \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \
        <<<grid, block, smem, stream>>>(                              \
          queries,                                                    \
          precompTerm1,                                               \
          precompTerm2T,                                              \
          precompTerm3T,                                              \
          topQueryToCentroid,                                         \
          listCodes.data().get(),                                     \
          listLengths.data().get(),                                   \
          prefixSumOffsets,                                           \
          allDistances);                                              \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

    CUDA_TEST_ERROR();

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_TEST_ERROR();
}
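// Entry point: splits the query set into tiles sized to the temporary
// memory available, allocates double-buffered scratch space (prefix sum
// offsets, flat distances, first-pass heaps), and runs runMultiPassTile
// per tile on alternating streams.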
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});
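  // Process the queries in tiles, ping-ponging between the two sets of
  // temporary buffers and the two alternate streams so that the next
  // tile can be queued while the previous one is still executing.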
  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace