Faiss
PQScanMultiPassPrecomputed.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include "PQScanMultiPassPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/MathOperators.cuh"
#include "../utils/StaticUtils.h"
#include <limits>

namespace faiss { namespace gpu {

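// For reference, assuming the standard IVFPQ precomputed-term decomposition:
// with query x, coarse centroid c and PQ-reconstructed residual r,
//
//   ||x - (c + r)||^2 = ||x - c||^2 + (||r||^2 + 2 <c, r>) + (-2 <x, r>)
//                     =    term1    +        term2         +    term3
//
// term1 depends on (query, probed list), term2 on (list, code) and term3 on
// (query, code), matching the shapes of precompTerm1/2/3 used below.
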
// For precomputed codes, this calculates and loads code distances
// into smem
template <typename LookupT, typename LookupVecT>
inline __device__ void
loadPrecomputedTerm(LookupT* smem,
                    LookupT* term2Start,
                    LookupT* term3Start,
                    int numCodes) {
  constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
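  // kWordSize is e.g. 4 for <float, float4> and 8 for <half, Half8>, the two
  // instantiations used via RUN_PQ_OPT below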

  // We can only use vector loads if the data is guaranteed to be
  // aligned. The codes are innermost, so if it is evenly divisible,
  // then any slice will be aligned.
  if (numCodes % kWordSize == 0) {
    constexpr int kUnroll = 2;

    // Load the data by vector word (e.g. float4) for efficiency, and then
    // handle any remainder.
    // limitVec is the number of vector words we can load in whole
    // kUnroll x blockDim.x rounds, with all threads of the block participating
    int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
    limitVec *= kUnroll * blockDim.x;
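    // e.g. with numCodes = 2304, kWordSize = 4, kUnroll = 2 and blockDim.x =
    // 256, one full round covers 2048 scalar entries, so limitVec = 512
    // vector words and the scalar tail loop below handles the last 256 entries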

    LookupVecT* smemV = (LookupVecT*) smem;
    LookupVecT* term2StartV = (LookupVecT*) term2Start;
    LookupVecT* term3StartV = (LookupVecT*) term3Start;

    for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
      LookupVecT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] =
          LoadStore<LookupVecT>::load(&term2StartV[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LookupVecT q =
          LoadStore<LookupVecT>::load(&term3StartV[i + j * blockDim.x]);

        vals[j] = Math<LookupVecT>::add(vals[j], q);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
      }
    }

    // This is where we start loading the remainder that does not evenly
    // fit into kUnroll x blockDim.x
    int remainder = limitVec * kWordSize;

    for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  } else {
    // Potential unaligned load
    constexpr int kUnroll = 4;

    int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

    int i = threadIdx.x;
    for (; i < limit; i += kUnroll * blockDim.x) {
      LookupT vals[kUnroll];

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = term2Start[i + j * blockDim.x];
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        vals[j] = Math<LookupT>::add(vals[j], term3Start[i + j * blockDim.x]);
      }

#pragma unroll
      for (int j = 0; j < kUnroll; ++j) {
        smem[i + j * blockDim.x] = vals[j];
      }
    }

    for (; i < numCodes; i += blockDim.x) {
      smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
    }
  }
}
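// Equivalent scalar sketch: both branches above compute
//
//   for (int i = threadIdx.x; i < numCodes; i += blockDim.x) {
//     smem[i] = Math<LookupT>::add(term2Start[i], term3Start[i]);
//   }
//
// the vectorized branch only changes how the loads and stores are issued.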

template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanPrecomputedMultiPass(Tensor<float, 2, true> queries,
                           Tensor<float, 2, true> precompTerm1,
                           Tensor<LookupT, 3, true> precompTerm2,
                           Tensor<LookupT, 3, true> precompTerm3,
                           Tensor<int, 2, true> topQueryToCentroid,
                           void** listCodes,
                           int* listLengths,
                           Tensor<int, 2, true> prefixSumOffsets,
                           Tensor<float, 1, true> distance) {
  // precomputed term 2 + 3 storage
  // (sub q)(code id)
  extern __shared__ char smemTerm23[];
  LookupT* term23 = (LookupT*) smemTerm23;

  // Each block handles a single (query, probe) pair
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;
  auto codesPerSubQuantizer = precompTerm2.getSize(2);
  auto precompTermSize = precompTerm2.getSize(1) * codesPerSubQuantizer;
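  // term23 will hold numSubQuantizers x codesPerSubQuantizer lookup values;
  // e.g. 8 sub-quantizers with 256 codes each gives precompTermSize = 2048
  // LookupT entries of shared memory per block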

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();
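  // e.g. if the tile's first query probes lists of lengths 3, 5 and 2, the
  // scan stores [3, 8, 10] with the 0 sentinel just before it, so its three
  // probes read outBase = 0, 3 and 8; later (query, probe) pairs continue
  // from the running total (assuming runCalcListOffsets performs an inclusive
  // prefix sum over the flattened per-(query, probe) list lengths)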

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  // Load precomputed terms 1, 2, 3
  float term1 = precompTerm1[queryId][probeId];
  loadPrecomputedTerm<LookupT, LookupVecT>(term23,
                                           precompTerm2[listId].data(),
                                           precompTerm3[queryId].data(),
                                           precompTermSize);

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = term1;

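    // Codes for this list entry are packed one byte per sub-quantizer, up to
    // 4 codes per 32-bit word (e.g. NumSubQuantizers = 8 gives kNumCode32 = 2
    // words). The loop below extracts each byte, whose sub-quantizer index is
    // word * kBytesPerCode32 + byte, and adds the corresponding term2+term3
    // lookup from shared memory on top of term1.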
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(term23[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(term23[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& precompTerm1,
                 NoTypeTensor<3, true>& precompTerm2,
                 NoTypeTensor<3, true>& precompTerm3,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
  // Calculate offsets (prefix sums of list lengths), so we know where to
  // write out intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Convert all codes to a distance, and write out distances for all
  // intermediate results (indices are recovered in the selection passes)
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);
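    // One block per (probe, query) pair: blockIdx.x indexes the probe and
    // blockIdx.y the query (matching the kernel above), with kThreadsPerBlock
    // threads striding over the probed list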

    // pq precomputed terms (2 + 3)
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
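    // e.g. 64 sub-quantizers x 256 codes needs 64 KiB of lookups as float but
    // only 32 KiB as half; many devices expose only 48 KiB of shared memory
    // per block by default, which is one reason float16 lookup tables matter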

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                  \
    do {                                                               \
      auto precompTerm2T = precompTerm2.toTensor<LOOKUP_T>();          \
      auto precompTerm3T = precompTerm3.toTensor<LOOKUP_T>();          \
                                                                       \
      pqScanPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>    \
        <<<grid, block, smem, stream>>>(                               \
          queries,                                                     \
          precompTerm1,                                                \
          precompTerm2T,                                               \
          precompTerm3T,                                               \
          topQueryToCentroid,                                          \
          listCodes.data().get(),                                      \
          listLengths.data().get(),                                    \
          prefixSumOffsets,                                            \
          allDistances);                                               \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

    CUDA_TEST_ERROR();

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);

  CUDA_TEST_ERROR();
}

void runPQScanMultiPassPrecomputed(Tensor<float, 2, true>& queries,
                                   Tensor<float, 2, true>& precompTerm1,
                                   NoTypeTensor<3, true>& precompTerm2,
                                   NoTypeTensor<3, true>& precompTerm3,
                                   Tensor<int, 2, true>& topQueryToCentroid,
                                   bool useFloat16Lookup,
                                   int bytesPerCode,
                                   int numSubQuantizers,
                                   int numSubQuantizerCodes,
                                   thrust::device_vector<void*>& listCodes,
                                   thrust::device_vector<void*>& listIndices,
                                   IndicesOptions indicesOptions,
                                   thrust::device_vector<int>& listLengths,
                                   int maxListLength,
                                   int k,
                                   // output
                                   Tensor<float, 2, true>& outDistances,
                                   // output
                                   Tensor<long, 2, true>& outIndices,
                                   GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // # streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     sizeForFirstSelectPass);

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);
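  // e.g. with nprobe = 64, maxListLength = 10000 and k = 100 this is roughly
  // 2 * (260 B + 2.56 MB + 6.4 KB), i.e. a bit over 5 MB per query, so 256 MB
  // of scratch space would allow a tile of about 50 queries before clamping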

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <=
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
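  // The temporaries above are double-buffered: tile i uses buffer set and
  // alternate stream (i % 2), so work for consecutive tiles can overlap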

  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);
    auto term1View =
      precompTerm1.narrowOutermost(query, numQueriesInTile);
    auto term3View =
      precompTerm3.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     term1View,
                     precompTerm2,
                     term3View,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace