// PQCodeDistances.cu
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "PQCodeDistances.cuh"
13 
14 #include "BroadcastSum.cuh"
15 #include "Distance.cuh"
16 #include "L2Norm.cuh"
17 #include "../utils/DeviceDefs.cuh"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Float16.cuh"
20 #include "../utils/MatrixMult.cuh"
21 #include "../utils/PtxUtils.cuh"
22 #include "../utils/StaticUtils.h"
23 #include "../utils/Transpose.cuh"
24 
25 namespace faiss { namespace gpu {
26 
// Converts an accumulated float distance into the element type of the
// output code-distance table. Specialized for float (identity) and,
// when FAISS_USE_FLOAT16 is enabled, for half.
template <typename T>
struct Converter {
};

#ifdef FAISS_USE_FLOAT16
template <>
struct Converter<half> {
  // Narrowing conversion for the float16 lookup table
  inline static __device__ half to(float v) { return __float2half(v); }
};
#endif

template <>
struct Converter<float> {
  // Identity conversion for the float lookup table
  inline static __device__ float to(float v) { return v; }
};
42 
// Kernel responsible for calculating distance from residual vector to
// each product quantizer code centroid
//
// Launch layout (see runPQCodeDistances):
//   gridDim.x  = tiles of `queriesPerBlock` queries
//   gridDim.y  = one row of blocks per sub-quantizer
//   blockDim.x = codesPerSubQuantizer "processing" threads, followed by
//                extra dedicated "loading" threads used only to prefetch
//                residual vectors
// Dynamic shared memory holds: the query sub-vector, two residual
// buffers (double buffered), then the list of coarse centroid ids.
template <typename OutCodeT, int DimsPerSubQuantizer>
__global__ void
__launch_bounds__(288, 4)
pqCodeDistances(Tensor<float, 2, true> queries,
                int queriesPerBlock,
                Tensor<float, 2, true> coarseCentroids,
                Tensor<float, 3, true> pqCentroids,
                Tensor<int, 2, true> topQueryToCentroid,
                // (query id)(coarse)(subquantizer)(code) -> dist
                Tensor<OutCodeT, 4, true> outCodeDistances) {
  // pqCentroids is (sub q)(sub dim)(code)
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  // The template parameter must match the runtime sub-dim (the host
  // side dispatches on it in runPQCodeDistances)
  assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

  // Threads past the code count do no distance math; they only
  // prefetch the next residual vector into shared memory
  bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
  int loadingThreadId = threadIdx.x - codesPerSubQuantizer;

  extern __shared__ float smem[];

  // Each thread calculates a single code
  float subQuantizerData[DimsPerSubQuantizer];

  auto code = threadIdx.x;
  auto subQuantizer = blockIdx.y;

  // Each thread will load the pq centroid data for the code that it
  // is processing
#pragma unroll
  for (int i = 0; i < DimsPerSubQuantizer; ++i) {
    subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
  }

  // Where we store our query vector
  float* smemQuery = smem;

  // Where we store our residual vector; this is double buffered so we
  // can be loading the next one while processing the current one
  float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
  float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];

  // Where we pre-load the coarse centroid IDs
  int* coarseIds = (int*) &smemResidual2[DimsPerSubQuantizer];

  // Each thread is calculating the distance for a single code,
  // performing the reductions locally

  // Handle multiple queries per block
  auto startQueryId = blockIdx.x * queriesPerBlock;
  // Clamp for the final (possibly partial) tile of queries
  auto numQueries = queries.getSize(0) - startQueryId;
  if (numQueries > queriesPerBlock) {
    numQueries = queriesPerBlock;
  }

  for (int query = 0; query < numQueries; ++query) {
    auto queryId = startQueryId + query;

    // Pointer to this sub-quantizer's slice of the query vector
    auto querySubQuantizer =
      queries[queryId][subQuantizer * DimsPerSubQuantizer].data();

    // Load current query vector
    for (int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
      smemQuery[i] = querySubQuantizer[i];
    }

    // Load list of coarse centroids found
    for (int i = threadIdx.x;
         i < topQueryToCentroid.getSize(1); i += blockDim.x) {
      coarseIds[i] = topQueryToCentroid[queryId][i];
    }

    // We need coarseIds below
    // FIXME: investigate loading separately, so we don't need this
    __syncthreads();

    // Preload first buffer of residual data
    if (isLoadingThread) {
      for (int i = loadingThreadId;
           i < DimsPerSubQuantizer;
           i += blockDim.x - codesPerSubQuantizer) {
        auto coarseId = coarseIds[0];
        auto coarseCentroidSubQuantizer =
          coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

        smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
      }
    }

    // The block walks the list for a single query
    for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
      // Wait for smemResidual1 to be loaded
      // (this barrier is outside the loading/processing branch, so all
      // threads in the block reach it)
      __syncthreads();

      if (isLoadingThread) {
        // Preload second buffer of residual data (the residual for the
        // NEXT coarse centroid), overlapping with the distance math on
        // the current one
        for (int i = loadingThreadId;
             i < DimsPerSubQuantizer;
             i += blockDim.x - codesPerSubQuantizer) {
          // FIXME: try always making this centroid id 0 so we can
          // terminate
          if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
            auto coarseId = coarseIds[coarse + 1];
            auto coarseCentroidSubQuantizer =
              coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();

            smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
          }
        }
      } else {
        // These are the processing threads
        float dist = 0.0f;

        constexpr int kUnroll = 4;
        constexpr int kRemainder = DimsPerSubQuantizer % kUnroll;
        constexpr int kRemainderBase = DimsPerSubQuantizer - kRemainder;
        float vals[kUnroll];

        // Calculate residual - pqCentroid for each dim that we're
        // processing, accumulating the squared L2 distance

        // Unrolled loop
#pragma unroll
        for (int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] = smemResidual1[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] -= subQuantizerData[i * kUnroll + j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            vals[j] *= vals[j];
          }

#pragma unroll
          for (int j = 0; j < kUnroll; ++j) {
            dist += vals[j];
          }
        }

        // Remainder loop (DimsPerSubQuantizer not a multiple of kUnroll)
#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] = smemResidual1[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] -= subQuantizerData[kRemainderBase + j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          vals[j] *= vals[j];
        }

#pragma unroll
        for (int j = 0; j < kRemainder; ++j) {
          dist += vals[j];
        }

        // We have the distance for our code; write it out
        outCodeDistances[queryId][coarse][subQuantizer][code] =
          Converter<OutCodeT>::to(dist);
      } // !isLoadingThread

      // Swap residual buffers, so the data we just prefetched is
      // consumed on the next iteration; the barrier at the top of the
      // loop makes it visible before it is read
      float* tmp = smemResidual1;
      smemResidual1 = smemResidual2;
      smemResidual2 = tmp;
    }
  }
}
223 
// Computes the (query - coarse centroid) residual for each of a
// query's closest coarse centroids, scattering it into the transposed
// (sub q)(query id)(centroid id)(sub dim) layout used by the batched
// matrix multiply in runPQCodeDistancesMM.
//
// Launch layout: blockIdx.x = query id, blockIdx.y = rank of the
// coarse centroid within the query's top list; threadIdx.x strides
// over the full dimension.
__global__ void
residualVector(Tensor<float, 2, true> queries,
               Tensor<float, 2, true> coarseCentroids,
               Tensor<int, 2, true> topQueryToCentroid,
               int numSubDim,
               // output is transposed:
               // (sub q)(query id)(centroid id)(sub dim)
               Tensor<float, 4, true> residual) {
  auto query = blockIdx.x;
  auto rank = blockIdx.y;

  // Translate the per-query rank into the actual coarse centroid index
  int centroid = topQueryToCentroid[query][rank];

  int numDims = queries.getSize(1);
  int dim = threadIdx.x;

  while (dim < numDims) {
    float diff = queries[query][dim] - coarseCentroids[centroid][dim];

    // dim / numSubDim selects the sub-quantizer; dim % numSubDim is the
    // offset within that sub-quantizer's slice
    residual[dim / numSubDim][query][rank][dim % numSubDim] = diff;

    dim += blockDim.x;
  }
}
248 
// Launches the residualVector kernel: one block per (query, coarse
// centroid) pair, with up to one thread per dimension (capped at the
// device's per-block thread limit).
void
runResidualVector(Tensor<float, 3, true>& pqCentroids,
                  Tensor<float, 2, true>& queries,
                  Tensor<float, 2, true>& coarseCentroids,
                  Tensor<int, 2, true>& topQueryToCentroid,
                  Tensor<float, 4, true>& residual,
                  cudaStream_t stream) {
  auto numQueries = topQueryToCentroid.getSize(0);
  auto nprobe = topQueryToCentroid.getSize(1);
  auto numDims = queries.getSize(1);

  dim3 grid(numQueries, nprobe);
  dim3 block(std::min(numDims, getMaxThreadsCurrentDevice()));

  residualVector<<<grid, block, 0, stream>>>(
    queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
    residual);
  CUDA_VERIFY(cudaGetLastError());
}
265 
// Computes the code distance table via batched matrix multiplication,
// using the expansion
//   ||(q - c) - p||^2 = ||q - c||^2 - 2 (q - c) . p + ||p||^2
// where q is a query, c a coarse centroid and p a pq sub-quantizer
// centroid. This is the general-dimension alternative to the
// specialized pqCodeDistances kernel used by runPQCodeDistances.
// Temporary buffers are allocated from `mem` on `stream`.
void
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
                     Tensor<float, 2, true>& queries,
                     Tensor<float, 2, true>& coarseCentroids,
                     Tensor<int, 2, true>& topQueryToCentroid,
                     NoTypeTensor<4, true>& outCodeDistances,
                     bool useFloat16Lookup,
                     DeviceMemory& mem,
                     cublasHandle_t handle,
                     cudaStream_t stream) {
  // Calculate (q - c) residual vector
  // (sub q)(query id)(centroid id)(sub dim)
  DeviceTensor<float, 4, true> residual(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0),
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)},
    stream);

  runResidualVector(pqCentroids, queries,
                    coarseCentroids, topQueryToCentroid,
                    residual, stream);

  // Calculate ||q - c||^2 -- one norm per (sub q, query, centroid)
  // residual sub-vector
  DeviceTensor<float, 1, true> residualNorms(
    mem,
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1)},
    stream);

  // Flatten the residual to (rows)(sub dim) for the norm computation
  auto residualView2 = residual.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  runL2Norm(residualView2, residualNorms, true, stream);

  // Perform a batch MM:
  // (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
  // (sub q) x {(q * c)(code)}
  auto residualView3 = residual.view<3>(
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(1)});

  DeviceTensor<float, 3, true> residualDistance(
    mem,
    {pqCentroids.getSize(0),
     topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)},
    stream);

  // alpha = -2 yields the -2 (q - c) . p term directly
  runIteratedMatrixMult(residualDistance, false,
                        residualView3, false,
                        pqCentroids, false,
                        -2.0f, 0.0f,
                        handle,
                        stream);

  // Sum ||q - c||^2 along rows
  auto residualDistanceView2 = residualDistance.view<2>(
    {pqCentroids.getSize(0) *
     topQueryToCentroid.getSize(0) *
     topQueryToCentroid.getSize(1),
     pqCentroids.getSize(2)});

  runSumAlongRows(residualNorms, residualDistanceView2, stream);

  // Float staging area for the output; when the final table is half we
  // compute in float here and convert at the end
  Tensor<float, 4, true> outCodeDistancesF;
  DeviceTensor<float, 4, true> outCodeDistancesFloatMem;

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
      mem, {outCodeDistances.getSize(0),
            outCodeDistances.getSize(1),
            outCodeDistances.getSize(2),
            outCodeDistances.getSize(3)},
      stream);

    outCodeDistancesF = outCodeDistancesFloatMem;
  }
