317 lines
11 KiB
Plaintext
317 lines
11 KiB
Plaintext
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
#include "../utils/DeviceTensor.cuh"
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/Select.cuh"

#include <limits>
|
|
|
|
namespace faiss { namespace gpu {
|
|
|
|
// Number of warps that each kernel in this file is instantiated with; the host
// launchers below use a (kLanes, kWarps) thread block
constexpr int kWarps = 8;

// Number of lanes per warp; also the edge length of the shared memory tiles
// used by the kernels
constexpr int kLanes = kWarpSize;

// The kernels have all kWarps warps cooperatively fill a kLanes-row tile of
// vectors in (kLanes / kWarps) passes. If kWarps did not evenly divide kLanes,
// the trailing rows of that tile would never be written before being read.
static_assert(kLanes % kWarps == 0,
              "kWarps must evenly divide kLanes");

// Distance used to seed the k-selection heap and to mask out (query, vec)
// pairs that fall outside of the input bounds
constexpr int kMaxDistance = std::numeric_limits<int>::max();
|
|
|
|
// Fused kernel: performs a binary matrix multiplication between `query` and
// `vecs` and keeps, for each query, the k lowest results in terms of Hamming
// distance.
//
// Layout expected by the kernel (see the host launchers in this file):
//   - thread block of (kLanes, kWarps); threadIdx.x is the lane, threadIdx.y
//     is the warp
//   - one query per warp, i.e. kWarps queries per block
//
// The reduction dimension (dimension 1 of both tensors) is walked in chunks of
// kLanes elements, so it may be of any size.
//
// The selection results are emitted through heap.writeOut() into row
// `query index` of outK / outV; the heap is fed (distance, vec index) pairs.
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType>
__global__ void
__launch_bounds__(kWarps * kLanes)
binaryDistanceAnySize(const Tensor<BinaryType, 2, true> vecs,
                      const Tensor<BinaryType, 2, true> query,
                      Tensor<int, 2, true> outK,
                      Tensor<int, 2, true> outV,
                      int k) {
  // Query tile: one row per warp, one element of the reduction chunk per lane.
  // The extra column is padding to avoid bank conflicts.
  __shared__ BinaryType queryTile[kWarps][kLanes + 1];

  // Vector tile: kLanes vectors by one reduction chunk, padded the same way
  __shared__ BinaryType vecTile[kLanes][kLanes + 1];

  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  const int lane = threadIdx.x;
  const int warp = threadIdx.y;

  const auto numVecs = vecs.getSize(0);
  const auto numWords = vecs.getSize(1);

  // The single query that this warp is responsible for
  const int myQuery = blockIdx.x * kWarps + warp;
  const bool haveQuery = myQuery < query.getSize(0);

  // Every warp walks over all vectors, kLanes of them at a time; each lane
  // accumulates the distance between the warp's query and one of them
  for (int vecBase = 0; vecBase < numVecs; vecBase += kLanes) {
    int dist = 0;

    // Walk the reduction dimension, kLanes elements at a time
    for (int wordBase = 0; wordBase < numWords; wordBase += kLanes) {
      const int word = wordBase + lane;
      const bool haveWord = word < numWords;

      // Out-of-bounds elements are zero on both sides, so they contribute
      // nothing to the xor/popcount below
      queryTile[warp][lane] = haveQuery && haveWord ?
        query[myQuery][word] : 0;

      // The kWarps warps share the work of loading kLanes vectors
#pragma unroll
      for (int pass = 0; pass < kLanes / kWarps; ++pass) {
        const int row = pass * kWarps + warp;
        const int vec = vecBase + row;
        const bool haveVec = vec < numVecs;

        vecTile[row][lane] = haveVec && haveWord ?
          vecs[vec][word] : 0;
      }

      // Both tiles must be completely written before anyone reads them
      __syncthreads();

      // Accumulate the Hamming distance over this chunk
#pragma unroll
      for (int i = 0; i < kLanes; ++i) {
        dist += __popc(queryTile[warp][i] ^ vecTile[lane][i]);
      }

      // All reads must be done before the tiles are overwritten again
      __syncthreads();
    }

    // Within a warp, each lane holds the result for a different vector
    // against the same query. Pairs that do not correspond to a real
    // (query, vec) combination are submitted as (kMaxDistance, -1).
    const int candidate = vecBase + lane;
    const bool valid = haveQuery && (candidate < numVecs);

    heap.add(valid ? dist : kMaxDistance, valid ? candidate : -1);
  }

  heap.reduce();

  if (haveQuery) {
    heap.writeOut(outK[myQuery].data(),
                  outV[myQuery].data(),
                  k);
  }
}
|
|
|
|
// Variant of the fused Hamming distance + k-selection kernel that does not
// loop over the reduction dimension, and therefore loads each query into
// shared memory only once.
//
// Only the first ReductionLimit elements of dimension 1 take part in the
// distance, and only the first kLanes elements are ever loaded, so callers
// are expected to pass tensors whose dimension 1 does not exceed
// ReductionLimit.
//
// Thread block layout is (kLanes, kWarps) as set up by the host launchers in
// this file: threadIdx.x is the lane, threadIdx.y is the warp, and each warp
// handles one query.
template <int NumWarpQ,
          int NumThreadQ,
          typename BinaryType,
          int ReductionLimit = kLanes>
__global__ void
__launch_bounds__(kWarps * kLanes)
binaryDistanceLimitSize(const Tensor<BinaryType, 2, true> vecs,
                        const Tensor<BinaryType, 2, true> query,
                        Tensor<int, 2, true> outK,
                        Tensor<int, 2, true> outV,
                        int k) {
  // Query tile: one row per warp; the extra column is padding to avoid bank
  // conflicts
  __shared__ BinaryType queryTile[kWarps][kLanes + 1];

  // Vector tile: kLanes vectors, padded the same way
  __shared__ BinaryType vecTile[kLanes][kLanes + 1];

  WarpSelect<int, int, false, Comparator<int>,
             NumWarpQ, NumThreadQ, kWarps * kLanes>
    heap(kMaxDistance, -1, k);

  const int lane = threadIdx.x;
  const int warp = threadIdx.y;

  const auto numVecs = vecs.getSize(0);

  // Each lane owns one fixed element of the reduction dimension
  const int word = lane;
  const bool haveWord = word < vecs.getSize(1);

  // The single query that this warp is responsible for
  const int myQuery = blockIdx.x * kWarps + warp;
  const bool haveQuery = myQuery < query.getSize(0);

  // The query is loaded once up front; the first barrier inside the loop
  // below makes it visible before it is read
  queryTile[warp][lane] = haveQuery && haveWord ?
    query[myQuery][word] : 0;

  // Every warp walks over all vectors, kLanes of them at a time
  for (int vecBase = 0; vecBase < numVecs; vecBase += kLanes) {
    int dist = 0;

    // The kWarps warps share the work of loading kLanes vectors
#pragma unroll
    for (int pass = 0; pass < kLanes / kWarps; ++pass) {
      const int row = pass * kWarps + warp;
      const int vec = vecBase + row;
      const bool haveVec = vec < numVecs;

      vecTile[row][lane] = haveVec && haveWord ?
        vecs[vec][word] : 0;
    }

    // The tiles must be completely written before anyone reads them
    __syncthreads();

    // Accumulate the Hamming distance over the (limited) reduction dimension
#pragma unroll
    for (int i = 0; i < ReductionLimit; ++i) {
      dist += __popc(queryTile[warp][i] ^ vecTile[lane][i]);
    }

    // All reads must be done before vecTile is overwritten again
    __syncthreads();

    // Within a warp, each lane holds the result for a different vector
    // against the same query. Pairs that do not correspond to a real
    // (query, vec) combination are submitted as (kMaxDistance, -1).
    const int candidate = vecBase + lane;
    const bool valid = haveQuery && (candidate < numVecs);

    heap.add(valid ? dist : kMaxDistance, valid ? candidate : -1);
  }

  heap.reduce();

  if (haveQuery) {
    heap.writeOut(outK[myQuery].data(),
                  outV[myQuery].data(),
                  k);
  }
}
|
|
|
|
// Host-side launcher for binaryDistanceAnySize.
//
// Uses one warp per query (kWarps queries per block) and picks the kernel
// instantiation whose selection queue sizes cover the requested k. No kernel
// is launched if k is larger than every case handled below.
template <typename BinaryType>
void runBinaryDistanceAnySize(Tensor<BinaryType, 2, true>& vecs,
                              Tensor<BinaryType, 2, true>& query,
                              Tensor<int, 2, true>& outK,
                              Tensor<int, 2, true>& outV,
                              int k, cudaStream_t stream) {
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

  // Launches the kernel instantiated for the given (warp queue, thread queue)
  // sizes
#define RUN_BINARY_ANY_SIZE(WARP_Q, THREAD_Q)                   \
  binaryDistanceAnySize<WARP_Q, THREAD_Q, BinaryType>           \
    <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k)

