Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
BinaryDistance.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include "../utils/DeviceTensor.cuh"
10 #include "../utils/DeviceDefs.cuh"
11 #include "../utils/DeviceUtils.h"
12 #include "../utils/Select.cuh"
13 
14 namespace faiss { namespace gpu {
15 
// Number of warps that the kernel is instantiated with
constexpr int kWarps = 8;
// Number of lanes per warp (the warp width); also the tile edge length below
constexpr int kLanes = kWarpSize;

// Sentinel distance used for out-of-bounds (query, vec) slots; compares worse
// than any real Hamming distance so padded entries never enter the top-k
constexpr int kMaxDistance = std::numeric_limits<int>::max();
22 // Performs a binary matrix multiplication, returning the lowest k results in
23 // `vecs` for each `query` in terms of Hamming distance (a fused kernel)
24 // Each warp calculates distance for a single query
// Performs a binary matrix multiplication, returning the lowest k results in
// `vecs` for each `query` in terms of Hamming distance (a fused kernel)
// Each warp calculates distance for a single query
//
// Expected launch configuration: block = (kLanes, kWarps), grid.x =
// ceil(numQueries / kWarps); threadIdx.y selects the warp/query, threadIdx.x
// the lane. Static shared memory: two (kLanes+1)-padded tiles.
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType>
__launch_bounds__(kWarps * kLanes)
__global__ void binaryDistanceAnySize(const Tensor<BinaryType, 2, true> vecs,
                                      const Tensor<BinaryType, 2, true> query,
                                      Tensor<int, 2, true> outK,
                                      Tensor<int, 2, true> outV,
                                      int k) {
  // A matrix tile (query, k): one row of query words per warp
  __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

  // B matrix tile (vec, k): one row of database words per lane
  __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

  // Per-warp k-selection; seeded with the (kMaxDistance, -1) sentinel so
  // unfilled slots sort after every real result
  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  int warpId = threadIdx.y;
  int laneId = threadIdx.x;

  // Each warp handles a single query
  int warpQuery = blockIdx.x * kWarps + warpId;
  bool queryInBounds = warpQuery < query.getSize(0);

  // Each warp loops through the entire chunk of vectors, kLanes at a time
  for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
    int threadDistance = 0;

    // Reduction dimension: walk the code words kLanes at a time
    for (int blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) {
      int laneK = blockK + laneId;
      bool kInBounds = laneK < vecs.getSize(1);

      // Out-of-bounds slots are zero-filled; 0 XOR anything adds no popcount
      // bits for the padded side, so padding never perturbs a real distance
      queryTile[warpId][laneId] = queryInBounds && kInBounds ?
        query[warpQuery][laneK] : 0;

      // kWarps warps are responsible for loading 32 vecs
#pragma unroll
      for (int i = 0; i < kLanes / kWarps; ++i) {
        int warpVec = i * kWarps + warpId;
        int vec = blockVec + warpVec;
        bool vecInBounds = vec < vecs.getSize(0);

        vecTile[warpVec][laneId] = vecInBounds && kInBounds ?
          vecs[vec][laneK] : 0;
      }

      // Both tiles must be fully written before any thread reads them
      __syncthreads();

      // Compare distances: lane j accumulates Hamming distance between this
      // warp's query and vector (blockVec + j), one word per iteration
#pragma unroll
      for (int i = 0; i < kLanes; ++i) {
        threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
      }

      // All reads must finish before the tiles are overwritten next pass
      __syncthreads();
    }

    // Lanes within a warp are different vec results against the same query
    // Only submit distances which represent real (query, vec) pairs
    bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0));
    threadDistance = valInBounds ? threadDistance : kMaxDistance;
    int id = valInBounds ? blockVec + laneId : -1;

    heap.add(threadDistance, id);
  }

  heap.reduce();

  // Only warps mapped to a real query have a row to write
  if (warpQuery < query.getSize(0)) {
    heap.writeOut(outK[warpQuery].data(),
                  outV[warpQuery].data(),
                  k);
  }
}
102 
103 // Version of the kernel that avoids a loop over the reduction dimension, and
104 // thus avoids reloading the query vectors
// Version of the kernel that avoids a loop over the reduction dimension, and
// thus avoids reloading the query vectors
//
// Precondition: vecs.getSize(1) <= ReductionLimit <= kLanes, so one tile of
// query words (loaded once, before the vector loop) covers the whole code.
// Same launch layout as binaryDistanceAnySize: block = (kLanes, kWarps).
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType,
          int ReductionLimit = kLanes>