#endif

  if (!useFloat16Lookup) {
    // Write directly into the caller's float output
    outCodeDistancesF = outCodeDistances.toTensor<float>();
  }

  // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
  // is where we build our output distances)
  auto outCodeDistancesView = outCodeDistancesF.view<3>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2),
     outCodeDistances.getSize(3)});

  runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);

  // Calculate code norms per each sub-dim
  // (sub q)(sub dim)(code) is pqCentroids
  // transpose to (sub q)(code)(sub dim)
  DeviceTensor<float, 3, true> pqCentroidsTranspose(
    mem,
    {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
    stream);

  runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);

  auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
    {pqCentroids.getSize(0) * pqCentroids.getSize(2),
     pqCentroids.getSize(1)});

  // ||p||^2 for every (sub q, code) pair
  DeviceTensor<float, 1, true> pqCentroidsNorm(
    mem,
    {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
    stream);

  runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm, true, stream);

  // View output as (q * c)(sub q * code), and add centroid norm to
  // each row
  auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
    {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
     outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});

  runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);

#ifdef FAISS_USE_FLOAT16
  if (useFloat16Lookup) {
    // Need to convert back
    auto outCodeDistancesH = outCodeDistances.toTensor<half>();
    toHalf(stream, outCodeDistancesF, outCodeDistancesH);
  }
