Transpose.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */


#pragma once

#include "../../FaissAssert.h"
#include "Tensor.cuh"
#include "DeviceUtils.h"
#include <cuda.h>

#include <algorithm>
#include <limits>
#include <utility>
#include <stdio.h>

namespace faiss { namespace gpu {

template <typename T, typename IndexT>
struct TensorInfo {
  static constexpr int kMaxDims = 8;

  T* data;
  IndexT sizes[kMaxDims];
  IndexT strides[kMaxDims];
  int dims;
};

template <typename T, typename IndexT, int Dim>
struct TensorInfoOffset {
  // Converts a linear element index into a data offset by peeling off the
  // index for each dimension, innermost first, and scaling by that
  // dimension's stride
  __device__ inline static unsigned int get(const TensorInfo<T, IndexT>& info,
                                            IndexT linearId) {
    IndexT offset = 0;

#pragma unroll
    for (int i = Dim - 1; i >= 0; --i) {
      IndexT curDimIndex = linearId % info.sizes[i];
      IndexT curDimOffset = curDimIndex * info.strides[i];

      offset += curDimOffset;

      if (i > 0) {
        linearId /= info.sizes[i];
      }
    }

    return offset;
  }
};
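
// Illustrative example (not part of the original header): for a contiguous
// 2 x 3 x 4 tensor (strides {12, 4, 1}), linear index 17 decomposes as
//   i = 2: 17 % 4 = 1 -> offset += 1 * 1;  linearId = 17 / 4 = 4
//   i = 1:  4 % 3 = 1 -> offset += 1 * 4;  linearId =  4 / 3 = 1
//   i = 0:  1 % 2 = 1 -> offset += 1 * 12
// for a total offset of 1 + 4 + 12 = 17. With contiguous strides the offset
// equals the linear index, which is what the Dim == -1 specialization below
// exploits; with swapped (transposed-view) strides the two differ.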

// Specialization for the contiguous (Dim == -1) case: the linear index is
// itself the data offset
template <typename T, typename IndexT>
struct TensorInfoOffset<T, IndexT, -1> {
  __device__ inline static unsigned int get(const TensorInfo<T, IndexT>& info,
                                            IndexT linearId) {
    return linearId;
  }
};

template <typename T, typename IndexT, int Dim>
TensorInfo<T, IndexT> getTensorInfo(const Tensor<T, Dim, true>& t) {
  TensorInfo<T, IndexT> info;

  for (int i = 0; i < Dim; ++i) {
    info.sizes[i] = (IndexT) t.getSize(i);
    info.strides[i] = (IndexT) t.getStride(i);
  }

  info.data = t.data();
  info.dims = Dim;

  return info;
}

template <typename T, typename IndexT, int DimInput, int DimOutput>
__global__ void transposeAny(TensorInfo<T, IndexT> input,
                             TensorInfo<T, IndexT> output,
                             IndexT totalSize) {
  // Grid-stride loop over every element of the linearized tensor
  for (IndexT i = blockIdx.x * blockDim.x + threadIdx.x;
       i < totalSize;
       i += gridDim.x * blockDim.x) {
    auto inputOffset = TensorInfoOffset<T, IndexT, DimInput>::get(input, i);
    auto outputOffset = TensorInfoOffset<T, IndexT, DimOutput>::get(output, i);

#if __CUDA_ARCH__ >= 350
    // Read through the read-only data cache where available
    output.data[outputOffset] = __ldg(&input.data[inputOffset]);
#else
    output.data[outputOffset] = input.data[inputOffset];
#endif
  }
}

/// Performs an out-of-place transposition between any two dimensions.
/// Performance is best when the transposed dimensions are not the
/// innermost, since then the reads and writes will be coalesced.
/// A shared-memory transposition could be used when the dimensions
/// being transposed are innermost, but that would require support for
/// arbitrary rectangular matrices.
/// This linearized implementation seems to perform well enough,
/// especially for the cases that we care about (outer dimension
/// transpositions).
template <typename T, int Dim>
void runTransposeAny(Tensor<T, Dim, true>& in,
                     int dim1, int dim2,
                     Tensor<T, Dim, true>& out,
                     cudaStream_t stream) {
  static_assert(Dim <= TensorInfo<T, unsigned int>::kMaxDims,
                "too many dimensions");

  FAISS_ASSERT(dim1 != dim2);
  FAISS_ASSERT(dim1 < Dim && dim2 < Dim);

  int outSize[Dim];

  for (int i = 0; i < Dim; ++i) {
    outSize[i] = in.getSize(i);
  }

  std::swap(outSize[dim1], outSize[dim2]);

  for (int i = 0; i < Dim; ++i) {
    FAISS_ASSERT(out.getSize(i) == outSize[i]);
  }

  size_t totalSize = in.numElements();
  size_t block = std::min((size_t) getMaxThreadsCurrentDevice(), totalSize);

  if (totalSize <= (size_t) std::numeric_limits<int>::max()) {
    // div/mod seems faster with unsigned types
    auto inInfo = getTensorInfo<T, unsigned int, Dim>(in);
    auto outInfo = getTensorInfo<T, unsigned int, Dim>(out);

    // View the input with dim1/dim2 swapped; walking this view linearly
    // visits elements in the output's contiguous order, so the output can
    // use the Dim = -1 (contiguous) fast path
    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);

    transposeAny<T, unsigned int, Dim, -1>
      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
  } else {
    // Same as above, but with 64-bit indexing
    auto inInfo = getTensorInfo<T, unsigned long, Dim>(in);
    auto outInfo = getTensorInfo<T, unsigned long, Dim>(out);

    std::swap(inInfo.sizes[dim1], inInfo.sizes[dim2]);
    std::swap(inInfo.strides[dim1], inInfo.strides[dim2]);

    auto grid = std::min(utils::divUp(totalSize, block), (size_t) 4096);

    transposeAny<T, unsigned long, Dim, -1>
      <<<grid, block, 0, stream>>>(inInfo, outInfo, totalSize);
  }
  CUDA_TEST_ERROR();
}

} } // namespace
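
A minimal usage sketch follows (an illustrative addition, not part of the header). It assumes a row-major 3-d device tensor in of shape (batch, num, dim) whose two outer dimensions are to be exchanged out of place; the DeviceTensor allocation shown is an assumption, and exact constructor signatures vary across Faiss versions.

#include "DeviceTensor.cuh"
#include "Transpose.cuh"

// Swap dimensions 0 and 1 of `in` out of place; `in` is left unchanged.
void transposeOuter(faiss::gpu::Tensor<float, 3, true>& in,
                    cudaStream_t stream) {
  // The output must already have the transposed shape: (num, batch, dim)
  faiss::gpu::DeviceTensor<float, 3, true> out(
      {in.getSize(1), in.getSize(0), in.getSize(2)});

  faiss::gpu::runTransposeAny(in, 0, 1, out, stream);
}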