PQScanMultiPassPrecomputed.cu

/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>

namespace faiss { namespace gpu {

// For precomputed codes, this calculates and loads code distances
// into smem
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if it is evenly divisible,
  // then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data by whole vector words (e.g. float4) for efficiency, and
    // then handle any remainder. limitVec is the number of whole vector words
    // we can load, in terms of whole blocks performing the load.
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;
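    // Editorial example of the split above, assuming LookupT == float,
    // LookupVecT == float4 (kWordSize == 4), blockDim.x == 256 and
    // numCodes == 16384 (64 sub-quantizers x 256 codes per sub-quantizer):
    // limitVec = 16384 / (2 * 4 * 256) = 8, then limitVec *= 2 * 256 = 4096
    // vector words, covering all 4096 * 4 = 16384 codes, so the scalar
    // remainder loop below has nothing left to do.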

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}

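// Editorial background (not part of the original source): with precomputed
// tables, the IVFPQ distance between a query x and a database vector encoded
// as coarse centroid y_C plus PQ-reconstructed residual y_R expands as
//
//   ||x - y_C - y_R||^2 = ||x - y_C||^2                 (term 1)
//                       + ||y_R||^2 + 2 * (y_C . y_R)   (term 2)
//                       - 2 * (x . y_R)                 (term 3)
//
// Term 1 depends only on (query, probed list), term 2 only on (list,
// sub-quantizer code) and term 3 only on (query, sub-quantizer code), which
// matches the shapes of precompTerm1, precompTerm2 and precompTerm3 below;
// terms 2 and 3 are summed into shared memory once per block by
// loadPrecomputedTerm above.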
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single query and a single probed list
  // (coarse centroid)
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];
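  // Editorial sketch of the packing, assuming NumSubQuantizers == 8: each
  // encoded vector occupies 8 code bytes, loaded as kNumCode32 == 2 32-bit
  // words; getByte() below peels off one byte (one sub-quantizer code) at a
  // time, so word 0 carries sub-quantizers 0..3 and word 1 carries 4..7.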

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Make sure the shared-memory term 2 + 3 table written by
  // loadPrecomputedTerm is complete before any thread reads it below
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        // Accumulate so that term 1 is retained; plain assignment here would
        // drop term1 in the single sub-quantizer case
        dist += ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to distances and write out the distance values for
  // all intermediate results; indices are recovered in the selection passes
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
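    // Editorial example of the sizing above, assuming 64 sub-quantizers with
    // 256 codes each: float16 lookups need 2 * 64 * 256 = 32768 bytes (32 KiB)
    // of shared memory per block, while float32 lookups need 64 KiB, hence the
    // assert against the device's per-block shared memory limit.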

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                  \
    do {                                                               \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();          \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();          \
                                                                       \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>    \
        <<<grid, block, smem, stream>>>(                               \
          queries,                                                     \
          precompTerm1,                                                \
          precompTerm2T,                                               \
          precompTerm3T,                                               \
          topQueryToCentroid,                                          \
          listCodes.data().get(),                                      \
          listLengths.data().get(),                                    \
          prefixSumOffsets,                                            \
          allDistances);                                               \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16
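    // Editorial note: for example, RUN_PQ(64) with FAISS_USE_FLOAT16 defined
    // and useFloat16Lookup set launches pqScanPrecomputedMultiPass<64, half,
    // Half8>; otherwise it launches pqScanPrecomputedMultiPass<64, float,
    // float4>.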

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}

void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }
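  // Editorial example of the sizing above, assuming nprobe == 32,
  // maxListLength == 10000 and k == 10: sizePerQuery is roughly
  // 2 * (132 + 1280000 + 640) bytes, about 2.4 MiB per query, so with 512 MiB
  // of temporary memory available the tile would be clamped to
  // kMaxQueryTileSize (128) queries.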

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
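  // Editorial illustration of the zero-padding trick, assuming a tile whose
  // first lists have lengths 3, 5, 2, 4: the prefix sums written into
  // prefixSumOffsets are [3, 8, 10, 14], and the extra zero stored just before
  // them lets the kernel read *(prefixSumOffsets[q][p].data() - 1) to obtain
  // the write offsets 0, 3, 8, 10 without a special case for the first entry.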

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
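  // Editorial note: the loop below ping-pongs between the two alternate
  // streams and the matching double-buffered temporaries above, so the
  // multi-pass work for one query tile can overlap with the next; the
  // streamWait() calls order the alternate streams against the default stream
  // on entry and exit.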

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace