Improve transpose performance

Summary:
The general GPU array transposition kernel in Faiss had two issues.

One, there was a typo (`+` instead of `*`) in the kernel's grid-stride loop that had been present since the general transposition kernel was written in 2016/2017. It did not cause a correctness bug, since every element is still covered and the redundant threads write identical values, but it made the loop stride far smaller than the number of launched threads, so the kernel did vastly more work than needed. This was causing large slowdowns with precomputed code usage that I noticed while profiling IVFPQ issues.
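For context, a minimal sketch of the grid-stride loop pattern (an illustrative kernel, not the Faiss one): the stride must be the total thread count, gridDim.x * blockDim.x, so each element is visited exactly once across the grid.

    // Hypothetical sketch of a grid-stride loop; names are illustrative.
    __global__ void copyKernel(const float* in, float* out, size_t n) {
      // Correct stride: total number of threads launched in the grid
      size_t stride = (size_t) gridDim.x * blockDim.x;
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        out[i] = in[i];
      }
      // With the typo'd stride gridDim.x + blockDim.x, every element is still
      // covered (so output is correct), but each element is processed roughly
      // (gridDim.x * blockDim.x) / (gridDim.x + blockDim.x) times over, which
      // is the redundant work this commit eliminates.
    }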

Two, the general transposition code was written for the most generic case. The transposition we care about and use the most in Faiss is a transposition of the two outermost dimensions, say transposing an array [s1 s2 s3] -> [s2 s1 s3], where the one or more innermost dimensions remain contiguous in the new layout. A separate kernel has been written to cover this case; see the sketch below.
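To make the access pattern concrete, here is a hedged CPU reference for this case (hypothetical helper, not part of the commit): transposing [s1 s2 s3] -> [s2 s1 s3] only moves whole contiguous rows of length s3, so a GPU kernel can assign one block per (i, j) pair and perform fully coalesced copies.

    // Hypothetical CPU reference for an outer-dimension transposition.
    // in is laid out as [s1][s2][s3]; out as [s2][s1][s3].
    #include <cstddef>
    #include <cstring>

    void transposeOuterRef(const float* in, float* out,
                           size_t s1, size_t s2, size_t s3) {
      for (size_t i = 0; i < s1; ++i) {
        for (size_t j = 0; j < s2; ++j) {
          // Each copy moves one contiguous run of s3 elements
          std::memcpy(out + (j * s1 + i) * s3,
                      in + (i * s2 + j) * s3,
                      s3 * sizeof(float));
        }
      }
    }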

Also updates the code to use `uint32_t` and `uint64_t` in lieu of `unsigned int` and `unsigned long`.

D25703821 (removing serialize tags for GPU tests) is reverted here as well, since that change prevents GPU tests from being run locally on devservers; RE may serialize test execution implicitly, but local execution does not.

Reviewed By: beauby

Differential Revision: D25929892

fbshipit-source-id: 66ddfc56189305f698a85c44abdeb64eb95ffe6b
commit d8b64b5122 (parent 010b05712c)
Author: Jeff Johnson
Date:   2021-01-19 13:20:40 -08:00
Committed by: Facebook GitHub Bot


@@ -13,6 +13,7 @@
 #include <faiss/gpu/utils/DeviceUtils.h>
 #include <faiss/gpu/utils/StaticUtils.h>
 #include <cuda.h>
+#include <stdint.h>

 namespace faiss { namespace gpu {
@@ -77,7 +78,7 @@ __global__ void transposeAny(TensorInfo<T, IndexT> input,
                              IndexT totalSize) {
   for (IndexT i = blockIdx.x * blockDim.x + threadIdx.x;
        i < totalSize;
-       i += gridDim.x + blockDim.x) {
+       i += gridDim.x * blockDim.x) {
     auto inputOffset = TensorInfoOffset<T, IndexT, DimInput>::get(input, i);
     auto outputOffset = TensorInfoOffset<T, IndexT, DimOutput>::get(output, i);
@@ -89,6 +90,22 @@ __global__ void transposeAny(TensorInfo<T, IndexT> input,
   }
 }

+// Transpose contiguous t1 t2 i1 -> t2 t1 i1
+template <typename T, typename IndexT>
+__global__ void transposeOuter(const T* in,
+                               T* out,
+                               IndexT t1, IndexT t2, IndexT i1) {
+  IndexT gt1 = blockIdx.y;
+  IndexT gt2 = blockIdx.x;
+
+  in += i1 * (gt1 * t2 + gt2);
+  out += i1 * (gt2 * t1 + gt1);
+
+  for (IndexT i = threadIdx.x; i < i1; i += blockDim.x) {
+    out[i] = in[i];
+  }
+}
+
 /// Performs an out-of-place transposition between any two dimensions.
 /// Best performance is if the transposed dimensions are not
 /// innermost, since the reads and writes will be coalesced.
@@ -109,6 +126,12 @@ void runTransposeAny(Tensor<T, Dim, true>& in,
   FAISS_ASSERT(dim1 != dim2);
   FAISS_ASSERT(dim1 < Dim && dim2 < Dim);

+  // Rearrange dim1 and dim2 in increasing order in order to see if this is an
+  // outer dimension transposition (below)
+  if (dim1 > dim2) {
+    std::swap(dim1, dim2);
+  }
+
   int outSize[Dim];

   for (int i = 0; i < Dim; ++i) {
@@ -121,33 +144,66 @@ void runTransposeAny(Tensor<T, Dim, true>& in,
     FAISS_ASSERT(out.getSize(i) == outSize[i]);
   }

-  size_t totalSize = in.numElements();
-  size_t block = std::min((size_t) getMaxThreadsCurrentDevice(), totalSize);
+  auto maxThreads = getMaxThreadsCurrentDevice();
+  auto totalSize = in.numElements();

-  if (totalSize <= (size_t) std::numeric_limits<int>::max()) {
-    // div/mod seems faster with unsigned types
-    auto inInfo = getTensorInfo<T, unsigned int, Dim>(in);
-    auto outInfo = getTensorInfo<T, unsigned int, Dim>(out);
+  // Is this a transposition of the two outer dimensions?
+  bool isTransposeOuter = (Dim >= 3) && (dim1 == 0) && (dim2 == 1);
+  if (isTransposeOuter) {
+    // Outer dimension transposition only (there is a contiguous inner
+    // dimension)
+    size_t innerSize = 1;
+    for (int i = 2; i < Dim; ++i) {
+      innerSize *= in.getSize(i);
+    }

-    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
-    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);
+    auto grid = dim3(in.getSize(1), in.getSize(0));
+    int block = (innerSize < maxThreads) ? innerSize : maxThreads;

-    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);
-
-    transposeAny<T, unsigned int, Dim, -1>
-      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
+    if (totalSize <= (size_t) std::numeric_limits<int>::max()) {
+      transposeOuter<T, int32_t><<<grid, block, 0, stream>>>(in.data(),
+                                                             out.data(),
+                                                             in.getSize(0),
+                                                             in.getSize(1),
+                                                             innerSize);
+    } else {
+      transposeOuter<T, int64_t><<<grid, block, 0, stream>>>(in.data(),
+                                                             out.data(),
+                                                             in.getSize(0),
+                                                             in.getSize(1),
+                                                             innerSize);
+    }
   } else {
-    auto inInfo = getTensorInfo<T, unsigned long, Dim>(in);
-    auto outInfo = getTensorInfo<T, unsigned long, Dim>(out);
+    int block = (totalSize < maxThreads) ? totalSize : maxThreads;

-    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
-    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);
+    // Non-outer transposition
+    if (totalSize <= (size_t) std::numeric_limits<int>::max()) {
+      // General transposition
+      // div/mod seems faster with unsigned types
+      auto inInfo = getTensorInfo<T, uint32_t, Dim>(in);
+      auto outInfo = getTensorInfo<T, uint32_t, Dim>(out);

-    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);
+      std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
+      std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

-    transposeAny<T, unsigned long, Dim, -1>
-      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
+      auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);
+
+      transposeAny<T, uint32_t, Dim, -1>
+        <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
+    } else {
+      auto inInfo = getTensorInfo<T, uint64_t, Dim>(in);
+      auto outInfo = getTensorInfo<T, uint64_t, Dim>(out);
+
+      std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
+      std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);
+
+      auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);
+
+      transposeAny<T, uint64_t, Dim, -1>
+        <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
+    }
   }

   CUDA_TEST_ERROR();
 }
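As a sanity check of the new kernel's launch geometry, here is a hedged standalone usage sketch (host code is illustrative and not part of this commit; it assumes faiss/gpu/utils/Transpose.cuh is included and its namespace is visible). grid.x indexes t2, grid.y indexes t1, and each block cooperatively copies one contiguous inner row of length i1.

    // Assumed demo driver for transposeOuter; transposes a [t1 t2 i1] float
    // array into [t2 t1 i1] on the default stream.
    #include <algorithm>
    #include <cuda_runtime.h>

    void runTransposeOuterDemo(const float* hostIn, float* hostOut,
                               int t1, int t2, int i1) {
      size_t bytes = (size_t) t1 * t2 * i1 * sizeof(float);
      float* devIn;
      float* devOut;
      cudaMalloc(&devIn, bytes);
      cudaMalloc(&devOut, bytes);
      cudaMemcpy(devIn, hostIn, bytes, cudaMemcpyHostToDevice);

      // One block per (t1, t2) pair; threads in the block copy the inner row.
      // Note gridDim.y is limited to 65535, so t1 must fit within that bound.
      dim3 grid(t2, t1);
      int block = std::min(i1, 1024);
      transposeOuter<float, int32_t><<<grid, block>>>(devIn, devOut, t1, t2, i1);

      cudaMemcpy(hostOut, devOut, bytes, cudaMemcpyDeviceToHost);
      cudaFree(devIn);
      cudaFree(devOut);
    }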