PQCodeDistances.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "PQCodeDistances.cuh"

#include "BroadcastSum.cuh"
#include "Distance.cuh"
#include "L2Norm.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Float16.cuh"
#include "../utils/MatrixMult.cuh"
#include "../utils/PtxUtils.cuh"
#include "../utils/StaticUtils.h"
#include "../utils/Transpose.cuh"

namespace faiss { namespace gpu {
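// Converts the accumulated float distance to the desired output
// storage type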
template <typename T>
struct Converter {
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Converter<half> {
  inline static __device__ half to(float v) { return __float2half(v); }
};
#endif

template <>
struct Converter<float> {
  inline static __device__ float to(float v) { return v; }
};

// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
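//
// Blocks are laid out as (query tile, subquantizer): blockIdx.x covers
// queriesPerBlock queries and blockIdx.y selects one subquantizer.
// Threads [0, codesPerSubQuantizer) each own a single code; the extra
// threads prefetch the next coarse centroid's residual into shared
// memory while the current one is being processed (double buffering).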
template <typename OutCodeT, int DimsPerSubQuantizer>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<float, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);
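  // Threads past codesPerSubQuantizer compute no distances; they only
  // load residual data for the next coarse centroid in the list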
  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];

  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();

    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        // In case NaNs were in the original query data
        coarseId = coarseId == -1 ? 0 : coarseId;
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

        smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
      }
    }

    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            // In case NaNs were in the original query data
            coarseId = coarseId == -1 ? 0 : coarseId;

            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

            smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];

        // Calculate residual - pqCentroid for each dim that we're
        // processing

        // Unrolled loop
#pragma unroll
        for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] = smemResidual1[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] -= subQuantizerData[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] *= vals[j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            dist += vals[j];
          }
        }

        // Remainder loop
#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] = smemResidual1[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] -= subQuantizerData[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] *= vals[j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          Converter<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}

__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<float, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  // block x is query id
  // block y is centroid id
  // thread x is dim
  auto queryId = blockIdx.x;
  auto centroidId = blockIdx.y;

  int realCentroidId = topQueryToCentroid[queryId][centroidId];

  for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
    float q = queries[queryId][dim];
    float c = coarseCentroids[realCentroidId][dim];

    residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
      q - c;
  }
}

void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<float, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto grid =
    dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
  auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);

  CUDA_TEST_ERROR();
}

void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<float, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
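  // This path expands ||(q - c) - p||^2 into
  // ||q - c||^2 - 2 (q - c) . p + ||p||^2 for each PQ centroid p: the
  // cross term comes from a batched GEMM, the two norm terms from
  // row/column broadcast additions.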
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, residualNorms, true, stream);

  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
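  // With alpha = -2 below, the GEMM directly yields the -2 (q - c) . p
  // cross terms for all codes at once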
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Sum ||q - c||^2 along rows
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, stream);

  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem, {outCodeDistances.getSize(0),
            outCodeDistances.getSize(1),
            outCodeDistances.getSize(2),
            outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  }
#endif

  if (!useFloat16Lookup) {
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    // Need to convert back
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    toHalf(stream, outCodeDistancesF, outCodeDistancesH);
  }
#endif
}

void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<float, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // FIXME: tune
  // Reuse of pq centroid data depends on # of queries * nprobe;
  // we should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);

  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);

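  // Shared memory holds the query sub-vector plus two residual buffers
  // (3 * dims floats), followed by the nprobe coarse centroid IDs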
  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);

#ifdef FAISS_USE_FLOAT16
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (useFloat16Lookup) {                                             \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();       \
                                                                        \
      pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>(       \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    } else {                                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#else
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (!useFloat16Lookup) {                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#endif

  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_DISTANCE(1);
      break;
    case 2:
      CODE_DISTANCE(2);
      break;
    case 3:
      CODE_DISTANCE(3);
      break;
    case 4:
      CODE_DISTANCE(4);
      break;
    case 6:
      CODE_DISTANCE(6);
      break;
    case 8:
      CODE_DISTANCE(8);
      break;
    case 10:
      CODE_DISTANCE(10);
      break;
    case 12:
      CODE_DISTANCE(12);
      break;
    case 16:
      CODE_DISTANCE(16);
      break;
    case 20:
      CODE_DISTANCE(20);
      break;
    case 24:
      CODE_DISTANCE(24);
      break;
    case 28:
      CODE_DISTANCE(28);
      break;
    case 32:
      CODE_DISTANCE(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef CODE_DISTANCE

  CUDA_TEST_ERROR();
}

} } // namespace