12 #include "L2Select.cuh"
13 #include "../../FaissAssert.h"
15 #include "../utils/DeviceUtils.h"
16 #include "../utils/MathOperators.cuh"
17 #include "../utils/Pair.cuh"
18 #include "../utils/Reductions.cuh"
19 #include "../utils/Select.cuh"
20 #include "../utils/Tensor.cuh"
21 #include "../utils/StaticUtils.h"
23 namespace faiss {
namespace gpu {
// k == 1 specialization of the L2 selection.
//
// For every row r handled by this block it computes
//   min over c of (centroidDistances[c] + productDistances[r][c])
// and writes the minimum to outDistances[r][0] and the column c that
// produced it to outIndices[r][0] (-1 if the row has no columns).
//
// Launch layout (as used by runL2SelectMin): 1-D grid, block b owns rows
// [b * kRowsPerBlock, (b + 1) * kRowsPerBlock). blockDim.x must equal
// kBlockSize, because the shared reduction scratch `blockMin` is sized from
// kBlockSize. The threads of a block stride over the columns by blockDim.x.
template <typename T, int kRowsPerBlock, int kBlockSize>
__global__ void l2SelectMin1(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices) {
  static_assert(kBlockSize % kWarpSize == 0,
                "kBlockSize must be a multiple of the warp size");

  // Per-thread running (distance, column) minimum for each row of the chunk
  Pair<T, int> threadMin[kRowsPerBlock];
  // Scratch for the block-wide reduction: one slot per warp per row
  __shared__ Pair<T, int> blockMin[kRowsPerBlock * (kBlockSize / kWarpSize)];

  T distance[kRowsPerBlock];

#pragma unroll
  for (int i = 0; i < kRowsPerBlock; ++i) {
    threadMin[i].k = Limits<T>::getMax();
    threadMin[i].v = -1;
  }

  // blockIdx.x selects which chunk of kRowsPerBlock rows this block owns
  int rowStart = blockIdx.x * kRowsPerBlock;

  // Only the last block can own a partial chunk, and only when the number of
  // rows is not a multiple of kRowsPerBlock. The value is uniform across the
  // block, so the __syncthreads() inside that branch is reached by all threads.
  bool endRow = (blockIdx.x == gridDim.x - 1) &&
                (productDistances.getSize(0) % kRowsPerBlock != 0);

  if (endRow) {
    // Partial chunk: handle the remaining rows one at a time
    for (int row = rowStart; row < productDistances.getSize(0); ++row) {
      for (int col = threadIdx.x; col < productDistances.getSize(1);
           col += blockDim.x) {
        distance[0] = Math<T>::add(centroidDistances[col],
                                   productDistances[row][col]);

        if (Math<T>::lt(distance[0], threadMin[0].k)) {
          threadMin[0].k = distance[0];
          threadMin[0].v = col;
        }
      }

      // Reduce within the block; the scalar overload returns the reduced
      // value, which thread 0 then writes out
      threadMin[0] =
        blockReduceAll<Pair<T, int>, Min<Pair<T, int> >, false, false>(
          threadMin[0], Min<Pair<T, int> >(), blockMin);

      if (threadIdx.x == 0) {
        outDistances[row][0] = threadMin[0].k;
        outIndices[row][0] = threadMin[0].v;
      }

      // blockMin is reused by the next row's reduction; make sure every
      // thread is finished with it first
      __syncthreads();

      threadMin[0].k = Limits<T>::getMax();
      threadMin[0].v = -1;
    }
  } else {
    // Full chunk: process all kRowsPerBlock rows together so that each
    // centroid distance is loaded once per column
    for (int col = threadIdx.x; col < productDistances.getSize(1);
         col += blockDim.x) {
      T centroidDistance = centroidDistances[col];

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = productDistances[rowStart + row][col];
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        distance[row] = Math<T>::add(distance[row], centroidDistance);
      }

#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        if (Math<T>::lt(distance[row], threadMin[row].k)) {
          threadMin[row].k = distance[row];
          threadMin[row].v = col;
        }
      }
    }

    // Reduce all kRowsPerBlock per-thread minima across the block; the array
    // overload updates threadMin in place
    blockReduceAll<kRowsPerBlock, Pair<T, int>, Min<Pair<T, int> >,
                   false, false>(threadMin, Min<Pair<T, int> >(), blockMin);

    if (threadIdx.x == 0) {
#pragma unroll
      for (int row = 0; row < kRowsPerBlock; ++row) {
        outDistances[rowStart + row][0] = threadMin[row].k;
        outIndices[rowStart + row][0] = threadMin[row].v;
      }
    }
  }
}
// General (k > 1) L2 selection.
//
// For row = blockIdx.x, selects the k smallest values of
//   centroidDistances[c] + productDistances[row][c]
// over all columns c. It writes them to outDistances[row][0..k) and their
// column indices to outIndices[row][0..k).
//
// Launch layout (as used by runL2SelectMin): one block per row, with
// ThreadsPerBlock threads. The caller picks NumWarpQ >= k. initK is the
// initial "worst" key given to BlockSelect (runL2SelectMin passes
// Limits<T>::getMax()).
template <typename T, int NumWarpQ, int NumThreadQ, int ThreadsPerBlock>
__global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
                             Tensor<T, 1, true> centroidDistances,
                             Tensor<T, 2, true> outDistances,
                             Tensor<int, 2, true> outIndices,
                             int k, T initK) {
  static_assert(ThreadsPerBlock % kWarpSize == 0,
                "ThreadsPerBlock must be a multiple of the warp size");
  constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

  // Shared key/value storage handed to BlockSelect; the selected results
  // are read back from here after reduce()
  __shared__ T smemK[kNumWarps * NumWarpQ];
  __shared__ int smemV[kNumWarps * NumWarpQ];

  BlockSelect<T, int, false, Comparator<T>,
              NumWarpQ, NumThreadQ, ThreadsPerBlock>
    heap(initK, -1, smemK, smemV, k);

  int row = blockIdx.x;

  // The main loop only runs up to the largest multiple of kWarpSize, so
  // heap.add() is always reached by whole warps (presumably required by
  // BlockSelect::add, which works warp-wide).
  int limit = utils::roundDown(productDistances.getSize(1), kWarpSize);
  int i = threadIdx.x;

  for (; i < limit; i += blockDim.x) {
    T v = Math<T>::add(centroidDistances[i], productDistances[row][i]);
    heap.add(v, i);
  }

  // Tail of fewer than kWarpSize columns: each thread has at most one left
  // and inserts it only into its thread-local queue
  if (i < productDistances.getSize(1)) {
    T v = Math<T>::add(centroidDistances[i], productDistances[row][i]);
    heap.addThreadQ(v, i);
  }

  // Merge all per-thread/per-warp queues before reading the shared results
  heap.reduce();

  for (int j = threadIdx.x; j < k; j += blockDim.x) {
    outDistances[row][j] = smemK[j];
    outIndices[row][j] = smemV[j];
  }
}
// Host-side launcher. For every row r, selects the k smallest values of
// centroidDistances[c] + productDistances[r][c] and writes them, with their
// column indices, to outDistances[r][0..k) / outIndices[r][0..k).
//
// - k == 1 uses the specialized l2SelectMin1 kernel (8 rows per block).
// - 1 < k <= 1024 uses l2SelectMinK, one block per row. Its queue sizes are
//   picked from the smallest bucket that holds k.
//
// The work is enqueued on `stream` and launch errors are checked with
// CUDA_VERIFY.
template <typename T>
void runL2SelectMin(Tensor<T, 2, true>& productDistances,
                    Tensor<T, 1, true>& centroidDistances,
                    Tensor<T, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  FAISS_ASSERT(productDistances.getSize(0) == outDistances.getSize(0));
  FAISS_ASSERT(productDistances.getSize(0) == outIndices.getSize(0));
  FAISS_ASSERT(centroidDistances.getSize(0) == productDistances.getSize(1));
  FAISS_ASSERT(outDistances.getSize(1) == k);
  FAISS_ASSERT(outIndices.getSize(1) == k);
  FAISS_ASSERT(k <= 1024);

  // No rows means nothing to select. Launching with a zero-sized grid is an
  // invalid configuration, which CUDA_VERIFY below would report as a failure.
  if (outDistances.getSize(0) == 0) {
    return;
  }

  if (k == 1) {
    constexpr int kThreadsPerBlock = 256;
    constexpr int kRowsPerBlock = 8;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(utils::divUp(outDistances.getSize(0), kRowsPerBlock));

    l2SelectMin1<T, kRowsPerBlock, kThreadsPerBlock>
      <<<grid, block, 0, stream>>>(productDistances, centroidDistances,
                                   outDistances, outIndices);
  } else {
    constexpr int kThreadsPerBlock = 128;

    auto block = dim3(kThreadsPerBlock);
    auto grid = dim3(outDistances.getSize(0));

#define RUN_L2_SELECT(NUM_WARP_Q, NUM_THREAD_Q)                           \
    do {                                                                  \
      l2SelectMinK<T, NUM_WARP_Q, NUM_THREAD_Q, kThreadsPerBlock>         \
        <<<grid, block, 0, stream>>>(productDistances, centroidDistances, \
                                     outDistances, outIndices,            \
                                     k, Limits<T>::getMax());             \
    } while (0)

    if (k <= 32) {
      RUN_L2_SELECT(32, 2);
    } else if (k <= 64) {
      RUN_L2_SELECT(64, 3);
    } else if (k <= 128) {
      RUN_L2_SELECT(128, 3);
    } else if (k <= 256) {
      RUN_L2_SELECT(256, 4);
    } else if (k <= 512) {
      RUN_L2_SELECT(512, 8);
    } else if (k <= 1024) {
      RUN_L2_SELECT(1024, 8);
    } else {
      // Unreachable given the k <= 1024 assertion above
      FAISS_ASSERT(false);
    }

    // Keep the helper macro local to this function
#undef RUN_L2_SELECT
  }

  CUDA_VERIFY(cudaGetLastError());
}
// float instantiation of the L2 k-selection launcher. It forwards all
// arguments unchanged to runL2SelectMin<float>.
void runL2SelectMin(Tensor<float, 2, true>& productDistances,
                    Tensor<float, 1, true>& centroidDistances,
                    Tensor<float, 2, true>& outDistances,
                    Tensor<int, 2, true>& outIndices,
                    int k,
                    cudaStream_t stream) {
  runL2SelectMin<float>(productDistances,
                        centroidDistances,
                        outDistances,
                        outIndices,
                        k,
                        stream);
}
237 #ifdef FAISS_USE_FLOAT16
238 void runL2SelectMin(Tensor<half, 2, true>& productDistances,
239 Tensor<half, 1, true>& centroidDistances,
240 Tensor<half, 2, true>& outDistances,
241 Tensor<int, 2, true>& outIndices,
243 cudaStream_t stream) {
244 runL2SelectMin<half>(productDistances,