PQScanMultiPassPrecomputed.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>

namespace faiss { namespace gpu {

// For precomputed codes, this calculates and loads code distances
// into smem
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
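  // (e.g. 4 for float lookups loaded as float4, 8 for half lookups loaded
  // as Half8)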

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if the number of codes is evenly
  // divisible by the vector word size, then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data as vector words for efficiency, and then handle any
    // remainder.
    // limitVec is the number of whole vector words we can load, rounded
    // down to whole (kUnroll x blockDim.x) strides of the block
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;
    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}

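// The precomputed-table IVFPQ distance for a query x, coarse centroid y_C
// and PQ residual codeword y_R decomposes (roughly, per the standard IVFADC
// precomputed-table derivation) as
//   || x - y_C - y_R ||^2 =
//     term 1: || x - y_C ||^2              (one scalar per (query, probe))
//   + term 2: || y_R ||^2 + 2 (y_C . y_R)  (per (list, sub q, code); query-independent)
//   + term 3: -2 (x . y_R)                 (per (query, sub q, code))
// The kernel below receives term 1 directly, sums the term 2 and term 3
// lookup tables into shared memory once per (query, probe) pair, and then
// gathers one entry per sub-quantizer for each encoded vector in the list.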
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // This is where we start writing out data;
  // we ensure that the element before the array (at offset -1) is 0
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

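  // Each 32-bit code word packs up to four one-byte PQ codes, so an encoded
  // vector occupies one word when NumSubQuantizers <= 4 and
  // NumSubQuantizers / 4 words otherwise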
  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Make sure the term 2 + 3 table is fully written to smem before any
  // thread reads from it
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist += ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

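// Runs one tile of queries through the full multi-pass pipeline:
//  1. prefix-sum the lengths of the probed lists, so each (query, probe)
//     pair knows where to write its intermediate distances;
//  2. launch pqScanPrecomputedMultiPass to write one distance per encoded
//     vector into allDistances;
//  3. k-select in two passes (within chunks of probes, then across chunks)
//     to produce the final (distance, index) output for the tile.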
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate the prefix sum of the probed list lengths, so we know where
  // to write out intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to distances and write them out for all intermediate
  // results; indices are recovered in the final selection pass
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
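    // The full term 2 + 3 table for one (query, probe) pair must fit in
    // shared memory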
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();           \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();           \
                                                                        \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>     \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          precompTerm1,                                                 \
          precompTerm2T,                                                \
          precompTerm3T,                                                \
          topQueryToCentroid,                                           \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_VERIFY(cudaGetLastError());
}

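// Top-level entry point for the precomputed-table PQ scan: splits the
// queries into tiles sized to the available temporary memory, and
// double-buffers the temporary storage across two alternate streams so
// that consecutive tiles can overlap.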
void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
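  // Pass 1 k-selects within each chunk of up to kNProbeSplit probed lists;
  // pass 2 then merges the pass2Chunks * k candidates per query.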
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Reserve space before the start of each buffer, which will be set to 0,
  // so the boundary condition is handled without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace