PQScanMultiPassNoPrecomputed.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"

namespace faiss { namespace gpu {

// This must be kept in sync with PQCodeDistances.cu
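// `dims` is the number of dimensions handled by each sub-quantizer
// (i.e., d / M); these are the only sub-dimension sizes for which a
// specialized residual distance kernel exists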
bool isSupportedNoPrecomputedSubDimSize(int dims) {
  switch (dims) {
    case 1:
    case 2:
    case 3:
    case 4:
    case 6:
    case 8:
    case 10:
    case 12:
    case 16:
    case 20:
    case 24:
    case 28:
    case 32:
      return true;
    default:
      // FIXME: larger sizes require too many registers - we need the
      // MM implementation working
      return false;
  }
}

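// Copies the code -> residual distance lookup table for one (query,
// probe) pair from global memory into shared memory, using vectorized
// loads where the layout guarantees alignment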
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

    // We can only use the vector type if the data is guaranteed to be
    // aligned. The codes are innermost, so if it is evenly divisible,
    // then any slice will be aligned.
    if (numCodes % kWordSize == 0) {
      // Load the data by float4 for efficiency, and then handle any remainder
      // limitVec is the number of whole vec words we can load, in terms
      // of whole blocks performing the load
      constexpr int kUnroll = 2;
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] =
            LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // This is where we start loading the remainder that does not evenly
      // fit into kUnroll x blockDim.x
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Potential unaligned load
      constexpr int kUnroll = 4;

      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};

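// Scans the encoded vectors of one inverted list. Each block handles a
// single (query, probe) pair: the code -> distance lookup table is
// staged in shared memory, each thread walks the list codes with a
// block-wide stride, and the per-code distances are written to the
// region that prefixSumOffsets reserved for this list.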
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Where the pq code -> residual distance is stored
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;

  // Each block handles a single query
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

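  // Each 32-bit word packs up to four 1-byte PQ codes, so each encoded
  // vector of NumSubQuantizers bytes is read as kNumCode32 words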
  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
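      // Each byte of the packed word indexes the corresponding
      // sub-quantizer's row of the lookup table; when there is only a
      // single sub-quantizer, the word holds the code directly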
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

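// Runs one tile of queries through the whole pipeline: compute the
// per-list write offsets, build the residual code distance tables,
// launch the scan kernel above, then k-select the candidates in two
// passes.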
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Calculate residual code distances, since this is without
  // precomputed codes
  runPQCodeDistances(pqCentroidsInnermostCode,
                     queries,
                     centroids,
                     topQueryToCentroid,
                     codeDistances,
                     useFloat16Lookup,
                     stream);

  // Convert all codes to a distance, and write out (distance,
  // index) values for all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq centroid distances
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
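    // For scale (illustrative numbers, not taken from the code): 64
    // sub-quantizers x 256 codes needs 64 KiB of shared memory with
    // float lookups but only 32 KiB with half lookups, which is the
    // main reason to prefer float16 lookup tables for large codebooks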
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();         \
                                                                        \
      pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          pqCentroidsInnermostCode,                                     \
          topQueryToCentroid,                                           \
          codeDistancesT,                                               \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16

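    // Dispatch to the kernel specialization matching the number of PQ
    // code bytes per encoded vector; unsupported sizes hit the assert
    // below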
    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  CUDA_TEST_ERROR();

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

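// Top-level entry point. Queries are processed in tiles sized so that
// the temporary buffers fit in the available scratch memory, and the
// per-tile work is double-buffered across two streams so consecutive
// tiles can overlap.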
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     // output
                                     Tensor<float, 2, true>& outDistances,
                                     // output
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     // residual distances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);

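  // Rough illustration (hypothetical numbers, not from the code): with
  // nprobe = 64, maxListLength = 10000, 64 sub-quantizers of 256 codes
  // and k = 100, this is about 2 * (260 B + 2.56 MB + 4.19 MB + 6.4 KB),
  // i.e. roughly 13.5 MB of scratch per query, which is what drives the
  // tile size below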
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

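  // Process the queries tile by tile, alternating between the two
  // buffer sets and streams so that tile i+1 can begin while tile i is
  // still executing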
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     centroids,
                     pqCentroidsInnermostCode,
                     codeDistancesView,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

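  // Make the caller's stream wait on the alternate streams, so that
  // work subsequently ordered on it sees all of the tile results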
  streamWait({stream}, streams);
}

} } // namespace