13 #include "../../FaissAssert.h"
15 #include "DeviceUtils.h"
20 namespace faiss {
namespace gpu {
/// Plain-POD mirror of a Tensor, passable by value to a kernel:
/// raw device data pointer plus per-dimension sizes and strides.
/// (Reconstructed: the original struct header and the `data`/`sizes`
/// members were garbled out of this chunk, but all three are used by
/// the visible kernels below.)
template <typename T>
struct TensorInfo {
  // Maximum dimensionality supported by kernels using this descriptor.
  static constexpr int kMaxDims = 8;

  // Raw device pointer to the start of the tensor's data.
  T* data;

  // Size (number of elements) of each dimension.
  int sizes[kMaxDims];

  // Stride (in elements) of each dimension.
  int strides[kMaxDims];
};
/// Computes the element offset into a (possibly non-contiguous) tensor
/// for a given linear element id, by peeling off per-dimension indices
/// from the innermost dimension outward (mod by size, scale by stride).
template <typename T, int Dim>
struct TensorInfoOffset {
  __device__ inline static unsigned int get(const TensorInfo<T>& info,
                                            unsigned int linearId) {
    unsigned int offset = 0;

#pragma unroll
    for (int i = Dim - 1; i >= 0; --i) {
      // Index within dimension i, then its contribution to the offset.
      unsigned int curDimIndex = linearId % info.sizes[i];
      unsigned int curDimOffset = curDimIndex * info.strides[i];

      offset += curDimOffset;

      // Strip off the handled dimension; the final iteration's quotient
      // is unused, so guarding it is unnecessary for correctness.
      linearId /= info.sizes[i];
    }

    return offset;
  }
};
/// Specialization for a contiguous tensor (dispatched with Dim == -1,
/// see the kernel launch below): the linear element id is itself the
/// element offset, so no per-dimension arithmetic is needed.
/// NOTE(review): body was missing from this chunk; returning linearId is
/// the only behavior consistent with the -1 "contiguous" dispatch — confirm.
template <typename T>
struct TensorInfoOffset<T, -1> {
  __device__ inline static unsigned int get(const TensorInfo<T>& info,
                                            unsigned int linearId) {
    return linearId;
  }
};
/// Builds a kernel-passable TensorInfo descriptor from a Tensor by
/// copying its data pointer and per-dimension sizes/strides.
/// NOTE(review): body partially missing from this chunk; reconstructed
/// from the visible copy loop and the Tensor accessors (data/getSize/
/// getStride) documented below — confirm against the original.
template <typename T, int Dim>
TensorInfo<T> getTensorInfo(Tensor<T, Dim, true>& t) {
  TensorInfo<T> info;

  for (int i = 0; i < Dim; ++i) {
    info.sizes[i] = t.getSize(i);
    info.strides[i] = t.getStride(i);
  }

  info.data = t.data();

  return info;
}
/// Element-wise copy kernel implementing an arbitrary transpose: each
/// thread moves one element from its offset in `input` to its offset in
/// `output`, where the offsets are decoded from the flat thread id via
/// TensorInfoOffset (Dim == -1 marks a contiguous layout).
///
/// Expects a 1D grid of 1D blocks with at least `totalSize` total
/// threads; excess threads exit at the guard.
template <typename T, int DimInput, int DimOutput>
__global__ void transposeAny(TensorInfo<T> input,
                             TensorInfo<T> output,
                             unsigned int totalSize) {
  auto linearThreadId = blockIdx.x * blockDim.x + threadIdx.x;

  // Tail guard: the grid rarely divides totalSize evenly.
  if (linearThreadId >= totalSize) {
    return;
  }

  auto inputOffset =
      TensorInfoOffset<T, DimInput>::get(input, linearThreadId);
  auto outputOffset =
      TensorInfoOffset<T, DimOutput>::get(output, linearThreadId);

#if __CUDA_ARCH__ >= 350
  // Route the read through the read-only data cache (SM 3.5+).
  output.data[outputOffset] = __ldg(&input.data[inputOffset]);
#else
  output.data[outputOffset] = input.data[inputOffset];
#endif
}
/// Performs an out-of-place transpose of `in` into `out`, exchanging
/// dimensions `dim1` and `dim2`, on `stream`.
///
/// Preconditions (checked): dim1 != dim2, both < Dim, and `out`'s sizes
/// equal `in`'s sizes with dim1/dim2 swapped.
///
/// Strategy: swap the sizes/strides of dim1/dim2 in the *input*
/// descriptor so that reading it in the output's iteration order yields
/// transposed data, then launch one thread per element writing the
/// (assumed contiguous, hence -1) output linearly.
template <typename T, int Dim>
void runTransposeAny(Tensor<T, Dim, true>& in,
                     int dim1,
                     int dim2,
                     Tensor<T, Dim, true>& out,
                     cudaStream_t stream) {
  static_assert(Dim <= TensorInfo<T>::kMaxDims,
                "too many dimensions");

  FAISS_ASSERT(dim1 != dim2);
  FAISS_ASSERT(dim1 < Dim && dim2 < Dim);

  // Expected output shape: input shape with dim1/dim2 exchanged.
  int outSize[Dim];

  for (int i = 0; i < Dim; ++i) {
    outSize[i] = in.getSize(i);
  }

  std::swap(outSize[dim1], outSize[dim2]);

  for (int i = 0; i < Dim; ++i) {
    FAISS_ASSERT(out.getSize(i) == outSize[i]);
  }

  auto inInfo = getTensorInfo<T, Dim>(in);
  auto outInfo = getTensorInfo<T, Dim>(out);

  // Reading the input through swapped sizes/strides realizes the
  // transpose without moving any data on the input side.
  std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
  std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

  int totalSize = in.numElements();

  // One thread per element; cap the block size at the device maximum.
  int numThreads = std::min(getMaxThreadsCurrentDevice(), totalSize);
  auto grid = dim3(utils::divUp(totalSize, numThreads));
  auto block = dim3(numThreads);

  transposeAny<T, Dim, -1><<<grid, block, 0, stream>>>(
      inInfo, outInfo, totalSize);

  // Surface launch-configuration errors immediately; asynchronous
  // execution errors appear at the next synchronizing call.
  FAISS_ASSERT(cudaGetLastError() == cudaSuccess);
}
__host__ __device__ DataPtrType data()
    Returns a raw pointer to the start of our data.
__host__ __device__ IndexT getStride(int i) const
    Returns the stride (in elements) of dimension `i`.
__host__ __device__ IndexT getSize(int i) const
    Returns the size (number of elements) of dimension `i`.