11 #include "Distance.cuh"
12 #include "BroadcastSum.cuh"
14 #include "L2Select.cuh"
15 #include "../../FaissAssert.h"
16 #include "../GpuResources.h"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Limits.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/BlockSelectKernel.cuh"
23 #include <thrust/fill.h>
24 #include <thrust/for_each.h>
25 #include <thrust/device_ptr.h>
26 #include <thrust/execution_policy.h>
28 namespace faiss {
namespace gpu {
// sliceCentroids: returns the part of the centroid set covering centroid
// indices [startCentroid, startCentroid + num).
// - If the whole set is requested (startCentroid == 0 and
//   num == centroids.getSize(0)) and a transposed copy is supplied, that
//   transposed copy is returned unchanged.
// - Otherwise, if a transposed copy is supplied it is narrowed along
//   dimension 1; if not, `centroids` is narrowed along dimension 0.
//   narrow(dim, start, size) presumably yields a sub-range view along `dim`,
//   so the transposed copy apparently holds one centroid per column
//   (TODO confirm).
// NOTE(review): this copy of the file is missing lines of this definition.
// `T`, `startCentroid` and `num` are used but not declared on any visible
// line (the template header and trailing parameters are absent), and so are
// the else branches and closing braces. Restore from upstream before building.
33 Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
34 Tensor<T, 2, true>* centroidsTransposed,
37 if (startCentroid == 0 && num == centroids.getSize(0)) {
38 if (centroidsTransposed) {
39 return *centroidsTransposed;
45 if (centroidsTransposed) {
// Transposed layout: centroids are selected along dimension 1.
47 return centroidsTransposed->narrow(1, startCentroid, num);
// Original layout: centroids are selected along dimension 0.
49 return centroids.narrow(0, startCentroid, num);
// incrementIndex: kernel that adds a per-segment offset to index values.
// Expected launch (as configured by runIncrementIndex): gridDim.y = number of
// rows of `indices`, and gridDim.x = number of k-wide column segments per row.
// Block (x, y) handles columns [x * k, (x + 1) * k) of row y. Its threads walk
// those k entries in steps of blockDim.x, adding blockIdx.x * increment to
// each one. No shared memory is used.
// runDistance calls this with increment = tileCols. Segment number c then
// receives c * tileCols, which is the first centroid of column tile c. This
// presumably converts tile-local indices into indices over the full set.
// NOTE(review): the template header, the `k` / `increment` parameters and the
// closing braces are missing from this copy of the file.
55 __global__
void incrementIndex(Tensor<T, 2, true> indices,
58 for (
int i = threadIdx.x; i < k; i += blockDim.x) {
59 indices[blockIdx.y][blockIdx.x * k + i] += blockIdx.x * increment;
// runIncrementIndex: host-side launcher for the incrementIndex kernel.
// It launches one block per (k-wide column segment, row) of `indices`, with
// min(k, 512) threads per block, on `stream`. The number of columns must be an
// exact multiple of k (asserted below).
// NOTE(review): the template header, the `k` / `increment` parameters and the
// closing brace are missing from this copy of the file.
66 void runIncrementIndex(Tensor<T, 2, true>& indices,
69 cudaStream_t stream) {
70 dim3 grid(indices.getSize(1) / k, indices.getSize(0));
// The kernel's threads step through the k entries by blockDim.x, so capping
// the block at 512 threads still covers any k.
71 int block = std::min(k, 512);
// The integer division above would silently drop a partial trailing segment,
// so that case is rejected here.
74 FAISS_ASSERT(grid.x * k == indices.getSize(1));
76 incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);
// NOTE(review): cudaDeviceSynchronize() blocks the host until all work on
// every stream of the device has finished, and its return code is ignored.
// runDistance apparently calls this once per query tile while alternating
// between two streams, so this sync serializes that pipeline. Consider
// checking cudaGetLastError() here and keeping a full sync for debug builds
// only.
78 cudaDeviceSynchronize();
// chooseTileSize: picks the query-tile height (tileRows) and centroid-tile
// width (tileCols) used by runDistance.
// The memory budget is a fixed byte target chosen from the device's total
// global memory:
//   <= 4 GiB  -> 512 MiB
//   <= 8 GiB  -> 768 MiB
//   otherwise -> 1 GiB (the branch keyword is on a line missing from this copy)
// The budget is then divided by 2 * elementSize, converting it to an element
// count and halving it. The halving is presumably because runDistance keeps
// two distance tiles, one per stream.
// Results:
//   tileRows = min(preferredTileRows, numQueries)
//   tileCols = min(budget / preferredTileRows, numCentroids)
// NOTE(review): lines of this definition are missing from this copy,
// including the remaining parameters, the declaration of targetUsage, the
// condition that selects 1024 preferred rows, and the closing brace.
// tempMemAvailable is not used on any visible line; confirm against upstream.
// tileCols is derived from preferredTileRows rather than tileRows, so with
// fewer queries than preferredTileRows the budget is not fully used.
84 void chooseTileSize(
int numQueries,
88 size_t tempMemAvailable,
99 auto totalMem = getCurrentDeviceProperties().totalGlobalMem;
// The thresholds are computed in size_t because 4 * 1024^3 does not fit in a
// 32-bit int.
103 if (totalMem <= ((
size_t) 4) * 1024 * 1024 * 1024) {
104 targetUsage = 512 * 1024 * 1024;
105 }
else if (totalMem <= ((
size_t) 8) * 1024 * 1024 * 1024) {
106 targetUsage = 768 * 1024 * 1024;
108 targetUsage = 1024 * 1024 * 1024;
// Bytes -> elements, split between the two buffered tiles.
111 targetUsage /= 2 * elementSize;
117 int preferredTileRows = 512;
119 preferredTileRows = 1024;
122 tileRows = std::min(preferredTileRows, numQueries);
125 tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
// runDistance: tiled brute-force k-selection of `centroids` for every row of
// `queries`.
// - computeL2 == true: L2 distances, keeping the smallest values.
// - computeL2 == false: inner products (this is how runIPDistance calls it),
//   presumably keeping the largest values. The selection direction is passed
//   as bare booleans, so confirm against the select-kernel signatures.
// Results are written to outDistances / outIndices, each shaped
// (number of queries, k) as asserted below.
// For L2, centroidNorms may be null, in which case the norms are computed
// here. centroidsTransposed, when non-null, is used both for slicing and to
// choose the GEMM transpose flag.
// Work is split into query tiles (tileRows) x centroid tiles (tileCols) chosen
// by chooseTileSize. Consecutive query tiles alternate between two streams,
// each with its own scratch buffers.
// NOTE(review): many lines of this definition are missing from this copy.
// `k` and `useHgemm` are used but never declared here. tileRows, tileCols,
// curStream and queryNormNiew (sic) are declared on missing lines. Several
// call arguments, else branches and closing braces are also absent.
130 template <
typename T>
131 void runDistance(
bool computeL2,
132 GpuResources* resources,
133 Tensor<T, 2, true>& centroids,
134 Tensor<T, 2, true>* centroidsTransposed,
135 Tensor<T, 1, true>* centroidNorms,
136 Tensor<T, 2, true>& queries,
138 Tensor<T, 2, true>& outDistances,
139 Tensor<int, 2, true>& outIndices,
141 bool ignoreOutDistances) {
// The outputs must have one row per query and k columns.
142 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
143 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
144 FAISS_ASSERT(outDistances.getSize(1) == k);
145 FAISS_ASSERT(outIndices.getSize(1) == k);
147 auto& mem = resources->getMemoryManagerCurrentDevice();
148 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// No centroids: fill the outputs with sentinel values on the default stream.
// Distances get Limits<T>::getMax(). The index fill value and any early
// return are on lines missing from this copy.
151 if (centroids.numElements() == 0) {
152 thrust::fill(thrust::cuda::par.on(defaultStream),
153 outDistances.data(), outDistances.end(),
154 Limits<T>::getMax());
156 thrust::fill(thrust::cuda::par.on(defaultStream),
157 outIndices.data(), outIndices.end(),
// For L2 without caller-supplied centroid norms, compute one norm per centroid
// into a local buffer that lives until this function returns, and point
// centroidNorms at it. The `true` flag passed to runL2Norm presumably requests
// squared norms (TODO confirm).
// NOTE(review): std::move around a temporary is redundant.
164 DeviceTensor<T, 1, true> cNorms;
165 if (computeL2 && !centroidNorms) {
166 cNorms = std::move(DeviceTensor<T, 1, true>(
168 {centroids.getSize(0)}, defaultStream));
169 runL2Norm(centroids, cNorms,
true, defaultStream);
170 centroidNorms = &cNorms;
// One norm per query row. Any computeL2 guard around this is on a missing
// line.
176 int qNormSize[1] = {queries.getSize(0)};
177 DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
181 runL2Norm(queries, queryNorms,
true, defaultStream);
// Choose the tile shape, taking into account the temporary memory currently
// available.
188 chooseTileSize(queries.getSize(0),
189 centroids.getSize(0),
192 mem.getSizeAvailable(),
// Number of centroid tiles (ceiling division). Each tile contributes k
// candidates per query.
196 int numColTiles = utils::divUp(centroids.getSize(0), tileCols);
// k cannot exceed the number of centroids, and this routine requires
// k <= 1024.
198 FAISS_ASSERT(k <= centroids.getSize(0));
199 FAISS_ASSERT(k <= 1024);
// Two of each scratch buffer, indexed by curStream, so consecutive query tiles
// on different streams never share scratch memory.
// - distanceBufs: one tileRows x tileCols distance tile.
// - outDistanceBufs / outIndexBufs: tileRows x (numColTiles * k), i.e. a
//   k-wide candidate slot per centroid tile.
202 DeviceTensor<T, 2, true> distanceBuf1(
203 mem, {tileRows, tileCols}, defaultStream);
204 DeviceTensor<T, 2, true> distanceBuf2(
205 mem, {tileRows, tileCols}, defaultStream);
206 DeviceTensor<T, 2, true>* distanceBufs[2] =
207 {&distanceBuf1, &distanceBuf2};
209 DeviceTensor<T, 2, true> outDistanceBuf1(
210 mem, {tileRows, numColTiles * k}, defaultStream);
211 DeviceTensor<T, 2, true> outDistanceBuf2(
212 mem, {tileRows, numColTiles * k}, defaultStream);
213 DeviceTensor<T, 2, true>* outDistanceBufs[2] =
214 {&outDistanceBuf1, &outDistanceBuf2};
216 DeviceTensor<int, 2, true> outIndexBuf1(
217 mem, {tileRows, numColTiles * k}, defaultStream);
218 DeviceTensor<int, 2, true> outIndexBuf2(
219 mem, {tileRows, numColTiles * k}, defaultStream);
220 DeviceTensor<int, 2, true>* outIndexBufs[2] =
221 {&outIndexBuf1, &outIndexBuf2};
// Work queued so far on the default stream (norms, buffer setup) is presumably
// made visible to the alternate streams here, assuming
// streamWait(waiting, waitedOn). TODO confirm.
223 auto streams = resources->getAlternateStreamsCurrentDevice();
224 streamWait(streams, {defaultStream});
// Outer loop: one iteration per tile of up to tileRows queries.
229 for (
int i = 0; i < queries.getSize(0); i += tileRows) {
230 int curQuerySize = std::min(tileRows, queries.getSize(0) - i);
// Views of this query tile in the outputs, the queries and the query norms.
232 auto outDistanceView =
233 outDistances.narrow(0, i, curQuerySize);
235 outIndices.narrow(0, i, curQuerySize);
238 queries.narrow(0, i, curQuerySize);
240 queryNorms.narrow(0, i, curQuerySize);
// This stream's candidate buffers, restricted to the rows of this tile.
242 auto outDistanceBufRowView =
243 outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
244 auto outIndexBufRowView =
245 outIndexBufs[curStream]->narrow(0, 0, curQuerySize);
// Inner loop: one iteration per tile of up to tileCols centroids.
248 for (
int j = 0; j < centroids.getSize(0); j += tileCols) {
249 int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);
251 int curColTile = j / tileCols;
254 sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);
// The part of this stream's distance tile that is actually used.
256 auto distanceBufView = distanceBufs[curStream]->
257 narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);
// This centroid tile's k-wide slot in the candidate buffers.
259 auto outDistanceBufColView =
260 outDistanceBufRowView.narrow(1, k * curColTile, k);
261 auto outIndexBufColView =
262 outIndexBufRowView.narrow(1, k * curColTile, k);
// Dot products of the query tile with the centroid tile, scaled by alpha.
// - For L2, alpha = -2 forms the cross term of
//   ||q||^2 - 2 q.c + ||c||^2. The norm terms are added by the selection and
//   summation steps below.
// - For inner product, alpha = 1.
// The centroid operand is marked transposed unless the pre-transposed copy is
// in use.
267 runMatrixMult(distanceBufView,
false,
270 centroidsTransposed ?
false :
true,
271 computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,
272 resources->getBlasHandleCurrentDevice(),
// L2: when one tile covers every centroid, select straight into the final
// outputs; otherwise into this tile's candidate slot (below).
283 if (tileCols == centroids.getSize(0)) {
285 runL2SelectMin(distanceBufView,
// Add the per-query norm terms, skipped when the caller ignores distances.
292 if (!ignoreOutDistances) {
295 runSumAlongRows(queryNormNiew, outDistanceView, streams[curStream]);
// Multi-tile L2: select from this tile using the matching slice of centroid
// norms.
298 auto centroidNormsView =
299 centroidNorms->narrow(0, j, curCentroidSize);
302 runL2SelectMin(distanceBufView,
304 outDistanceBufColView,
309 if (!ignoreOutDistances) {
312 runSumAlongRows(queryNormNiew,
313 outDistanceBufColView,
// Presumably the non-L2 (inner product) branch; the enclosing else is on a
// missing line. Same single-tile / multi-tile split, via runBlockSelect.
319 if (tileCols == centroids.getSize(0)) {
321 runBlockSelect(distanceBufView,
324 true, k, streams[curStream]);
327 runBlockSelect(distanceBufView,
328 outDistanceBufColView,
330 true, k, streams[curStream]);
// With several centroid tiles, shift each tile's indices by its first
// centroid (tile number * tileCols). Then merge the numColTiles * k
// candidates per query into the final k.
337 if (tileCols != centroids.getSize(0)) {
340 runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);
342 runBlockSelectPair(outDistanceBufRowView,
346 computeL2 ?
false :
true, k, streams[curStream]);
// The next query tile uses the other stream and its buffers.
349 curStream = (curStream + 1) % 2;
// Make the default stream wait on both alternate streams before the caller
// continues on it.
353 streamWait({defaultStream}, streams);
// runL2Distance<T>: templated L2-distance entry point over centroids and
// queries of element type T. centroidNorms may presumably be null, as
// runDistance handles that case. ignoreOutDistances defaults to false.
// NOTE(review): the body and the parameters on lines 362/365 are missing from
// this copy. It presumably forwards to runDistance<T> with computeL2 = true;
// confirm against upstream.
356 template <
typename T>
357 void runL2Distance(GpuResources* resources,
358 Tensor<T, 2, true>& centroids,
359 Tensor<T, 2, true>* centroidsTransposed,
360 Tensor<T, 1, true>* centroidNorms,
361 Tensor<T, 2, true>& queries,
363 Tensor<T, 2, true>& outDistances,
364 Tensor<int, 2, true>& outIndices,
366 bool ignoreOutDistances =
false) {
// runIPDistance<T>: templated inner-product entry point. It takes no
// centroid-norm argument and forwards to runDistance<T> with
// computeL2 = false.
// NOTE(review): the parameters on lines 385/388 and the remaining call
// arguments are missing from this copy.
380 template <
typename T>
381 void runIPDistance(GpuResources* resources,
382 Tensor<T, 2, true>& centroids,
383 Tensor<T, 2, true>* centroidsTransposed,
384 Tensor<T, 2, true>& queries,
386 Tensor<T, 2, true>& outDistances,
387 Tensor<int, 2, true>& outIndices,
389 runDistance<T>(
false,
// runIPDistance (float): non-template overload for float vectors that
// forwards to runIPDistance<float>.
// NOTE(review): the return type, a parameter (line 411), the remaining call
// arguments and the closing brace are missing from this copy.
407 runIPDistance(GpuResources* resources,
408 Tensor<float, 2, true>& vectors,
409 Tensor<float, 2, true>* vectorsTransposed,
410 Tensor<float, 2, true>& queries,
412 Tensor<float, 2, true>& outDistances,
413 Tensor<int, 2, true>& outIndices) {
414 runIPDistance<float>(resources,
424 #ifdef FAISS_USE_FLOAT16
// runIPDistance (half): half-precision overload, compiled only when
// FAISS_USE_FLOAT16 is defined. It forwards to runIPDistance<half>.
// NOTE(review): the return type, some parameters, the remaining call
// arguments, the closing brace and the matching #endif are not visible in
// this copy.
426 runIPDistance(GpuResources* resources,
427 Tensor<half, 2, true>& vectors,
428 Tensor<half, 2, true>* vectorsTransposed,
429 Tensor<half, 2, true>& queries,
431 Tensor<half, 2, true>& outDistances,
432 Tensor<int, 2, true>& outIndices,
434 runIPDistance<half>(resources,
// runL2Distance (float): non-template overload for float vectors that
// forwards to runL2Distance<float>. vectorNorms is passed as a pointer, so
// it may presumably be null (TODO confirm against the header).
// NOTE(review): the return type, a parameter (line 451), the remaining call
// arguments and the closing brace are missing from this copy.
446 runL2Distance(GpuResources* resources,
447 Tensor<float, 2, true>& vectors,
448 Tensor<float, 2, true>* vectorsTransposed,
449 Tensor<float, 1, true>* vectorNorms,
450 Tensor<float, 2, true>& queries,
452 Tensor<float, 2, true>& outDistances,
453 Tensor<int, 2, true>& outIndices,
454 bool ignoreOutDistances) {
455 runL2Distance<float>(resources,
467 #ifdef FAISS_USE_FLOAT16
469 runL2Distance(GpuResources* resources,
470 Tensor<half, 2, true>& vectors,
471 Tensor<half, 2, true>* vectorsTransposed,
472 Tensor<half, 1, true>* vectorNorms,
473 Tensor<half, 2, true>& queries,
475 Tensor<half, 2, true>& outDistances,
476 Tensor<int, 2, true>& outIndices,
478 bool ignoreOutDistances) {
479 runL2Distance<half>(resources,