Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Distance.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "Distance.cuh"
11 #include "BroadcastSum.cuh"
12 #include "L2Norm.cuh"
13 #include "L2Select.cuh"
14 #include "../../FaissAssert.h"
15 #include "../GpuResources.h"
16 #include "../utils/DeviceUtils.h"
17 #include "../utils/Limits.cuh"
18 #include "../utils/MatrixMult.cuh"
19 #include "../utils/BlockSelectKernel.cuh"
20 
21 #include <memory>
22 #include <thrust/fill.h>
23 #include <thrust/for_each.h>
24 #include <thrust/device_ptr.h>
25 #include <thrust/execution_policy.h>
26 
27 namespace faiss { namespace gpu {
28 
29 namespace {
30 
// Returns the sub-range [startCentroid, startCentroid + num) of the centroid
// set, honoring the optional pre-transposed layout.
// `centroids` is (num, dim); `centroidsTransposed`, if given, is (dim, num),
// so the slice is taken along dim 1 instead of dim 0.
template <typename T>
Tensor<T, 2, true> sliceCentroids(Tensor<T, 2, true>& centroids,
                                  Tensor<T, 2, true>* centroidsTransposed,
                                  int startCentroid,
                                  int num) {
  bool wholeRange =
    (startCentroid == 0) && (num == centroids.getSize(0));

  // Fast path: the full set is requested, so no narrow() is needed
  if (wholeRange) {
    return centroidsTransposed ? *centroidsTransposed : centroids;
  }

  // Transposed layout is (dim, num): narrow the centroid dimension (1);
  // otherwise the centroid dimension is 0
  return centroidsTransposed ?
    centroidsTransposed->narrow(1, startCentroid, num) :
    centroids.narrow(0, startCentroid, num);
}
51 
// For each chunk of k indices, increment the index by chunk * increment.
// Launch layout: grid.x = number of k-sized chunks per row, grid.y = number
// of rows in `indices`; blockDim.x threads cooperate on one chunk.
template <typename T>
__global__ void incrementIndex(Tensor<T, 2, true> indices,
                               int k,
                               int increment) {
  int row = blockIdx.y;
  int chunk = blockIdx.x;
  // All indices in this chunk get the same offset
  int offset = chunk * increment;
  int chunkStart = chunk * k;

  for (int t = threadIdx.x; t < k; t += blockDim.x) {
    indices[row][chunkStart + t] += offset;
  }
}
61 
// Used to update result indices in distance computation where the number of
// centroids is high, and is tiled: converts tile-relative indices into
// absolute centroid indices by adding (tile number * increment) to each
// k-sized chunk of `indices`. Runs asynchronously on `stream`.
template <typename T>
void runIncrementIndex(Tensor<T, 2, true>& indices,
                       int k,
                       int increment,
                       cudaStream_t stream) {
  // One block per (chunk, row); at most 512 threads cooperate on a chunk
  dim3 grid(indices.getSize(1) / k, indices.getSize(0));
  int block = std::min(k, 512);

  // should be exact
  FAISS_ASSERT(grid.x * k == indices.getSize(1));

  incrementIndex<<<grid, block, 0, stream>>>(indices, k, increment);

  // No cudaDeviceSynchronize() here: it would stall every stream on the
  // device and defeat the double-buffered overlap in runDistance. Ordering
  // is already guaranteed by stream semantics -- subsequent work that reads
  // `indices` is enqueued on the same `stream`, and runDistance makes the
  // default stream wait on the alternate streams before results are used.
}
79 
// Chooses the (tileRows, tileCols) tiling of the (queries x centroids)
// distance matrix. If the inner size (dim) of the vectors is small, we want
// a larger query tile size, like 1024.
void chooseTileSize(int numQueries,
                    int numCentroids,
                    int dim,
                    int elementSize,
                    size_t tempMemAvailable,
                    int& tileRows,
                    int& tileCols) {
  // The matrix multiplication should be large enough to be efficient, but if
  // it is too large, we seem to lose efficiency as opposed to
  // double-streaming. Each tile size here defines 1/2 of the memory use due
  // to double streaming. We ignore available temporary memory, as that is
  // adjusted independently by the user and can thus meet these requirements
  // (or not).
  // For <= 4 GB GPUs, prefer 512 MB of usage.
  // For <= 8 GB GPUs, prefer 768 MB of usage.
  // Otherwise, prefer 1 GB of usage.
  auto totalMem = getCurrentDeviceProperties().totalGlobalMem;

  const size_t kOneGB = (size_t) 1024 * 1024 * 1024;

  int targetUsage = 0;
  if (totalMem <= 4 * kOneGB) {
    targetUsage = 512 * 1024 * 1024;
  } else if (totalMem <= 8 * kOneGB) {
    targetUsage = 768 * 1024 * 1024;
  } else {
    targetUsage = 1024 * 1024 * 1024;
  }

  // Two tiles are live at once (double streaming), in elements
  targetUsage /= 2 * elementSize;

  // 512 rows seems to be a batch size sweetspot; when the GEMM k dimension
  // (vec dim) is small (<= 32), a taller 1024-row tile works better
  int preferredTileRows = (dim <= 32) ? 1024 : 512;

  tileRows = std::min(preferredTileRows, numQueries);

  // tileCols is the remainder size
  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
}
126 
127 }
128 
// Shared implementation of brute-force k-nearest-neighbor search for both
// L2 (computeL2 = true) and inner-product (computeL2 = false) distance.
//
// L2 distance is expanded as ||c||^2 - 2qc + ||q||^2; the GEMM computes
// -2qc and the norm terms are added afterwards. IP just computes qc.
// The (numQueries x numCentroids) distance matrix is tiled along both
// dimensions, with work double-buffered across two alternate streams so
// that GEMM and k-selection for successive tiles overlap.
//
// Preconditions (asserted): outDistances and outIndices are
// (numQueries, k); k <= numCentroids; k <= 1024 (select limitation).
// If centroidNorms is null and computeL2 is set, ||c||^2 is computed here.
template <typename T>
void runDistance(bool computeL2,
                 GpuResources* resources,
                 Tensor<T, 2, true>& centroids,
                 Tensor<T, 2, true>* centroidsTransposed,
                 Tensor<T, 1, true>* centroidNorms,
                 Tensor<T, 2, true>& queries,
                 int k,
                 Tensor<T, 2, true>& outDistances,
                 Tensor<int, 2, true>& outIndices,
                 bool useHgemm,
                 bool ignoreOutDistances) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0 sized set, just return empty results
  // (max distance, index -1)
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // L2: If ||c||^2 is not pre-computed, calculate it
  DeviceTensor<T, 1, true> cNorms;
  if (computeL2 && !centroidNorms) {
    // Direct assignment from the prvalue; the previous std::move was
    // redundant (a temporary already binds to the move-assignment operator)
    cNorms = DeviceTensor<T, 1, true>(
      mem,
      {centroids.getSize(0)}, defaultStream);
    runL2Norm(centroids, cNorms, true, defaultStream);
    centroidNorms = &cNorms;
  }

  //
  // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
  //
  int qNormSize[1] = {queries.getSize(0)};
  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);

  // ||q||^2
  if (computeL2) {
    runL2Norm(queries, queryNorms, true, defaultStream);
  }

  // By default, aim to use up to 512 MB of memory for the processing, with
  // both number of queries and number of centroids being at least 512.
  int tileRows = 0;
  int tileCols = 0;
  chooseTileSize(queries.getSize(0),
                 centroids.getSize(0),
                 queries.getSize(1),
                 sizeof(T),
                 mem.getSizeAvailable(),
                 tileRows,
                 tileCols);

  int numColTiles = utils::divUp(centroids.getSize(0), tileCols);

  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  // Temporary output memory space we'll use; two of each buffer for
  // double-buffering across the two alternate streams
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {tileRows, tileCols}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] =
    {&distanceBuf1, &distanceBuf2};

  DeviceTensor<T, 2, true> outDistanceBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true> outDistanceBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<T, 2, true>* outDistanceBufs[2] =
    {&outDistanceBuf1, &outDistanceBuf2};

  DeviceTensor<int, 2, true> outIndexBuf1(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true> outIndexBuf2(
    mem, {tileRows, numColTiles * k}, defaultStream);
  DeviceTensor<int, 2, true>* outIndexBufs[2] =
    {&outIndexBuf1, &outIndexBuf2};

  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

  // Tile over the input queries
  for (int i = 0; i < queries.getSize(0); i += tileRows) {
    int curQuerySize = std::min(tileRows, queries.getSize(0) - i);

    auto outDistanceView =
      outDistances.narrow(0, i, curQuerySize);
    auto outIndexView =
      outIndices.narrow(0, i, curQuerySize);

    auto queryView =
      queries.narrow(0, i, curQuerySize);
    // Fixed identifier typo: was `queryNormNiew`
    auto queryNormView =
      queryNorms.narrow(0, i, curQuerySize);

    auto outDistanceBufRowView =
      outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
    auto outIndexBufRowView =
      outIndexBufs[curStream]->narrow(0, 0, curQuerySize);

    // Tile over the centroids
    for (int j = 0; j < centroids.getSize(0); j += tileCols) {
      int curCentroidSize = std::min(tileCols, centroids.getSize(0) - j);

      int curColTile = j / tileCols;

      auto centroidsView =
        sliceCentroids(centroids, centroidsTransposed, j, curCentroidSize);

      auto distanceBufView = distanceBufs[curStream]->
        narrow(0, 0, curQuerySize).narrow(1, 0, curCentroidSize);

      auto outDistanceBufColView =
        outDistanceBufRowView.narrow(1, k * curColTile, k);
      auto outIndexBufColView =
        outIndexBufRowView.narrow(1, k * curColTile, k);

      // L2: distance is ||c||^2 - 2qc + ||q||^2, we compute -2qc
      // IP: just compute qc
      // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
      runMatrixMult(distanceBufView, false,
                    queryView, false,
                    centroidsView,
                    // centroids must be transposed by the GEMM unless a
                    // pre-transposed copy was supplied
                    !centroidsTransposed,
                    computeL2 ? -2.0f : 1.0f, 0.0f, useHgemm,
                    resources->getBlasHandleCurrentDevice(),
                    streams[curStream]);

      if (computeL2) {
        // For L2 distance, we use this fused kernel that performs both
        // adding ||c||^2 to -2qc and k-selection, so we only need two
        // passes (one write by the gemm, one read here) over the huge
        // region of output memory
        //
        // If we aren't tiling along the number of centroids, we can perform
        // the output work directly
        if (tileCols == centroids.getSize(0)) {
          // Write into the final output
          runL2SelectMin(distanceBufView,
                         *centroidNorms,
                         outDistanceView,
                         outIndexView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView,
                            outDistanceView,
                            true, // L2 distances should not go below zero
                            // due to roundoff error
                            streams[curStream]);
          }
        } else {
          auto centroidNormsView =
            centroidNorms->narrow(0, j, curCentroidSize);

          // Write into our intermediate output
          runL2SelectMin(distanceBufView,
                         centroidNormsView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         k,
                         streams[curStream]);

          if (!ignoreOutDistances) {
            // expand (query id) to (query id, k) by duplicating along rows
            // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
            runSumAlongRows(queryNormView,
                            outDistanceBufColView,
                            true, // L2 distances should not go below zero
                            // due to roundoff error
                            streams[curStream]);
          }
        }
      } else {
        // For IP, just k-select the output for this tile
        if (tileCols == centroids.getSize(0)) {
          // Write into the final output
          runBlockSelect(distanceBufView,
                         outDistanceView,
                         outIndexView,
                         true, k, streams[curStream]);
        } else {
          // Write into the intermediate output
          runBlockSelect(distanceBufView,
                         outDistanceBufColView,
                         outIndexBufColView,
                         true, k, streams[curStream]);
        }
      }
    }

    // As we're finished with processing a full set of centroids, perform the
    // final k-selection
    if (tileCols != centroids.getSize(0)) {
      // The indices are tile-relative; for each tile of k, we need to add
      // tileCols to the index
      runIncrementIndex(outIndexBufRowView, k, tileCols, streams[curStream]);

      runBlockSelectPair(outDistanceBufRowView,
                         outIndexBufRowView,
                         outDistanceView,
                         outIndexView,
                         // L2 selects smallest distances, IP largest
                         !computeL2, k, streams[curStream]);
    }

    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream wait on the multi-stream
  streamWait({defaultStream}, streams);
}
360 
// L2 entry point: forwards to the shared runDistance implementation with
// computeL2 = true. See runDistance for the tiling/stream behavior and the
// preconditions on k and the output tensors.
template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   bool ignoreOutDistances = false) {
  const bool kComputeL2 = true;

  runDistance<T>(kComputeL2,
                 resources,
                 centroids,
                 centroidsTransposed,
                 centroidNorms,
                 queries,
                 k,
                 outDistances,
                 outIndices,
                 useHgemm,
                 ignoreOutDistances);
}
384 
// Inner-product entry point: forwards to the shared runDistance
// implementation with computeL2 = false. No centroid norms are needed for
// IP, so nullptr is passed; out-distances are always produced.
template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm) {
  const bool kComputeL2 = false;

  runDistance<T>(kComputeL2,
                 resources,
                 centroids,
                 centroidsTransposed,
                 nullptr, // no centroid norms for IP
                 queries,
                 k,
                 outDistances,
                 outIndices,
                 useHgemm,
                 false); // always compute out-distances
}
406 
407 //
408 // Instantiations of the distance templates
409 //
410 
// float32 instantiation of runIPDistance; hgemm does not apply to float32,
// so it is disabled.
void
runIPDistance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices) {
  const bool kUseHgemm = false;

  runIPDistance<float>(resources, vectors, vectorsTransposed,
                       queries, k, outDistances, outIndices,
                       kUseHgemm);
}
428 
#ifdef FAISS_USE_FLOAT16
// float16 instantiation of runIPDistance; the caller chooses whether the
// GEMM runs in half precision (useHgemm).
void
runIPDistance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm) {
  runIPDistance<half>(resources, vectors, vectorsTransposed,
                      queries, k, outDistances, outIndices,
                      useHgemm);
}
#endif
449 
// float32 instantiation of runL2Distance; hgemm does not apply to float32,
// so it is disabled.
void
runL2Distance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 1, true>* vectorNorms,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool ignoreOutDistances) {
  const bool kUseHgemm = false;

  runL2Distance<float>(resources, vectors, vectorsTransposed, vectorNorms,
                       queries, k, outDistances, outIndices,
                       kUseHgemm, ignoreOutDistances);
}
471 
#ifdef FAISS_USE_FLOAT16
// float16 instantiation of runL2Distance; the caller chooses whether the
// GEMM runs in half precision (useHgemm).
void
runL2Distance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 1, true>* vectorNorms,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              bool ignoreOutDistances) {
  runL2Distance<half>(resources, vectors, vectorsTransposed, vectorNorms,
                      queries, k, outDistances, outIndices,
                      useHgemm, ignoreOutDistances);
}
#endif
496 
497 } } // namespace