Tensor.cuh
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <initializer_list>

/// Multi-dimensional array class for CUDA device and host usage.
/// Originally from Facebook's fbcunn, since added to the Torch GPU
/// library cutorch as well.

namespace faiss { namespace gpu {

/// Our tensor type
template <typename T,
          int Dim,
          bool Contig,
          typename IndexT,
          template <typename U> class PtrTraits>
class Tensor;

/// Type of a subspace of a tensor
namespace detail {
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor;
}

namespace traits {

template <typename T>
struct RestrictPtrTraits {
  typedef T* __restrict__ PtrType;
};

template <typename T>
struct DefaultPtrTraits {
  typedef T* PtrType;
};

}

/**
   Templated multi-dimensional array that supports strided access of
   elements. Main access is through `operator[]`; e.g.,
   `tensor[x][y][z]`.

   - `T` is the contained type (e.g., `float`)
   - `Dim` is the tensor rank
   - If `Contig` is true, then the tensor is assumed to be contiguous,
     and only operations that make sense on contiguous arrays are
     allowed (e.g., no transpose). Strides are still calculated, but
     the innermost stride is assumed to be 1.
   - `IndexT` is the integer type used for size/stride arrays, and for
     all indexing math. Default is `int`, but for large tensors `long`
     can be used instead.
   - `PtrTraits` are traits applied to our data pointer (T*). By default
     this is just T*, but RestrictPtrTraits can be used to apply
     T* __restrict__ for alias-free analysis.
*/
template <typename T,
          int Dim,
          bool Contig = false,
          typename IndexT = int,
          template <typename U> class PtrTraits = traits::DefaultPtrTraits>
class Tensor {
 public:
  enum { NumDim = Dim };
  typedef T DataType;
  typedef IndexT IndexType;
  enum { IsContig = Contig };
  typedef typename PtrTraits<T>::PtrType DataPtrType;
  typedef Tensor<T, Dim, Contig, IndexT, PtrTraits> TensorType;

  /// Default constructor
  __host__ __device__ Tensor();

  /// Copy constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t)
    = default;

  /// Move constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t)
    = default;

  /// Assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t) = default;

  /// Move assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t);

  /// Constructor that calculates strides with no padding
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim]);
  __host__ __device__ Tensor(DataPtrType data,
                             std::initializer_list<IndexT> sizes);
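
  // Usage sketch (editor's illustration, not part of the original header):
  // wrapping an existing device allocation in a 3-d tensor. `devFloats` is a
  // hypothetical pointer to at least 2 * 3 * 4 floats in device memory.
  //
  //   float* devFloats = nullptr;
  //   cudaMalloc(&devFloats, 2 * 3 * 4 * sizeof(float));
  //   Tensor<float, 3> t(devFloats, {2, 3, 4});
  //   // strides are calculated with no padding: {12, 4, 1}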

  /// Constructor that takes arbitrary size/stride arrays.
  /// Errors if you attempt to pass non-contiguous strides to a
  /// contiguous tensor.
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim],
                             const IndexT strides[Dim]);

  /// Copies a tensor into ourselves; sizes must match
  __host__ void copyFrom(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                         cudaStream_t stream);

  /// Copies ourselves into a tensor; sizes must match
  __host__ void copyTo(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                       cudaStream_t stream);
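
  // Usage sketch (editor's illustration): `hostT` and `devT` are hypothetical
  // tensors of identical sizes wrapping host and device allocations, and
  // `stream` is the cudaStream_t on which the copy is issued.
  //
  //   devT.copyFrom(hostT, stream);  // fill the device tensor from the host
  //   devT.copyTo(hostT, stream);    // copy the device tensor back to the host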

  /// Returns true if the two tensors are of the same dimensionality,
  /// size and stride.
  template <int OtherDim>
  __host__ __device__ bool
  isSame(const Tensor<T, OtherDim, Contig, IndexT, PtrTraits>& rhs) const;

  /// Cast to a tensor of a different type of the same size and
  /// stride. U and our type T must be of the same size
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> cast();

  /// Const version of `cast`
  template <typename U>
  __host__ __device__
  const Tensor<U, Dim, Contig, IndexT, PtrTraits> cast() const;

  /// Cast to a tensor of a different type which is potentially a
  /// different size than our type T. Tensor must be aligned and the
  /// innermost dimension must be a size that is a multiple of
  /// sizeof(U) / sizeof(T), and the stride of the innermost dimension
  /// must be contiguous. The stride of all outer dimensions must be a
  /// multiple of sizeof(U) / sizeof(T) as well.
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> castResize();

  /// Const version of `castResize`
  template <typename U>
  __host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
  castResize() const;

  /// Returns true if we can castResize() this tensor to the new type
  template <typename U>
  __host__ __device__ bool canCastResize() const;
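
  // Usage sketch (editor's illustration): reading four floats at a time by
  // reinterpreting the innermost dimension as float4. Assumes `t` is a
  // contiguous Tensor<float, 2> whose innermost size is a multiple of 4 and
  // whose data is aligned to sizeof(float4).
  //
  //   if (t.canCastResize<float4>()) {
  //     Tensor<float4, 2> t4 = t.castResize<float4>();
  //     // t4.getSize(1) == t.getSize(1) / 4
  //   }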

  /// Returns a raw pointer to the start of our data.
  __host__ __device__ inline DataPtrType data() {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity
  __host__ __device__ inline DataPtrType end() {
    return data() + numElements();
  }

  /// Returns a raw pointer to the start of our data (const).
  __host__ __device__ inline
  const DataPtrType data() const {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity (const)
  __host__ __device__ inline DataPtrType end() const {
    return data() + numElements();
  }

  /// Cast to a different datatype
  template <typename U>
  __host__ __device__ inline
  typename PtrTraits<U>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<U>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename U>
  __host__ __device__ inline
  const typename PtrTraits<const U>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const U>::PtrType>(data_);
  }

  /// Returns a read/write view of a portion of our tensor.
  __host__ __device__ inline
  detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT);

  /// Returns a read/write view of a portion of our tensor (const).
  __host__ __device__ inline
  const detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT) const;

  /// Returns the size of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getSize(int i) const {
    return size_[i];
  }

  /// Returns the stride of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getStride(int i) const {
    return stride_[i];
  }

  /// Returns the total number of elements contained within our data
  /// (product of `getSize(i)`)
  __host__ __device__ IndexT numElements() const;

  /// If we are contiguous, returns the total size in bytes of our
  /// data
  __host__ __device__ size_t getSizeInBytes() const {
    return (size_t) numElements() * sizeof(T);
  }

  /// Returns the size array.
  __host__ __device__ inline const IndexT* sizes() const {
    return size_;
  }

  /// Returns the stride array.
  __host__ __device__ inline const IndexT* strides() const {
    return stride_;
  }

  /// Returns true if there is no padding within the tensor and no
  /// re-ordering of the dimensions.
  /// ~~~
  /// (stride(i) == size(i + 1) * stride(i + 1)) && stride(dim - 1) == 1
  /// ~~~
  __host__ __device__ bool isContiguous() const;
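
  // Worked example (editor's illustration): with sizes {2, 3, 4} and no
  // padding, the strides are {12, 4, 1}; stride(2) == 1, stride(1) == 4 * 1
  // and stride(0) == 3 * 4, so isContiguous() returns true. With strides
  // {15, 5, 1} (a padded innermost row), stride(1) == 5 while
  // size(2) * stride(2) == 4, so it returns false.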

  /// Returns whether a given dimension has only increasing stride
  /// from the previous dimension. A tensor that was permuted by
  /// exchanging size and stride only will fail this check.
  /// If `i == 0`, just checks that `size > 0`. Returns `false` if the
  /// stride is `<= 0`.
  __host__ __device__ bool isConsistentlySized(int i) const;

  // Returns whether at each dimension `stride <= size`.
  // If this is not the case then iterating once over the size space will
  // touch the same memory locations multiple times.
  __host__ __device__ bool isConsistentlySized() const;

  /// Returns true if the given dimension index has no padding
  __host__ __device__ bool isContiguousDim(int i) const;

  /// Returns a tensor of the same dimension after transposing the two
  /// dimensions given. Does not actually move elements; transposition
  /// is made by permuting the size/stride arrays.
  /// If the dimensions are not valid, asserts.
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  transpose(int dim1, int dim2) const;
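
  // Usage sketch (editor's illustration): for a non-contiguous
  // Tensor<float, 2, false> `t` of size [2][3],
  //
  //   auto tt = t.transpose(0, 1);  // size [3][2]; size/stride entries swapped
  //   // tt[i][j] refers to the same element as t[j][i]; no data is moved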

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the leading dimensions with size 1,
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[1][1][2][3]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastOuter();

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the lowest/most varying dimensions with size 1,
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[2][3][1][1]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastInner();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the leading dimensions. Asserts if there is
  /// padding on the leading dimensions.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  downcastOuter();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the innermost dimensions. Asserts if there is
  /// padding on the innermost dimensions.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  downcastInner();

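  // Usage sketch (editor's illustration): collapsing dimensions of a
  // padding-free Tensor<float, 3> `t` of size [2][3][4]:
  //
  //   auto outer = t.downcastOuter<2>();  // size [6][4]
  //   auto inner = t.downcastInner<2>();  // size [2][12]
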
  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting at `at`.
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view(DataPtrType at);

  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting where our data begins
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view();

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the outermost dimension restricted to the
  /// elements in the range [start, start + size)
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  narrowOutermost(IndexT start, IndexT size);

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the specified dimension restricted to the
  /// elements in the range [start, start + size).
  /// Can occur in an arbitrary dimension, and is possibly
  /// non-contiguous
  __host__ __device__ Tensor<T, Dim, false, IndexT, PtrTraits>
  narrow(int dim, IndexT start, IndexT size);
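
  // Usage sketch (editor's illustration): for a Tensor<float, 2> `t` of size
  // [4][5],
  //
  //   auto rows = t.narrowOutermost(1, 2);  // view of rows 1..2, size [2][5]
  //   auto cols = t.narrow(1, 0, 3);        // view of columns 0..2, size [4][3]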

  /// Returns a view of the given tensor expressed as a tensor of a
  /// different number of dimensions.
  /// Only works if we are contiguous.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  view(std::initializer_list<IndexT> sizes);
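
  // Usage sketch (editor's illustration): reshaping a contiguous
  // Tensor<float, 2> `t` of size [2][12] without copying:
  //
  //   auto t3 = t.view<3>({2, 3, 4});  // same 24 elements viewed as [2][3][4]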

 protected:
  /// Raw pointer to where the tensor data begins
  DataPtrType data_;

  /// Array of strides (in sizeof(T) terms) for each dimension
  IndexT stride_[Dim];

  /// Size of each dimension
  IndexT size_[Dim];
};
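
// Device-side usage sketch (editor's illustration, not part of this header):
// a hypothetical kernel that scales a 2-d tensor in place, using getSize()
// for bounds and operator[] for strided element access.
//
//   __global__ void scaleKernel(Tensor<float, 2> t, float alpha) {
//     for (int i = blockIdx.x; i < t.getSize(0); i += gridDim.x) {
//       for (int j = threadIdx.x; j < t.getSize(1); j += blockDim.x) {
//         t[i][j] = t[i][j] * alpha;
//       }
//     }
//   }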

namespace detail {

/// Specialization for a view of a single value (0-dimensional)
template <typename TensorType, template <typename U> class PtrTraits>
class SubTensor<TensorType, 0, PtrTraits> {
 public:
  __host__ __device__ SubTensor<TensorType, 0, PtrTraits>
  operator=(typename TensorType::DataType val) {
    *data_ = val;
    return *this;
  }

  // operator T&
  __host__ __device__ operator typename TensorType::DataType&() {
    return *data_;
  }

  // const operator T& returning const T&
  __host__ __device__ operator const typename TensorType::DataType&() const {
    return *data_;
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class Tensor<typename TensorType::DataType,
                      1,
                      TensorType::IsContig,
                      typename TensorType::IndexType,
                      PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// Where our value is located
  typename TensorType::DataPtrType const data_;
};

/// A `SubDim`-rank slice of a parent Tensor
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor {
 public:
  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor).
  __host__ __device__ inline
  SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor) (const).
  __host__ __device__ inline
  const SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) const {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

  /// Returns a tensor that is a view of the SubDim-dimensional slice
  /// of this tensor, starting where our data begins
  Tensor<typename TensorType::DataType,
         SubDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits> view() {
    return tensor_.template view<SubDim>(data_);
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, SubDim + 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class
  Tensor<typename TensorType::DataType,
         TensorType::NumDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// The start of our sub-region
  typename TensorType::DataPtrType const data_;
};
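
// Indexing sketch (editor's illustration): for a 3-d tensor with strides
// {12, 4, 1}, the chained operator[] above resolves t[i][j][k] to the element
// at data() + i * 12 + j * 4 + k; when the tensor is statically known to be
// contiguous, the innermost step skips the stride multiply entirely.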

} // namespace detail

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                  Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      *this, data_)[index]);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
const detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                        Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) const {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      const_cast<TensorType&>(*this), data_)[index]);
}

} } // namespace

#include "Tensor-inl.cuh"