PQScanMultiPassNoPrecomputed.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.
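// IVFPQ list scanning for the case where per-query precomputed distance
// tables are not used. The scan is multi-pass: code -> distance lookup
// tables are computed per (query, probe) pair, the inverted lists are
// scanned into a flat distance array, and the results are then reduced
// with a two-pass k-selection.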
#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"
namespace faiss { namespace gpu {

bool isSupportedNoPrecomputedSubDimSize(int dims) {
  switch (dims) {
    case 1:
    case 2:
    case 3:
    case 4:
    case 6:
    case 8:
    case 10:
    case 12:
    case 16:
    case 32:
      return true;
    default:
      // FIXME: larger sizes require too many registers - we need the
      // MM implementation working
      return false;
  }
}
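// Cooperatively loads the (numSubQuantizers x numSubQuantizerCodes) table of
// code -> distance lookups for one (query, probe) pair from global into
// shared memory. Vectorized LookupVecT loads are used when the element count
// divides evenly by the vector width; otherwise it falls back to scalar loads.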
template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);

    // We can only use the vector type if the data is guaranteed to be
    // aligned. The codes are innermost, so if it is evenly divisible,
    // then any slice will be aligned.
    if (numCodes % kWordSize == 0) {
      // Load the data by float4 for efficiency, and then handle any remainder
      // limitVec is the number of whole vec words we can load, in terms
      // of whole blocks performing the load
      constexpr int kUnroll = 2;
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] =
            LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // This is where we start loading the remainder that does not evenly
      // fit into kUnroll x blockDim.x
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Potential unaligned load
      constexpr int kUnroll = 4;

      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};
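
// Scan kernel for the no-precomputed-table case. One CUDA block handles one
// (query, probe) pair: blockIdx.y is the query and blockIdx.x the probe.
// Each thread strides over the inverted list, unpacks the packed 32-bit code
// words into per-sub-quantizer byte indices, sums the corresponding entries
// of the shared-memory lookup table, and writes one distance per list entry.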
template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Where the pq code -> residual distance is stored
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;

  // Each block handles a single query
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}
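
// Runs one tile of queries end to end: computes list write offsets via prefix
// sum, builds the residual code-distance tables, launches the scan kernel over
// all (query, probe) pairs, and then k-selects the results in two passes.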
void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Calculate residual code distances, since this is without
  // precomputed codes
  runPQCodeDistances(pqCentroidsInnermostCode,
                     queries,
                     centroids,
                     topQueryToCentroid,
                     codeDistances,
                     useFloat16Lookup,
                     stream);

  // Convert all codes to a distance, and write out (distance,
  // index) values for all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq centroid distances
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
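    // For example, with 64 sub-quantizers of 256 codes each, the lookup table
    // needs 64 * 256 * 2 bytes = 32 KiB of shared memory per block for half
    // lookups, or 64 KiB for float lookups.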

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();         \
                                                                        \
      pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          pqCentroidsInnermostCode,                                     \
          topQueryToCentroid,                                           \
          codeDistancesT,                                               \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16
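    // Codes use one byte per sub-quantizer, so bytesPerCode selects the
    // NumSubQuantizers template instantiation of the scan kernel.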
    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  CUDA_TEST_ERROR();
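
  // Pass 1 k-selects within groups of probes per query (pass2Chunks chunks of
  // the nprobe lists); pass 2 merges those partial results down to the final
  // k per query and recovers the user-visible indices.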
  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}
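
// Top-level entry point: sizes a query tile to fit in temporary GPU memory,
// then runs runMultiPassTile per tile, double-buffering the temporary tensors
// across two alternate streams so that consecutive tiles can overlap.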
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     // output
                                     Tensor<float, 2, true>& outDistances,
                                     // output
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     // residual distances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);
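  // For example, with nprobe = 32, maxListLength = 10000, 64 sub-quantizers
  // of 256 codes, and k = 100, this is roughly
  // 2 * (132 B + 1.28 MB + 2.10 MB + 6.4 KB), i.e. about 6.8 MB per query.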
  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
  int codeSize = sizeof(float);
#endif

  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;
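
  // Process the queries in tiles, alternating between the two sets of
  // temporary buffers and the two alternate streams so that work for
  // successive tiles can overlap.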
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     centroids,
                     pqCentroidsInnermostCode,
                     codeDistancesView,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace