11 #include "Distance.cuh"
12 #include "BroadcastSum.cuh"
14 #include "L2Select.cuh"
15 #include "../../FaissAssert.h"
16 #include "../GpuResources.h"
17 #include "../utils/DeviceUtils.h"
18 #include "../utils/Limits.cuh"
19 #include "../utils/MatrixMult.cuh"
20 #include "../utils/BlockSelectKernel.cuh"
23 #include <thrust/fill.h>
24 #include <thrust/for_each.h>
25 #include <thrust/device_ptr.h>
26 #include <thrust/execution_policy.h>
28 namespace faiss {
namespace gpu {
// Default number of query rows processed per tile by the distance routines.
// chooseTileSize() doubles this starting value for element types smaller
// than 4 bytes.
32 constexpr
int kDefaultTileSize = 256;
// Chooses how many query rows each tile (one iteration of the query loops in
// runL2Distance / runIPDistance) processes.
//
// A positive tileSizeOverride is returned unchanged. Otherwise the starting
// size is kDefaultTileSize, doubled when sizeof(T) < 4. While tileSize is
// above 64, the loop checks whether two tileSize x numCentroids buffers of T
// fit in tempMemAvailable. The result is asserted to be at least 64.
//
// NOTE(review): the template header, the numCentroids parameter line, the
// tileSize declaration, and the loop's shrink/return statements are not
// visible in this excerpt. Presumably tileSize is halved until the
// requirement fits; confirm against the full source.
35 int chooseTileSize(
int tileSizeOverride,
37 size_t tempMemAvailable) {
// An explicit, positive override bypasses the memory-based heuristic.
38 if (tileSizeOverride > 0) {
39 return tileSizeOverride;
43 sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
45 while (tileSize > 64) {
// Factor of 2: the callers allocate a pair of
// [tileSize x numCentroids] distance buffers and alternate between them.
46 size_t memRequirement = 2 * tileSize * numCentroids *
sizeof(T);
48 if (memRequirement <= tempMemAvailable) {
// Sanity check: the chosen tile never drops below 64 query rows.
58 FAISS_ASSERT(tileSize >= 64);
// Brute-force k-selection of centroids by L2 distance for every query row.
//
// For each row of `queries`, k centroids are selected (presumably the k
// smallest distances, via runL2SelectMin). Their distances go to outDistances
// and their row indices go to outIndices. Both outputs must have
// queries.getSize(0) rows and k columns (asserted below). k must not exceed
// the number of centroids, nor 1024.
//
// Distances are assembled from ||q||^2 + ||c||^2 - 2 q.c:
//  - runMatrixMult produces -2 * q.c for one tile of queries (alpha -2.0f,
//    beta 0.0f).
//  - Centroid norms come from `centroidNorms`, or are computed here with
//    runL2Norm when that pointer is null. They are presumably folded in by
//    runL2SelectMin (its remaining arguments are not visible here).
//  - Query norms are added to the k selected values by runSumAlongRows,
//    unless ignoreOutDistances is true. ||q||^2 is constant per query row,
//    so skipping it changes only the reported values, not which centroids
//    are selected.
//
// When `centroidsTransposed` is non-null it is passed as the GEMM operand
// with flag `false`; otherwise `centroids` is passed with flag `true`
// (presumably a transpose flag). tileSizeOverride > 0 forces the tile size
// (see chooseTileSize).
//
// Streams: streamWait(streams, {defaultStream}) presumably orders the
// alternate streams after work already queued on the default stream. The
// final streamWait makes the default stream wait on the alternate streams.
//
// NOTE(review): the template header and the `k` / `useHgemm` parameter lines
// are not visible in this excerpt.
69 void runL2Distance(GpuResources* resources,
70 Tensor<T, 2, true>& centroids,
71 Tensor<T, 2, true>* centroidsTransposed,
72 Tensor<T, 1, true>* centroidNorms,
73 Tensor<T, 2, true>& queries,
75 Tensor<T, 2, true>& outDistances,
76 Tensor<int, 2, true>& outIndices,
78 bool ignoreOutDistances =
false,
79 int tileSizeOverride = -1) {
80 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
81 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
82 FAISS_ASSERT(outDistances.getSize(1) == k);
83 FAISS_ASSERT(outIndices.getSize(1) == k);
85 auto& mem = resources->getMemoryManagerCurrentDevice();
86 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// No centroids: fill both outputs on the default stream. The fill values
// and the early exit are on lines not visible in this excerpt.
89 if (centroids.numElements() == 0) {
90 thrust::fill(thrust::cuda::par.on(defaultStream),
91 outDistances.data(), outDistances.end(),
94 thrust::fill(thrust::cuda::par.on(defaultStream),
95 outIndices.data(), outIndices.end(),
// Use caller-provided centroid norms when available. Otherwise compute them
// on the default stream into the local cNorms, which stays alive until this
// function returns.
// NOTE(review): the `true` argument to runL2Norm presumably requests squared
// norms; confirm.
102 DeviceTensor<T, 1, true> cNorms;
103 if (!centroidNorms) {
104 cNorms = std::move(DeviceTensor<T, 1, true>(
106 {centroids.getSize(0)}, defaultStream));
107 runL2Norm(centroids, cNorms,
true, defaultStream);
108 centroidNorms = &cNorms;
// Norms of every query row, computed once up front on the default stream.
114 int qNormSize[1] = {queries.getSize(0)};
115 DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
118 runL2Norm(queries, queryNorms,
true, defaultStream);
// NOTE(review): the 1024 cap presumably reflects a limit of the k-selection
// kernels; confirm.
125 FAISS_ASSERT(k <= centroids.getSize(0));
126 FAISS_ASSERT(k <= 1024);
131 centroids.getSize(0),
132 resources->getMemoryManagerCurrentDevice().getSizeAvailable());
// Query rows handled per iteration; never more than the number of queries.
134 int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
// Two [maxQueriesPerIteration x numCentroids] distance buffers. Consecutive
// tiles alternate between them and between alternate streams, presumably so
// that one tile's work can overlap the next.
137 DeviceTensor<T, 2, true> distanceBuf1(
138 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
139 DeviceTensor<T, 2, true> distanceBuf2(
140 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
141 DeviceTensor<T, 2, true>* distanceBufs[2] =
142 {&distanceBuf1, &distanceBuf2};
// Per-tile work is queued on the alternate streams. They must first be
// ordered after the norm computations queued on the default stream.
144 auto streams = resources->getAlternateStreamsCurrentDevice();
145 streamWait(streams, {defaultStream});
// Process the queries in tiles of at most maxQueriesPerIteration rows; the
// last tile may be shorter.
149 for (
int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
150 int numQueriesForIteration = std::min(maxQueriesPerIteration,
151 queries.getSize(0) - i);
// Restrict the scratch buffer, queries, outputs and query norms to the rows
// of this tile.
153 auto distanceBufView =
154 distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
156 queries.narrowOutermost(i, numQueriesForIteration);
157 auto outDistanceView =
158 outDistances.narrowOutermost(i, numQueriesForIteration);
160 outIndices.narrowOutermost(i, numQueriesForIteration);
162 queryNorms.narrowOutermost(i, numQueriesForIteration);
// distanceBufView = -2 * (query tile) . (centroids), beta = 0. The query
// operand lines are not visible in this excerpt.
168 runMatrixMult(distanceBufView,
false,
170 centroidsTransposed ? *centroidsTransposed : centroids,
171 centroidsTransposed ?
false :
true,
172 -2.0f, 0.0f, useHgemm,
173 resources->getBlasHandleCurrentDevice(),
// Select k entries per query row into the output views (remaining arguments
// not visible in this excerpt).
180 runL2SelectMin(distanceBufView,
// Add ||q||^2 to the selected distances unless the caller asked to skip it.
// NOTE(review): `queryNormNiew` looks like a typo for `queryNormView`. Its
// declaration is not visible in this excerpt, so it is left unchanged.
187 if (!ignoreOutDistances) {
190 runSumAlongRows(queryNormNiew, outDistanceView, streams[curStream]);
// Switch to the other buffer/stream pair for the next tile.
193 curStream = (curStream + 1) % 2;
// The default stream waits for all alternate-stream work before returning.
197 streamWait({defaultStream}, streams);
// Brute-force k-selection of centroids by inner product for every query row.
//
// For each row of `queries`, the raw inner products with every centroid are
// computed (runMatrixMult with alpha 1.0f, beta 0.0f). runBlockSelect then
// keeps k of them per row, writing values to outDistances and centroid row
// indices to outIndices. Both outputs must have queries.getSize(0) rows and
// k columns (asserted below). k must not exceed the number of centroids,
// nor 1024.
//
// NOTE(review): the `true` argument to runBlockSelect presumably requests
// largest-first (max) selection; confirm against runBlockSelect.
//
// When `centroidsTransposed` is non-null it is passed as the GEMM operand
// with flag `false`; otherwise `centroids` is passed with flag `true`
// (presumably a transpose flag). tileSizeOverride > 0 forces the tile size
// (see chooseTileSize).
//
// Streams: streamWait(streams, {defaultStream}) presumably orders the
// alternate streams after the default stream. The final streamWait makes
// the default stream wait on the alternate streams.
//
// NOTE(review): the `k` / `useHgemm` parameter lines are not visible in this
// excerpt.
200 template <
typename T>
201 void runIPDistance(GpuResources* resources,
202 Tensor<T, 2, true>& centroids,
203 Tensor<T, 2, true>* centroidsTransposed,
204 Tensor<T, 2, true>& queries,
206 Tensor<T, 2, true>& outDistances,
207 Tensor<int, 2, true>& outIndices,
209 int tileSizeOverride = -1) {
210 FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
211 FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
212 FAISS_ASSERT(outDistances.getSize(1) == k);
213 FAISS_ASSERT(outIndices.getSize(1) == k);
215 auto& mem = resources->getMemoryManagerCurrentDevice();
216 auto defaultStream = resources->getDefaultStreamCurrentDevice();
// No centroids: fill the outputs on the default stream. Distances are set to
// Limits<T>::getMax(); the index fill value and the early exit are not
// visible in this excerpt.
// NOTE(review): if selection is largest-first, a "no result" sentinel of the
// maximum value ranks as the best possible score. Confirm that this sentinel
// is intended for inner product.
219 if (centroids.numElements() == 0) {
220 thrust::fill(thrust::cuda::par.on(defaultStream),
221 outDistances.data(), outDistances.end(),
222 Limits<T>::getMax());
224 thrust::fill(thrust::cuda::par.on(defaultStream),
225 outIndices.data(), outIndices.end(),
// NOTE(review): the 1024 cap presumably reflects a limit of the k-selection
// kernels; confirm.
236 FAISS_ASSERT(k <= centroids.getSize(0));
237 FAISS_ASSERT(k <= 1024);
242 centroids.getSize(0),
243 resources->getMemoryManagerCurrentDevice().getSizeAvailable());
// Query rows handled per iteration; never more than the number of queries.
245 int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
// Two [maxQueriesPerIteration x numCentroids] distance buffers, alternated
// between consecutive tiles together with the alternate streams.
248 DeviceTensor<T, 2, true> distanceBuf1(
249 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
250 DeviceTensor<T, 2, true> distanceBuf2(
251 mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
252 DeviceTensor<T, 2, true>* distanceBufs[2] =
253 {&distanceBuf1, &distanceBuf2};
// Per-tile work runs on the alternate streams, ordered after the default
// stream.
255 auto streams = resources->getAlternateStreamsCurrentDevice();
256 streamWait(streams, {defaultStream});
// Process the queries in tiles of at most maxQueriesPerIteration rows; the
// last tile may be shorter.
260 for (
int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
261 int numQueriesForIteration = std::min(maxQueriesPerIteration,
262 queries.getSize(0) - i);
// Restrict the scratch buffer, queries and outputs to this tile's rows.
264 auto distanceBufView =
265 distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
267 queries.narrowOutermost(i, numQueriesForIteration);
268 auto outDistanceView =
269 outDistances.narrowOutermost(i, numQueriesForIteration);
271 outIndices.narrowOutermost(i, numQueriesForIteration);
// distanceBufView = 1.0 * (query tile) . (centroids), beta = 0: raw inner
// products. The query operand lines are not visible in this excerpt.
274 runMatrixMult(distanceBufView,
false,
276 centroidsTransposed ? *centroidsTransposed : centroids,
277 centroidsTransposed ?
false :
true,
278 1.0f, 0.0f, useHgemm,
279 resources->getBlasHandleCurrentDevice(),
// Keep k entries per query row, on this tile's stream.
284 runBlockSelect(distanceBufView,
287 true, k, streams[curStream]);
// Switch to the other buffer/stream pair for the next tile.
289 curStream = (curStream + 1) % 2;
// The default stream waits for all alternate-stream work before returning.
292 streamWait({defaultStream}, streams);
// Non-template float entry point: forwards to runIPDistance<float>. Here
// `vectors` / `vectorsTransposed` presumably play the role of the template's
// `centroids` / `centroidsTransposed`. The forwarded argument list and the
// return-type line are not visible in this excerpt.
300 runIPDistance(GpuResources* resources,
301 Tensor<float, 2, true>& vectors,
302 Tensor<float, 2, true>* vectorsTransposed,
303 Tensor<float, 2, true>& queries,
305 Tensor<float, 2, true>& outDistances,
306 Tensor<int, 2, true>& outIndices,
307 int tileSizeOverride) {
308 runIPDistance<float>(resources,
// Half-precision overload, compiled only when FAISS_USE_FLOAT16 is defined
// (the matching #endif is outside this excerpt). It forwards to
// runIPDistance<half>; the forwarded argument list is not visible here.
319 #ifdef FAISS_USE_FLOAT16
321 runIPDistance(GpuResources* resources,
322 Tensor<half, 2, true>& vectors,
323 Tensor<half, 2, true>* vectorsTransposed,
324 Tensor<half, 2, true>& queries,
326 Tensor<half, 2, true>& outDistances,
327 Tensor<int, 2, true>& outIndices,
329 int tileSizeOverride) {
330 runIPDistance<half>(resources,
// Non-template float entry point: forwards to runL2Distance<float>. Here
// `vectors` / `vectorsTransposed` / `vectorNorms` presumably play the role of
// the template's `centroids` / `centroidsTransposed` / `centroidNorms`.
// The forwarded argument list and the return-type line are not visible in
// this excerpt.
343 runL2Distance(GpuResources* resources,
344 Tensor<float, 2, true>& vectors,
345 Tensor<float, 2, true>* vectorsTransposed,
346 Tensor<float, 1, true>* vectorNorms,
347 Tensor<float, 2, true>& queries,
349 Tensor<float, 2, true>& outDistances,
350 Tensor<int, 2, true>& outIndices,
351 bool ignoreOutDistances,
352 int tileSizeOverride) {
353 runL2Distance<float>(resources,
// Half-precision overload, compiled only when FAISS_USE_FLOAT16 is defined
// (the matching #endif is outside this excerpt). It forwards to
// runL2Distance<half>; the forwarded argument list is not visible here.
366 #ifdef FAISS_USE_FLOAT16
368 runL2Distance(GpuResources* resources,
369 Tensor<half, 2, true>& vectors,
370 Tensor<half, 2, true>* vectorsTransposed,
371 Tensor<half, 1, true>* vectorNorms,
372 Tensor<half, 2, true>& queries,
374 Tensor<half, 2, true>& outDistances,
375 Tensor<int, 2, true>& outIndices,
377 bool ignoreOutDistances,
378 int tileSizeOverride) {
379 runL2Distance<half>(resources,