#endif
}
403 
// Computes the lookup table
//   (query id)(coarse)(subquantizer)(code) -> ||residual - pq centroid||^2
// by dispatching to a pqCodeDistances kernel specialized on the
// sub-quantizer dimension. Only the dimensions enumerated in the
// switch below are supported; other sizes must use the MM path
// (runPQCodeDistancesMM). outCodeDistances holds half when
// useFloat16Lookup is true, float otherwise.
void
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
                   Tensor<float, 2, true>& queries,
                   Tensor<float, 2, true>& coarseCentroids,
                   Tensor<int, 2, true>& topQueryToCentroid,
                   NoTypeTensor<4, true>& outCodeDistances,
                   bool useFloat16Lookup,
                   cudaStream_t stream) {
  const auto numSubQuantizers = pqCentroids.getSize(0);
  const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
  const auto codesPerSubQuantizer = pqCentroids.getSize(2);

#ifndef FAISS_USE_FLOAT16
  // Without half support compiled in we cannot produce a float16
  // lookup table. Previously this case fell through the CODE_DISTANCE
  // macro below without launching anything, silently leaving
  // outCodeDistances uninitialized; fail loudly instead.
  FAISS_ASSERT(!useFloat16Lookup);
#endif

  // FIXME: tune
  // Reuse of pq centroid data is based on both # of queries * nprobe,
  // and we should really be tiling in both dimensions
  constexpr int kQueriesPerBlock = 8;

  // grid x: tiles of kQueriesPerBlock queries
  // grid y: one row of blocks per sub-quantizer
  auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
                   numSubQuantizers);

  // Reserve one block of threads for double buffering
  // FIXME: probably impractical for large # of dims?
  auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
  auto block = dim3(codesPerSubQuantizer + loadingThreads);

  // Dynamic shared memory: query sub-vector + two residual buffers
  // (double buffered) + per-query coarse centroid id list; must match
  // the layout assumed in pqCodeDistances
  auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
    + topQueryToCentroid.getSize(1) * sizeof(int);

