82 lines
1.8 KiB
C++
82 lines
1.8 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#ifndef CORE_TENSOR_H
|
|
#define CORE_TENSOR_H
|
|
|
|
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "device.h"
#include "types.h"
|
|
|
|
namespace mmdeploy {
|
|
|
|
using TensorShape = std::vector<int64_t>;
|
|
struct TensorDesc {
|
|
Device device;
|
|
DataType data_type{DataType::kFLOAT};
|
|
TensorShape shape;
|
|
std::string name;
|
|
};
|
|
|
|
class MMDEPLOY_API Tensor {
|
|
public:
|
|
Tensor() = default;
|
|
Tensor(const Tensor&) = default;
|
|
Tensor(Tensor&&) noexcept = default;
|
|
Tensor& operator=(const Tensor&) = default;
|
|
Tensor& operator=(Tensor&&) noexcept = default;
|
|
|
|
Tensor(const TensorDesc& desc, Allocator allocator = {}); // NOLINT
|
|
Tensor(const TensorDesc& desc, Buffer buffer);
|
|
Tensor(const TensorDesc& desc, std::shared_ptr<void> data);
|
|
~Tensor() = default;
|
|
|
|
const TensorDesc& desc() const;
|
|
const TensorShape& shape() const;
|
|
TensorShape::value_type shape(int dim) const;
|
|
DataType data_type() const;
|
|
const char* name() const;
|
|
int64_t size() const;
|
|
int64_t byte_size() const;
|
|
|
|
const Buffer& buffer() const;
|
|
Buffer& buffer();
|
|
Device device() const;
|
|
|
|
void Reshape(const TensorShape& shape);
|
|
|
|
Tensor Slice(int index);
|
|
|
|
Result<void> CopyFrom(const Tensor& tensor, Stream stream = {});
|
|
Result<void> CopyTo(Tensor& tensor, Stream stream = {}) const;
|
|
|
|
Result<void> CopyFrom(void* host_ptr, Stream stream = {});
|
|
Result<void> CopyTo(void* host_ptr, Stream stream = {}) const;
|
|
|
|
Allocator allocator() { return allocator_; }
|
|
|
|
template <typename T = void>
|
|
T* data() {
|
|
return GetNative<T*>(buffer());
|
|
}
|
|
|
|
template <typename T = void, typename U = std::add_const_t<T> >
|
|
U* data() const {
|
|
return GetNative<U*>(buffer());
|
|
}
|
|
|
|
private:
|
|
void Allocate();
|
|
|
|
TensorDesc desc_;
|
|
Allocator allocator_;
|
|
Buffer buffer_;
|
|
};
|
|
|
|
// static_assert(sizeof(Tensor) == 80);
|
|
|
|
} // namespace mmdeploy
|
|
|
|
#endif // !CORE_TENSOR_H
|