Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
Distance.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "Distance.cuh"
13 #include "BroadcastSum.cuh"
14 #include "L2Norm.cuh"
15 #include "L2Select.cuh"
16 #include "../../FaissAssert.h"
17 #include "../GpuResources.h"
18 #include "../utils/DeviceUtils.h"
19 #include "../utils/Limits.cuh"
20 #include "../utils/MatrixMult.cuh"
21 #include "../utils/BlockSelectKernel.cuh"
22 
23 #include <memory>
24 #include <thrust/fill.h>
25 #include <thrust/for_each.h>
26 #include <thrust/device_ptr.h>
27 #include <thrust/execution_policy.h>
28 
29 namespace faiss { namespace gpu {
30 
31 constexpr int kDefaultTileSize = 256; // base number of queries per tile when the caller passes tileSize <= 0; doubled for element types narrower than 4 bytes
32 
33 template <typename T>
34 void runL2Distance(GpuResources* resources,
35  Tensor<T, 2, true>& centroids,
36  Tensor<T, 1, true>* centroidNorms,
37  Tensor<T, 2, true>& queries,
38  int k,
39  Tensor<T, 2, true>& outDistances,
40  Tensor<int, 2, true>& outIndices,
41  bool ignoreOutDistances = false,
42  int tileSize = -1) {
43  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
44  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
45  FAISS_ASSERT(outDistances.getSize(1) == k);
46  FAISS_ASSERT(outIndices.getSize(1) == k);
47 
48  auto& mem = resources->getMemoryManagerCurrentDevice();
49  auto defaultStream = resources->getDefaultStreamCurrentDevice();
50 
51  // If we're quering against a 0 sized set, just return empty results
52  if (centroids.numElements() == 0) {
53  thrust::fill(thrust::cuda::par.on(defaultStream),
54  outDistances.data(), outDistances.end(),
55  Limits<T>::getMax());
56 
57  thrust::fill(thrust::cuda::par.on(defaultStream),
58  outIndices.data(), outIndices.end(),
59  -1);
60 
61  return;
62  }
63 
64  // If ||c||^2 is not pre-computed, calculate it
65  DeviceTensor<T, 1, true> cNorms;
66  if (!centroidNorms) {
67  cNorms = std::move(DeviceTensor<T, 1, true>(
68  mem,
69  {centroids.getSize(0)}, defaultStream));
70  runL2Norm(centroids, cNorms, true, defaultStream);
71  centroidNorms = &cNorms;
72  }
73 
74  //
75  // Prepare norm vector ||q||^2; ||c||^2 is already pre-computed
76  //
77  int qNormSize[1] = {queries.getSize(0)};
78  DeviceTensor<T, 1, true> queryNorms(mem, qNormSize, defaultStream);
79 
80  // ||q||^2
81  runL2Norm(queries, queryNorms, true, defaultStream);
82 
83  //
84  // Handle the problem in row tiles, to avoid excessive temporary
85  // memory requests
86  //
87 
88  FAISS_ASSERT(k <= centroids.getSize(0));
89  FAISS_ASSERT(k <= 1024); // select limitation
90 
91  // To allocate all of (#queries, #centroids) is potentially too much
92  // memory. Limit our total size requested
93  size_t distanceRowSize = centroids.getSize(0) * sizeof(T);
94 
95  // FIXME: parameterize based on # of centroids and DeviceMemory
96  int defaultTileSize = sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
97  tileSize = tileSize <= 0 ? defaultTileSize : tileSize;
98 
99  int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
100 
101  // Temporary output memory space we'll use
102  DeviceTensor<T, 2, true> distanceBuf1(
103  mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
104  DeviceTensor<T, 2, true> distanceBuf2(
105  mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
106  DeviceTensor<T, 2, true>* distanceBufs[2] =
107  {&distanceBuf1, &distanceBuf2};
108 
109  auto streams = resources->getAlternateStreamsCurrentDevice();
110  streamWait(streams, {defaultStream});
111 
112  int curStream = 0;
113 
114  for (int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
115  int numQueriesForIteration = std::min(maxQueriesPerIteration,
116  queries.getSize(0) - i);
117 
118  auto distanceBufView =
119  distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
120  auto queryView =
121  queries.narrowOutermost(i, numQueriesForIteration);
122  auto outDistanceView =
123  outDistances.narrowOutermost(i, numQueriesForIteration);
124  auto outIndexView =
125  outIndices.narrowOutermost(i, numQueriesForIteration);
126  auto queryNormNiew =
127  queryNorms.narrowOutermost(i, numQueriesForIteration);
128 
129  // L2 distance is ||c||^2 - 2qc + ||q||^2
130 
131  // -2qc
132  // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
133  runMatrixMult(distanceBufView, false,
134  queryView, false,
135  centroids, true,
136  -2.0f, 0.0f,
137  resources->getBlasHandleCurrentDevice(),
138  streams[curStream]);
139 
140  // For L2 distance, we use this fused kernel that performs both
141  // adding ||c||^2 to -2qc and k-selection, so we only need two
142  // passes (one write by the gemm, one read here) over the huge
143  // region of output memory
144  runL2SelectMin(distanceBufView,
145  *centroidNorms,
146  outDistanceView,
147  outIndexView,
148  k,
149  streams[curStream]);
150 
151  if (!ignoreOutDistances) {
152  // expand (query id) to (query id, k) by duplicating along rows
153  // top-k ||c||^2 - 2qc + ||q||^2 in the form (query id, k)
154  runSumAlongRows(queryNormNiew, outDistanceView, streams[curStream]);
155  }
156 
157  curStream = (curStream + 1) % 2;
158  }
159 
160  // Have the desired ordering stream wait on the multi-stream
161  streamWait({defaultStream}, streams);
162 }
163 
164 template <typename T>
165 void runIPDistance(GpuResources* resources,
166  Tensor<T, 2, true>& centroids,
167  Tensor<T, 2, true>& queries,
168  int k,
169  Tensor<T, 2, true>& outDistances,
170  Tensor<int, 2, true>& outIndices,
171  int tileSize = -1) {
172  FAISS_ASSERT(outDistances.getSize(0) == queries.getSize(0));
173  FAISS_ASSERT(outIndices.getSize(0) == queries.getSize(0));
174  FAISS_ASSERT(outDistances.getSize(1) == k);
175  FAISS_ASSERT(outIndices.getSize(1) == k);
176 
177  auto& mem = resources->getMemoryManagerCurrentDevice();
178  auto defaultStream = resources->getDefaultStreamCurrentDevice();
179 
180  // If we're quering against a 0 sized set, just return empty results
181  if (centroids.numElements() == 0) {
182  thrust::fill(thrust::cuda::par.on(defaultStream),
183  outDistances.data(), outDistances.end(),
184  Limits<T>::getMax());
185 
186  thrust::fill(thrust::cuda::par.on(defaultStream),
187  outIndices.data(), outIndices.end(),
188  -1);
189 
190  return;
191  }
192 
193  //
194  // Handle the problem in row tiles, to avoid excessive temporary
195  // memory requests
196  //
197 
198  FAISS_ASSERT(k <= centroids.getSize(0));
199  FAISS_ASSERT(k <= 1024); // select limitation
200 
201  // To allocate all of (#queries, #centroids) is potentially too much
202  // memory. Limit our total size requested
203  size_t distanceRowSize = centroids.getSize(0) * sizeof(T);
204 
205  // FIXME: parameterize based on # of centroids and DeviceMemory
206  int defaultTileSize = sizeof(T) < 4 ? kDefaultTileSize * 2 : kDefaultTileSize;
207  tileSize = tileSize <= 0 ? defaultTileSize : tileSize;
208 
209  int maxQueriesPerIteration = std::min(tileSize, queries.getSize(0));
210 
211  // Temporary output memory space we'll use
212  DeviceTensor<T, 2, true> distanceBuf1(
213  mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
214  DeviceTensor<T, 2, true> distanceBuf2(
215  mem, {maxQueriesPerIteration, centroids.getSize(0)}, defaultStream);
216  DeviceTensor<T, 2, true>* distanceBufs[2] =
217  {&distanceBuf1, &distanceBuf2};
218 
219  auto streams = resources->getAlternateStreamsCurrentDevice();
220  streamWait(streams, {defaultStream});
221 
222  int curStream = 0;
223 
224  for (int i = 0; i < queries.getSize(0); i += maxQueriesPerIteration) {
225  int numQueriesForIteration = std::min(maxQueriesPerIteration,
226  queries.getSize(0) - i);
227 
228  auto distanceBufView =
229  distanceBufs[curStream]->narrowOutermost(0, numQueriesForIteration);
230  auto queryView =
231  queries.narrowOutermost(i, numQueriesForIteration);
232  auto outDistanceView =
233  outDistances.narrowOutermost(i, numQueriesForIteration);
234  auto outIndexView =
235  outIndices.narrowOutermost(i, numQueriesForIteration);
236 
237  // (query id x dim) x (centroid id, dim)' = (query id, centroid id)
238  runMatrixMult(distanceBufView, false,
239  queryView, false, centroids, true,
240  1.0f, 0.0f,
241  resources->getBlasHandleCurrentDevice(),
242  streams[curStream]);
243 
244  // top-k of dot products
245  // (query id, top k centroids)
246  runBlockSelect(distanceBufView,
247  outDistanceView,
248  outIndexView,
249  true, k, streams[curStream]);
250 
251  curStream = (curStream + 1) % 2;
252  }
253 
254  streamWait({defaultStream}, streams);
255 }
256 
257 //
258 // Instantiations of the distance templates
259 //
260 
261 void
262 runIPDistance(GpuResources* resources,
263  Tensor<float, 2, true>& vectors,
264  Tensor<float, 2, true>& queries,
265  int k,
266  Tensor<float, 2, true>& outDistances,
267  Tensor<int, 2, true>& outIndices,
268  int tileSize) {
269  runIPDistance<float>(resources,
270  vectors,
271  queries,
272  k,
273  outDistances,
274  outIndices,
275  tileSize);
276 }
277 
278 #ifdef FAISS_USE_FLOAT16
279 void
280 runIPDistance(GpuResources* resources,
281  Tensor<half, 2, true>& vectors,
282  Tensor<half, 2, true>& queries,
283  int k,
284  Tensor<half, 2, true>& outDistances,
285  Tensor<int, 2, true>& outIndices,
286  int tileSize) {
287  runIPDistance<half>(resources,
288  vectors,
289  queries,
290  k,
291  outDistances,
292  outIndices,
293  tileSize);
294 }
295 #endif
296 
297 void
298 runL2Distance(GpuResources* resources,
299  Tensor<float, 2, true>& vectors,
300  Tensor<float, 1, true>* vectorNorms,
301  Tensor<float, 2, true>& queries,
302  int k,
303  Tensor<float, 2, true>& outDistances,
304  Tensor<int, 2, true>& outIndices,
305  bool ignoreOutDistances,
306  int tileSize) {
307  runL2Distance<float>(resources,
308  vectors,
309  vectorNorms,
310  queries,
311  k,
312  outDistances,
313  outIndices,
314  ignoreOutDistances,
315  tileSize);
316 }
317 
318 #ifdef FAISS_USE_FLOAT16
319 void
320 runL2Distance(GpuResources* resources,
321  Tensor<half, 2, true>& vectors,
322  Tensor<half, 1, true>* vectorNorms,
323  Tensor<half, 2, true>& queries,
324  int k,
325  Tensor<half, 2, true>& outDistances,
326  Tensor<int, 2, true>& outIndices,
327  bool ignoreOutDistances,
328  int tileSize) {
329  runL2Distance<half>(resources,
330  vectors,
331  vectorNorms,
332  queries,
333  k,
334  outDistances,
335  outIndices,
336  ignoreOutDistances,
337  tileSize);
338 }
339 #endif
340 
341 } } // namespace