Faiss
PQScanMultiPassNoPrecomputed.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "PQScanMultiPassNoPrecomputed.cuh"
#include "../GpuResources.h"
#include "PQCodeDistances.cuh"
#include "PQCodeLoad.cuh"
#include "IVFUtils.cuh"
#include "../utils/ConversionOperators.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/LoadStoreOperators.cuh"
#include "../utils/NoTypeTensor.cuh"
#include "../utils/StaticUtils.h"

#include "../utils/HostTensor.cuh"

namespace faiss { namespace gpu {

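// The sub dimension size below is the dimensionality handled by each PQ
// sub-quantizer, i.e. the full vector dimension divided by the number of
// sub-quantizers.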
// This must be kept in sync with PQCodeDistances.cu
bool isSupportedNoPrecomputedSubDimSize(int dims) {
  switch (dims) {
    case 1:
    case 2:
    case 3:
    case 4:
    case 6:
    case 8:
    case 10:
    case 12:
    case 16:
    case 20:
    case 24:
    case 28:
    case 32:
      return true;
    default:
      // FIXME: larger sizes require too many registers - we need the
      // MM implementation working
      return false;
  }
}

template <typename LookupT, typename LookupVecT>
struct LoadCodeDistances {
  static inline __device__ void load(LookupT* smem,
                                     LookupT* codes,
                                     int numCodes) {
    constexpr int kWordSize = sizeof(LookupVecT) / sizeof(LookupT);
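    // For example, with LookupT = float and LookupVecT = float4 (or half and
    // Half8), kWordSize is 4 (or 8) elements per vectorized load.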

    // We can only use the vector type if the data is guaranteed to be
    // aligned. The codes are innermost, so if it is evenly divisible,
    // then any slice will be aligned.
    if (numCodes % kWordSize == 0) {
      // Load the data by float4 for efficiency, and then handle any remainder
      // limitVec is the number of whole vec words we can load, in terms
      // of whole blocks performing the load
      constexpr int kUnroll = 2;
      int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
      limitVec *= kUnroll * blockDim.x;

      LookupVecT* smemV = (LookupVecT*) smem;
      LookupVecT* codesV = (LookupVecT*) codes;

      for (int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
        LookupVecT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] =
            LoadStore<LookupVecT>::load(&codesV[i + j * blockDim.x]);
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          LoadStore<LookupVecT>::store(&smemV[i + j * blockDim.x], vals[j]);
        }
      }

      // This is where we start loading the remainder that does not evenly
      // fit into kUnroll x blockDim.x
      int remainder = limitVec * kWordSize;

      for (int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    } else {
      // Potential unaligned load
      constexpr int kUnroll = 4;

      int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);

      int i = threadIdx.x;
      for (; i < limit; i += kUnroll * blockDim.x) {
        LookupT vals[kUnroll];

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          vals[j] = codes[i + j * blockDim.x];
        }

#pragma unroll
        for (int j = 0; j < kUnroll; ++j) {
          smem[i + j * blockDim.x] = vals[j];
        }
      }

      for (; i < numCodes; i += blockDim.x) {
        smem[i] = codes[i];
      }
    }
  }
};

template <int NumSubQuantizers, typename LookupT, typename LookupVecT>
__global__ void
pqScanNoPrecomputedMultiPass(Tensor<float, 2, true> queries,
                             Tensor<float, 3, true> pqCentroids,
                             Tensor<int, 2, true> topQueryToCentroid,
                             Tensor<LookupT, 4, true> codeDistances,
                             void** listCodes,
                             int* listLengths,
                             Tensor<int, 2, true> prefixSumOffsets,
                             Tensor<float, 1, true> distance) {
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Where the pq code -> residual distance is stored
  extern __shared__ char smemCodeDistances[];
  LookupT* codeDist = (LookupT*) smemCodeDistances;
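  // The shared-memory table has the same innermost layout as
  // codeDistances[queryId][probeId]: numSubQuantizers rows of
  // codesPerSubQuantizer lookup values, with the code index innermost.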

  // Each block handles a single query
  auto queryId = blockIdx.y;
  auto probeId = blockIdx.x;

  // This is where we start writing out data
  // We ensure that before the array (at offset -1), there is a 0 value
  int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
  float* distanceOut = distance[outBase].data();

  auto listId = topQueryToCentroid[queryId][probeId];
  // Safety guard in case NaNs in input cause no list ID to be generated
  if (listId == -1) {
    return;
  }

  unsigned char* codeList = (unsigned char*) listCodes[listId];
  int limit = listLengths[listId];

  constexpr int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
    (NumSubQuantizers / 4);
  unsigned int code32[kNumCode32];
  unsigned int nextCode32[kNumCode32];

  // We double-buffer the code loading, which improves memory utilization
  if (threadIdx.x < limit) {
    LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
  }

  LoadCodeDistances<LookupT, LookupVecT>::load(
    codeDist,
    codeDistances[queryId][probeId].data(),
    codeDistances.getSize(2) * codeDistances.getSize(3));

  // Prevent WAR dependencies
  __syncthreads();

  // Each thread handles one code element in the list, with a
  // block-wide stride
  for (int codeIndex = threadIdx.x;
       codeIndex < limit;
       codeIndex += blockDim.x) {
    // Prefetch next codes
    if (codeIndex + blockDim.x < limit) {
      LoadCode32<NumSubQuantizers>::load(
        nextCode32, codeList, codeIndex + blockDim.x);
    }

    float dist = 0.0f;
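    // The approximate distance for this encoded vector is the sum over
    // sub-quantizers of the looked-up distance between the query residual
    // and the sub-centroid selected by each code byte.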

#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      constexpr int kBytesPerCode32 =
        NumSubQuantizers < 4 ? NumSubQuantizers : 4;

      if (kBytesPerCode32 == 1) {
        auto code = code32[0];
        dist = ConvertTo<float>::to(codeDist[code]);

      } else {
#pragma unroll
        for (int byte = 0; byte < kBytesPerCode32; ++byte) {
          auto code = getByte(code32[word], byte * 8, 8);

          auto offset =
            codesPerSubQuantizer * (word * kBytesPerCode32 + byte);

          dist += ConvertTo<float>::to(codeDist[offset + code]);
        }
      }
    }

    // Write out intermediate distance result
    // We do not maintain indices here, in order to reduce global
    // memory traffic. Those are recovered in the final selection step.
    distanceOut[codeIndex] = dist;

    // Rotate buffers
#pragma unroll
    for (int word = 0; word < kNumCode32; ++word) {
      code32[word] = nextCode32[word];
    }
  }
}

void
runMultiPassTile(Tensor<float, 2, true>& queries,
                 Tensor<float, 2, true>& centroids,
                 Tensor<float, 3, true>& pqCentroidsInnermostCode,
                 NoTypeTensor<4, true>& codeDistances,
                 Tensor<int, 2, true>& topQueryToCentroid,
                 bool useFloat16Lookup,
                 int bytesPerCode,
                 int numSubQuantizers,
                 int numSubQuantizerCodes,
                 thrust::device_vector<void*>& listCodes,
                 thrust::device_vector<void*>& listIndices,
                 IndicesOptions indicesOptions,
                 thrust::device_vector<int>& listLengths,
                 Tensor<char, 1, true>& thrustMem,
                 Tensor<int, 2, true>& prefixSumOffsets,
                 Tensor<float, 1, true>& allDistances,
                 Tensor<float, 3, true>& heapDistances,
                 Tensor<int, 3, true>& heapIndices,
                 int k,
                 Tensor<float, 2, true>& outDistances,
                 Tensor<long, 2, true>& outIndices,
                 cudaStream_t stream) {
#ifndef FAISS_USE_FLOAT16
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  // Calculate offset lengths, so we know where to write out
  // intermediate results
  runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
                     thrustMem, stream);

  // Calculate residual code distances, since this is without
  // precomputed codes
  runPQCodeDistances(pqCentroidsInnermostCode,
                     queries,
                     centroids,
                     topQueryToCentroid,
                     codeDistances,
                     useFloat16Lookup,
                     stream);

  // Convert all codes to a distance, and write out (distance,
  // index) values for all intermediate results
  {
    auto kThreadsPerBlock = 256;

    auto grid = dim3(topQueryToCentroid.getSize(1),
                     topQueryToCentroid.getSize(0));
    auto block = dim3(kThreadsPerBlock);

    // pq centroid distances
    auto smem = sizeof(float);
#ifdef FAISS_USE_FLOAT16
    if (useFloat16Lookup) {
      smem = sizeof(half);
    }
#endif
    smem *= numSubQuantizers * numSubQuantizerCodes;
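    // For example, 64 sub-quantizers of 256 codes require 64 KiB of float
    // lookup table, which exceeds the common 48 KiB per-block shared memory
    // limit; the float16 lookup option halves this to 32 KiB.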
    FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());

#define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \
    do {                                                                \
      auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();         \
                                                                        \
      pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \
        <<<grid, block, smem, stream>>>(                                \
          queries,                                                      \
          pqCentroidsInnermostCode,                                     \
          topQueryToCentroid,                                           \
          codeDistancesT,                                               \
          listCodes.data().get(),                                       \
          listLengths.data().get(),                                     \
          prefixSumOffsets,                                             \
          allDistances);                                                \
    } while (0)
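
    // The kernel is templated on the number of sub-quantizers so that the
    // per-code loops are fully unrolled at compile time, which is why each
    // supported code size needs its own instantiation below.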

#ifdef FAISS_USE_FLOAT16
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      if (useFloat16Lookup) {                   \
        RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \
      } else {                                  \
        RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \
      }                                         \
    } while (0)
#else
#define RUN_PQ(NUM_SUB_Q)                       \
    do {                                        \
      RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \
    } while (0)
#endif // FAISS_USE_FLOAT16

    switch (bytesPerCode) {
      case 1:
        RUN_PQ(1);
        break;
      case 2:
        RUN_PQ(2);
        break;
      case 3:
        RUN_PQ(3);
        break;
      case 4:
        RUN_PQ(4);
        break;
      case 8:
        RUN_PQ(8);
        break;
      case 12:
        RUN_PQ(12);
        break;
      case 16:
        RUN_PQ(16);
        break;
      case 20:
        RUN_PQ(20);
        break;
      case 24:
        RUN_PQ(24);
        break;
      case 28:
        RUN_PQ(28);
        break;
      case 32:
        RUN_PQ(32);
        break;
      case 40:
        RUN_PQ(40);
        break;
      case 48:
        RUN_PQ(48);
        break;
      case 56:
        RUN_PQ(56);
        break;
      case 64:
        RUN_PQ(64);
        break;
      case 96:
        RUN_PQ(96);
        break;
      default:
        FAISS_ASSERT(false);
        break;
    }

#undef RUN_PQ
#undef RUN_PQ_OPT
  }

  CUDA_TEST_ERROR();

  // k-select the output in chunks, to increase parallelism
  runPass1SelectLists(prefixSumOffsets,
                      allDistances,
                      topQueryToCentroid.getSize(1),
                      k,
                      false, // L2 distance chooses smallest
                      heapDistances,
                      heapIndices,
                      stream);

  // k-select final output
  auto flatHeapDistances = heapDistances.downcastInner<2>();
  auto flatHeapIndices = heapIndices.downcastInner<2>();
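  // downcastInner<2> flattens (query, chunk, k) into (query, chunk * k), so
  // the final selection treats each query's per-chunk results as one list.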

  runPass2SelectLists(flatHeapDistances,
                      flatHeapIndices,
                      listIndices,
                      indicesOptions,
                      prefixSumOffsets,
                      topQueryToCentroid,
                      k,
                      false, // L2 distance chooses smallest
                      outDistances,
                      outIndices,
                      stream);
}

void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
                                     Tensor<float, 2, true>& centroids,
                                     Tensor<float, 3, true>& pqCentroidsInnermostCode,
                                     Tensor<int, 2, true>& topQueryToCentroid,
                                     bool useFloat16Lookup,
                                     int bytesPerCode,
                                     int numSubQuantizers,
                                     int numSubQuantizerCodes,
                                     thrust::device_vector<void*>& listCodes,
                                     thrust::device_vector<void*>& listIndices,
                                     IndicesOptions indicesOptions,
                                     thrust::device_vector<int>& listLengths,
                                     int maxListLength,
                                     int k,
                                     // output
                                     Tensor<float, 2, true>& outDistances,
                                     // output
                                     Tensor<long, 2, true>& outIndices,
                                     GpuResources* res) {
  constexpr int kMinQueryTileSize = 8;
  constexpr int kMaxQueryTileSize = 128;
  constexpr int kThrustMemSize = 16384;

  int nprobe = topQueryToCentroid.getSize(1);

  auto& mem = res->getMemoryManagerCurrentDevice();
  auto stream = res->getDefaultStreamCurrentDevice();

  // Make a reservation for Thrust to do its dirty work (global memory
  // cross-block reduction space); hopefully this is large enough.
  DeviceTensor<char, 1, true> thrustMem1(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true> thrustMem2(
    mem, {kThrustMemSize}, stream);
  DeviceTensor<char, 1, true>* thrustMem[2] =
    {&thrustMem1, &thrustMem2};

  // How much temporary storage is available?
  // If possible, we'd like to fit within the space available.
  size_t sizeAvailable = mem.getSizeAvailable();

  // We run two passes of heap selection
  // This is the size of the first-level heap passes
  constexpr int kNProbeSplit = 8;
  int pass2Chunks = std::min(nprobe, kNProbeSplit);
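  // Pass 1 k-selects within each of pass2Chunks groups of probe lists;
  // pass 2 then k-selects across those partial results to produce the
  // final top-k per query.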

  size_t sizeForFirstSelectPass =
    pass2Chunks * k * (sizeof(float) + sizeof(int));

  // How much temporary storage we need per each query
  size_t sizePerQuery =
    2 * // streams
    ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
     nprobe * maxListLength * sizeof(float) + // allDistances
     // residual distances
     nprobe * numSubQuantizers * numSubQuantizerCodes * sizeof(float) +
     sizeForFirstSelectPass);
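  // As a rough example, nprobe = 32, maxListLength = 10000, 8 sub-quantizers
  // of 256 codes and k = 100 gives about 3 MB per query, dominated by the
  // allDistances scratch space.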

  int queryTileSize = (int) (sizeAvailable / sizePerQuery);

  if (queryTileSize < kMinQueryTileSize) {
    queryTileSize = kMinQueryTileSize;
  } else if (queryTileSize > kMaxQueryTileSize) {
    queryTileSize = kMaxQueryTileSize;
  }

  // FIXME: we should adjust queryTileSize to deal with this, since
  // indexing is in int32
  FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
               std::numeric_limits<int>::max());

  // Temporary memory buffers
  // Make sure there is space prior to the start which will be 0, and
  // will handle the boundary condition without branches
  DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
    mem, {queryTileSize * nprobe + 1}, stream);
  DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
    mem, {queryTileSize * nprobe + 1}, stream);

  DeviceTensor<int, 2, true> prefixSumOffsets1(
    prefixSumOffsetSpace1[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true> prefixSumOffsets2(
    prefixSumOffsetSpace2[1].data(),
    {queryTileSize, nprobe});
  DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
    {&prefixSumOffsets1, &prefixSumOffsets2};

  // Make sure the element before prefixSumOffsets is 0, since we
  // depend upon simple, boundary-less indexing to get proper results
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
                              0,
                              sizeof(int),
                              stream));
  CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
                              0,
                              sizeof(int),
                              stream));

  int codeDistanceTypeSize = sizeof(float);
#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    codeDistanceTypeSize = sizeof(half);
  }
#else
  FAISS_ASSERT(!useFloat16Lookup);
  int codeSize = sizeof(float);
#endif

  int totalCodeDistancesSize =
    queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
    codeDistanceTypeSize;

  DeviceTensor<char, 1, true> codeDistances1Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances1(
    codeDistances1Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  DeviceTensor<char, 1, true> codeDistances2Mem(
    mem, {totalCodeDistancesSize}, stream);
  NoTypeTensor<4, true> codeDistances2(
    codeDistances2Mem.data(),
    codeDistanceTypeSize,
    {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});

  NoTypeTensor<4, true>* codeDistances[2] =
    {&codeDistances1, &codeDistances2};

  DeviceTensor<float, 1, true> allDistances1(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true> allDistances2(
    mem, {queryTileSize * nprobe * maxListLength}, stream);
  DeviceTensor<float, 1, true>* allDistances[2] =
    {&allDistances1, &allDistances2};

  DeviceTensor<float, 3, true> heapDistances1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true> heapDistances2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<float, 3, true>* heapDistances[2] =
    {&heapDistances1, &heapDistances2};

  DeviceTensor<int, 3, true> heapIndices1(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true> heapIndices2(
    mem, {queryTileSize, pass2Chunks, k}, stream);
  DeviceTensor<int, 3, true>* heapIndices[2] =
    {&heapIndices1, &heapIndices2};

  auto streams = res->getAlternateStreamsCurrentDevice();
  streamWait(streams, {stream});

  int curStream = 0;

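  // Ping-pong between the two alternate streams (and their double-buffered
  // temporaries) so that work on one query tile can overlap with the next.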
  for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
    int numQueriesInTile =
      std::min(queryTileSize, queries.getSize(0) - query);

    auto prefixSumOffsetsView =
      prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);

    auto codeDistancesView =
      codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto coarseIndicesView =
      topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
    auto queryView =
      queries.narrowOutermost(query, numQueriesInTile);

    auto heapDistancesView =
      heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
    auto heapIndicesView =
      heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

    auto outDistanceView =
      outDistances.narrowOutermost(query, numQueriesInTile);
    auto outIndicesView =
      outIndices.narrowOutermost(query, numQueriesInTile);

    runMultiPassTile(queryView,
                     centroids,
                     pqCentroidsInnermostCode,
                     codeDistancesView,
                     coarseIndicesView,
                     useFloat16Lookup,
                     bytesPerCode,
                     numSubQuantizers,
                     numSubQuantizerCodes,
                     listCodes,
                     listIndices,
                     indicesOptions,
                     listLengths,
                     *thrustMem[curStream],
                     prefixSumOffsetsView,
                     *allDistances[curStream],
                     heapDistancesView,
                     heapIndicesView,
                     k,
                     outDistanceView,
                     outIndicesView,
                     streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({stream}, streams);
}

} } // namespace