__global__ void
__launch_bounds__(kWarps * kLanes)
binaryDistanceLimitSize(const Tensor<BinaryType, 2, true> vecs,
                        const Tensor<BinaryType, 2, true> query,
                        Tensor<int, 2, true> outK,
                        Tensor<int, 2, true> outV,
                        int k) {
  // A matrix tile (query, k): one row of query words per warp
  __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

  // B matrix tile (vec, k): one row of database words per lane
  __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

  // Per-warp k-selection; seeded with the (kMaxDistance, -1) sentinel so
  // unfilled slots sort after every real result
  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  int warpId = threadIdx.y;
  int laneId = threadIdx.x;

  // Each warp handles a single query; each lane holds one code word
  int laneK = laneId;
  int warpQuery = blockIdx.x * kWarps + warpId;
  bool kInBounds = laneK < vecs.getSize(1);
  bool queryInBounds = warpQuery < query.getSize(0);


  // Load the query tile exactly once; zero-fill out-of-bounds slots so the
  // XOR/popcount below is unaffected by padding. The __syncthreads() inside
  // the first loop iteration publishes this before any read.
  queryTile[warpId][laneId] = queryInBounds && kInBounds ?
    query[warpQuery][laneK] : 0;

  // Each warp loops through the entire chunk of vectors, kLanes at a time
  for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
    int threadDistance = 0;

    // kWarps warps are responsible for loading 32 vecs
#pragma unroll
    for (int i = 0; i < kLanes / kWarps; ++i) {
      int warpVec = i * kWarps + warpId;
      int vec = blockVec + warpVec;
      bool vecInBounds = vec < vecs.getSize(0);

      vecTile[warpVec][laneId] = vecInBounds && kInBounds ?
        vecs[vec][laneK] : 0;
    }

    // Tiles must be fully written before any thread reads them
    __syncthreads();

    // Compare distances: lane j accumulates Hamming distance between this
    // warp's query and vector (blockVec + j) over at most ReductionLimit words
#pragma unroll
    for (int i = 0; i < ReductionLimit; ++i) {
      threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
    }

    // All reads must finish before vecTile is overwritten next iteration
    __syncthreads();

    // Lanes within a warp are different vec results against the same query
    // Only submit distances which represent real (query, vec) pairs
    bool valInBounds = queryInBounds && (blockVec + laneId < vecs.getSize(0));
    threadDistance = valInBounds ? threadDistance : kMaxDistance;
    int id = valInBounds ? blockVec + laneId : -1;

    heap.add(threadDistance, id);
  }

  heap.reduce();

  // Only warps mapped to a real query have a row to write
  if (warpQuery < query.getSize(0)) {
    heap.writeOut(outK[warpQuery].data(),
                  outV[warpQuery].data(),
                  k);
  }
}
181 
// Launches binaryDistanceAnySize with the smallest warp-queue instantiation
// (NumWarpQ) that can hold k results.
//
// vecs:   (numVecs, wordsPerCode) database codes
// query:  (numQueries, wordsPerCode) query codes
// outK:   (numQueries, k) output distances
// outV:   (numQueries, k) output indices into vecs
// k:      number of nearest neighbors requested; must be <= 1024
// stream: CUDA stream the kernel is enqueued on
template <typename BinaryType>
void runBinaryDistanceAnySize(Tensor<BinaryType, 2, true>& vecs,
                              Tensor<BinaryType, 2, true>& query,
                              Tensor<int, 2, true>& outK,
                              Tensor<int, 2, true>& outV,
                              int k, cudaStream_t stream) {
  // One warp per query, kWarps queries per block
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

  if (k == 1) {
    binaryDistanceAnySize<1, 1, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 32) {
    binaryDistanceAnySize<32, 2, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 64) {
    binaryDistanceAnySize<64, 3, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 128) {
    binaryDistanceAnySize<128, 3, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 256) {
    binaryDistanceAnySize<256, 4, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 512) {
    binaryDistanceAnySize<512, 8, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 1024) {
    binaryDistanceAnySize<1024, 8, BinaryType>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else {
    // Previously an unsupported k (> 1024) fell through and silently launched
    // nothing, leaving outK/outV uninitialized; fail loudly instead
    FAISS_ASSERT(false);
  }
}
221 
// Launches binaryDistanceLimitSize (the single-pass, limited-code-size
// kernel) with the smallest warp-queue instantiation (NumWarpQ) that can
// hold k results.
//
// Caller must guarantee vecs.getSize(1) <= ReductionLimit (see the kernel's
// precondition); k must be <= 1024.
template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(Tensor<BinaryType, 2, true>& vecs,
                                Tensor<BinaryType, 2, true>& query,
                                Tensor<int, 2, true>& outK,
                                Tensor<int, 2, true>& outV,
                                int k, cudaStream_t stream) {
  // One warp per query, kWarps queries per block
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

  if (k == 1) {
    binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 32) {
    binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 64) {
    binaryDistanceLimitSize<64, 3, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 128) {
    binaryDistanceLimitSize<128, 3, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 256) {
    binaryDistanceLimitSize<256, 4, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 512) {
    binaryDistanceLimitSize<512, 8, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else if (k <= 1024) {
    binaryDistanceLimitSize<1024, 8, BinaryType, ReductionLimit>
      <<<grid, block, 0, stream>>>(
        vecs, query, outK, outV, k);
  } else {
    // Previously an unsupported k (> 1024) fell through and silently launched
    // nothing, leaving outK/outV uninitialized; fail loudly instead
    FAISS_ASSERT(false);
  }
}
261 
// Entry point for brute-force binary (Hamming) k-NN search.
//
// Picks one of three kernels based on the code length in bytes:
//  - codes that are a whole number of 32-bit words, at most kReductionLimit32
//    words long, are reinterpreted as uint32 and use the single-pass kernel;
//  - codes of at most kReductionLimit8 bytes use the byte-wide single-pass
//    kernel;
//  - everything else falls back to the general any-size kernel.
void runBinaryDistance(Tensor<unsigned char, 2, true>& vecs,
                       Tensor<unsigned char, 2, true>& query,
                       Tensor<int, 2, true>& outK,
                       Tensor<int, 2, true>& outV,
                       int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= 1024);
  FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

  FAISS_ASSERT(outK.getSize(1) == k);
  FAISS_ASSERT(outV.getSize(1) == k);

  // The optimized uint32 kernel handles up to 32 * 8 = 256 bits per code
  constexpr int kReductionLimit32 = 8;

  // The optimized uint8 kernel handles up to 8 * 16 = 128 bits per code
  constexpr int kReductionLimit8 = 16;

  int bytesPerCode = vecs.getSize(1);
  bool wholeWords = (bytesPerCode % sizeof(unsigned int)) == 0;

  if (wholeWords &&
      (bytesPerCode / sizeof(unsigned int)) <= kReductionLimit32) {
    // Code size is a multiple of 32 bits and small enough for the word-wide
    // single-pass kernel; reinterpret the byte tensors as uint32 tensors
    auto vecsAsWords = vecs.castResize<unsigned int>();
    auto queryAsWords = query.castResize<unsigned int>();

    runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
      vecsAsWords, queryAsWords, outK, outV, k, stream);

  } else if (bytesPerCode <= kReductionLimit8) {
    // Small code that is not a whole number of 32-bit words: byte-wide
    // single-pass kernel
    runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
      vecs, query, outK, outV, k, stream);
  } else {
    // Arbitrary code size: general kernel
    runBinaryDistanceAnySize<unsigned char>(
      vecs, query, outK, outV, k, stream);
  }
}
302 
303 } } // namespace