9 #include "Distance.cuh"
10 #include "BroadcastSum.cuh"
12 #include "L2Select.cuh"
13 #include "../../FaissAssert.h"
14 #include "../../AuxIndexStructures.h"
15 #include "../GpuResources.h"
16 #include "../utils/DeviceDefs.cuh"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Limits.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/BlockSelectKernel.cuh"
24 #include <thrust/fill.h>
25 #include <thrust/for_each.h>
26 #include <thrust/device_ptr.h>
27 #include <thrust/execution_policy.h>
29 namespace faiss {
namespace gpu {
// Produces a view over a contiguous range of centroids. The centroid axis is
// dim 0 when centroidsRowMajor is true and dim 1 otherwise.
34 Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
35 bool centroidsRowMajor,
// Tests whether the requested range begins at 0 and spans the entire centroid
// axis, i.e. whether the "slice" is really the whole tensor.
40 if (startCentroid == 0 &&
41 num == centroids.getSize(centroidsRowMajor ? 0 : 1)) {
// Otherwise narrow() is applied along the centroid axis with
// (startCentroid, num) as its start/size arguments.
45 return centroids.narrow(centroidsRowMajor ? 0 : 1, startCentroid, num);
// Kernel that offsets index values in place. Block (blockIdx.x, blockIdx.y)
// owns the k consecutive entries of row blockIdx.y that start at column
// blockIdx.x * k, and adds blockIdx.x * increment to each of them.
50 __global__
void incrementIndex(Tensor<T, 2, true> indices,
// Threads of the block walk the k entries in steps of blockDim.x, so the
// kernel is correct for any block size, including block sizes smaller than k.
53 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
54 indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment;
// Host-side launcher for the incrementIndex kernel on the given stream.
61 void runIncrementIndex(Tensor<T, 2, true>& indices,
64 cudaStream_t stream) {
// grid.x: number of k-wide column segments per row; grid.y: number of rows.
65 dim3 grid(indices.getSize(1) / k, indices.getSize(0));
// One thread per segment entry, capped at 512 threads per block; the kernel's
// strided loop covers the remainder when k > 512.
66 int block = std::min(k, 512);
// The integer division used for grid.x must not have truncated: the column
// count has to be an exact multiple of k.
69 FAISS_ASSERT(grid.x * k == indices.getSize(1));
71 incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);
// NOTE(review): cudaDeviceSynchronize() blocks the host until all outstanding
// work on the device has finished (not only `stream`), and its return code is
// discarded here. Confirm the device-wide sync is intended; otherwise a
// checked cudaGetLastError() would catch launch errors without stalling.
73 cudaDeviceSynchronize();
// Chooses the tile extents (tileRows x tileCols) used to split up the
// queries x centroids distance computation.
79 void chooseTileSize(
int numQueries,
83 size_t tempMemAvailable,
// The usage target below is tiered on the total global memory of the
// current device.
94 auto totalMem = getCurrentDeviceProperties().totalGlobalMem;
// Devices with at most 4 GiB: target 512 MiB.
98 if (totalMem <= ((
size_t) 4) * 1024 * 1024 * 1024) {
99 targetUsage = 512 * 1024 * 1024;
100 }
// Devices with at most 8 GiB: target 768 MiB.
else if (totalMem <= ((
size_t) 8) * 1024 * 1024 * 1024) {
101 targetUsage = 768 * 1024 * 1024;
// NOTE(review): presumably the fallback for larger devices (1 GiB) -- confirm
// against the enclosing branch.
103 targetUsage = 1024 * 1024 * 1024;
// Scales the byte target down by 2 * elementSize.
// NOTE(review): presumably converts bytes into an element count for each of
// two equally sized buffers -- confirm.
106 targetUsage /= 2 * elementSize;
112 int preferredTileRows = 512;
114 preferredTileRows = 1024;
// Never use more rows per tile than there are queries.
117 tileRows = std::min(preferredTileRows, numQueries);
// Columns per tile: the scaled target divided by the preferred row count,
// capped at the number of centroids.
120 tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
// Tiled distance computation plus k-selection between `queries` and
// `centroids`, writing one row of k results per query into outDistances /
// outIndices. computeL2 enables the L2-specific steps visible below (norm
// computation and a -2 scale factor on the matrix product).
125 template <
typename T>
126 void runDistance(
bool computeL2,
127 GpuResources* resources,
128 Tensor<T, 2, true>& centroids,
129 bool centroidsRowMajor,
130 Tensor<T, 1, true>* centroidNorms,
131 Tensor<T, 2, true>& queries,
132 bool queriesRowMajor,
134 Tensor<T, 2, true>& outDistances,
135 Tensor<int, 2, true>& outIndices,
137 bool ignoreOutDistances) {
// Which tensor dimension holds the vector count versus the vector dimension
// depends on the row-major flags.
139 auto numCentroids = centroids.getSize(centroidsRowMajor ? 0 : 1);
142 auto numQueries = queries.getSize(queriesRowMajor ? 0 : 1);
145 auto dim = queries.getSize(queriesRowMajor ? 1 : 0);
// Dimensionality must agree between queries and centroids unless either set
// is empty.
146 FAISS_ASSERT((numQueries == 0 || numCentroids == 0) ||
147 dim == centroids.getSize(centroidsRowMajor ? 1 : 0));
// Output tensors must both be numQueries x k.
149 FAISS_ASSERT(outDistances.getSize(0) == numQueries);
150 FAISS_ASSERT(outIndices.getSize(0) == numQueries);
151 FAISS_ASSERT(outDistances.getSize(1) == k);
152 FAISS_ASSERT(outIndices.getSize(1) == k);
154 auto& mem = resources->getMemoryManagerCurrentDevice();
155 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// Empty centroid set: fill the outputs with sentinel values on the default
// stream (distances get Limits<T>::getMax()).
158 if (centroids.numElements() == 0) {
159 thrust::fill(thrust::cuda::par.on(defaultStream),
160 outDistances.data(), outDistances.end(),
161 Limits<T>::getMax());
163 thrust::fill(thrust::cuda::par.on(defaultStream),
164 outIndices.data(), outIndices.end(),
// For L2 with no caller-supplied centroid norms, compute them into a local
// tensor and point centroidNorms at it.
171 DeviceTensor<T, 1, true> cNorms;
172 if (computeL2 && !centroidNorms) {
174 std::move(DeviceTensor<T, 1, true>(mem,
175 {numCentroids}, defaultStream));
176 runL2Norm(centroids, centroidsRowMajor, cNorms,
true, defaultStream);
177 centroidNorms = &cNorms;
// Per-query norms, one entry per query.
183 int qNormSize[1] = {numQueries};
184 DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
188 runL2Norm(queries, queriesRowMajor, queryNorms,
true, defaultStream);
// Tile extents are chosen from the problem size and the temporary memory
// currently available from the memory manager.
195 chooseTileSize(numQueries,
199 mem.getSizeAvailable(),
// Number of centroid (column) tiles, rounded up.
203 int numColTiles = utils::divUp(numCentroids, tileCols);
207 FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
// Every temporary buffer exists twice; the pair is indexed by curStream so
// that work on the two alternate streams does not share scratch space.
210 DeviceTensor<T, 2, true> distanceBuf1(
211 mem, {tileRows, tileCols}, defaultStream);
212 DeviceTensor<T, 2, true> distanceBuf2(
213 mem, {tileRows, tileCols}, defaultStream);
214 DeviceTensor<T, 2, true>* distanceBufs[2] =
215 {&distanceBuf1, &distanceBuf2};
// Per-tile selection results: k columns for each column tile.
217 DeviceTensor<T, 2, true> outDistanceBuf1(
218 mem, {tileRows, numColTiles * k}, defaultStream);
219 DeviceTensor<T, 2, true> outDistanceBuf2(
220 mem, {tileRows, numColTiles * k}, defaultStream);
221 DeviceTensor<T, 2, true>* outDistanceBufs[2] =
222 {&outDistanceBuf1, &outDistanceBuf2};
224 DeviceTensor<int, 2, true> outIndexBuf1(
225 mem, {tileRows, numColTiles * k}, defaultStream);
226 DeviceTensor<int, 2, true> outIndexBuf2(
227 mem, {tileRows, numColTiles * k}, defaultStream);
228 DeviceTensor<int, 2, true>* outIndexBufs[2] =
229 {&outIndexBuf1, &outIndexBuf2};
// NOTE(review): presumably makes the alternate streams wait for work already
// queued on the default stream -- confirm streamWait's argument order.
231 auto streams = resources->getAlternateStreamsCurrentDevice();
232 streamWait(streams, {defaultStream});
235 bool interrupt =
false;
// Outer loop: one tile of at most tileRows queries per iteration.
238 for (
int i = 0; i < numQueries; i += tileRows) {
// The final tile may hold fewer than tileRows queries.
244 int curQuerySize = std::min(tileRows, numQueries - i);
// Views onto this tile's rows of the outputs, queries and query norms.
246 auto outDistanceView =
247 outDistances.narrow(0, i, curQuerySize);
249 outIndices.narrow(0, i, curQuerySize);
252 queries.narrow(queriesRowMajor ? 0 : 1, i, curQuerySize);
254 queryNorms.narrow(0, i, curQuerySize);
// Row-limited views of the current stream's per-tile result buffers.
256 auto outDistanceBufRowView =
257 outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
258 auto outIndexBufRowView =
259 outIndexBufs[curStream]->narrow(0, 0, curQuerySize);
// Inner loop: one tile of at most tileCols centroids per iteration.
262 for (
int j = 0; j < numCentroids; j += tileCols) {
268 int curCentroidSize = std::min(tileCols, numCentroids - j);
269 int curColTile = j / tileCols;
272 sliceCentroids(centroids, centroidsRowMajor, j, curCentroidSize);
// Scratch view sized to the current (queries x centroids) tile.
274 auto distanceBufView = distanceBufs[curStream]->
275 narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);
// The k-wide column slot of the per-tile result buffers that belongs to this
// column tile.
277 auto outDistanceBufColView =
278 outDistanceBufRowView.narrow(1, k * curColTile, k);
279 auto outIndexBufColView =
280 outIndexBufRowView.narrow(1, k * curColTile, k);
// Matrix product into distanceBufView; the scale factor is -2 for L2 and 1
// otherwise.
285 runMatrixMult(distanceBufView,
291 computeL2 ? -2.0f : 1.0f,
294 resources->getBlasHandleCurrentDevice(),
// A single column tile covers every centroid when tileCols == numCentroids.
305 if (tileCols == numCentroids) {
307 runL2SelectMin(distanceBufView,
314 if (!ignoreOutDistances) {
317 runSumAlongRows(queryNormNiew,
// Multi-tile case: select against this tile's slice of the centroid norms
// and store the result in this tile's slot of the intermediate buffer.
324 auto centroidNormsView = centroidNorms->narrow(0, j, curCentroidSize);
327 runL2SelectMin(distanceBufView,
329 outDistanceBufColView,
334 if (!ignoreOutDistances) {
337 runSumAlongRows(queryNormNiew,
338 outDistanceBufColView,
// Block-select path, again split into single-tile and multi-tile cases.
346 if (tileCols == numCentroids) {
348 runBlockSelect(distanceBufView,
351 true, k, streams[curStream]);
354 runBlockSelect(distanceBufView,
355 outDistanceBufColView,
357 true, k, streams[curStream]);
// With more than one column tile, the per-tile indices are shifted by
// (tile number * tileCols) via runIncrementIndex, then a pair selection runs
// over the concatenated per-tile results.
364 if (tileCols != numCentroids) {
367 runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);
369 runBlockSelectPair(outDistanceBufRowView,
373 computeL2 ?
false :
true, k, streams[curStream]);
// Alternate between the two streams / buffer sets.
376 curStream = (curStream + 1) % 2;
// NOTE(review): presumably makes the default stream wait for both alternate
// streams before returning -- confirm streamWait's argument order.
380 streamWait({defaultStream}, streams);
383 FAISS_THROW_MSG(
"interrupted");
// Templated L2 entry point over element type T. centroidNorms is passed by
// pointer; ignoreOutDistances defaults to false.
387 template <
typename T>
388 void runL2Distance(GpuResources* resources,
389 Tensor<T, 2, true>& centroids,
390 bool centroidsRowMajor,
391 Tensor<T, 1, true>* centroidNorms,
392 Tensor<T, 2, true>& queries,
393 bool queriesRowMajor,
395 Tensor<T, 2, true>& outDistances,
396 Tensor<int, 2, true>& outIndices,
398 bool ignoreOutDistances =
false) {
// Templated inner-product entry point over element type T; it takes no
// centroid-norm argument.
413 template <
typename T>
414 void runIPDistance(GpuResources* resources,
415 Tensor<T, 2, true>& centroids,
416 bool centroidsRowMajor,
417 Tensor<T, 2, true>& queries,
418 bool queriesRowMajor,
420 Tensor<T, 2, true>& outDistances,
421 Tensor<int, 2, true>& outIndices,
// Forwards to runDistance with computeL2 = false.
423 runDistance<T>(
false,
// Non-template float overload of runIPDistance; forwards to the
// runIPDistance<float> instantiation.
442 runIPDistance(GpuResources* resources,
443 Tensor<float, 2, true>& vectors,
444 bool vectorsRowMajor,
445 Tensor<float, 2, true>& queries,
446 bool queriesRowMajor,
448 Tensor<float, 2, true>& outDistances,
449 Tensor<int, 2, true>& outIndices) {
450 runIPDistance<float>(resources,
// Half-precision overload of runIPDistance, compiled only when
// FAISS_USE_FLOAT16 is defined; forwards to runIPDistance<half>.
461 #ifdef FAISS_USE_FLOAT16
463 runIPDistance(GpuResources* resources,
464 Tensor<half, 2, true>& vectors,
465 bool vectorsRowMajor,
466 Tensor<half, 2, true>& queries,
467 bool queriesRowMajor,
469 Tensor<half, 2, true>& outDistances,
470 Tensor<int, 2, true>& outIndices,
472 runIPDistance<half>(resources,
// Non-template float overload of runL2Distance; forwards to the
// runL2Distance<float> instantiation. vectorNorms is passed by pointer.
485 runL2Distance(GpuResources* resources,
486 Tensor<float, 2, true>& vectors,
487 bool vectorsRowMajor,
488 Tensor<float, 1, true>* vectorNorms,
489 Tensor<float, 2, true>& queries,
490 bool queriesRowMajor,
492 Tensor<float, 2, true>& outDistances,
493 Tensor<int, 2, true>& outIndices,
494 bool ignoreOutDistances) {
495 runL2Distance<float>(resources,
// Half-precision overload of runL2Distance, compiled only when
// FAISS_USE_FLOAT16 is defined; forwards to runL2Distance<half>.
508 #ifdef FAISS_USE_FLOAT16
510 runL2Distance(GpuResources* resources,
511 Tensor<half, 2, true>& vectors,
512 bool vectorsRowMajor,
513 Tensor<half, 1, true>* vectorNorms,
514 Tensor<half, 2, true>& queries,
515 bool queriesRowMajor,
517 Tensor<half, 2, true>& outDistances,
518 Tensor<int, 2, true>& outIndices,
520 bool ignoreOutDistances) {
521 runL2Distance<half>(resources,
static bool is_interrupted()