Tensor.cuh

/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the CC-by-NC license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Copyright 2004-present Facebook. All Rights Reserved.

#pragma once

#include <assert.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <initializer_list>
/// Multi-dimensional array class for CUDA device and host usage.
/// Originally from Facebook's fbcunn, since added to the Torch GPU
/// library cutorch as well.

namespace faiss { namespace gpu {

/// Our tensor type
template <typename T,
          int Dim,
          bool Contig,
          typename IndexT,
          template <typename U> class PtrTraits>
class Tensor;

/// Type of a subspace of a tensor
namespace detail {
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor;
}

namespace traits {

template <typename T>
struct RestrictPtrTraits {
  typedef T* __restrict__ PtrType;
};

template <typename T>
struct DefaultPtrTraits {
  typedef T* PtrType;
};

}

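// Illustrative sketch (not part of this header): the traits choose only the
// type of the underlying data pointer. A hypothetical kernel taking a
// restrict-qualified contiguous tensor might look like:
//
//   __global__ void scaleKernel(
//       Tensor<float, 2, true, int, traits::RestrictPtrTraits> t, float v) {
//     // data() is float* __restrict__ here, enabling alias-free analysis
//     for (int i = threadIdx.x; i < t.numElements(); i += blockDim.x) {
//       t.data()[i] *= v;
//     }
//   }
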
/**
   Templated multi-dimensional array that supports strided access of
   elements. Main access is through `operator[]`; e.g.,
   `tensor[x][y][z]`.

   - `T` is the contained type (e.g., `float`)
   - `Dim` is the tensor rank
   - If `Contig` is true, then the tensor is assumed to be
     contiguous, and only operations that make sense on contiguous
     arrays are allowed (e.g., no transpose). Strides are still
     calculated, but the innermost stride is assumed to be 1.
   - `IndexT` is the integer type used for size/stride arrays, and for
     all indexing math. Default is `int`, but for large tensors, `long`
     can be used instead.
   - `PtrTraits` are traits applied to our data pointer (`T*`). By default,
     this is just `T*`, but `RestrictPtrTraits` can be used to apply
     `T* __restrict__` for alias-free analysis.
*/
template <typename T,
          int Dim,
          bool Contig = false,
          typename IndexT = int,
          template <typename U> class PtrTraits = traits::DefaultPtrTraits>
class Tensor {
 public:
  enum { NumDim = Dim };
  typedef T DataType;
  typedef IndexT IndexType;
  enum { IsContig = Contig };
  typedef typename PtrTraits<T>::PtrType DataPtrType;
  typedef Tensor<T, Dim, Contig, IndexT, PtrTraits> TensorType;

  /// Default constructor
  __host__ __device__ Tensor();

  /// Copy constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t)
    = default;

  /// Move constructor
  __host__ __device__ Tensor(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t)
    = default;

  /// Assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t) = default;

  /// Move assignment
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>&
  operator=(Tensor<T, Dim, Contig, IndexT, PtrTraits>&& t);

  /// Constructor that calculates strides with no padding
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim]);
  __host__ __device__ Tensor(DataPtrType data,
                             std::initializer_list<IndexT> sizes);

  /// Constructor that takes arbitrary size/stride arrays.
  /// Errors if you attempt to pass non-contiguous strides to a
  /// contiguous tensor.
  __host__ __device__ Tensor(DataPtrType data,
                             const IndexT sizes[Dim],
                             const IndexT strides[Dim]);

  /// Copies a tensor into ourselves; sizes must match
  __host__ void copyFrom(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                         cudaStream_t stream);

  /// Copies ourselves into a tensor; sizes must match
  __host__ void copyTo(Tensor<T, Dim, Contig, IndexT, PtrTraits>& t,
                       cudaStream_t stream);

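  // A host-side sketch (illustrative, not part of the API surface): copying
  // between two same-sized tensors on a stream. `devA`, `devB`, and `stream`
  // are hypothetical.
  //
  //   int sizes[2] = {16, 32};
  //   Tensor<float, 2, true> src(devA, sizes);
  //   Tensor<float, 2, true> dst(devB, sizes);
  //   dst.copyFrom(src, stream); // src -> dst, asynchronous on `stream`
  //   src.copyTo(dst, stream);   // the same copy, initiated from src
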
  /// Returns true if the two tensors are of the same dimensionality,
  /// size and stride.
  template <int OtherDim>
  __host__ __device__ bool
  isSame(const Tensor<T, OtherDim, Contig, IndexT, PtrTraits>& rhs) const;

  /// Cast to a tensor of a different type of the same size and
  /// stride. U and our type T must be of the same size
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> cast();

  /// Const version of `cast`
  template <typename U>
  __host__ __device__
  const Tensor<U, Dim, Contig, IndexT, PtrTraits> cast() const;

  /// Cast to a tensor of a different type which is potentially a
  /// different size than our type T. Tensor must be aligned and the
  /// innermost dimension must be a size that is a multiple of
  /// sizeof(U) / sizeof(T), and the stride of the innermost dimension
  /// must be contiguous. The stride of all outer dimensions must be a
  /// multiple of sizeof(U) / sizeof(T) as well.
  template <typename U>
  __host__ __device__ Tensor<U, Dim, Contig, IndexT, PtrTraits> castResize();

  /// Const version of `castResize`
  template <typename U>
  __host__ __device__ const Tensor<U, Dim, Contig, IndexT, PtrTraits>
  castResize() const;

  /// Returns true if we can castResize() this tensor to the new type
  template <typename U>
  __host__ __device__ bool canCastResize() const;

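  // Sketch: reinterpreting a float tensor as float4 for vectorized access,
  // assuming proper alignment and an innermost size divisible by 4. `t` and
  // `devA` are hypothetical.
  //
  //   Tensor<float, 2, true> t(devA, {16, 32});
  //   if (t.canCastResize<float4>()) {
  //     Tensor<float4, 2, true> t4 = t.castResize<float4>();
  //     // t4.getSize(1) == 8; each element now covers 4 floats
  //   }
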
  /// Attempts to cast this tensor to a tensor of a different IndexT.
  /// Fails if size or stride entries are not representable in the new
  /// IndexT.
  template <typename NewIndexT>
  __host__ Tensor<T, Dim, Contig, NewIndexT, PtrTraits>
  castIndexType() const;

  /// Returns true if we can castIndexType() this tensor to the new
  /// index type
  template <typename NewIndexT>
  __host__ bool canCastIndexType() const;

  /// Returns a raw pointer to the start of our data.
  __host__ __device__ inline DataPtrType data() {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity
  __host__ __device__ inline DataPtrType end() {
    return data() + numElements();
  }

  /// Returns a raw pointer to the start of our data (const).
  __host__ __device__ inline
  const DataPtrType data() const {
    return data_;
  }

  /// Returns a raw pointer to the end of our data, assuming
  /// continuity (const)
  __host__ __device__ inline DataPtrType end() const {
    return data() + numElements();
  }

  /// Cast to a different datatype
  template <typename U>
  __host__ __device__ inline
  typename PtrTraits<U>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<U>::PtrType>(data_);
  }

  /// Cast to a different datatype
  template <typename U>
  __host__ __device__ inline
  const typename PtrTraits<const U>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const U>::PtrType>(data_);
  }

  /// Returns a read/write view of a portion of our tensor.
  __host__ __device__ inline
  detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT);

  /// Returns a read/write view of a portion of our tensor (const).
  __host__ __device__ inline
  const detail::SubTensor<TensorType, Dim - 1, PtrTraits>
  operator[](IndexT) const;

  /// Returns the size of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getSize(int i) const {
    return size_[i];
  }

  /// Returns the stride of a given dimension, `[0, Dim - 1]`. No bounds
  /// checking.
  __host__ __device__ inline IndexT getStride(int i) const {
    return stride_[i];
  }

  /// Returns the total number of elements contained within our data
  /// (product of `getSize(i)`)
  __host__ __device__ IndexT numElements() const;

  /// If we are contiguous, returns the total size in bytes of our
  /// data
  __host__ __device__ size_t getSizeInBytes() const {
    return (size_t) numElements() * sizeof(T);
  }

  /// Returns the size array.
  __host__ __device__ inline const IndexT* sizes() const {
    return size_;
  }

  /// Returns the stride array.
  __host__ __device__ inline const IndexT* strides() const {
    return stride_;
  }

  /// Returns true if there is no padding within the tensor and no
  /// re-ordering of the dimensions.
  /// ~~~
  /// (stride(i) == size(i + 1) * stride(i + 1)) && stride(dim - 1) == 1
  /// ~~~
  /// e.g., a contiguous `[2][3]` tensor has sizes `{2, 3}` and strides
  /// `{3, 1}`.
  __host__ __device__ bool isContiguous() const;

  /// Returns whether a given dimension has only increasing stride
  /// from the previous dimension. A tensor that was permuted by
  /// exchanging size and stride only will fail this check.
  /// If `i == 0` just check `size > 0`. Returns `false` if `stride` is `<= 0`.
  __host__ __device__ bool isConsistentlySized(int i) const;

  // Returns whether at each dimension `stride <= size`.
  // If this is not the case then iterating once over the size space will
  // touch the same memory locations multiple times.
  __host__ __device__ bool isConsistentlySized() const;

  /// Returns true if the given dimension index has no padding
  __host__ __device__ bool isContiguousDim(int i) const;

  /// Returns a tensor of the same dimension after transposing the two
  /// dimensions given. Does not actually move elements; transposition
  /// is made by permuting the size/stride arrays.
  /// If the dimensions are not valid, asserts.
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  transpose(int dim1, int dim2) const;

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the leading dimensions by 1
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[1][1][2][3]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastOuter();

  /// Upcast a tensor of dimension `D` to some tensor of dimension
  /// D' > D by padding the lowest/most varying dimensions by 1
  /// e.g., upcasting a 2-d tensor `[2][3]` to a 4-d tensor `[2][3][1][1]`
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  upcastInner();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the leading dimensions. asserts if there is
  /// padding on the leading dimensions.
  template <int NewDim>
  __host__ __device__
  Tensor<T, NewDim, Contig, IndexT, PtrTraits> downcastOuter();

  /// Downcast a tensor of dimension `D` to some tensor of dimension
  /// D' < D by collapsing the innermost dimensions. asserts if there is
  /// padding on the innermost dimensions.
  template <int NewDim>
  __host__ __device__
  Tensor<T, NewDim, Contig, IndexT, PtrTraits> downcastInner();

  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting at `at`.
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view(DataPtrType at);

  /// Returns a tensor that is a view of the `SubDim`-dimensional slice
  /// of this tensor, starting where our data begins
  template <int SubDim>
  __host__ __device__ Tensor<T, SubDim, Contig, IndexT, PtrTraits>
  view();

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the outermost dimension restricted to the
  /// elements in the range [start, start + size)
  __host__ __device__ Tensor<T, Dim, Contig, IndexT, PtrTraits>
  narrowOutermost(IndexT start, IndexT size);

  /// Returns a tensor of the same dimension that is a view of the
  /// original tensor with the specified dimension restricted to the
  /// elements in the range [start, start + size).
  /// Can occur in an arbitrary dimension, and is possibly
  /// non-contiguous
  __host__ __device__ Tensor<T, Dim, false, IndexT, PtrTraits>
  narrow(int dim, IndexT start, IndexT size);

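  // Sketch: both narrowing calls return views over the same storage; no
  // elements move. `t` is a hypothetical contiguous [8][16] tensor.
  //
  //   auto rows = t.narrowOutermost(2, 4); // rows [2, 6), same rank
  //   auto cols = t.narrow(1, 0, 8);       // first 8 columns; statically
  //                                        // typed as non-contiguous
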
  /// Returns a view of the given tensor expressed as a tensor of a
  /// different number of dimensions.
  /// Only works if we are contiguous.
  template <int NewDim>
  __host__ __device__ Tensor<T, NewDim, Contig, IndexT, PtrTraits>
  view(std::initializer_list<IndexT> sizes);

 protected:
  /// Raw pointer to where the tensor data begins
  DataPtrType data_;

  /// Array of strides (in sizeof(T) terms) per each dimension
  IndexT stride_[Dim];

  /// Size per each dimension
  IndexT size_[Dim];
};

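// Usage sketch (illustrative, not part of this header): wrapping existing
// device memory and taking views. `devFloats` is a hypothetical device
// allocation of at least 2 * 3 floats.
//
//   Tensor<float, 2, true> t(devFloats, {2, 3});
//   // t.getSize(0) == 2, t.getSize(1) == 3, strides are {3, 1}
//   // In device code: float v = t[1][2]; t[0][0] = v;
//   auto t4d = t.upcastOuter<4>(); // same data viewed as [1][1][2][3]
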
// Utilities for checking a collection of tensors
namespace detail {

template <typename IndexType>
bool canCastIndexType() {
  return true;
}

template <typename IndexType, typename T, typename... U>
bool canCastIndexType(const T& arg, const U&... args) {
  return arg.template canCastIndexType<IndexType>() &&
    canCastIndexType<IndexType>(args...);
}

} // namespace detail

template <typename IndexType, typename... T>
bool canCastIndexType(const T&... args) {
  return detail::canCastIndexType<IndexType>(args...);
}

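// Sketch: the variadic overload folds canCastIndexType<NewIndexT>() over any
// number of tensors. `a` and `b` are hypothetical long-indexed tensors.
//
//   // Verify that both tensors fit 32-bit indexing before downcasting:
//   if (canCastIndexType<int>(a, b)) {
//     auto aInt = a.castIndexType<int>();
//     auto bInt = b.castIndexType<int>();
//   }
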
namespace detail {

/// Specialization for a view of a single value (0-dimensional)
template <typename TensorType, template <typename U> class PtrTraits>
class SubTensor<TensorType, 0, PtrTraits> {
 public:
  __host__ __device__ SubTensor<TensorType, 0, PtrTraits>
  operator=(typename TensorType::DataType val) {
    *data_ = val;
    return *this;
  }

  // operator T&
  __host__ __device__ operator typename TensorType::DataType&() {
    return *data_;
  }

  // const operator T& returning const T&
  __host__ __device__ operator const typename TensorType::DataType&() const {
    return *data_;
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class Tensor<typename TensorType::DataType,
                      1,
                      TensorType::IsContig,
                      typename TensorType::IndexType,
                      PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// Where our value is located
  typename TensorType::DataPtrType const data_;
};

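// Device-side sketch: a fully-indexed expression yields this 0-dim view,
// which converts to T&, exposes its address, and can read through the
// texture cache. `t` is a hypothetical Tensor<float, 2, true>.
//
//   float v = t[i][j];       // operator T&
//   float* p = &t[i][j];     // operator&
//   float c = t[i][j].ldg(); // __ldg where __CUDA_ARCH__ >= 350
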
/// A `SubDim`-rank slice of a parent Tensor
template <typename TensorType,
          int SubDim,
          template <typename U> class PtrTraits>
class SubTensor {
 public:
  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor).
  __host__ __device__ inline
  SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  /// Returns a view of the data located at our offset (the dimension
  /// `SubDim` - 1 tensor) (const).
  __host__ __device__ inline
  const SubTensor<TensorType, SubDim - 1, PtrTraits>
  operator[](typename TensorType::IndexType index) const {
    if (TensorType::IsContig && SubDim == 1) {
      // Innermost dimension is stride 1 for contiguous arrays
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_, data_ + index);
    } else {
      return SubTensor<TensorType, SubDim - 1, PtrTraits>(
        tensor_,
        data_ + index * tensor_.getStride(TensorType::NumDim - SubDim));
    }
  }

  // operator& returning T*
  __host__ __device__ typename TensorType::DataType* operator&() {
    return data_;
  }

  // const operator& returning const T*
  __host__ __device__ const typename TensorType::DataType* operator&() const {
    return data_;
  }

  /// Returns a raw accessor to our slice.
  __host__ __device__ inline typename TensorType::DataPtrType data() {
    return data_;
  }

  /// Returns a raw accessor to our slice (const).
  __host__ __device__ inline
  const typename TensorType::DataPtrType data() const {
    return data_;
  }

  /// Cast to a different datatype.
  template <typename T>
  __host__ __device__ T& as() {
    return *dataAs<T>();
  }

  /// Cast to a different datatype (const).
  template <typename T>
  __host__ __device__ const T& as() const {
    return *dataAs<T>();
  }

  /// Cast to a different datatype
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<T>::PtrType dataAs() {
    return reinterpret_cast<typename PtrTraits<T>::PtrType>(data_);
  }

  /// Cast to a different datatype (const)
  template <typename T>
  __host__ __device__ inline
  typename PtrTraits<const T>::PtrType dataAs() const {
    return reinterpret_cast<typename PtrTraits<const T>::PtrType>(data_);
  }

  /// Use the texture cache for reads
  __device__ inline typename TensorType::DataType ldg() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(data_);
#else
    return *data_;
#endif
  }

  /// Use the texture cache for reads; cast as a particular type
  template <typename T>
  __device__ inline T ldgAs() const {
#if __CUDA_ARCH__ >= 350
    return __ldg(dataAs<T>());
#else
    return as<T>();
#endif
  }

  /// Returns a tensor that is a view of the SubDim-dimensional slice
  /// of this tensor, starting where our data begins
  Tensor<typename TensorType::DataType,
         SubDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits> view() {
    return tensor_.template view<SubDim>(data_);
  }

 protected:
  /// One dimension greater can create us
  friend class SubTensor<TensorType, SubDim + 1, PtrTraits>;

  /// Our parent tensor can create us
  friend class
  Tensor<typename TensorType::DataType,
         TensorType::NumDim,
         TensorType::IsContig,
         typename TensorType::IndexType,
         PtrTraits>;

  __host__ __device__ inline SubTensor(
    TensorType& t,
    typename TensorType::DataPtrType data)
      : tensor_(t),
        data_(data) {
  }

  /// The tensor we're referencing
  TensorType& tensor_;

  /// The start of our sub-region
  typename TensorType::DataPtrType const data_;
};

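// Sketch: partially indexing a tensor yields one of these slices, which can
// be materialized as a lower-rank Tensor via view(). `t` is a hypothetical
// Tensor<float, 2, true>.
//
//   auto row = t[1].view(); // Tensor<float, 1, true, int, ...> over row 1
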
} // namespace detail

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                  Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      *this, data_)[index]);
}

template <typename T, int Dim, bool Contig,
          typename IndexT, template <typename U> class PtrTraits>
__host__ __device__ inline
const detail::SubTensor<Tensor<T, Dim, Contig, IndexT, PtrTraits>,
                        Dim - 1, PtrTraits>
Tensor<T, Dim, Contig, IndexT, PtrTraits>::operator[](IndexT index) const {
  return detail::SubTensor<TensorType, Dim - 1, PtrTraits>(
    detail::SubTensor<TensorType, Dim, PtrTraits>(
      const_cast<TensorType&>(*this), data_)[index]);
}

} } // namespace

#include "Tensor-inl.cuh"