  if (k == 1) {
    RUN_BINARY_ANY_SIZE(1, 1);
  } else if (k <= 32) {
    RUN_BINARY_ANY_SIZE(32, 2);
  } else if (k <= 64) {
    RUN_BINARY_ANY_SIZE(64, 3);
  } else if (k <= 128) {
    RUN_BINARY_ANY_SIZE(128, 3);
  } else if (k <= 256) {
    RUN_BINARY_ANY_SIZE(256, 4);
  } else if (k <= 512) {
    RUN_BINARY_ANY_SIZE(512, 8);
  } else if (k <= 1024) {
    RUN_BINARY_ANY_SIZE(1024, 8);
  }
#if GPU_MAX_SELECTION_K >= 2048
  else if (k <= 2048) {
    RUN_BINARY_ANY_SIZE(2048, 8);
  }
#endif

#undef RUN_BINARY_ANY_SIZE
}
|
|
|
|
// Host-side launcher for binaryDistanceLimitSize.
//
// Uses one warp per query (kWarps queries per block) and picks the kernel
// instantiation whose selection queue sizes cover the requested k;
// ReductionLimit is forwarded to the kernel unchanged. No kernel is launched
// if k is larger than every case handled below.
template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(Tensor<BinaryType, 2, true>& vecs,
                                Tensor<BinaryType, 2, true>& query,
                                Tensor<int, 2, true>& outK,
                                Tensor<int, 2, true>& outV,
                                int k, cudaStream_t stream) {
  dim3 grid(utils::divUp(query.getSize(0), kWarps));
  dim3 block(kLanes, kWarps);

  // Launches the kernel instantiated for the given (warp queue, thread queue)
  // sizes
#define RUN_BINARY_LIMIT_SIZE(WARP_Q, THREAD_Q)                           \
  binaryDistanceLimitSize<WARP_Q, THREAD_Q, BinaryType, ReductionLimit>   \
    <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k)

