GPU supports arbitrary dimensions per PQ sub-quantizer
Summary: This diff removes a long-standing limitation of GpuIndexIVFPQ: only a limited set of dimensions per sub-quantizer was supported when not using precomputed codes. This is part of the general cleanup and extension/optimization that I am performing on the GPU PQ code.

We keep the same specialized distance computations as before, but if the requested number of dimensions per sub-quantizer is not one of the specialized sizes, we now fall back to a general implementation based on batch matrix multiplication for computing PQ distances per code. The batch MM PQ distance computation is enabled automatically if you use a number of dimensions per sub-quantizer that is not specialized (say, 7, 11, 53, ...). It can also be enabled manually via the `useMMCodeDistance` option in `GpuIndexIVFPQConfig` for testing purposes; the results should be within some epsilon of the specialized implementation.

This diff also removes the iterated GEMM wrapper. I honestly don't know why I was using that instead of `cublasGemmStridedBatchedEx`; maybe I couldn't find it, or the code was originally written against a much older version of CUDA. The iterated GEMM call was used in a few other places as well (e.g., precomputed code computation). Those (and the PQ distance computation) now use batch MM, which is a single CUDA call.

This diff also adds stream synchronization to the temporary memory manager, because the fallback PQ distance computation needs temporary memory and there were too many buffers to pre-allocate.

It also fixes the bug reported in https://github.com/facebookresearch/faiss/issues/1421.

Reviewed By: mdouze

Differential Revision: D24130629

fbshipit-source-id: 1c8bc53c86d0523832ad89c8bd4fa4b5fc187cae
parent 5ad630635c
commit 9b007c7418
branch pull/1445/head
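As a usage sketch for the new option (this example is mine, not part of the diff; the dimensions, list count, and training data are arbitrary placeholders), the debug flag can be forced on the index config, mirroring the new testMMCodeDistance test further down:

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <random>
#include <vector>

int main() {
  int d = 64, nlist = 256, M = 8, nbits = 8, nTrain = 20000;

  // Train a CPU IVFPQ index on random placeholder data.
  std::mt19937 rng(123);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  std::vector<float> xt(size_t(nTrain) * d);
  for (auto& v : xt) v = dist(rng);

  faiss::IndexFlatL2 coarseQuantizer(d);
  faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, d, nlist, M, nbits);
  cpuIndex.train(nTrain, xt.data());
  cpuIndex.add(nTrain, xt.data());

  faiss::gpu::StandardGpuResources res;

  faiss::gpu::GpuIndexIVFPQConfig config;
  config.usePrecomputedTables = false;
  // Force the GEMM-backed PQ code distance path even for dims-per-sub-quantizer
  // sizes that have a specialized kernel (debugging/testing option from this diff).
  config.useMMCodeDistance = true;

  // Copy to the GPU under this config and query as usual.
  faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config);
  gpuIndex.setNumProbes(4);

  std::vector<float> distances(10);
  std::vector<faiss::Index::idx_t> labels(10);
  gpuIndex.search(1, xt.data(), 10, distances.data(), labels.data());
  return 0;
}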
@@ -40,7 +40,6 @@ target_sources(faiss PRIVATE
utils/BlockSelectHalf.cu
utils/DeviceUtils.cu
utils/Float16.cu
utils/MatrixMult.cu
utils/StackDeviceMemory.cpp
utils/Timer.cpp
utils/WarpSelectFloat.cu
@@ -100,10 +100,11 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
quantizer->getGpuData(),
subQuantizers_,
bitsPerCode_,
ivfpqConfig_.useFloat16LookupTables,
ivfpqConfig_.useMMCodeDistance,
ivfpqConfig_.alternativeLayout,
(float*) index->pq.centroids.data(),
ivfpqConfig_.indicesOptions,
ivfpqConfig_.useFloat16LookupTables,
memorySpace_));
// Doesn't make sense to reserve memory here
index_->setPrecomputedCodes(ivfpqConfig_.usePrecomputedTables);
@@ -258,10 +259,11 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
quantizer->getGpuData(),
subQuantizers_,
bitsPerCode_,
ivfpqConfig_.useFloat16LookupTables,
ivfpqConfig_.useMMCodeDistance,
ivfpqConfig_.alternativeLayout,
pq.centroids.data(),
ivfpqConfig_.indicesOptions,
ivfpqConfig_.useFloat16LookupTables,
memorySpace_));
if (reserveMemoryVecs_) {
index_->reserveMemory(reserveMemoryVecs_);
@@ -406,20 +408,6 @@ GpuIndexIVFPQ::verifySettings_() const {
"reduce parameters",
device_, smemPerBlock, bitsPerCode_, subQuantizers_,
requiredSmemSize);

// If precomputed codes are disabled, we have an extra limitation in
// terms of the number of dimensions per subquantizer
FAISS_THROW_IF_NOT_FMT(ivfpqConfig_.usePrecomputedTables ||
IVFPQ::isSupportedNoPrecomputedSubDimSize(
this->d / subQuantizers_),
"Number of dimensions per sub-quantizer (%d) "
"is not currently supported without precomputed codes. "
"Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims "
"per sub-quantizer are currently supported with no "
"precomputed codes. "
"Precomputed codes supports any number of dimensions, but "
"will involve memory overheads.",
this->d / subQuantizers_);
}

} } // namespace
@@ -23,7 +23,8 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
inline GpuIndexIVFPQConfig()
: useFloat16LookupTables(false),
usePrecomputedTables(false),
alternativeLayout(false) {
alternativeLayout(false),
useMMCodeDistance(false) {
}

/// Whether or not float16 residual distance tables are used in the
@@ -38,6 +39,16 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
/// Use the alternative memory layout for the IVF lists
/// WARNING: this is a feature under development, do not use!
bool alternativeLayout;

/// Use GEMM-backed computation of PQ code distances for the no precomputed
/// table version of IVFPQ.
/// This is for debugging purposes, it should not substantially affect the
/// results one way for another.
///
/// Note that MM code distance is enabled automatically if one uses a number
/// of dimensions per sub-quantizer that is not natively specialized (an odd
/// number like 7 or so).
bool useMMCodeDistance;
};

/// IVFPQ index for the GPU
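To complement the doc comment above, here is a minimal sketch (mine, not part of the diff; sizes are placeholders) where the GEMM path is selected automatically: with 84 dimensions and 12 sub-quantizers each sub-quantizer covers 7 dimensions, which has no specialized kernel, so no config flag is needed.

#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <random>
#include <vector>

int main() {
  // 7 dims per sub-quantizer is not a specialized size, so the batch-MM
  // code distance path is used without setting useMMCodeDistance.
  int d = 84, nlist = 512, M = 12, nbits = 8, nTrain = 50000;

  faiss::gpu::StandardGpuResources res;

  faiss::gpu::GpuIndexIVFPQConfig config;
  config.usePrecomputedTables = false;

  faiss::gpu::GpuIndexIVFPQ index(&res, d, nlist, M, nbits,
                                  faiss::METRIC_L2, config);

  std::mt19937 rng(456);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  std::vector<float> xt(size_t(nTrain) * d);
  for (auto& v : xt) v = dist(rng);

  index.train(nTrain, xt.data());
  index.add(nTrain, xt.data());
  index.setNumProbes(8);
  return 0;
}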
@@ -214,7 +214,9 @@ __global__ void sumAlongRows(Tensor<T, 1, true> input,
for (int i = threadIdx.x; i < output.getSize(1); i += blockDim.x) {
T out = output[row][i];
out = Math<T>::add(out, val);
out = Math<T>::lt(out, Math<T>::zero()) ? Math<T>::zero() : out;
if (ZeroClamp) {
out = Math<T>::lt(out, Math<T>::zero()) ? Math<T>::zero() : out;
}

output[row][i] = out;
}
@@ -38,10 +38,11 @@ IVFPQ::IVFPQ(GpuResources* resources,
FlatIndex* quantizer,
int numSubQuantizers,
int bitsPerSubQuantizer,
bool layoutBy32,
bool useFloat16LookupTables,
bool useMMCodeDistance,
bool alternativeLayout,
float* pqCentroidData,
IndicesOptions indicesOptions,
bool useFloat16LookupTables,
MemorySpace space) :
IVFBase(resources,
metric,
@@ -53,8 +54,9 @@ IVFPQ::IVFPQ(GpuResources* resources,
bitsPerSubQuantizer_(bitsPerSubQuantizer),
numSubQuantizerCodes_(utils::pow2(bitsPerSubQuantizer_)),
dimPerSubQuantizer_(dim_ / numSubQuantizers),
layoutBy32_(layoutBy32),
useFloat16LookupTables_(useFloat16LookupTables),
useMMCodeDistance_(useMMCodeDistance),
alternativeLayout_(alternativeLayout),
precomputedCodes_(false) {
FAISS_ASSERT(pqCentroidData);
@@ -93,11 +95,6 @@ IVFPQ::isSupportedPQCodeLength(int size) {
}
}

bool
IVFPQ::isSupportedNoPrecomputedSubDimSize(int dims) {
return faiss::gpu::isSupportedNoPrecomputedSubDimSize(dims);
}

void
IVFPQ::setPrecomputedCodes(bool enable) {
if (enable && metric_ == MetricType::METRIC_INNER_PRODUCT) {
@@ -212,7 +209,7 @@ IVFPQ::appendVectors_(Tensor<float, 2, true>& vecs,
listOffset,
encodings,
indices,
layoutBy32_,
alternativeLayout_,
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
@@ -221,7 +218,7 @@ IVFPQ::appendVectors_(Tensor<float, 2, true>& vecs,

size_t
IVFPQ::getGpuVectorsEncodingSize_(int numVecs) const {
if (layoutBy32_) {
if (alternativeLayout_) {
return utils::roundUp(
(size_t) numVecs, (size_t) 32) * numSubQuantizers_;
} else {
@@ -238,7 +235,7 @@ IVFPQ::getCpuVectorsEncodingSize_(int numVecs) const {
std::vector<unsigned char>
IVFPQ::translateCodesToGpu_(std::vector<unsigned char> codes,
size_t numVecs) const {
if (!layoutBy32_) {
if (!alternativeLayout_) {
return codes;
}

@@ -264,7 +261,7 @@ IVFPQ::translateCodesToGpu_(std::vector<unsigned char> codes,
std::vector<unsigned char>
IVFPQ::translateCodesFromGpu_(std::vector<unsigned char> codes,
size_t numVecs) const {
if (!layoutBy32_) {
if (!alternativeLayout_) {
return codes;
}
@@ -388,20 +385,20 @@ IVFPQ::precomputeCodesT_() {

convertTensor(stream, centroidsTransposed, centroidsTransposedF32);

runIteratedMatrixMult(coarsePQProduct, false,
centroidsTransposedF32, false,
pqCentroidsMiddleCode_, true,
2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);
runBatchMatrixMult(coarsePQProduct, false,
centroidsTransposedF32, false,
pqCentroidsMiddleCode_, true,
2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);
} else {
// All in f32
runIteratedMatrixMult(coarsePQProduct, false,
centroidsTransposed, false,
pqCentroidsMiddleCode_, true,
2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);
runBatchMatrixMult(coarsePQProduct, false,
centroidsTransposed, false,
pqCentroidsMiddleCode_, true,
2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);
}
}
@@ -587,12 +584,12 @@ IVFPQ::runPQPrecomputedCodes_(
resources_, makeTempAlloc(AllocType::Other, stream),
{numSubQuantizers_, queries.getSize(0), numSubQuantizerCodes_});

runIteratedMatrixMult(term3, false,
queriesTransposed, false,
pqCentroidsMiddleCode_, true,
-2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);
runBatchMatrixMult(term3, false,
queriesTransposed, false,
pqCentroidsMiddleCode_, true,
-2.0f, 0.0f,
resources_->getBlasHandleCurrentDevice(),
stream);

runTransposeAny(term3, 0, 1, term3Transposed, stream);
}
@@ -619,7 +616,7 @@ IVFPQ::runPQPrecomputedCodes_(
term3, // term 3
coarseIndices,
useFloat16LookupTables_,
layoutBy32_,
alternativeLayout_,
numSubQuantizers_,
numSubQuantizerCodes_,
deviceListDataPointers_,
@@ -647,9 +644,11 @@ IVFPQ::runPQNoPrecomputedCodesT_(
runPQScanMultiPassNoPrecomputed(queries,
coarseCentroids,
pqCentroidsInnermostCode_,
coarseDistances,
coarseIndices,
useFloat16LookupTables_,
layoutBy32_,
useMMCodeDistance_,
alternativeLayout_,
numSubQuantizers_,
numSubQuantizerCodes_,
deviceListDataPointers_,
@@ -24,20 +24,16 @@ class IVFPQ : public IVFBase {
FlatIndex* quantizer,
int numSubQuantizers,
int bitsPerSubQuantizer,
bool layoutBy32,
bool useFloat16LookupTables,
bool useMMCodeDistance,
bool alternativeLayout,
float* pqCentroidData,
IndicesOptions indicesOptions,
bool useFloat16LookupTables,
MemorySpace space);

/// Returns true if we support PQ in this size
static bool isSupportedPQCodeLength(int size);

/// For no precomputed codes, is this a supported sub-dimension
/// size?
/// FIXME: get MM implementation working again
static bool isSupportedNoPrecomputedSubDimSize(int dims);

~IVFPQ() override;

/// Enable or disable pre-computed codes
@@ -134,6 +130,15 @@ class IVFPQ : public IVFBase {
/// Number of dimensions per each sub-quantizer
const int dimPerSubQuantizer_;

/// Do we maintain precomputed terms and lookup tables in float16
/// form?
const bool useFloat16LookupTables_;

/// For usage without precomputed codes, do we force usage of the
/// general-purpose MM code distance computation? This is for testing
/// purposes.
const bool useMMCodeDistance_;

/// The default memory layout is [vector][PQ component]:
/// (v0 d0) (v0 d1) ... (v0 dD-1) (v1 d0) (v1 d1) ...
///
@@ -142,11 +147,7 @@ class IVFPQ : public IVFBase {
/// (v0 d0) (v1 d0) ... (v31 d0) (v0 d1) (v1 d1) ... (v31 dD-1) (v32 d0) (v33
/// d0) ...
/// so the list length is always a multiple of numSubQuantizers * 32
const bool layoutBy32_;

/// Do we maintain precomputed terms and lookup tables in float16
/// form?
const bool useFloat16LookupTables_;
const bool alternativeLayout_;

/// On the GPU, we prefer different PQ centroid data layouts for
/// different purposes.
@ -32,7 +32,7 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
int queriesPerBlock,
|
||||
Tensor<CentroidT, 2, true> coarseCentroids,
|
||||
Tensor<float, 3, true> pqCentroids,
|
||||
Tensor<int, 2, true> topQueryToCentroid,
|
||||
Tensor<int, 2, true> coarseIndices,
|
||||
// (query id)(coarse)(subquantizer)(code) -> dist
|
||||
Tensor<OutCodeT, 4, true> outCodeDistances) {
|
||||
const auto numSubQuantizers = pqCentroids.getSize(0);
|
||||
|
@ -53,9 +53,12 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
|
||||
// Each thread will load the pq centroid data for the code that it
|
||||
// is processing
|
||||
// The loading threads are out of bounds for the number of codes available
|
||||
if (!isLoadingThread) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DimsPerSubQuantizer; ++i) {
|
||||
subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
|
||||
for (int i = 0; i < DimsPerSubQuantizer; ++i) {
|
||||
subQuantizerData[i] = pqCentroids[subQuantizer][i][code].ldg();
|
||||
}
|
||||
}
|
||||
|
||||
// Where we store our query vector
|
||||
|
@ -91,9 +94,8 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
}
|
||||
|
||||
// Load list of coarse centroids found
|
||||
for (int i = threadIdx.x;
|
||||
i < topQueryToCentroid.getSize(1); i += blockDim.x) {
|
||||
coarseIds[i] = topQueryToCentroid[queryId][i];
|
||||
for (int i = threadIdx.x; i < coarseIndices.getSize(1); i += blockDim.x) {
|
||||
coarseIds[i] = coarseIndices[queryId][i];
|
||||
}
|
||||
|
||||
// We need coarseIds below
|
||||
|
@ -122,7 +124,7 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
}
|
||||
|
||||
// The block walks the list for a single query
|
||||
for (int coarse = 0; coarse < topQueryToCentroid.getSize(1); ++coarse) {
|
||||
for (int coarse = 0; coarse < coarseIndices.getSize(1); ++coarse) {
|
||||
// Wait for smemResidual1 to be loaded
|
||||
__syncthreads();
|
||||
|
||||
|
@ -133,7 +135,7 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
i += blockDim.x - codesPerSubQuantizer) {
|
||||
// FIXME: try always making this centroid id 0 so we can
|
||||
// terminate
|
||||
if (coarse != (topQueryToCentroid.getSize(1) - 1)) {
|
||||
if (coarse != (coarseIndices.getSize(1) - 1)) {
|
||||
auto coarseId = coarseIds[coarse + 1];
|
||||
// In case NaNs were in the original query data
|
||||
coarseId = coarseId == -1 ? 0 : coarseId;
|
||||
|
@ -268,123 +270,131 @@ pqCodeDistances(Tensor<float, 2, true> queries,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename CentroidT>
|
||||
template <typename CentroidT, bool L2Residual>
|
||||
__global__ void
|
||||
residualVector(Tensor<float, 2, true> queries,
|
||||
Tensor<CentroidT, 2, true> coarseCentroids,
|
||||
Tensor<int, 2, true> topQueryToCentroid,
|
||||
int numSubDim,
|
||||
// output is transposed:
|
||||
// (sub q)(query id)(centroid id)(sub dim)
|
||||
Tensor<float, 4, true> residual) {
|
||||
// block x is query id
|
||||
// block y is centroid id
|
||||
// thread x is dim
|
||||
pqResidualVector(Tensor<float, 2, true> queries,
|
||||
Tensor<CentroidT, 2, true> coarseCentroids,
|
||||
Tensor<int, 2, true> coarseIndices,
|
||||
int numSubDim,
|
||||
// output is transposed:
|
||||
// (sub q)(query id)(centroid id)(sub dim)
|
||||
Tensor<float, 4, true> residual) {
|
||||
auto queryId = blockIdx.x;
|
||||
auto centroidId = blockIdx.y;
|
||||
|
||||
int realCentroidId = topQueryToCentroid[queryId][centroidId];
|
||||
int realCentroidId = coarseIndices[queryId][centroidId];
|
||||
|
||||
for (int dim = threadIdx.x; dim < queries.getSize(1); dim += blockDim.x) {
|
||||
float q = queries[queryId][dim];
|
||||
float c = ConvertTo<float>::to(coarseCentroids[realCentroidId][dim]);
|
||||
|
||||
residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] = q - c;
|
||||
float r;
|
||||
|
||||
if (L2Residual) {
|
||||
r = q - c;
|
||||
} else {
|
||||
// IP does not use a residual. Instead, the estimated distance is
|
||||
// (query . (centroid + sub quantizer centroid).
|
||||
//
|
||||
// This kernel is used to calculate (query . sub quantizer centroid),
|
||||
// providing the query value replicated across all of the sub
|
||||
// quantizers. The batch matrix multiplication in runPQCodeDistancesMM
|
||||
// will perform this inner product. The adjustment (query . centroid) is
|
||||
// added later.
|
||||
r = q;
|
||||
}
|
||||
|
||||
residual[dim / numSubDim][queryId][centroidId][dim % numSubDim] = r;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CentroidT>
|
||||
void
|
||||
runResidualVector(Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 4, true>& residual,
|
||||
cudaStream_t stream) {
|
||||
auto grid =
|
||||
dim3(topQueryToCentroid.getSize(0), topQueryToCentroid.getSize(1));
|
||||
runPQResidualVector(Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
Tensor<float, 4, true>& residual,
|
||||
bool l2Residual,
|
||||
cudaStream_t stream) {
|
||||
auto grid = dim3(coarseIndices.getSize(0), coarseIndices.getSize(1));
|
||||
auto block = dim3(std::min(queries.getSize(1), getMaxThreadsCurrentDevice()));
|
||||
|
||||
residualVector<<<grid, block, 0, stream>>>(
|
||||
queries, coarseCentroids, topQueryToCentroid, pqCentroids.getSize(1),
|
||||
residual);
|
||||
if (l2Residual) {
|
||||
pqResidualVector<CentroidT, true><<<grid, block, 0, stream>>>(
|
||||
queries, coarseCentroids,
|
||||
coarseIndices, pqCentroids.getSize(1), residual);
|
||||
} else {
|
||||
pqResidualVector<CentroidT, false><<<grid, block, 0, stream>>>(
|
||||
queries, coarseCentroids,
|
||||
coarseIndices, pqCentroids.getSize(1), residual);
|
||||
}
|
||||
|
||||
CUDA_TEST_ERROR();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void
|
||||
pqDistanceIPCorrection(Tensor<T, 3, true> codeDistances,
|
||||
Tensor<T, 2, true> coarseDistances,
|
||||
int numSubQ) {
|
||||
int centroid = blockIdx.x;
|
||||
int query = blockIdx.y;
|
||||
|
||||
// We need to add the (query . centroid) correction factor (coarseDistances)
|
||||
// to all output code distances (q)(c)(sub q)(code).
|
||||
// However, there are numSubQ code distance sums per each approximated
|
||||
// distance, so we need to divide this correction by numSubQ since we will be
|
||||
// adding it numSubQ times.
|
||||
auto d = coarseDistances[query][centroid] / (float) numSubQ;
|
||||
|
||||
auto base = codeDistances[query][centroid].data();
|
||||
|
||||
for (int i = threadIdx.x; i < codeDistances.getSize(2); i += blockDim.x) {
|
||||
base[i] += d;
|
||||
}
|
||||
}
|
||||
|
||||
// We have previously calculated (query . sub quantizer centroid), but we
|
||||
// need to calculate (query . (centroid + sub quantizer centroid). This will add
|
||||
// in the correction factor to each calculated code distance.
|
||||
template <typename T>
|
||||
void
|
||||
runPQDistanceIPCorrection(Tensor<T, 4, true>& codeDistances,
|
||||
Tensor<T, 2, true>& coarseDistances,
|
||||
cudaStream_t stream) {
|
||||
auto grid = dim3(coarseDistances.getSize(1), coarseDistances.getSize(0));
|
||||
auto block = 512;
|
||||
|
||||
auto codeView = codeDistances.template downcastInner<3>();
|
||||
|
||||
pqDistanceIPCorrection<<<grid, block, 0, stream>>>(codeView,
|
||||
coarseDistances,
|
||||
codeDistances.getSize(2));
|
||||
}
|
||||
|
||||
// This is a general purpose implementation that leverages GEMM to calculate
|
||||
// code distances for PQ codes for any number of dimensions per sub-quantizer /
|
||||
// number of sub-quantizers
|
||||
template <typename CentroidT>
|
||||
void
|
||||
runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
|
||||
runPQCodeDistancesMM(GpuResources* res,
|
||||
Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
// Output is (query)(centroid)(sub q)(code)
|
||||
NoTypeTensor<4, true>& outCodeDistances,
|
||||
bool l2Distance,
|
||||
bool useFloat16Lookup,
|
||||
GpuResources* res,
|
||||
cublasHandle_t handle,
|
||||
cudaStream_t stream) {
|
||||
// Calculate (q - c) residual vector
|
||||
// (sub q)(query id)(centroid id)(sub dim)
|
||||
DeviceTensor<float, 4, true> residual(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0),
|
||||
topQueryToCentroid.getSize(0),
|
||||
topQueryToCentroid.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
runResidualVector(pqCentroids, queries,
|
||||
coarseCentroids, topQueryToCentroid,
|
||||
residual, stream);
|
||||
|
||||
// Calculate ||q - c||^2
|
||||
DeviceTensor<float, 1, true> residualNorms(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0) *
|
||||
topQueryToCentroid.getSize(0) *
|
||||
topQueryToCentroid.getSize(1)});
|
||||
|
||||
auto residualView2 = residual.view<2>(
|
||||
{pqCentroids.getSize(0) *
|
||||
topQueryToCentroid.getSize(0) *
|
||||
topQueryToCentroid.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
runL2Norm(residualView2, true, residualNorms, true, stream);
|
||||
|
||||
// Perform a batch MM:
|
||||
// (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
|
||||
// (sub q) x {(q * c)(code)}
|
||||
auto residualView3 = residual.view<3>(
|
||||
{pqCentroids.getSize(0),
|
||||
topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
DeviceTensor<float, 3, true> residualDistance(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0),
|
||||
topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
|
||||
pqCentroids.getSize(2)});
|
||||
|
||||
runIteratedMatrixMult(residualDistance, false,
|
||||
residualView3, false,
|
||||
pqCentroids, false,
|
||||
-2.0f, 0.0f,
|
||||
handle,
|
||||
stream);
|
||||
|
||||
// Sum ||q - c||^2 along rows
|
||||
auto residualDistanceView2 = residualDistance.view<2>(
|
||||
{pqCentroids.getSize(0) *
|
||||
topQueryToCentroid.getSize(0) *
|
||||
topQueryToCentroid.getSize(1),
|
||||
pqCentroids.getSize(2)});
|
||||
|
||||
runSumAlongRows(residualNorms, residualDistanceView2, false, stream);
|
||||
|
||||
// We construct our float32 output in outCodeDistancesF
|
||||
Tensor<float, 4, true> outCodeDistancesF;
|
||||
DeviceTensor<float, 4, true> outCodeDistancesFloatMem;
|
||||
|
||||
if (useFloat16Lookup) {
|
||||
// outCodeDistances has half memory, we need to allocate a buffer for float
|
||||
outCodeDistancesFloatMem = DeviceTensor<float, 4, true>(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{outCodeDistances.getSize(0),
|
||||
|
@ -394,47 +404,123 @@ runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
|
|||
|
||||
outCodeDistancesF = outCodeDistancesFloatMem;
|
||||
} else {
|
||||
// We can use the memory that we were given
|
||||
outCodeDistancesF = outCodeDistances.toTensor<float>();
|
||||
}
|
||||
|
||||
// Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which
|
||||
// is where we build our output distances)
|
||||
// Calculate (q - c) residual vector if L2. Otherwise, for IP, this kernel
|
||||
// will just replicate q
|
||||
//
|
||||
// (sub q)(query id)(centroid id)(sub dim)
|
||||
DeviceTensor<float, 4, true> residual(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0),
|
||||
coarseIndices.getSize(0),
|
||||
coarseIndices.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
runPQResidualVector(pqCentroids, queries,
|
||||
coarseCentroids, coarseIndices,
|
||||
residual, l2Distance, stream);
|
||||
|
||||
// Perform a batch MM:
|
||||
// (sub q) x {(q * c)(sub dim) x (sub dim)(code)} =>
|
||||
// (sub q) x {(q * c)(code)}
|
||||
auto residualView3 = residual.view<3>(
|
||||
{pqCentroids.getSize(0),
|
||||
coarseIndices.getSize(0) * coarseIndices.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
DeviceTensor<float, 3, true> residualDistance(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0),
|
||||
coarseIndices.getSize(0) * coarseIndices.getSize(1),
|
||||
pqCentroids.getSize(2)});
|
||||
|
||||
runBatchMatrixMult(residualDistance, false,
|
||||
residualView3, false,
|
||||
pqCentroids, false,
|
||||
l2Distance ? -2.0f : 1.0f, 0.0f,
|
||||
res->getBlasHandleCurrentDevice(),
|
||||
stream);
|
||||
|
||||
if (l2Distance) {
|
||||
// Calculate ||q - c||^2
|
||||
DeviceTensor<float, 1, true> residualNorms(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0) *
|
||||
coarseIndices.getSize(0) *
|
||||
coarseIndices.getSize(1)});
|
||||
|
||||
auto residualView2 = residual.view<2>(
|
||||
{pqCentroids.getSize(0) *
|
||||
coarseIndices.getSize(0) *
|
||||
coarseIndices.getSize(1),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
runL2Norm(residualView2, true, residualNorms, true, stream);
|
||||
|
||||
// Sum ||q - c||^2 along rows
|
||||
auto residualDistanceView2 = residualDistance.view<2>(
|
||||
{pqCentroids.getSize(0) *
|
||||
coarseIndices.getSize(0) *
|
||||
coarseIndices.getSize(1),
|
||||
pqCentroids.getSize(2)});
|
||||
|
||||
runSumAlongRows(residualNorms, residualDistanceView2, false, stream);
|
||||
}
|
||||
|
||||
// Transpose (sub q)(q * c)(code) to (q * c)(sub q)(code) (which
|
||||
// is where we build our output distances). L2 version of this has an added -2
|
||||
// multiplicative factor
|
||||
auto outCodeDistancesView = outCodeDistancesF.view<3>(
|
||||
{topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
|
||||
outCodeDistances.getSize(2),
|
||||
outCodeDistances.getSize(3)});
|
||||
{coarseIndices.getSize(0) * coarseIndices.getSize(1),
|
||||
outCodeDistances.getSize(2),
|
||||
outCodeDistances.getSize(3)});
|
||||
|
||||
runTransposeAny(residualDistance, 0, 1, outCodeDistancesView, stream);
|
||||
|
||||
// Calculate code norms per each sub-dim
|
||||
// (sub q)(sub dim)(code) is pqCentroids
|
||||
// transpose to (sub q)(code)(sub dim)
|
||||
DeviceTensor<float, 3, true> pqCentroidsTranspose(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)});
|
||||
if (l2Distance) {
|
||||
// Calculate code norms per each sub-dim
|
||||
// (sub q)(sub dim)(code) is pqCentroids
|
||||
// transpose to (sub q)(code)(sub dim)
|
||||
DeviceTensor<float, 3, true> pqCentroidsTranspose(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0), pqCentroids.getSize(2), pqCentroids.getSize(1)});
|
||||
|
||||
runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
|
||||
runTransposeAny(pqCentroids, 1, 2, pqCentroidsTranspose, stream);
|
||||
|
||||
auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
|
||||
{pqCentroids.getSize(0) * pqCentroids.getSize(2),
|
||||
pqCentroids.getSize(1)});
|
||||
auto pqCentroidsTransposeView = pqCentroidsTranspose.view<2>(
|
||||
{pqCentroids.getSize(0) * pqCentroids.getSize(2),
|
||||
pqCentroids.getSize(1)});
|
||||
|
||||
DeviceTensor<float, 1, true> pqCentroidsNorm(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0) * pqCentroids.getSize(2)});
|
||||
// The norm of each (sub q)(code)
|
||||
DeviceTensor<float, 1, true> pqCentroidsNorm(
|
||||
res, makeTempAlloc(AllocType::Other, stream),
|
||||
{pqCentroids.getSize(0) * pqCentroids.getSize(2)});
|
||||
|
||||
runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream);
|
||||
runL2Norm(pqCentroidsTransposeView, true, pqCentroidsNorm, true, stream);
|
||||
|
||||
// View output as (q * c)(sub q * code), and add centroid norm to
|
||||
// each row
|
||||
auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
|
||||
{topQueryToCentroid.getSize(0) * topQueryToCentroid.getSize(1),
|
||||
outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
|
||||
// View output as (q * c)(sub q * code), and add centroid norm to
|
||||
// each row
|
||||
auto outDistancesCodeViewCols = outCodeDistancesView.view<2>(
|
||||
{coarseIndices.getSize(0) * coarseIndices.getSize(1),
|
||||
outCodeDistances.getSize(2) * outCodeDistances.getSize(3)});
|
||||
|
||||
runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
|
||||
runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream);
|
||||
} else {
|
||||
// We have previously calculated (query . sub quantizer centroid), but we
|
||||
// need to calculate (query . (centroid + sub quantizer centroid).
|
||||
//
|
||||
// We need to add the (query . centroid) correction factor (coarseDistances)
|
||||
// to all output code distances (q)(c)(sub q)(code).
|
||||
runPQDistanceIPCorrection(outCodeDistancesF, coarseDistances, stream);
|
||||
}
|
||||
|
||||
HostTensor<float, 4, true> debugT(outCodeDistancesF, stream);
|
||||
|
||||
if (useFloat16Lookup) {
|
||||
// Need to convert back
|
||||
// Need to convert back to half in the output memory
|
||||
auto outCodeDistancesH = outCodeDistances.toTensor<half>();
|
||||
convertTensor<float, half, 4>(stream,
|
||||
outCodeDistancesF,
|
||||
|
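An aside for readers of the hunk above (my own summary of what the new GEMM path computes, not text from the diff): writing $q$ for a query, $c$ for a coarse centroid taken from coarseIndices, $r = q - c$, $y_{m,j}$ for code $j$ of sub-quantizer $m$, and $M$ for the number of sub-quantizers, the L2 branch assembles

\[ D^{L2}_{m,j} = \lVert r_m \rVert^2 \; - \; 2\, r_m \cdot y_{m,j} \; + \; \lVert y_{m,j} \rVert^2 \; = \; \lVert r_m - y_{m,j} \rVert^2 , \]

where the middle term is the batched GEMM with alpha = -2 and the two squared norms are added in by runL2Norm / runSumAlongRows and runSumAlongColumns. The inner-product branch instead computes

\[ D^{IP}_{m,j} = q_m \cdot y_{m,j} \; + \; \frac{q \cdot c}{M} , \]

so that summing over the $M$ sub-quantizers recovers $q \cdot (c + y(\text{code}))$; the correction term is the coarseDistances value spread across sub-quantizers by pqDistanceIPCorrection.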
@ -442,13 +528,39 @@ runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
|
|||
}
|
||||
}
|
||||
|
||||
// Must be kept in sync with runPQDistances
|
||||
inline bool isSpecializedPQCodeDistanceDims(int dims) {
|
||||
switch (dims) {
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
case 5:
|
||||
case 6:
|
||||
case 8:
|
||||
case 10:
|
||||
case 12:
|
||||
case 16:
|
||||
case 20:
|
||||
case 24:
|
||||
case 28:
|
||||
case 32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CentroidT>
|
||||
void
|
||||
runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
||||
runPQCodeDistances(GpuResources* res,
|
||||
Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
NoTypeTensor<4, true>& outCodeDistances,
|
||||
bool useMMImplementation,
|
||||
bool l2Distance,
|
||||
bool useFloat16Lookup,
|
||||
cudaStream_t stream) {
|
||||
|
@ -456,6 +568,26 @@ runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
|||
const auto dimsPerSubQuantizer = pqCentroids.getSize(1);
|
||||
const auto codesPerSubQuantizer = pqCentroids.getSize(2);
|
||||
|
||||
// Only a certain number of dimensions per sub quantizer are supported by the
|
||||
// specialized implementation. Every other value falls back to the generalized
|
||||
// MM implementation.
|
||||
if (!isSpecializedPQCodeDistanceDims(dimsPerSubQuantizer) ||
|
||||
useMMImplementation) {
|
||||
// Use the general purpose matrix multiplication implementation which
|
||||
// handles any number of sub-quantizers and dimensions per sub-quantizer
|
||||
runPQCodeDistancesMM<CentroidT>(res,
|
||||
pqCentroids,
|
||||
queries,
|
||||
coarseCentroids,
|
||||
coarseDistances,
|
||||
coarseIndices,
|
||||
outCodeDistances,
|
||||
l2Distance,
|
||||
useFloat16Lookup,
|
||||
stream);
|
||||
return;
|
||||
}
|
||||
|
||||
// FIXME: tune
|
||||
// Reuse of pq centroid data is based on both # of queries * nprobe,
|
||||
// and we should really be tiling in both dimensions
|
||||
|
@ -470,7 +602,7 @@ runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
|||
auto block = dim3(codesPerSubQuantizer + loadingThreads);
|
||||
|
||||
auto smem = (3 * dimsPerSubQuantizer) * sizeof(float)
|
||||
+ topQueryToCentroid.getSize(1) * sizeof(int);
|
||||
+ coarseIndices.getSize(1) * sizeof(int);
|
||||
|
||||
#define RUN_CODE(DIMS, L2) \
|
||||
do { \
|
||||
|
@ -480,14 +612,14 @@ runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
|||
pqCodeDistances<half, CentroidT, DIMS, L2><<<grid, block, smem, stream>>>( \
|
||||
queries, kQueriesPerBlock, \
|
||||
coarseCentroids, pqCentroids, \
|
||||
topQueryToCentroid, outCodeDistancesT); \
|
||||
coarseIndices, outCodeDistancesT); \
|
||||
} else { \
|
||||
auto outCodeDistancesT = outCodeDistances.toTensor<float>(); \
|
||||
\
|
||||
pqCodeDistances<float, CentroidT, DIMS, L2><<<grid, block, smem, stream>>>( \
|
||||
queries, kQueriesPerBlock, \
|
||||
coarseCentroids, pqCentroids, \
|
||||
topQueryToCentroid, outCodeDistancesT); \
|
||||
coarseIndices, outCodeDistancesT); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
@ -513,6 +645,9 @@ runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
|||
case 4:
|
||||
CODE_L2(4);
|
||||
break;
|
||||
case 5:
|
||||
CODE_L2(5);
|
||||
break;
|
||||
case 6:
|
||||
CODE_L2(6);
|
||||
break;
|
||||
|
@ -540,11 +675,11 @@ runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
|||
case 32:
|
||||
CODE_L2(32);
|
||||
break;
|
||||
// FIXME: larger sizes require too many registers - we need the
|
||||
// MM implementation working
|
||||
default:
|
||||
FAISS_THROW_MSG("Too many dimensions (>32) per subquantizer "
|
||||
"not currently supported");
|
||||
// This should not be reached, we should fall back to the MM
|
||||
// implementation
|
||||
FAISS_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
|
||||
#undef RUN_CODE
|
||||
|
|
|
@ -19,27 +19,18 @@ class DeviceMemory;
|
|||
/// pqCentroids is of the form (sub q)(sub dim)(code id)
|
||||
/// Calculates the distance from the (query - centroid) residual to
|
||||
/// each sub-code vector, for the given list of query results in
|
||||
/// topQueryToCentroid
|
||||
/// coarseIndices
|
||||
template <typename CentroidT>
|
||||
void runPQCodeDistances(Tensor<float, 3, true>& pqCentroids,
|
||||
void runPQCodeDistances(GpuResources* res,
|
||||
Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
NoTypeTensor<4, true>& outCodeDistances,
|
||||
bool useMMImplementation,
|
||||
bool l2Distance,
|
||||
bool useFloat16Lookup,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename CentroidT>
|
||||
void runPQCodeDistancesMM(Tensor<float, 3, true>& pqCentroids,
|
||||
Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& coarseCentroids,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
NoTypeTensor<4, true>& outCodeDistances,
|
||||
bool useFloat16Lookup,
|
||||
GpuResources* res,
|
||||
cublasHandle_t handle,
|
||||
cudaStream_t stream);
|
||||
bool useFloat16Lookup);
|
||||
|
||||
} } // namespace
|
||||
|
||||
|
|
|
@ -76,30 +76,6 @@ pqScanInterleaved(Tensor<float, 2, true> queries,
|
|||
}
|
||||
}
|
||||
|
||||
// This must be kept in sync with PQCodeDistances.cu
|
||||
inline bool isSupportedNoPrecomputedSubDimSize(int dims) {
|
||||
switch (dims) {
|
||||
case 1:
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
case 6:
|
||||
case 8:
|
||||
case 10:
|
||||
case 12:
|
||||
case 16:
|
||||
case 20:
|
||||
case 24:
|
||||
case 28:
|
||||
case 32:
|
||||
return true;
|
||||
default:
|
||||
// FIXME: larger sizes require too many registers - we need the
|
||||
// MM implementation working
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LookupT, typename LookupVecT>
|
||||
struct LoadCodeDistances {
|
||||
static inline __device__ void load(LookupT* smem,
|
||||
|
@ -278,8 +254,10 @@ runMultiPassTile(GpuResources* res,
|
|||
Tensor<CentroidT, 2, true>& centroids,
|
||||
Tensor<float, 3, true>& pqCentroidsInnermostCode,
|
||||
NoTypeTensor<4, true>& codeDistances,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
bool useFloat16Lookup,
|
||||
bool useMMCodeDistance,
|
||||
bool interleavedCodeLayout,
|
||||
int numSubQuantizers,
|
||||
int numSubQuantizerCodes,
|
||||
|
@ -305,16 +283,19 @@ runMultiPassTile(GpuResources* res,
|
|||
|
||||
// Calculate offset lengths, so we know where to write out
|
||||
// intermediate results
|
||||
runCalcListOffsets(res, topQueryToCentroid, listLengths, prefixSumOffsets,
|
||||
runCalcListOffsets(res, coarseIndices, listLengths, prefixSumOffsets,
|
||||
thrustMem, stream);
|
||||
|
||||
// Calculate residual code distances, since this is without
|
||||
// precomputed codes
|
||||
runPQCodeDistances(pqCentroidsInnermostCode,
|
||||
runPQCodeDistances(res,
|
||||
pqCentroidsInnermostCode,
|
||||
queries,
|
||||
centroids,
|
||||
topQueryToCentroid,
|
||||
coarseDistances,
|
||||
coarseIndices,
|
||||
codeDistances,
|
||||
useMMCodeDistance,
|
||||
l2Distance,
|
||||
useFloat16Lookup,
|
||||
stream);
|
||||
|
@ -323,8 +304,8 @@ runMultiPassTile(GpuResources* res,
|
|||
// The vector interleaved layout implementation
|
||||
auto kThreadsPerBlock = 256;
|
||||
|
||||
auto grid = dim3(topQueryToCentroid.getSize(1),
|
||||
topQueryToCentroid.getSize(0));
|
||||
auto grid = dim3(coarseIndices.getSize(1),
|
||||
coarseIndices.getSize(0));
|
||||
auto block = dim3(kThreadsPerBlock);
|
||||
|
||||
if (useFloat16Lookup) {
|
||||
|
@ -333,7 +314,7 @@ runMultiPassTile(GpuResources* res,
|
|||
pqScanInterleaved<half>
|
||||
<<<grid, block, 0, stream>>>(queries,
|
||||
pqCentroidsInnermostCode,
|
||||
topQueryToCentroid,
|
||||
coarseIndices,
|
||||
codeDistancesT,
|
||||
listCodes.data().get(),
|
||||
listLengths.data().get(),
|
||||
|
@ -345,7 +326,7 @@ runMultiPassTile(GpuResources* res,
|
|||
pqScanInterleaved<float>
|
||||
<<<grid, block, 0, stream>>>(queries,
|
||||
pqCentroidsInnermostCode,
|
||||
topQueryToCentroid,
|
||||
coarseIndices,
|
||||
codeDistancesT,
|
||||
listCodes.data().get(),
|
||||
listLengths.data().get(),
|
||||
|
@ -357,8 +338,8 @@ runMultiPassTile(GpuResources* res,
|
|||
// index) values for all intermediate results
|
||||
auto kThreadsPerBlock = 256;
|
||||
|
||||
auto grid = dim3(topQueryToCentroid.getSize(1),
|
||||
topQueryToCentroid.getSize(0));
|
||||
auto grid = dim3(coarseIndices.getSize(1),
|
||||
coarseIndices.getSize(0));
|
||||
auto block = dim3(kThreadsPerBlock);
|
||||
|
||||
// pq centroid distances
|
||||
|
@ -375,7 +356,7 @@ runMultiPassTile(GpuResources* res,
|
|||
<<<grid, block, smem, stream>>>( \
|
||||
queries, \
|
||||
pqCentroidsInnermostCode, \
|
||||
topQueryToCentroid, \
|
||||
coarseIndices, \
|
||||
codeDistancesT, \
|
||||
listCodes.data().get(), \
|
||||
listLengths.data().get(), \
|
||||
|
@ -455,7 +436,7 @@ runMultiPassTile(GpuResources* res,
|
|||
// k-select the output in chunks, to increase parallelism
|
||||
runPass1SelectLists(prefixSumOffsets,
|
||||
allDistances,
|
||||
topQueryToCentroid.getSize(1),
|
||||
coarseIndices.getSize(1),
|
||||
k,
|
||||
!l2Distance, // L2 distance chooses smallest
|
||||
heapDistances,
|
||||
|
@ -471,7 +452,7 @@ runMultiPassTile(GpuResources* res,
|
|||
listIndices,
|
||||
indicesOptions,
|
||||
prefixSumOffsets,
|
||||
topQueryToCentroid,
|
||||
coarseIndices,
|
||||
k,
|
||||
!l2Distance, // L2 distance chooses smallest
|
||||
outDistances,
|
||||
|
@ -484,8 +465,10 @@ void
|
|||
runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& centroids,
|
||||
Tensor<float, 3, true>& pqCentroidsInnermostCode,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
bool useFloat16Lookup,
|
||||
bool useMMCodeDistance,
|
||||
bool interleavedCodeLayout,
|
||||
int numSubQuantizers,
|
||||
int numSubQuantizerCodes,
|
||||
|
@ -505,7 +488,7 @@ runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
|
|||
constexpr int kMaxQueryTileSize = 128;
|
||||
constexpr int kThrustMemSize = 16384;
|
||||
|
||||
int nprobe = topQueryToCentroid.getSize(1);
|
||||
int nprobe = coarseIndices.getSize(1);
|
||||
|
||||
auto stream = res->getDefaultStreamCurrentDevice();
|
||||
|
||||
|
@ -650,8 +633,10 @@ runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
|
|||
|
||||
auto codeDistancesView =
|
||||
codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
|
||||
auto coarseDistancesView =
|
||||
coarseDistances.narrowOutermost(query, numQueriesInTile);
|
||||
auto coarseIndicesView =
|
||||
topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
|
||||
coarseIndices.narrowOutermost(query, numQueriesInTile);
|
||||
auto queryView =
|
||||
queries.narrowOutermost(query, numQueriesInTile);
|
||||
|
||||
|
@ -670,8 +655,10 @@ runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
|
|||
centroids,
|
||||
pqCentroidsInnermostCode,
|
||||
codeDistancesView,
|
||||
coarseDistancesView,
|
||||
coarseIndicesView,
|
||||
useFloat16Lookup,
|
||||
useMMCodeDistance,
|
||||
interleavedCodeLayout,
|
||||
numSubQuantizers,
|
||||
numSubQuantizerCodes,
|
||||
|
|
|
@ -17,16 +17,14 @@ namespace faiss { namespace gpu {
|
|||
|
||||
class GpuResources;
|
||||
|
||||
/// For no precomputed codes, is this a supported number of dimensions
|
||||
/// per subquantizer?
|
||||
inline bool isSupportedNoPrecomputedSubDimSize(int dims);
|
||||
|
||||
template <typename CentroidT>
|
||||
void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
|
||||
Tensor<CentroidT, 2, true>& centroids,
|
||||
Tensor<float, 3, true>& pqCentroidsInnermostCode,
|
||||
Tensor<int, 2, true>& topQueryToCentroid,
|
||||
Tensor<float, 2, true>& coarseDistances,
|
||||
Tensor<int, 2, true>& coarseIndices,
|
||||
bool useFloat16Lookup,
|
||||
bool useMMCodeDistance,
|
||||
bool interleavedCodeLayout,
|
||||
int numSubQuantizers,
|
||||
int numSubQuantizerCodes,
|
||||
|
|
|
@ -51,7 +51,7 @@ struct Options {
|
|||
// support non-multiple of 8 subcodes for IVFPQ.
|
||||
bitsPerCode = 8;
|
||||
nprobe = std::min(faiss::gpu::randVal(40, 1000), numCentroids);
|
||||
numQuery = faiss::gpu::randVal(1, 8);
|
||||
numQuery = faiss::gpu::randVal(4, 8);
|
||||
|
||||
// Due to the approximate nature of the query and of floating point
|
||||
// differences between GPU and CPU, to stay within our error bounds, only
|
||||
|
@ -91,7 +91,7 @@ struct Options {
|
|||
}
|
||||
|
||||
float getCompareEpsilon() const {
|
||||
return 0.032f;
|
||||
return 0.035f;
|
||||
}
|
||||
|
||||
float getPctMaxDiff1() const {
|
||||
|
@ -136,7 +136,7 @@ TEST(TestGpuIndexIVFPQ, Query_L2) {
|
|||
|
||||
faiss::gpu::GpuIndexIVFPQConfig config;
|
||||
config.device = opt.device;
|
||||
config.usePrecomputedTables = opt.usePrecomputed;
|
||||
config.usePrecomputedTables = (tries % 2 == 0);
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.useFloat16LookupTables = opt.useFloat16;
|
||||
|
||||
|
@ -151,6 +151,93 @@ TEST(TestGpuIndexIVFPQ, Query_L2) {
|
|||
}
|
||||
}
|
||||
|
||||
void testMMCodeDistance(faiss::MetricType mt) {
|
||||
// Explicitly test the code distance via batch matrix multiplication route
|
||||
// (even for dimension sizes that would otherwise be handled by the
|
||||
// specialized route (via enabling `useMMCodeDistance`)
|
||||
for (int tries = 0; tries < 2; ++tries) {
|
||||
Options opt;
|
||||
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
|
||||
faiss::IndexFlat coarseQuantizer(opt.dim, mt);
|
||||
faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids,
|
||||
opt.codes, opt.bitsPerCode);
|
||||
cpuIndex.nprobe = opt.nprobe;
|
||||
cpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
cpuIndex.add(opt.numAdd, addVecs.data());
|
||||
|
||||
// Use the default temporary memory management to test the memory manager
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQConfig config;
|
||||
config.device = opt.device;
|
||||
config.usePrecomputedTables = false;
|
||||
config.useMMCodeDistance = true;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
|
||||
// Make sure that the float16 version works as well
|
||||
config.useFloat16LookupTables = (tries % 2 == 0);
|
||||
config.flatConfig.useFloat16 = (tries % 2 == 1);
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config);
|
||||
gpuIndex.setNumProbes(opt.nprobe);
|
||||
|
||||
faiss::gpu::compareIndices(cpuIndex, gpuIndex,
|
||||
opt.numQuery, opt.dim, opt.k, opt.toString(),
|
||||
opt.getCompareEpsilon(),
|
||||
opt.getPctMaxDiff1(),
|
||||
opt.getPctMaxDiffN());
|
||||
}
|
||||
|
||||
// These sizes are not specialized, they will fall back to the MM version
|
||||
for (int dimPerSubQ : {7, 11}) {
|
||||
Options opt;
|
||||
|
||||
opt.codes = 12;
|
||||
opt.dim = dimPerSubQ * opt.codes;
|
||||
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
|
||||
faiss::IndexFlat coarseQuantizer(opt.dim, mt);
|
||||
faiss::IndexIVFPQ cpuIndex(&coarseQuantizer, opt.dim, opt.numCentroids,
|
||||
opt.codes, opt.bitsPerCode);
|
||||
cpuIndex.nprobe = opt.nprobe;
|
||||
cpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
cpuIndex.add(opt.numAdd, addVecs.data());
|
||||
|
||||
// Use the default temporary memory management to test the memory manager
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQConfig config;
|
||||
config.device = opt.device;
|
||||
config.usePrecomputedTables = false;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
|
||||
// Make sure that the float16 version works as well
|
||||
config.useFloat16LookupTables = (dimPerSubQ == 7);
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQ gpuIndex(&res, &cpuIndex, config);
|
||||
gpuIndex.setNumProbes(opt.nprobe);
|
||||
|
||||
faiss::gpu::compareIndices(cpuIndex, gpuIndex,
|
||||
opt.numQuery, opt.dim, opt.k, opt.toString(),
|
||||
opt.getCompareEpsilon(),
|
||||
opt.getPctMaxDiff1(),
|
||||
opt.getPctMaxDiffN());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFPQ, Query_L2_MMCodeDistance) {
|
||||
testMMCodeDistance(faiss::MetricType::METRIC_L2);
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFPQ, Query_IP_MMCodeDistance) {
|
||||
testMMCodeDistance(faiss::MetricType::METRIC_INNER_PRODUCT);
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFPQ, Query_IP) {
|
||||
for (int tries = 0; tries < 2; ++tries) {
|
||||
Options opt;
|
||||
|
@ -296,54 +383,56 @@ TEST(TestGpuIndexIVFPQ, Add_IP) {
|
|||
}
|
||||
|
||||
TEST(TestGpuIndexIVFPQ, CopyTo) {
|
||||
Options opt;
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
for (int tries = 0; tries < 2; ++tries) {
|
||||
Options opt;
|
||||
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
|
||||
std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
|
||||
|
||||
// Use the default temporary memory management to test the memory manager
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
// Use the default temporary memory management to test the memory manager
|
||||
faiss::gpu::StandardGpuResources res;
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQConfig config;
|
||||
config.device = opt.device;
|
||||
config.usePrecomputedTables = opt.usePrecomputed;
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.useFloat16LookupTables = opt.useFloat16;
|
||||
faiss::gpu::GpuIndexIVFPQConfig config;
|
||||
config.device = opt.device;
|
||||
config.usePrecomputedTables = (tries % 2 == 0);
|
||||
config.indicesOptions = opt.indicesOpt;
|
||||
config.useFloat16LookupTables = opt.useFloat16;
|
||||
|
||||
faiss::gpu::GpuIndexIVFPQ gpuIndex(&res,
|
||||
opt.dim,
|
||||
opt.numCentroids,
|
||||
opt.codes,
|
||||
opt.bitsPerCode,
|
||||
faiss::METRIC_L2,
|
||||
config);
|
||||
gpuIndex.setNumProbes(opt.nprobe);
|
||||
gpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
gpuIndex.add(opt.numAdd, addVecs.data());
|
||||
faiss::gpu::GpuIndexIVFPQ gpuIndex(&res,
|
||||
opt.dim,
|
||||
opt.numCentroids,
|
||||
opt.codes,
|
||||
opt.bitsPerCode,
|
||||
faiss::METRIC_L2,
|
||||
config);
|
||||
gpuIndex.setNumProbes(opt.nprobe);
|
||||
gpuIndex.train(opt.numTrain, trainVecs.data());
|
||||
gpuIndex.add(opt.numAdd, addVecs.data());
|
||||
|
||||
// Use garbage values to see if we overwrite them
|
||||
faiss::IndexFlatL2 cpuQuantizer(1);
|
||||
faiss::IndexIVFPQ cpuIndex(&cpuQuantizer, 1, 1, 1, 1);
|
||||
// Use garbage values to see if we overwrite them
|
||||
faiss::IndexFlatL2 cpuQuantizer(1);
|
||||
faiss::IndexIVFPQ cpuIndex(&cpuQuantizer, 1, 1, 1, 1);
|
||||
|
||||
gpuIndex.copyTo(&cpuIndex);
|
||||
gpuIndex.copyTo(&cpuIndex);
|
||||
|
||||
EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal);
|
||||
EXPECT_EQ(gpuIndex.ntotal, opt.numAdd);
|
||||
EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal);
|
||||
EXPECT_EQ(gpuIndex.ntotal, opt.numAdd);
|
||||
|
||||
EXPECT_EQ(cpuIndex.d, gpuIndex.d);
|
||||
EXPECT_EQ(cpuIndex.d, opt.dim);
|
||||
EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists());
|
||||
EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes());
|
||||
EXPECT_EQ(cpuIndex.pq.M, gpuIndex.getNumSubQuantizers());
|
||||
EXPECT_EQ(gpuIndex.getNumSubQuantizers(), opt.codes);
|
||||
EXPECT_EQ(cpuIndex.pq.nbits, gpuIndex.getBitsPerCode());
|
||||
EXPECT_EQ(gpuIndex.getBitsPerCode(), opt.bitsPerCode);
|
||||
EXPECT_EQ(cpuIndex.d, gpuIndex.d);
|
||||
EXPECT_EQ(cpuIndex.d, opt.dim);
|
||||
EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists());
|
||||
EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes());
|
||||
EXPECT_EQ(cpuIndex.pq.M, gpuIndex.getNumSubQuantizers());
|
||||
EXPECT_EQ(gpuIndex.getNumSubQuantizers(), opt.codes);
|
||||
EXPECT_EQ(cpuIndex.pq.nbits, gpuIndex.getBitsPerCode());
|
||||
EXPECT_EQ(gpuIndex.getBitsPerCode(), opt.bitsPerCode);
|
||||
|
||||
// Query both objects; results should be equivalent
|
||||
faiss::gpu::compareIndices(cpuIndex, gpuIndex,
|
||||
opt.numQuery, opt.dim, opt.k, opt.toString(),
|
||||
opt.getCompareEpsilon(),
|
||||
opt.getPctMaxDiff1(),
|
||||
opt.getPctMaxDiffN());
|
||||
// Query both objects; results should be equivalent
|
||||
faiss::gpu::compareIndices(cpuIndex, gpuIndex,
|
||||
opt.numQuery, opt.dim, opt.k, opt.toString(),
|
||||
opt.getCompareEpsilon(),
|
||||
opt.getPctMaxDiff1(),
|
||||
opt.getPctMaxDiffN());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestGpuIndexIVFPQ, CopyFrom) {
|
||||
|
|
|
@ -21,7 +21,7 @@ class EvalIVFPQAccuracy(unittest.TestCase):
|
|||
nq = 2000
|
||||
else:
|
||||
d = 32
|
||||
nb = 1000
|
||||
nb = 10000
|
||||
nt = 1000
|
||||
nq = 200
|
||||
np.random.seed(123)
|
||||
|
@ -66,7 +66,7 @@ class EvalIVFPQAccuracy(unittest.TestCase):
|
|||
ts.append(time.time())
|
||||
|
||||
index.nprobe = 4
|
||||
D, Iref = index.search(xq, 10)
|
||||
Dref, Iref = index.search(xq, 10)
|
||||
ts.append(time.time())
|
||||
|
||||
res = faiss.StandardGpuResources()
|
||||
|
@ -83,7 +83,7 @@ class EvalIVFPQAccuracy(unittest.TestCase):
|
|||
|
||||
gpu_index.setNumProbes(4)
|
||||
|
||||
D, Inew = gpu_index.search(xq, 10)
|
||||
Dnew, Inew = gpu_index.search(xq, 10)
|
||||
ts.append(time.time())
|
||||
print('times:', [t - ts[0] for t in ts])
|
||||
|
||||
|
@ -105,7 +105,7 @@ class EvalIVFPQAccuracy(unittest.TestCase):
|
|||
faiss.GpuParameterSpace().set_index_parameter(
|
||||
gpu_index, 'nprobe', 4)
|
||||
|
||||
D, Inew = gpu_index.search(xq, 10)
|
||||
Dnew, Inew = gpu_index.search(xq, 10)
|
||||
|
||||
# 0.99: allow some tolerance in results otherwise test
|
||||
# fails occasionally (not reproducible)
|
||||
|
|
|
@ -115,7 +115,7 @@ template <typename T, int Dim, bool InnerContig,
|
|||
typename IndexT, template <typename U> class PtrTraits>
|
||||
__host__
|
||||
HostTensor<T, Dim, InnerContig, IndexT, PtrTraits>::HostTensor(
|
||||
Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
|
||||
const Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
|
||||
cudaStream_t stream) :
|
||||
Tensor<T, Dim, InnerContig, IndexT, PtrTraits>(nullptr, t.sizes(), t.strides()),
|
||||
state_(AllocState::Owner) {
|
||||
|
|
|
@ -56,7 +56,7 @@ class HostTensor : public Tensor<T, Dim, InnerContig, IndexT, PtrTraits> {
|
|||
/// Copies a tensor into ourselves, allocating memory for it
|
||||
/// locally. If the tensor is on the GPU, then we will copy it to
|
||||
/// ourselves wrt the given stream.
|
||||
__host__ HostTensor(Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
|
||||
__host__ HostTensor(const Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
|
||||
cudaStream_t stream);
|
||||
|
||||
/// Call to zero out memory
|
||||
|
|
|
@ -38,9 +38,9 @@ rawGemm(cublasHandle_t handle,
|
|||
int n,
|
||||
int k,
|
||||
const float fAlpha,
|
||||
const AT *A,
|
||||
const void *A,
|
||||
int lda,
|
||||
const BT *B,
|
||||
const void *B,
|
||||
int ldb,
|
||||
const float fBeta,
|
||||
float *C,
|
||||
|
@ -56,6 +56,37 @@ rawGemm(cublasHandle_t handle,
|
|||
C, CUDA_R_32F, ldc);
|
||||
}
|
||||
|
||||
template <typename AT, typename BT>
|
||||
cublasStatus_t
|
||||
rawBatchGemm(cublasHandle_t handle,
|
||||
cublasOperation_t transa,
|
||||
cublasOperation_t transb,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const float fAlpha,
|
||||
const void* A,
|
||||
int lda,
|
||||
long long int strideA,
|
||||
const void* B,
|
||||
int ldb,
|
||||
long long int strideB,
|
||||
const float fBeta,
|
||||
float* C,
|
||||
int ldc,
|
||||
long long int strideC,
|
||||
int batchCount) {
|
||||
auto cAT = GetCudaType<AT>::Type;
|
||||
auto cBT = GetCudaType<BT>::Type;
|
||||
|
||||
// Always accumulate in f32
|
||||
return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k,
|
||||
&fAlpha, A, cAT, lda, strideA,
|
||||
B, cBT, ldb, strideB, &fBeta,
|
||||
C, CUDA_R_32F, ldc, strideC, batchCount,
|
||||
CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
|
||||
}
|
||||
|
||||
template <typename AT, typename BT>
|
||||
void
|
||||
runMatrixMult(Tensor<float, 2, true>& c, bool transC,
|
||||
|
@ -109,17 +140,17 @@ runMatrixMult(Tensor<float, 2, true>& c, bool transC,
|
|||
cublasStatus_t err;
|
||||
|
||||
if (transC) {
|
||||
err = rawGemm(handle,
|
||||
gemmTrA, gemmTrB,
|
||||
m, n, k, alpha,
|
||||
a.data(), lda, b.data(), ldb, beta,
|
||||
pC, ldc);
|
||||
err = rawGemm<AT, BT>(handle,
|
||||
gemmTrA, gemmTrB,
|
||||
m, n, k, alpha,
|
||||
a.data(), lda, b.data(), ldb, beta,
|
||||
pC, ldc);
|
||||
} else {
|
||||
err = rawGemm(handle,
|
||||
gemmTrA, gemmTrB,
|
||||
m, n, k, alpha,
|
||||
b.data(), lda, a.data(), ldb, beta,
|
||||
pC, ldc);
|
||||
err = rawGemm<AT, BT>(handle,
|
||||
gemmTrA, gemmTrB,
|
||||
m, n, k, alpha,
|
||||
b.data(), lda, a.data(), ldb, beta,
|
||||
pC, ldc);
|
||||
}
|
||||
|
||||
FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
|
||||
|
@@ -133,26 +164,76 @@ runMatrixMult(Tensor<float, 2, true>& c, bool transC,
}

template <typename AT, typename BT>
void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           Tensor<AT, 3, true>& a, bool transA,
                           Tensor<BT, 3, true>& b, bool transB,
                           float alpha,
                           float beta,
                           cublasHandle_t handle,
                           cudaStream_t stream) {
void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<AT, 3, true>& a, bool transA,
                   Tensor<BT, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));

  for (int i = 0; i < a.getSize(0); ++i) {
    auto cView = c[i].view();
    auto aView = a[i].view();
    auto bView = b[i].view();
  // This uses the strided batch MM, which assumes a uniform stride
  FAISS_ASSERT(a.getStride(0) == a.getSize(1) * a.getSize(2));
  FAISS_ASSERT(b.getStride(0) == b.getSize(1) * b.getSize(2));
  FAISS_ASSERT(c.getStride(0) == c.getSize(1) * c.getSize(2));

    runMatrixMult(cView, transC,
                  aView, transA,
                  bView, transB,
                  alpha, beta, handle, stream);
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout
  void* pA = transC ? (void*) a.data() : (void*) b.data();
  void* pB = transC ? (void*) b.data() : (void*) a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  long long int gemmStrideA = transC ? a.getStride(0) : b.getStride(0);
  long long int gemmStrideB = transC ? b.getStride(0) : a.getStride(0);
  long long int gemmStrideC = c.getStride(0);

  auto err =
    rawBatchGemm<AT, BT>(handle,
                         gemmTrA, gemmTrB,
                         m, n, k, alpha,
                         pA, lda, gemmStrideA,
                         pB, ldb, gemmStrideB, beta,
                         pC, ldc, gemmStrideC, a.getSize(0));

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasGemmStridedBatchedEx failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}

} } // namespace
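The pA/pB swapping in the new runBatchMatrixMult is the usual row-major-to-column-major trick: cuBLAS only understands column-major storage, and a row-major buffer reinterpreted as column-major is its transpose, so row-major C = A * B is computed as column-major C^T = B^T * A^T by passing the operands in reverse order and swapping the m/n extents. A standalone sketch of the same idea with a single plain GEMM follows; the sizes and values are arbitrary, not faiss code.

// Illustrative sketch only: row-major C = A * B via a column-major cuBLAS GEMM.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  // Row-major A (2x3), B (3x2), C (2x2)
  const int M = 2, K = 3, N = 2;
  std::vector<float> hA = {1, 2, 3,
                           4, 5, 6};
  std::vector<float> hB = {1, 0,
                           0, 1,
                           1, 1};
  std::vector<float> hC(M * N, 0.f);

  float *dA, *dB, *dC;
  cudaMalloc(&dA, hA.size() * sizeof(float));
  cudaMalloc(&dB, hB.size() * sizeof(float));
  cudaMalloc(&dC, hC.size() * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  const float alpha = 1.f, beta = 0.f;
  // Column-major call computing C^T (N x M) = B^T (N x K) * A^T (K x M):
  // pass B first and A second; leading dims are the row-major row widths.
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
              N, M, K,
              &alpha, dB, N, dA, K,
              &beta, dC, N);

  cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
  // Expect row-major C = {{4, 5}, {10, 11}}
  printf("%.0f %.0f / %.0f %.0f\n", hC[0], hC[1], hC[2], hC[3]);

  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}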
@@ -1,97 +0,0 @@
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include <faiss/gpu/utils/MatrixMult.cuh>
#include <faiss/gpu/GpuResources.h>

namespace faiss { namespace gpu {

void
runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                   Tensor<float, 3, true>& a, bool transA,
                   Tensor<float, 3, true>& b, bool transB,
                   float alpha,
                   float beta,
                   GpuResources* res,
                   cublasHandle_t handle,
                   cudaStream_t stream) {
  FAISS_ASSERT(c.getSize(0) == a.getSize(0));
  FAISS_ASSERT(a.getSize(0) == b.getSize(0));
  cublasSetStream(handle, stream);

  // Check that we have (m x k) * (k x n) = (m x n)
  // using the input row-major layout
  int aM = transA ? a.getSize(2) : a.getSize(1);
  int aK = transA ? a.getSize(1) : a.getSize(2);

  int bK = transB ? b.getSize(2) : b.getSize(1);
  int bN = transB ? b.getSize(1) : b.getSize(2);

  int cM = transC ? c.getSize(2) : c.getSize(1);
  int cN = transC ? c.getSize(1) : c.getSize(2);

  FAISS_ASSERT(aM == cM);
  FAISS_ASSERT(aK == bK);
  FAISS_ASSERT(bN == cN);

  // Now, we have to represent the matrix multiplication in
  // column-major layout
  float* pA = transC ? a.data() : b.data();
  float* pB = transC ? b.data() : a.data();
  float* pC = c.data();

  int m = c.getSize(2); // stride 1 size
  int n = c.getSize(1); // other size
  int k = transA ? a.getSize(1) : a.getSize(2);

  int lda = transC ? a.getStride(1) : b.getStride(1);
  int ldb = transC ? b.getStride(1) : a.getStride(1);
  int ldc = c.getStride(1);

  auto gemmTrA = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto gemmTrB = transA ? CUBLAS_OP_T : CUBLAS_OP_N;

  if (transC) {
    gemmTrA = transA ? CUBLAS_OP_N : CUBLAS_OP_T;
    gemmTrB = transB ? CUBLAS_OP_N : CUBLAS_OP_T;
  }

  HostTensor<float*, 1, true> hostA({a.getSize(0)});
  HostTensor<float*, 1, true> hostB({b.getSize(0)});
  HostTensor<float*, 1, true> hostC({c.getSize(0)});

  size_t aOffset = a.getStride(0);
  size_t bOffset = b.getStride(0);
  size_t cOffset = c.getStride(0);

  for (int i = 0; i < a.getSize(0); ++i) {
    hostA[i] = transC ? a.data() + i * aOffset : b.data() + i * bOffset;
    hostB[i] = transC ? b.data() + i * bOffset : a.data() + i * aOffset;
    hostC[i] = c.data() + i * cOffset;
  }

  DeviceTensor<float*, 1, true>
    deviceA(res, makeTempAlloc(AllocType::Other, stream), hostA);
  DeviceTensor<float*, 1, true>
    deviceB(res, makeTempAlloc(AllocType::Other, stream), hostB);
  DeviceTensor<float*, 1, true>
    deviceC(res, makeTempAlloc(AllocType::Other, stream), hostC);

  auto err =
    cublasSgemmBatched(handle,
                       gemmTrA, gemmTrB,
                       m, n, k, &alpha,
                       (const float**) deviceA.data(), lda,
                       (const float**) deviceB.data(), ldb, &beta,
                       deviceC.data(), ldc, a.getSize(0));
  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublasSgemmBatched failed (%d)", (int) err);
  CUDA_TEST_ERROR();
}

} } // namespace
@@ -42,18 +42,6 @@ void runIteratedMatrixMult(Tensor<float, 3, true>& c, bool transC,
                           cublasHandle_t handle,
                           cudaStream_t stream);

/// C_i = alpha * A_i * B_i + beta * C_i
/// where `i` is the outermost dimension, via batched gemm
/// Expects row major layout, not fortran/blas column major!
void runBatchMatrixMult(Tensor<float, 3, true>& c, bool transC,
                        Tensor<float, 3, true>& a, bool transA,
                        Tensor<float, 3, true>& b, bool transB,
                        float alpha,
                        float beta,
                        GpuResources* res,
                        cublasHandle_t handle,
                        cudaStream_t stream);

} } // namespace

#include <faiss/gpu/utils/MatrixMult-inl.cuh>
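As a hedged illustration of how the templated batch MM path can be applied to PQ, the sketch below computes the -2 * query . centroid term of the per-sub-quantizer distance tables with a single runBatchMatrixMult call. The function name, tensor shapes, and the surrounding setup are assumptions made for illustration; this is not the actual faiss PQ kernel.

// Hedged sketch (names and shapes are illustrative, not the faiss kernel):
// distanceTerm[q] = -2 * queries[q] * centroids[q]^T for each sub-quantizer q,
// issued as one strided batch GEMM over the sub-quantizer dimension.
#include <faiss/gpu/utils/MatrixMult.cuh>

namespace faiss { namespace gpu {

void pqQueryCentroidTerm(Tensor<float, 3, true>& distanceTerm, // subQ x numQueries x numCentroids
                         Tensor<float, 3, true>& queries,      // subQ x numQueries x dimPerSubQ
                         Tensor<float, 3, true>& centroids,    // subQ x numCentroids x dimPerSubQ
                         cublasHandle_t handle,
                         cudaStream_t stream) {
  // -2 q . c term of ||q - c||^2; the ||c||^2 term would be added separately
  runBatchMatrixMult(distanceTerm, false,
                     queries, false,
                     centroids, true,  // transposed: dimPerSubQ x numCentroids
                     -2.0f, 0.0f,
                     handle, stream);
}

} } // namespace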
@@ -97,8 +97,7 @@ StackDeviceMemory::Stack::getAlloc(size_t size,

    if (stream != prevUser.stream_) {
      // Synchronization required
      // FIXME
      FAISS_ASSERT(false);
      streamWait({stream}, {prevUser.stream_});
    }

    if (endAlloc < prevUser.end_) {
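The streamWait call above replaces the hard assert with a real cross-stream dependency, so a temporary buffer last used on one stream can be handed to another without blocking the host. A minimal sketch of that pattern in raw CUDA is shown below; the helper name is hypothetical, not the faiss utility.

// Minimal sketch (hypothetical helper, not the faiss streamWait): make
// `listener` wait until all work currently queued on `producer` is done.
#include <cuda_runtime.h>

void streamWaitOn(cudaStream_t listener, cudaStream_t producer) {
  cudaEvent_t event;
  // Timing is not needed; disable it to keep the event cheap.
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, producer);
  cudaStreamWaitEvent(listener, event, 0);
  // The event can be destroyed right away; the queued dependency remains.
  cudaEventDestroy(event);
}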
@@ -128,7 +128,7 @@ template <typename T, int Dim, bool InnerContig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, InnerContig, IndexT, PtrTraits>::copyFrom(
    Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
    const Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
    cudaStream_t stream) {
  // The tensor must be fully contiguous
  GPU_FAISS_ASSERT(this->isContiguous());
@@ -112,7 +112,7 @@ class Tensor {
                        const IndexT strides[Dim]);

  /// Copies a tensor into ourselves; sizes must match
  __host__ void copyFrom(Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
  __host__ void copyFrom(const Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
                         cudaStream_t stream);

  /// Copies ourselves into a tensor; sizes must match
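The const-qualification of copyFrom (and the matching HostTensor constructor earlier in the diff) lets callers copy out of tensors they only hold by const reference. A hypothetical usage sketch, assuming only the faiss Tensor header:

// Hypothetical usage sketch: copyFrom now accepts a const source tensor.
#include <faiss/gpu/utils/Tensor.cuh>

namespace faiss { namespace gpu {

void snapshot(const Tensor<float, 2, true>& src,
              Tensor<float, 2, true>& dst,
              cudaStream_t stream) {
  dst.copyFrom(src, stream);  // would not compile against the old non-const signature
}

} } // namespace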