PQCodeDistances.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include "PQCodeDistances.cuh"

#include "BroadcastSum.cuh"
#include "Distance.cuh"
#include "L2Norm.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MatrixMult.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Transpose.cuh"

namespace faiss { namespace gpu {

template <typename T>
struct Converter {
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Converter<half> {
  inline static __device__ half to(float v) { return __float2half(v); }
};
#endif

template <>
struct Converter<float> {
  inline static __device__ float to(float v) { return v; }
};

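// Distances are accumulated in float inside the kernel below; Converter
// narrows the result to the element type of the output lookup table
// (float, or half when FAISS_USE_FLOAT16 lookup tables are enabled).
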
// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
template <typename OutCodeT, int DimsPerSubQuantizer>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<float, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];
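
  // Shared memory layout, in order of the pointers above:
  //   [query sub-vector | residual buffer 1 | residual buffer 2 | coarse IDs]
  // The first three are DimsPerSubQuantizer floats each; the last is
  // nprobe (topQueryToCentroid.getSize(1)) ints. The host-side launch
  // below sizes the dynamic smem allocation to exactly this.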

  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();

    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        // In case NaNs were in the original query data
        coarseId = coarseId == -1 ? 0 : coarseId;
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

        smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
      }
    }

    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            // In case NaNs were in the original query data
            coarseId = coarseId == -1 ? 0 : coarseId;

            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

            smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];
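
        // e.g. with DimsPerSubQuantizer = 6: kRemainder = 2 and
        // kRemainderBase = 4, so dims 0-3 run through the unrolled loop
        // below and dims 4-5 through the remainder loop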

        // Calculate residual - pqCentroid for each dim that we're
        // processing

        // Unrolled loop
#pragma unroll
        for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] = smemResidual1[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] -= subQuantizerData[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] *= vals[j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            dist += vals[j];
          }
        }

        // Remainder loop
#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] = smemResidual1[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] -= subQuantizerData[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] *= vals[j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          Converter<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}

__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<float, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  // block x is query id
  // block y is centroid id
  // thread x is dim
  auto queryId = blockIdx.x;
  auto centroidId = blockIdx.y;

  int realCentroidId = topQueryToCentroid[queryId][centroidId];
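
  // dim / numSubDim selects the sub-quantizer and dim % numSubDim the
  // dimension within it, yielding the transposed layout noted above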

  for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
    float q = queries[queryId][dim];
    float c = coarseCentroids[realCentroidId][dim];

    residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
      q - c;
  }
}

void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<float, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto grid =
    dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
  auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);

  CUDA_TEST_ERROR();
}

void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<float, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
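  // For residual r = (q - c) and PQ sub-centroid p, the code distance
  // decomposes as
  //   ||r - p||^2 = ||r||^2 - 2 r . p + ||p||^2
  // Below, ||r||^2 comes from runL2Norm over the residuals, -2 r . p
  // from a batched GEMM with alpha = -2, and ||p||^2 from an L2 norm
  // over the PQ centroids, broadcast-added at the end.
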
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, residualNorms, true, stream);

  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Add ||q - c||^2 to each row of the result
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, false, stream);

  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem, {outCodeDistances.getSize(0),
            outCodeDistances.getSize(1),
            outCodeDistances.getSize(2),
            outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  }
#endif

  if (!useFloat16Lookup) {
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add the centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    // Need to convert back to half
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    toHalf(stream, outCodeDistancesF, outCodeDistancesH);
  }
#endif
}

void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<float, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // FIXME: tune
  // Reuse of pq centroid data depends upon both the number of queries
  // and nprobe; we should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);

  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);
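
  // For the common 8-bit PQ case (codesPerSubQuantizer = 256) with
  // dimsPerSubQuantizer <= 32, this yields 256 + 32 = 288 threads per
  // block, matching the __launch_bounds__(288, 4) on the kernel above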

  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);

#ifdef FAISS_USE_FLOAT16
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (useFloat16Lookup) {                                             \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();       \
                                                                        \
      pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>(       \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    } else {                                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#else
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (!useFloat16Lookup) {                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#endif

  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_DISTANCE(1);
      break;
    case 2:
      CODE_DISTANCE(2);
      break;
    case 3:
      CODE_DISTANCE(3);
      break;
    case 4:
      CODE_DISTANCE(4);
      break;
    case 6:
      CODE_DISTANCE(6);
      break;
    case 8:
      CODE_DISTANCE(8);
      break;
    case 10:
      CODE_DISTANCE(10);
      break;
    case 12:
      CODE_DISTANCE(12);
      break;
    case 16:
      CODE_DISTANCE(16);
      break;
    case 20:
      CODE_DISTANCE(20);
      break;
    case 24:
      CODE_DISTANCE(24);
      break;
    case 28:
      CODE_DISTANCE(28);
      break;
    case 32:
      CODE_DISTANCE(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef CODE_DISTANCE

  CUDA_TEST_ERROR();
}
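
// A minimal shape sketch for the entry point above (hypothetical sizes,
// not part of the original file): with 8 sub-quantizers of 8 dims and
// 256 codes each, nq queries of dimension 64, and nprobe coarse lists
// per query, one would pass
//   pqCentroids:        (8, 8, 256)
//   queries:            (nq, 64)
//   coarseCentroids:    (nlist, 64)
//   topQueryToCentroid: (nq, nprobe)
//   outCodeDistances:   (nq, nprobe, 8, 256)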

} } // namespace