  if (k == 1) {
    RUN_BINARY_LIMIT_SIZE(1, 1);
  } else if (k <= 32) {
    RUN_BINARY_LIMIT_SIZE(32, 2);
  } else if (k <= 64) {
    RUN_BINARY_LIMIT_SIZE(64, 3);
  } else if (k <= 128) {
    RUN_BINARY_LIMIT_SIZE(128, 3);
  } else if (k <= 256) {
    RUN_BINARY_LIMIT_SIZE(256, 4);
  } else if (k <= 512) {
    RUN_BINARY_LIMIT_SIZE(512, 8);
  } else if (k <= 1024) {
    RUN_BINARY_LIMIT_SIZE(1024, 8);
  }
#if GPU_MAX_SELECTION_K >= 2048
  else if (k <= 2048) {
    RUN_BINARY_LIMIT_SIZE(2048, 8);
  }
#endif

#undef RUN_BINARY_LIMIT_SIZE
}
|
|
|
|
// Computes, for every row of `query`, the k rows of `vecs` with the lowest
// Hamming distance, dispatching to the kernel best suited to the code size.
//
// vecs   : (num vecs, code size in bytes)
// query  : (num queries, code size in bytes); the code size must match vecs
// outK   : (num queries, k), receives the selected keys (distances)
// outV   : (num queries, k), receives the selected values (vec indices)
// k      : number of results per query; must not exceed GPU_MAX_SELECTION_K
// stream : stream on which the kernel is launched
//
// Violated preconditions trip a FAISS_ASSERT. An empty query batch is a
// no-op.
void runBinaryDistance(Tensor<unsigned char, 2, true>& vecs,
                       Tensor<unsigned char, 2, true>& query,
                       Tensor<int, 2, true>& outK,
                       Tensor<int, 2, true>& outV,
                       int k, cudaStream_t stream) {
  FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
  FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

  FAISS_ASSERT(outK.getSize(1) == k);
  FAISS_ASSERT(outV.getSize(1) == k);

  // Without queries the launchers would compute a grid of zero blocks, which
  // is not a valid launch configuration; there is nothing to write anyway
  if (query.getSize(0) == 0) {
    return;
  }

  // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims
  constexpr int kReductionLimit32 = 8;

  // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims
  constexpr int kReductionLimit8 = 16;

  // All other cases (large or small) go through the general kernel

  if (vecs.getSize(1) % sizeof(unsigned int) == 0 &&
      (vecs.getSize(1) / sizeof(unsigned int)) <= kReductionLimit32) {
    // NOTE(review): castResize presumably requires the underlying data to be
    // suitably aligned for unsigned int access -- confirm against Tensor
    auto vecs32 = vecs.castResize<unsigned int>();
    auto query32 = query.castResize<unsigned int>();

    // Optimize for vectors with dimensions a multiple of 32 that are at most
    // 32 * kReductionLimit32 (256) dimensions in size
    runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
      vecs32, query32, outK, outV, k, stream);

  } else if (vecs.getSize(1) <= kReductionLimit8) {
    // Optimize for short vectors of at most kReductionLimit8 (16) bytes, i.e.
    // 8 * 16 = 128 dimensions, that did not qualify for the uint32 kernel
    // above (their byte size is not a multiple of sizeof(unsigned int))
    runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
      vecs, query, outK, outV, k, stream);
  } else {
    // Arbitrary size kernel
    runBinaryDistanceAnySize<unsigned char>(
      vecs, query, outK, outV, k, stream);
  }
}
|
|
|
|
} } // namespace
|