12 #include "PQCodeDistances.cuh"
14 #include "BroadcastSum.cuh"
15 #include "Distance.cuh"
17 #include "../utils/DeviceDefs.cuh"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Float16.cuh"
20 #include "../utils/MatrixMult.cuh"
21 #include "../utils/PtxUtils.cuh"
22 #include "../utils/StaticUtils.h"
23 #include "../utils/Transpose.cuh"
25 namespace faiss {
namespace gpu {
31 #ifdef FAISS_USE_FLOAT16
// Device-side conversion of a float to half precision (used when the
// output code-distance table is stored as float16).
// NOTE(review): this looks like the half specialization of a Converter
// traits struct whose declaration is missing from this chunk — confirm
// against the full upstream file before relying on this text.
34 inline static __device__ half to(
float v) {
return __float2half(v); }
// Identity conversion for the float output path: returns the input unchanged.
// NOTE(review): presumably the float specialization of the same Converter
// traits struct; the enclosing struct header is not visible in this chunk.
40 inline static __device__
float to(
float v) {
return v; }
// Kernel: computes the per-query, per-coarse-centroid table of squared L2
// distances between the query residual and every PQ sub-quantizer centroid.
//
// Layout (as visible in the body):
//   blockIdx.x -> group of queries (queriesPerBlock per block)
//   blockIdx.y -> sub-quantizer index
//   threadIdx.x < codesPerSubQuantizer -> one PQ code per thread
//   threadIdx.x >= codesPerSubQuantizer -> "loading" threads that prefetch
//     the next coarse-centroid residual into shared memory (double buffer).
//
// Dynamic shared memory holds: query sub-vector (DimsPerSubQuantizer floats),
// two residual buffers (double-buffered), then the coarse-centroid id list.
//
// NOTE(review): this chunk is truncated — the `__global__ void` qualifier,
// the `queriesPerBlock` parameter declaration, the `dist`/`vals[kUnroll]`
// accumulator declarations, the distance-accumulation statements inside the
// unrolled loops, the `__syncthreads()` barriers between the shared-memory
// phases, and several closing braces are all missing. Recover the complete
// body from upstream before compiling; do not hand-reconstruct.
45 template <
typename OutCodeT,
int DimsPerSubQuantizer>
47 __launch_bounds__(288, 4)
48 pqCodeDistances(
Tensor<
float, 2, true> queries,
50 Tensor<
float, 2, true> coarseCentroids,
51 Tensor<
float, 3, true> pqCentroids,
52 Tensor<
int, 2, true> topQueryToCentroid,
54 Tensor<OutCodeT, 4, true> outCodeDistances) {
55 const auto numSubQuantizers = pqCentroids.getSize(0);
56 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
// The runtime dimension must match the compile-time template parameter.
57 assert(DimsPerSubQuantizer == dimsPerSubQuantizer);
58 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Threads past the code range only prefetch residuals into shared memory.
60 bool isLoadingThread = threadIdx.x >= codesPerSubQuantizer;
61 int loadingThreadId = threadIdx.x - codesPerSubQuantizer;
63 extern __shared__
float smem[];
// Each code-thread caches its own sub-quantizer centroid in registers.
66 float subQuantizerData[DimsPerSubQuantizer];
68 auto code = threadIdx.x;
69 auto subQuantizer = blockIdx.y;
74 for (
int i = 0; i < DimsPerSubQuantizer; ++i) {
// .ldg(): read-only-cache load of this thread's PQ centroid coordinates.
75 subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
// Carve the dynamic shared-memory arena: query, two residual buffers
// (ping-pong), then the int list of coarse centroid ids.
79 float* smemQuery = smem;
83 float* smemResidual1 = &smemQuery[DimsPerSubQuantizer];
84 float* smemResidual2 = &smemResidual1[DimsPerSubQuantizer];
87 int* coarseIds = (
int*) &smemResidual2[DimsPerSubQuantizer];
// Clamp the per-block query count at the tail of the query set.
93 auto startQueryId = blockIdx.x * queriesPerBlock;
94 auto numQueries = queries.getSize(0) - startQueryId;
95 if (numQueries > queriesPerBlock) {
96 numQueries = queriesPerBlock;
99 for (
int query = 0; query < numQueries; ++query) {
100 auto queryId = startQueryId + query;
102 auto querySubQuantizer =
103 queries[queryId][subQuantizer * DimsPerSubQuantizer].data();
// Stage this query's sub-vector into shared memory (all threads help).
106 for (
int i = threadIdx.x; i < DimsPerSubQuantizer; i += blockDim.x) {
107 smemQuery[i] = querySubQuantizer[i];
// Stage the coarse centroid ids selected for this query.
111 for (
int i = threadIdx.x;
112 i < topQueryToCentroid.getSize(1); i += blockDim.x) {
113 coarseIds[i] = topQueryToCentroid[queryId][i];
// Loading threads precompute the residual for the FIRST coarse centroid
// so the compute loop below can start immediately.
121 if (isLoadingThread) {
122 for (
int i = loadingThreadId;
123 i < DimsPerSubQuantizer;
124 i += blockDim.x - codesPerSubQuantizer) {
125 auto coarseId = coarseIds[0];
126 auto coarseCentroidSubQuantizer =
127 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
129 smemResidual1[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Main loop over the coarse list: compute distances against the current
// residual buffer while loading threads prefetch the next one.
134 for (
int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
138 if (isLoadingThread) {
140 for (
int i = loadingThreadId;
141 i < DimsPerSubQuantizer;
142 i += blockDim.x - codesPerSubQuantizer) {
// No prefetch past the last coarse centroid.
145 if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
146 auto coarseId = coarseIds[coarse + 1];
147 auto coarseCentroidSubQuantizer =
148 coarseCentroids[coarseId][subQuantizer * dimsPerSubQuantizer].data();
150 smemResidual2[i] = smemQuery[i] - coarseCentroidSubQuantizer[i];
// Manually unrolled squared-L2 accumulation, kUnroll dims at a time,
// with a scalar remainder loop for DimsPerSubQuantizer % kUnroll.
157 constexpr
int kUnroll = 4;
158 constexpr
int kRemainder = DimsPerSubQuantizer % kUnroll;
159 constexpr
int kRemainderBase = DimsPerSubQuantizer - kRemainder;
// NOTE(review): `vals` and the `dist` accumulator declarations are
// missing from this chunk, as are the accumulation statements in the
// two empty-bodied loops below.
167 for (
int i = 0; i < DimsPerSubQuantizer / kUnroll; ++i) {
170 for (
int j = 0; j < kUnroll; ++j) {
171 vals[j] = smemResidual1[i * kUnroll + j];
175 for (
int j = 0; j < kUnroll; ++j) {
176 vals[j] -= subQuantizerData[i * kUnroll + j];
180 for (
int j = 0; j < kUnroll; ++j) {
185 for (
int j = 0; j < kUnroll; ++j) {
// Remainder dims handled one at a time.
192 for (
int j = 0; j < kRemainder; ++j) {
193 vals[j] = smemResidual1[kRemainderBase + j];
197 for (
int j = 0; j < kRemainder; ++j) {
198 vals[j] -= subQuantizerData[kRemainderBase + j];
202 for (
int j = 0; j < kRemainder; ++j) {
207 for (
int j = 0; j < kRemainder; ++j) {
// Write the finished distance for (query, coarse, subQuantizer, code),
// converting to the output type (float or half).
212 outCodeDistances[queryId][coarse][subQuantizer][code] =
213 Converter<OutCodeT>::to(dist);
// Swap the residual ping-pong buffers for the next coarse iteration.
// NOTE(review): the `smemResidual2 = tmp;` half of the swap and the
// closing braces are missing from this chunk.
217 float* tmp = smemResidual1;
218 smemResidual1 = smemResidual2;
// Kernel: computes (query - coarseCentroid) residual vectors, scattered into
// a 4-D layout keyed by (sub-dimension chunk, query, coarse rank, dim within
// chunk) so the MM-based distance path can consume them per sub-quantizer.
// Grid layout as visible below: blockIdx.x = query, blockIdx.y = rank of the
// coarse centroid in the query's top list; threads stride over dimensions.
// NOTE(review): truncated — the `__global__ void` qualifier, the `numSubDim`
// parameter declaration, the right-hand side of the residual assignment
// (presumably `q - c`), and the closing braces are missing from this chunk.
225 residualVector(Tensor<float, 2, true> queries,
226 Tensor<float, 2, true> coarseCentroids,
227 Tensor<int, 2, true> topQueryToCentroid,
231 Tensor<float, 4, true> residual) {
235 auto queryId = blockIdx.x;
236 auto centroidId = blockIdx.y;
// Map the per-query rank to the actual coarse centroid id.
238 int realCentroidId = topQueryToCentroid[queryId][centroidId];
240 for (
int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
241 float q = queries[queryId][dim];
242 float c = coarseCentroids[realCentroidId][dim];
// Split the full dimension index into (sub-quantizer chunk, offset).
244 residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] =
// Host wrapper: launches residualVector with one block per
// (query, coarse-rank) pair and up to one thread per dimension.
// NOTE(review): truncated — the return type, the `auto grid =` prefix on the
// grid-dim line, the trailing `residual` launch argument, and the closing
// brace are missing from this chunk.
250 runResidualVector(Tensor<float, 3, true>& pqCentroids,
251 Tensor<float, 2, true>& queries,
252 Tensor<float, 2, true>& coarseCentroids,
253 Tensor<int, 2, true>& topQueryToCentroid,
254 Tensor<float, 4, true>& residual,
255 cudaStream_t stream) {
// One block per (query, coarse centroid rank).
257 dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
// Block size capped at the device's max threads per block.
258 auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));
260 residualVector<<<grid, block, 0, stream>>>(
261 queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
// Surface any launch-configuration error immediately.
263 CUDA_VERIFY(cudaGetLastError());
// Host path: computes the PQ code-distance table via batched matrix multiply
// instead of the specialized kernel. Expands
//   ||r - c||^2 = ||r||^2 - 2 <r, c> + ||c||^2
// as: residual norms (runL2Norm) + -2 * residual x centroids (gemm, via
// runIteratedMatrixMult) + centroid norms, broadcast-added along rows and
// columns respectively, then transposed into the output layout and optionally
// converted to float16.
// NOTE(review): truncated — the return type, a memory-manager parameter
// (a `mem` identifier is used below but never declared here), several
// launch/gemm argument lines (e.g. the residual argument to
// runResidualVector, the scale factors and second operand of the gemm),
// stream/temp-memory ctor arguments, the `#endif`s, and the closing brace
// are all missing from this chunk. Recover from upstream before compiling.
267 runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
268 Tensor<float, 2, true>& queries,
269 Tensor<float, 2, true>& coarseCentroids,
270 Tensor<int, 2, true>& topQueryToCentroid,
271 NoTypeTensor<4, true>& outCodeDistances,
272 bool useFloat16Lookup,
274 cublasHandle_t handle,
275 cudaStream_t stream) {
// Residuals laid out as (subQ, query, coarse-rank, dimsPerSubQ).
278 DeviceTensor<float, 4, true> residual(
280 {pqCentroids.getSize(0),
281 topQueryToCentroid.getSize(0),
282 topQueryToCentroid.getSize(1),
283 pqCentroids.getSize(1)},
286 runResidualVector(pqCentroids, queries,
287 coarseCentroids, topQueryToCentroid,
// ||r||^2 per (subQ, query, coarse-rank) row.
291 DeviceTensor<float, 1, true> residualNorms(
293 {pqCentroids.getSize(0) *
294 topQueryToCentroid.getSize(0) *
295 topQueryToCentroid.getSize(1)},
298 auto residualView2 = residual.view<2>(
299 {pqCentroids.getSize(0) *
300 topQueryToCentroid.getSize(0) *
301 topQueryToCentroid.getSize(1),
302 pqCentroids.getSize(1)});
// `true` requests squared norms here — consistent with the L2 expansion.
304 runL2Norm(residualView2, residualNorms,
true, stream);
// Batch the gemm over sub-quantizers: (queries*ranks) x dims by
// dims x codes -> (queries*ranks) x codes per sub-quantizer.
309 auto residualView3 = residual.view<3>(
310 {pqCentroids.getSize(0),
311 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
312 pqCentroids.getSize(1)});
314 DeviceTensor<float, 3, true> residualDistance(
316 {pqCentroids.getSize(0),
317 topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
318 pqCentroids.getSize(2)},
// NOTE(review): the second gemm operand and alpha/beta lines (presumably
// pqCentroids with alpha = -2) are missing below.
321 runIteratedMatrixMult(residualDistance,
false,
322 residualView3,
false,
// Broadcast-add ||r||^2 along each row of the gemm result.
329 auto residualDistanceView2 = residualDistance.view<2>(
330 {pqCentroids.getSize(0) *
331 topQueryToCentroid.getSize(0) *
332 topQueryToCentroid.getSize(1),
333 pqCentroids.getSize(2)});
335 runSumAlongRows(residualNorms, residualDistanceView2, stream);
// When the caller wants half-precision output, accumulate in a float
// scratch tensor first, then convert at the end.
337 Tensor<float, 4, true> outCodeDistancesF;
338 DeviceTensor<float, 4, true> outCodeDistancesFloatMem;
340 #ifdef FAISS_USE_FLOAT16
341 if (useFloat16Lookup) {
342 outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
343 mem, {outCodeDistances.getSize(0),
344 outCodeDistances.getSize(1),
345 outCodeDistances.getSize(2),
346 outCodeDistances.getSize(3)},
349 outCodeDistancesF = outCodeDistancesFloatMem;
353 if (!useFloat16Lookup) {
354 outCodeDistancesF = outCodeDistances.toTensor<
float>();
// Transpose (subQ, query*rank, codes) into the caller's
// (query*rank, subQ, codes) output layout.
359 auto outCodeDistancesView = outCodeDistancesF.view<3>(
360 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
361 outCodeDistances.getSize(2),
362 outCodeDistances.getSize(3)});
364 runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);
// ||c||^2: transpose centroids to (subQ, codes, dims) so each code's
// coordinates are contiguous for the norm computation.
369 DeviceTensor<float, 3, true> pqCentroidsTranspose(
371 {pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)},
374 runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
376 auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
377 {pqCentroids.getSize(0) * pqCentroids.getSize(2),
378 pqCentroids.getSize(1)});
380 DeviceTensor<float, 1, true> pqCentroidsNorm(
382 {pqCentroids.getSize(0) * pqCentroids.getSize(2)},
385 runL2Norm(pqCentroidsTransposeView, pqCentroidsNorm,
true, stream);
// Broadcast-add ||c||^2 along the columns of the distance table.
389 auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
390 {topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
391 outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
393 runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
395 #ifdef FAISS_USE_FLOAT16
// Final conversion from the float scratch result into the half output.
396 if (useFloat16Lookup) {
398 auto outCodeDistancesH = outCodeDistances.toTensor<half>();
399 toHalf(stream, outCodeDistancesF, outCodeDistancesH);
// Host dispatcher: launches the specialized pqCodeDistances kernel, selecting
// the output type (half vs float) at runtime and the DimsPerSubQuantizer
// template instantiation via the CODE_DISTANCE macro + switch.
// NOTE(review): truncated — the return type, the grid y-dimension line
// (presumably numSubQuantizers), the `#else`/`#endif` pairs around the two
// CODE_DISTANCE definitions, the switch cases invoking CODE_DISTANCE, and
// the closing brace are missing from this chunk.
405 runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
406 Tensor<float, 2, true>& queries,
407 Tensor<float, 2, true>& coarseCentroids,
408 Tensor<int, 2, true>& topQueryToCentroid,
409 NoTypeTensor<4, true>& outCodeDistances,
410 bool useFloat16Lookup,
411 cudaStream_t stream) {
412 const auto numSubQuantizers = pqCentroids.getSize(0);
413 const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
414 const auto codesPerSubQuantizer = pqCentroids.getSize(2);
// Each block handles kQueriesPerBlock queries for one sub-quantizer.
419 constexpr
int kQueriesPerBlock = 8;
421 auto grid = dim3(utils::divUp(queries.getSize(0), kQueriesPerBlock),
// Extra "loading" threads (rounded up to a warp) prefetch residuals;
// matches the isLoadingThread split inside the kernel.
426 auto loadingThreads = utils::roundUp(dimsPerSubQuantizer, kWarpSize);
427 auto block = dim3(codesPerSubQuantizer + loadingThreads);
// Dynamic shared memory: query + two residual buffers (floats) plus the
// coarse-id list (ints) — must mirror the kernel's smem carve-up.
429 auto smem = (3 * dimsPerSubQuantizer) *
sizeof(
float)
430 + topQueryToCentroid.getSize(1) *
sizeof(int);
// Float16-enabled build: pick half or float output at runtime.
432 #ifdef FAISS_USE_FLOAT16
433 #define CODE_DISTANCE(DIMS) \
435 if (useFloat16Lookup) { \
436 auto outCodeDistancesT = outCodeDistances.toTensor<half>(); \
438 pqCodeDistances<half, DIMS><<<grid, block, smem, stream>>>( \
439 queries, kQueriesPerBlock, \
440 coarseCentroids, pqCentroids, \
441 topQueryToCentroid, outCodeDistancesT); \
443 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
445 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
446 queries, kQueriesPerBlock, \
447 coarseCentroids, pqCentroids, \
448 topQueryToCentroid, outCodeDistancesT); \
// Non-float16 build: float output only.
452 #define CODE_DISTANCE(DIMS) \
454 if (!useFloat16Lookup) { \
455 auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
457 pqCodeDistances<float, DIMS><<<grid, block, smem, stream>>>( \
458 queries, kQueriesPerBlock, \
459 coarseCentroids, pqCentroids, \
460 topQueryToCentroid, outCodeDistancesT); \
// Dispatch on the runtime dims to the matching template instantiation.
// NOTE(review): the case labels are missing from this chunk.
465 switch (dimsPerSubQuantizer) {
// Surface any launch-configuration error from the kernel launch above.
505 CUDA_VERIFY(cudaGetLastError());