Faiss
PQScanMultiPassPrecomputed.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>

namespace faiss { namespace gpu {

// For precomputed codes, this calculates and loads code distances
// into smem
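//
// For reference, the precomputed-table decomposition being assembled here is
// (roughly) the standard IVFPQ one: for query x, coarse centroid y_C and PQ
// residual reconstruction y_R,
//
//   ||x - y_C - y_R||^2 = ||x - y_C||^2              (term 1: query x probe)
//                       + ||y_R||^2 + 2 <y_C, y_R>   (term 2: list  x code)
//                       - 2 <x, y_R>                 (term 3: query x code)
//
// Terms 2 and 3 depend only on the per-sub-quantizer code, so their sum is
// staged once per (query, probe) pair in shared memory, laid out as
// [sub-quantizer][code id]; term 1 is then added per encoded vector in the
// scan kernel below.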
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if it is evenly divisible,
  // then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data by float4 for efficiency, and then handle any remainder.
    // limitVec is the number of whole vec words we can load, in terms
    // of whole blocks performing the load
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;
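
    // As a concrete example (assuming LookupT = half and LookupVecT = Half8,
    // so kWordSize = 8, with blockDim.x = 256): each unrolled iteration
    // covers kUnroll * blockDim.x = 512 vector words, i.e. 4096 codes. For a
    // 64 x 256 table (numCodes = 16384), limitVec ends up as 4 * 512 = 2048
    // vector words, and the scalar remainder loop below does nothing.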

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}

template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probed list) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }
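
  // A sketch of the resulting pipeline: the codes for the current stride are
  // already resident in code32; each iteration of the main loop below
  // prefetches the next stride into nextCode32, computes distances from
  // code32, and finally rotates nextCode32 into code32, so code loads
  // overlap with the shared-memory table lookups.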

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to a distance, and write out (distance,
  // index) values for all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
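
    // As a rough sizing example: a PQ configuration with 64 sub-quantizers
    // of 256 codes each needs 64 * 256 * sizeof(half) = 32 KiB of shared
    // memory with float16 lookup tables, but 64 KiB with float32 tables,
    // which would exceed the 48 KiB per-block default on many devices and
    // trip the assertion above.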

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)               \
    do {                                                            \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();       \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();       \
                                                                    \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T> \
        <<<grid, block, smem, stream>>>(                            \
          queries,                                                  \
          precompTerm1,                                             \
          precompTerm2T,                                            \
          precompTerm3T,                                            \
          topQueryToCentroid,                                       \
          listCodes.data().get(),                                   \
          listLengths.data().get(),                                 \
          prefixSumOffsets,                                         \
          allDistances);                                            \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                     \
    do {                                      \
      if (useFloat16Lookup) {                 \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);   \
      } else {                                \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4); \
      }                                       \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                     \
    do {                                      \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
    } while (0)
#endif // FAISS_USE_FLOAT16

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

    CUDA_TEST_ERROR();

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_TEST_ERROR();
}

void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);
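
  // In other words: pass 1 runs a k-selection within each of (up to)
  // kNProbeSplit chunks of a query's nprobe candidate lists, and pass 2
  // merges those pass2Chunks * k candidates down to the final k; this is
  // what runPass1SelectLists / runPass2SelectLists do in runMultiPassTile
  // above.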

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }
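
  // For a sense of scale (hypothetical numbers): with nprobe = 32,
  // maxListLength = 10000 and k = 100, sizePerQuery is roughly
  // 2 * (132 + 32 * 10000 * 4 + 8 * 100 * 8) bytes, about 2.5 MiB, so
  // 256 MiB of scratch space would yield a tile of ~100 queries before
  // the clamp to [kMinQueryTileSize, kMaxQueryTileSize] applies.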

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));
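
  // This pairs with the kernel reading
  // *(prefixSumOffsets[queryId][probeId].data() - 1) as its output base:
  // for the very first (query, probe) slot, that read lands on the zeroed
  // element allocated above, so no boundary branch is needed.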

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
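
  // The tile loop below ping-pongs between the two alternate streams (and
  // the matching double-buffered temporaries above), so one tile's scan and
  // selection work can overlap with the next tile's.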

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace