14 #include "../../FaissAssert.h"
16 #include "DeviceUtils.h"
21 namespace faiss {
namespace gpu {
25 static constexpr
int kMaxDims = 8;
29 int strides[kMaxDims];
// Device-side helper that converts a linear element index into a strided
// element offset for a tensor described by `info`, treating it as having
// `Dim` dimensions.
//
// info     - tensor description; only info.sizes[i] and info.strides[i]
//            for i in [0, Dim) are read here.
// linearId - linear index of the element; taken by value and consumed
//            (divided down) as the dimensions are processed.
//
// The accumulated result is held in `offset`. The code that hands it back to
// the caller is not visible in this excerpt.
33 template <
typename T,
int Dim>
35 __device__
inline static unsigned int get(
const TensorInfo<T>& info,
36 unsigned int linearId) {
37 unsigned int offset = 0;
// Walk the dimensions from the last one (Dim - 1) down to dimension 0.
40 for (
int i = Dim - 1; i >= 0; --i) {
// Coordinate along dimension i: remainder of the linear id by that size.
41 unsigned int curDimIndex = linearId % info.sizes[i];
// Contribution of that coordinate to the final offset, scaled by the stride.
42 unsigned int curDimOffset = curDimIndex * info.strides[i];
44 offset += curDimOffset;
// Strip dimension i from the linear id so the next iteration sees the
// remaining (outer) dimensions only.
47 linearId /= info.sizes[i];
57 __device__
inline static unsigned int get(
const TensorInfo<T>& info,
58 unsigned int linearId) {
63 template <
typename T,
int Dim>
67 for (
int i = 0; i < Dim; ++i) {
// Kernel that copies elements from `input` to `output`, one element per
// thread, using independently computed offsets on each side.
//
// Template parameters:
//   T         - element type.
//   DimInput  - dimension count passed to TensorInfoOffset for the input.
//   DimOutput - dimension count passed to TensorInfoOffset for the output.
//
// Launch layout: one-dimensional grid and block; only the .x components of
// blockIdx / blockDim / threadIdx are used. Threads whose linear id is not
// below `totalSize` take the guarded branch below (its body is not visible in
// this excerpt — presumably an early return; confirm).
//
// All indices are `unsigned int`, so element counts are limited to what
// fits in that type.
79 __global__
void transposeAny(TensorInfo<T> input,
81 unsigned int totalSize) {
// Flat global thread index for a 1-D launch.
82 auto linearThreadId = blockIdx.x * blockDim.x + threadIdx.x;
// Bounds guard: the grid may contain more threads than there are elements.
84 if (linearThreadId >= totalSize) {
// The same linear id is mapped to an offset in the input ...
89 TensorInfoOffset<T, DimInput>::get(input, linearThreadId);
// ... and to an offset in the output.
91 TensorInfoOffset<T, DimOutput>::get(output, linearThreadId);
// Load the input element through __ldg (read-only data cache load) and store
// it at the computed output position.
// NOTE(review): __ldg is only available on some architectures; any arch guard
// around this line is not visible in this excerpt — confirm.
93 output.data[outputOffset] = __ldg(&input.data[inputOffset]);
// Host-side launcher that swaps dimensions `dim1` and `dim2` of `in`, writing
// the result out-of-place into `out` on `stream`.
//
// Preconditions checked here:
//   - Dim does not exceed TensorInfo<T>::kMaxDims (compile time).
//   - dim1 != dim2, and both are less than Dim.
//   - out.getSize(i) equals in.getSize(i) with the dim1/dim2 entries swapped,
//     for every i in [0, Dim).
//
// The kernel launch is asynchronous; only launch errors are checked
// (cudaGetLastError), no stream synchronization is performed here.
105 template <
typename T,
int Dim>
106 void runTransposeAny(Tensor<T, Dim, true>& in,
108 Tensor<T, Dim, true>& out,
109 cudaStream_t stream) {
110 static_assert(Dim <= TensorInfo<T>::kMaxDims,
"too many dimensions");
112 FAISS_ASSERT(dim1 != dim2);
113 FAISS_ASSERT(dim1 < Dim && dim2 < Dim);
// Build the expected output shape: the input shape ...
117 for (
int i = 0; i < Dim; ++i) {
118 outSize[i] = in.getSize(i);
// ... with the two transposed dimensions exchanged.
121 std::swap(outSize[dim1], outSize[dim2]);
// Verify the caller-provided output tensor has exactly that shape.
123 for (
int i = 0; i < Dim; ++i) {
124 FAISS_ASSERT(out.getSize(i) == outSize[i]);
127 auto inInfo = getTensorInfo<T, Dim>(in);
128 auto outInfo = getTensorInfo<T, Dim>(out);
// Exchange the input's size and stride entries for dim1/dim2, so the kernel
// indexes the input using the swapped dimension order.
130 std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
131 std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);
// NOTE(review): the element count is stored in an int; if numElements() can
// return a wider type this narrows silently — confirm.
133 int totalSize = in.numElements();
// Block size is capped by both the device limit and the element count.
// NOTE(review): for an empty tensor totalSize is 0, so numThreads is 0 and
// divUp(totalSize, numThreads) would divide by zero if divUp is a plain
// integer ceiling division (its definition is not visible here) — confirm
// and consider an early return when totalSize == 0.
135 int numThreads = std::min(getMaxThreadsCurrentDevice(), totalSize);
136 auto grid = dim3(utils::divUp(totalSize, numThreads));
137 auto block = dim3(numThreads);
// The input offset is computed over Dim dimensions; the output side is
// instantiated with -1. The TensorInfoOffset<T, -1> behavior is not visible
// in this excerpt — presumably a specialization for a contiguous output;
// confirm.
139 transposeAny<T, Dim, -1><<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
// Surfaces launch-configuration errors from the kernel launch above.
140 CUDA_VERIFY(cudaGetLastError());
__host__ __device__ DataPtrType data()
Returns a raw pointer to the start of our data.
__host__ __device__ IndexT getStride(int i) const
__host__ __device__ IndexT getSize(int i) const