#ifdef FAISS_USE_FLOAT16
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (useFloat16Lookup) {                                             \
      auto outCodeDistancesT = outCodeDistances.toTensor<half>();       \
                                                                        \
      pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>(       \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    } else {                                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#else
#define CODE_DISTANCE(DIMS)                                             \
  do {                                                                  \
    if (!useFloat16Lookup) {                                            \
      auto outCodeDistancesT = outCodeDistances.toTensor<float>();      \
                                                                        \
      pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>(      \
        queries, kQueriesPerBlock,                                      \
        coarseCentroids, pqCentroids,                                   \
        topQueryToCentroid, outCodeDistancesT);                         \
    }                                                                   \
  } while (0)
#endif

  switch (dimsPerSubQuantizer) {
    case 1:
      CODE_DISTANCE(1);
      break;
    case 2:
      CODE_DISTANCE(2);
      break;
    case 3:
      CODE_DISTANCE(3);
      break;
    case 4:
      CODE_DISTANCE(4);
      break;
    case 6:
      CODE_DISTANCE(6);
      break;
    case 8:
      CODE_DISTANCE(8);
      break;
    case 10:
      CODE_DISTANCE(10);
      break;
    case 12:
      CODE_DISTANCE(12);
      break;
    case 16:
      CODE_DISTANCE(16);
      break;
    case 32:
      CODE_DISTANCE(32);
      break;
    // FIXME: larger sizes require too many registers - we need the
    // MM implementation working
    default:
      FAISS_ASSERT(false);
      break;
  }

#undef CODE_DISTANCE

  CUDA_VERIFY(cudaGetLastError());
}
507 
508 } } // namespace
// (Tensor: Faiss GPU tensor type, defined in ../utils/Tensor.cuh)