mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
75 lines
1.6 KiB
C
75 lines
1.6 KiB
C
|
// Copyright (c) OpenMMLab. All rights reserved.
|
||
|
|
||
|
#ifndef CORE_TENSOR_H
|
||
|
#define CORE_TENSOR_H
|
||
|
|
||
|
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "device.h"
#include "types.h"
|
||
|
|
||
|
namespace mmdeploy {
|
||
|
|
||
|
/// Dimension extents of a tensor, one signed 64-bit entry per axis.
using TensorShape = std::vector<int64_t>;
|
||
|
struct TensorDesc {
|
||
|
Device device;
|
||
|
DataType data_type{DataType::kFLOAT};
|
||
|
TensorShape shape;
|
||
|
std::string name;
|
||
|
};
|
||
|
|
||
|
class Tensor final {
|
||
|
public:
|
||
|
Tensor() = default;
|
||
|
Tensor(const TensorDesc& desc, Allocator allocator = {}); // NOLINT
|
||
|
Tensor(const TensorDesc& desc, Buffer buffer);
|
||
|
Tensor(const TensorDesc& desc, std::shared_ptr<void> data);
|
||
|
~Tensor() = default;
|
||
|
|
||
|
const TensorDesc& desc() const;
|
||
|
const TensorShape& shape() const;
|
||
|
TensorShape::value_type shape(int dim) const;
|
||
|
DataType data_type() const;
|
||
|
const char* name() const;
|
||
|
int64_t size() const;
|
||
|
int64_t byte_size() const;
|
||
|
|
||
|
const Buffer& buffer() const;
|
||
|
Buffer& buffer();
|
||
|
Device device() const;
|
||
|
|
||
|
void Reshape(const TensorShape& shape);
|
||
|
|
||
|
Tensor Slice(int index);
|
||
|
|
||
|
Result<void> CopyFrom(const Tensor& tensor, Stream stream = {});
|
||
|
Result<void> CopyTo(Tensor& tensor, Stream stream = {}) const;
|
||
|
|
||
|
Result<void> CopyFrom(void* host_ptr, Stream stream = {});
|
||
|
Result<void> CopyTo(void* host_ptr, Stream stream = {}) const;
|
||
|
|
||
|
template <typename T = void>
|
||
|
T* data() {
|
||
|
return GetNative<T*>(buffer());
|
||
|
}
|
||
|
|
||
|
template <typename T = void, typename U = std::add_const_t<T> >
|
||
|
U* data() const {
|
||
|
return GetNative<U*>(buffer());
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
void Allocate();
|
||
|
|
||
|
TensorDesc desc_;
|
||
|
Allocator allocator_;
|
||
|
Buffer buffer_;
|
||
|
};
|
||
|
|
||
|
// static_assert(sizeof(Tensor) == 80);
|
||
|
|
||
|
} // namespace mmdeploy
|
||
|
|
||
|
#endif // !CORE_TENSOR_H
|