Faiss
Distance.cu
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#include "Distance.cuh"
#include "BroadcastSum.cuh"
#include "L2Norm.cuh"
#include "L2Select.cuh"
#include "../../FaissAssert.h"
#include "../GpuResources.h"
#include "../utils/DeviceUtils.h"
#include "../utils/Limits.cuh"
#include "../utils/MatrixMult.cuh"
#include "../utils/BlockSelectKernel.cuh"

#include <memory>
#include <thrust/fill.h>
#include <thrust/for_each.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>

namespace faiss { namespace gpu {

namespace {

constexpr int kDefaultTileSize = 256;

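// Chooses how many query rows are processed per tile, and thus the size of the
// two temporary distance buffers (2 * tileSize * numCentroids * sizeof(T) bytes
// in total). For example, with float32 data (sizeof(T) == 4), 100,000 centroids
// and 256 MiB of temporary memory available, the first candidate tile of 256
// rows needs 2 * 256 * 100,000 * 4 bytes (about 195 MiB) and is accepted; with
// only 128 MiB available, the loop below halves the tile to 128 rows (about
// 98 MiB).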
template <typename T>
int chooseTileSize(int tileSizeOverride,
                   size_t numCentroids,
                   size_t tempMemAvailable) {
  if (tileSizeOverride > 0) {
    return tileSizeOverride;
  }

  size_t tileSize =
    sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;

  while (tileSize > 64) {
    size_t memRequirement = 2 * tileSize * numCentroids * sizeof(T);

    if (memRequirement <= tempMemAvailable) {
      // This fits entirely into our temporary memory
      return tileSize;
    }

    // Otherwise, halve the tile size
    tileSize /= 2;
  }

  // We use 64 as the minimum acceptable tile size
  FAISS_ASSERT(tileSize >= 64);

  // FIXME: if we're running with no available temp memory, do we try
  // and go larger based on free memory available on the device?

  return tileSize;
}

}

template <typename T>
void runL2Distance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 1, true>* centroidNorms,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   bool ignoreOutDistances = false,
                   int tileSizeOverride = -1) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0-sized set, just return empty results
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  // If ||c||^2 is not pre-computed, calculate it
  DeviceTensor<T, 1, true> cNorms;
  if (!centroidNorms) {
    cNorms = std::move(DeviceTensor<T, 1, true>(
                         mem,
                         {centroids.getSize(0)}, defaultStream));
    runL2Norm(centroids, cNorms, true, defaultStream);
    centroidNorms = &cNorms;
  }

  //
  // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
  //
  int qNormSize[1] = {queries.getSize(0)};
  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);

  // ||q||^2
  runL2Norm(queries, queryNorms, true, defaultStream);

  //
  // Handle the problem in row tiles, to avoid excessive temporary
  // memory requests
  //

  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  int tileSize =
    chooseTileSize<T>(
      tileSizeOverride,
      centroids.getSize(0),
      resources->getMemoryManagerCurrentDevice().getSizeAvailable());

  int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));

  // Temporary output memory space we'll use
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] =
    {&distanceBuf1, &distanceBuf2};

  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

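  // Process the query tiles round-robin across the two alternate streams, each
  // with its own distance buffer, so that work on consecutive tiles can overlap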
  for (int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
    int numQueriesForIteration = std::min(maxQueriesPerIteration,
                                          queries.getSize(0) - i);

    auto distanceBufView =
      distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
    auto queryView =
      queries.narrowOutermost(i, numQueriesForIteration);
    auto outDistanceView =
      outDistances.narrowOutermost(i, numQueriesForIteration);
    auto outIndexView =
      outIndices.narrowOutermost(i, numQueriesForIteration);
    auto queryNormView =
      queryNorms.narrowOutermost(i, numQueriesForIteration);

    // L2 distance is ||c||^2 - 2qc + ||q||^2
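    // (this is ||q - c||^2 = (q - c) . (q - c) expanded; the GEMM below
    // produces the -2qc cross term for every (query, centroid) pair, and the
    // two squared-norm terms are added in afterwards)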

    // -2qc
    // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
    runMatrixMult(distanceBufView, false,
                  queryView, false,
                  centroidsTransposed ? *centroidsTransposed : centroids,
                  centroidsTransposed ? false : true,
                  -2.0f, 0.0f, useHgemm,
                  resources->getBlasHandleCurrentDevice(),
                  streams[curStream]);

    // For L2 distance, we use this fused kernel that performs both
    // adding ||c||^2 to -2qc and k-selection, so we only need two
    // passes (one write by the gemm, one read here) over the huge
    // region of output memory
    runL2SelectMin(distanceBufView,
                   *centroidNorms,
                   outDistanceView,
                   outIndexView,
                   k,
                   streams[curStream]);

    if (!ignoreOutDistances) {
      // expand (query id) to (query id, k) by duplicating along rows
      // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
      runSumAlongRows(queryNormView, outDistanceView, streams[curStream]);
    }

    curStream = (curStream + 1) % 2;
  }

  // Have the desired ordering stream (the default stream) wait on the
  // alternate streams before returning
  streamWait({defaultStream}, streams);
}

template <typename T>
void runIPDistance(GpuResources* resources,
                   Tensor<T, 2, true>& centroids,
                   Tensor<T, 2, true>* centroidsTransposed,
                   Tensor<T, 2, true>& queries,
                   int k,
                   Tensor<T, 2, true>& outDistances,
                   Tensor<int, 2, true>& outIndices,
                   bool useHgemm,
                   int tileSizeOverride = -1) {
  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);

  auto& mem = resources->getMemoryManagerCurrentDevice();
  auto defaultStream = resources->getDefaultStreamCurrentDevice();

  // If we're querying against a 0-sized set, just return empty results
  if (centroids.numElements() == 0) {
    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outDistances.data(), outDistances.end(),
                 Limits<T>::getMax());

    thrust::fill(thrust::cuda::par.on(defaultStream),
                 outIndices.data(), outIndices.end(),
                 -1);

    return;
  }

  //
  // Handle the problem in row tiles, to avoid excessive temporary
  // memory requests
  //

  FAISS_ASSERT(k <= centroids.getSize(0));
  FAISS_ASSERT(k <= 1024); // select limitation

  int tileSize =
    chooseTileSize<T>(
      tileSizeOverride,
      centroids.getSize(0),
      resources->getMemoryManagerCurrentDevice().getSizeAvailable());

  int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));

  // Temporary output memory space we'll use
  DeviceTensor<T, 2, true> distanceBuf1(
    mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
  DeviceTensor<T, 2, true> distanceBuf2(
    mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
  DeviceTensor<T, 2, true>* distanceBufs[2] =
    {&distanceBuf1, &distanceBuf2};

  auto streams = resources->getAlternateStreamsCurrentDevice();
  streamWait(streams, {defaultStream});

  int curStream = 0;

  for (int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
    int numQueriesForIteration = std::min(maxQueriesPerIteration,
                                          queries.getSize(0) - i);

    auto distanceBufView =
      distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
    auto queryView =
      queries.narrowOutermost(i, numQueriesForIteration);
    auto outDistanceView =
      outDistances.narrowOutermost(i, numQueriesForIteration);
    auto outIndexView =
      outIndices.narrowOutermost(i, numQueriesForIteration);

    // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
    runMatrixMult(distanceBufView, false,
                  queryView, false,
                  centroidsTransposed ? *centroidsTransposed : centroids,
                  centroidsTransposed ? false : true,
                  1.0f, 0.0f, useHgemm,
                  resources->getBlasHandleCurrentDevice(),
                  streams[curStream]);

    // top-k of dot products
    // (query id, top k centroids)
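    // (the 'true' direction flag selects the k largest values, since larger
    // dot products are better matches)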
    runBlockSelect(distanceBufView,
                   outDistanceView,
                   outIndexView,
                   true, k, streams[curStream]);

    curStream = (curStream + 1) % 2;
  }

  streamWait({defaultStream}, streams);
}

//
// Instantiations of the distance templates
//

void
runIPDistance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              int tileSizeOverride) {
  runIPDistance<float>(resources,
                       vectors,
                       vectorsTransposed,
                       queries,
                       k,
                       outDistances,
                       outIndices,
                       false,
                       tileSizeOverride);
}

#ifdef FAISS_USE_FLOAT16
void
runIPDistance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              int tileSizeOverride) {
  runIPDistance<half>(resources,
                      vectors,
                      vectorsTransposed,
                      queries,
                      k,
                      outDistances,
                      outIndices,
                      useHgemm,
                      tileSizeOverride);
}
#endif

void
runL2Distance(GpuResources* resources,
              Tensor<float, 2, true>& vectors,
              Tensor<float, 2, true>* vectorsTransposed,
              Tensor<float, 1, true>* vectorNorms,
              Tensor<float, 2, true>& queries,
              int k,
              Tensor<float, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool ignoreOutDistances,
              int tileSizeOverride) {
  runL2Distance<float>(resources,
                       vectors,
                       vectorsTransposed,
                       vectorNorms,
                       queries,
                       k,
                       outDistances,
                       outIndices,
                       false,
                       ignoreOutDistances,
                       tileSizeOverride);
}
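
// A minimal calling sketch for the float overload above (assuming `res` is an
// initialized GpuResources*, `vectors` and `queries` are [num x dim] device
// tensors, and `outDistances`/`outIndices` are [numQueries x k] device tensors):
//
//   runL2Distance(res, vectors, nullptr, nullptr, queries, k,
//                 outDistances, outIndices, false, -1);
//
// Passing nullptr for the transposed vectors and the vector norms makes the
// implementation use the non-transposed GEMM path and compute ||c||^2 itself.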

#ifdef FAISS_USE_FLOAT16
void
runL2Distance(GpuResources* resources,
              Tensor<half, 2, true>& vectors,
              Tensor<half, 2, true>* vectorsTransposed,
              Tensor<half, 1, true>* vectorNorms,
              Tensor<half, 2, true>& queries,
              int k,
              Tensor<half, 2, true>& outDistances,
              Tensor<int, 2, true>& outIndices,
              bool useHgemm,
              bool ignoreOutDistances,
              int tileSizeOverride) {
  runL2Distance<half>(resources,
                      vectors,
                      vectorsTransposed,
                      vectorNorms,
                      queries,
                      k,
                      outDistances,
                      outIndices,
                      useHgemm,
                      ignoreOutDistances,
                      tileSizeOverride);
}
#endif

} } // namespace