PQCodeDistances.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "PQCodeDistances.cuh"

#include "BroadcastSum.cuh"
#include "Distance.cuh"
#include "L2Norm.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MatrixMult.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Transpose.cuh"

namespace faiss { namespace gpu {

template <typename T>
struct Converter {
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Converter<half> {
  inline static __device__ half to(float v) { return __float2half(v); }
};
#endif

template <>
struct Converter<float> {
  inline static __device__ float to(float v) { return v; }
};

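// Thread/block mapping, as set up in runPQCodeDistances below:
// blockIdx.y selects the sub-quantizer, blockIdx.x selects a group of
// queriesPerBlock queries, and threadIdx.x < codesPerSubQuantizer
// identifies the single PQ code whose distance this thread
// accumulates; the remaining threads of the block are dedicated to
// loading residual data (see isLoadingThread).
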
// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
template <typename OutCodeT, int DimsPerSubQuantizer>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<float, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];

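  // Resulting shared memory layout (sizes in floats/ints):
  //   [query: DimsPerSubQuantizer][residual1: DimsPerSubQuantizer]
  //   [residual2: DimsPerSubQuantizer][coarseIds: nprobe]
  // As an illustrative example (not a fixed configuration): with
  // DimsPerSubQuantizer = 32 and nprobe = 64, this is
  // 3 * 32 * 4 + 64 * 4 = 640 bytes, matching the smem calculation in
  // runPQCodeDistances below.
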
  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();

    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        // In case NaNs were in the original query data
        coarseId = coarseId == -1 ? 0 : coarseId;
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

        smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
      }
    }

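    // Software pipeline over the coarse list: on each iteration the
    // loading threads fill the other residual buffer with the next
    // centroid's residual while the compute threads consume the
    // current one; the two buffers are swapped at the bottom of the
    // loop, with __syncthreads() separating the phases.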
    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            // In case NaNs were in the original query data
            coarseId = coarseId == -1 ? 0 : coarseId;

            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

            smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];

        // Calculate residual - pqCentroid for each dim that we're
        // processing

        // Unrolled loop
#pragma unroll
        for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] = smemResidual1[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] -= subQuantizerData[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] *= vals[j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            dist += vals[j];
          }
        }

        // Remainder loop
#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] = smemResidual1[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] -= subQuantizerData[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] *= vals[j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          Converter<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}

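// The remainder of this file implements a GEMM-based formulation of
// the same computation; per the FIXME in runPQCodeDistances below, it
// is intended to cover sub-quantizer sizes that the specialized kernel
// above cannot handle within its register budget.
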
__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<float, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  // block x is query id
  // block y is centroid id
  // thread x is dim
  auto queryId = blockIdx.x;
  auto centroidId = blockIdx.y;

  int realCentroidId = topQueryToCentroid[queryId][centroidId];

  for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
    float q = queries[queryId][dim];
    float c = coarseCentroids[realCentroidId][dim];

    residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
      q - c;
  }
}

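// Launches residualVector with one block per (query, coarse centroid)
// pair; each thread strides over the full vector dimension, scattering
// each element into the sub-quantizer-major output layout so that
// every sub-quantizer's residuals are contiguous for the batched GEMM
// below.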
void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<float, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto grid =
    dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
  auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);

  CUDA_TEST_ERROR();
}

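// Computes the per-code distance table via the expansion
//   ||r - c||^2 = ||r||^2 - 2 <r, c> + ||c||^2
// where r is a (query - coarse centroid) residual restricted to one
// sub-quantizer's dimensions and c is a PQ sub-centroid: the -2 <r, c>
// terms come from the batched GEMM below (alpha = -2), ||r||^2 is then
// added along rows and ||c||^2 along columns.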
void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<float, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, true, residualNorms, true, stream);

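  // As an illustrative example (values are hypothetical, not fixed):
  // with 8 sub-quantizers, 1024 queries, nprobe = 64, 32 dims per
  // sub-quantizer and 256 codes, this performs 8 GEMMs of shape
  // (65536 x 32) * (32 x 256) -> (65536 x 256).
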
  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Sum ||q - c||^2 along rows
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, false, stream);

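  // The GEMM and sums above accumulate in float32. When half-precision
  // lookup tables are requested, the result is staged in a temporary
  // float tensor and converted with toHalf at the end of this function.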
  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem, {outCodeDistances.getSize(0),
            outCodeDistances.getSize(1),
            outCodeDistances.getSize(2),
            outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  }
#endif

  if (!useFloat16Lookup) {
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    // Need to convert back
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    toHalf(stream, outCodeDistancesF, outCodeDistancesH);
  }
#endif
}

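// Specialized-kernel path: each block computes the distance tables for
// kQueriesPerBlock queries against one sub-quantizer, so the PQ
// centroid data held in each thread's registers is reused across all
// queries and all nprobe coarse centroids handled by the block.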
void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<float, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // FIXME: tune
  // Reuse of pq centroid data depends on both # of queries and nprobe;
  // we should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);

  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);

  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);

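  // As an illustrative example (not a fixed configuration): with 256
  // codes per sub-quantizer (8-bit PQ) and dimsPerSubQuantizer up to
  // 32, loadingThreads = 32 and block = 256 + 32 = 288 threads, which
  // is the figure used in __launch_bounds__(288, 4) above.
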
#ifdef FAISS_USE_FLOAT16
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (useFloat16Lookup) {                                             \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();       \
                                                                        \
      pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>(       \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    } else {                                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#else
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (!useFloat16Lookup) {                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#endif

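  // The switch below dispatches to a compile-time specialization of
  // the kernel. As an illustrative example (values are hypothetical):
  // a PQ index with 32 sub-quantizers over 256-dim vectors has
  // dimsPerSubQuantizer = 256 / 32 = 8, selecting CODE_DISTANCE(8).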
  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_DISTANCE(1);
      break;
    case 2:
      CODE_DISTANCE(2);
      break;
    case 3:
      CODE_DISTANCE(3);
      break;
    case 4:
      CODE_DISTANCE(4);
      break;
    case 6:
      CODE_DISTANCE(6);
      break;
    case 8:
      CODE_DISTANCE(8);
      break;
    case 10:
      CODE_DISTANCE(10);
      break;
    case 12:
      CODE_DISTANCE(12);
      break;
    case 16:
      CODE_DISTANCE(16);
      break;
    case 20:
      CODE_DISTANCE(20);
      break;
    case 24:
      CODE_DISTANCE(24);
      break;
    case 28:
      CODE_DISTANCE(28);
      break;
    case 32:
      CODE_DISTANCE(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef CODE_DISTANCE

  CUDA_TEST_ERROR();
}

} } // namespace