mmdeploy/csrc/net/ort/ort_net.cpp

// Copyright (c) OpenMMLab. All rights reserved.

#include "ort_net.h"

#include <algorithm>

#include "core/logger.h"
#include "core/model.h"
#include "core/utils/formatter.h"

namespace mmdeploy {

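// Helper: translate the shape reported by an Ort::TypeInfo into mmdeploy's TensorShape.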
static TensorShape to_shape(const Ort::TypeInfo& info) {
  auto shape = info.GetTensorTypeAndShapeInfo().GetShape();
  return {shape.begin(), shape.end()};
}

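// Map an ONNX element type onto mmdeploy's DataType; element types not listed here are
// rejected with eNotSupported so callers can fail gracefully.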
static Result<DataType> ConvertElementType(ONNXTensorElementDataType type) {
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return DataType::kFLOAT;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return DataType::kHALF;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return DataType::kINT8;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return DataType::kINT32;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return DataType::kINT64;
    default:
      ERROR("unsupported ONNXTensorElementDataType: {}", static_cast<int>(type));
      return Status(eNotSupported);
  }
}

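// Initialization: read the serialized ONNX model from the Model archive, create an
// Ort::Session (with the CUDA execution provider when running on a device), then query the
// session for input/output names, shapes and element types to pre-build the corresponding
// mmdeploy tensors.
//
// Expected layout of `args`, reconstructed from the accesses below (values illustrative):
//   {
//     "name": "<model entry name>",
//     "context": {"device": Device, "stream": Stream, "model": Model}
//   }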
// TODO: handle datatype
Result<void> OrtNet::Init(const Value& args) {
  auto& context = args["context"];
  device_ = context["device"].get<Device>();
  stream_ = context["stream"].get<Stream>();

  auto name = args["name"].get<std::string>();
  auto model = context["model"].get<Model>();
  OUTCOME_TRY(auto config, model.GetModelConfig(name));
  OUTCOME_TRY(auto onnx, model.ReadFile(config.net));

  Ort::SessionOptions options;
  options.SetLogSeverityLevel(3);  // 3 == ORT_LOGGING_LEVEL_ERROR, only log errors from ORT
  if (device_.is_device()) {
    OrtCUDAProviderOptions cuda_options{};
    cuda_options.device_id = device_.device_id();
    // TODO set compute stream
    options.AppendExecutionProvider_CUDA(cuda_options);
  }
  session_ = {env_, onnx.data(), onnx.size(), options};

  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Allocator allocator(session_, memory_info);

  auto n_inputs = session_.GetInputCount();
  // force shapes with negative (dynamic) dimensions to be empty
  auto filter_shape = [](TensorShape& shape) {
    if (std::any_of(begin(shape), end(shape), [](auto x) { return x < 0; })) {
      shape = {};
    }
  };
  for (int i = 0; i < n_inputs; ++i) {
    auto input_name = session_.GetInputName(i, allocator);
    auto type_info = session_.GetInputTypeInfo(i);
    auto shape = to_shape(type_info);
    INFO("input {}, shape = {}", i, shape);
    filter_shape(shape);
    OUTCOME_TRY(auto data_type,
                ConvertElementType(type_info.GetTensorTypeAndShapeInfo().GetElementType()));
    input_tensors_.emplace_back(
        TensorDesc{.device = device_, .data_type = data_type, .shape = shape, .name = input_name});
    allocator.Free(input_name);
  }

  auto n_outputs = session_.GetOutputCount();
  for (int i = 0; i < n_outputs; ++i) {
    auto output_name = session_.GetOutputName(i, allocator);
    auto type_info = session_.GetOutputTypeInfo(i);
    auto shape = to_shape(type_info);
    INFO("output {}, shape = {}", i, shape);
    filter_shape(shape);
    OUTCOME_TRY(auto data_type,
                ConvertElementType(type_info.GetTensorTypeAndShapeInfo().GetElementType()));
    output_tensors_.emplace_back(TensorDesc{
        .device = device_, .data_type = data_type, .shape = shape, .name = output_name});
    allocator.Free(output_name);
  }

  return success();
}

Result<void> OrtNet::ForwardAsync(Event* event) { return Status(eNotSupported); }

Result<void> OrtNet::Deinit() { return success(); }

Result<Span<Tensor>> OrtNet::GetInputTensors() { return input_tensors_; }

Result<Span<Tensor>> OrtNet::GetOutputTensors() { return output_tensors_; }

Result<void> OrtNet::Reshape(Span<TensorShape> input_shapes) {
  for (size_t i = 0; i < input_shapes.size(); ++i) {
    input_tensors_[i].Reshape(input_shapes[i]);
  }
  return success();
}

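// Build an Ort::MemoryInfo matching a tensor's location so ORT can read/write the buffer in
// place: "Cpu" for host tensors, "Cuda" (with the device id) for device tensors.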
static Ort::MemoryInfo MemoryInfo(const TensorDesc& desc) {
  const char* device_name = desc.device.is_host() ? "Cpu" : "Cuda";
  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, desc.device.device_id(),
                              OrtMemTypeDefault);
  return memory_info;
}

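// Wrap an mmdeploy Tensor as an Ort::Value without copying. Note that the buffer is exposed
// via data<float>() regardless of the tensor's element type, so only float tensors are handled
// correctly here (see the "handle datatype" TODO above).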
static Ort::Value AsOrtValue(Tensor& tensor) {
  auto memory_info = MemoryInfo(tensor.desc());
  std::vector<int64_t> shape(begin(tensor.shape()), end(tensor.shape()));
  return Ort::Value::CreateTensor(memory_info, tensor.data<float>(), tensor.size(), shape.data(),
                                  shape.size());
}

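// View an Ort::Value as an mmdeploy Tensor. The shared_ptr uses a no-op deleter, so the
// returned tensor merely aliases memory owned by ORT and must be copied out before the
// Ort::Value is destroyed (as Forward() does below).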
static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) {
  auto info = value.GetTensorTypeAndShapeInfo();
  TensorDesc desc;
  desc.shape = info.GetShape();
  desc.device = device;
  OUTCOME_TRY(desc.data_type, ConvertElementType(info.GetElementType()));
  std::shared_ptr<void> data(const_cast<void*>(value.GetTensorData<void>()), [](void*) {});
  return Tensor(desc, data);
}

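// Forward pass via Ort::IoBinding: bind the pre-allocated input tensors in place, let ORT
// allocate the outputs on the requested device (their shapes are only known after the run),
// then reshape the declared output tensors and copy the results into them on stream_.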
Result<void> OrtNet::Forward() {
  try {
    OUTCOME_TRY(stream_.Wait());

    Ort::IoBinding binding(session_);
    std::vector<Ort::Value> inputs;
    std::vector<Ort::Value> outputs;
    Ort::RunOptions options;

    inputs.reserve(input_tensors_.size());
    for (auto& t : input_tensors_) {
      inputs.push_back(AsOrtValue(t));
      binding.BindInput(t.name(), inputs.back());
    }

    // TODO: We are in the same situation as PPL.nn, the backend can't infer shapes
    //  without executing forward
    for (auto& t : output_tensors_) {
      binding.BindOutput(t.name(), MemoryInfo(t.desc()));
    }

    session_.Run({}, binding);

    outputs = binding.GetOutputValues();
    for (size_t i = 0; i < output_tensors_.size(); ++i) {
      OUTCOME_TRY(auto tmp, AsTensor(outputs[i], output_tensors_[i].device()));
      output_tensors_[i].Reshape(tmp.shape());
      OUTCOME_TRY(tmp.CopyTo(output_tensors_[i], stream_));
    }

    OUTCOME_TRY(stream_.Wait());
  } catch (const std::exception& e) {
    ERROR(e.what());
    return Status(eFail);
  }
  return success();
}

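// Creator that exposes OrtNet through the Net registry under the scheme name "onnxruntime".
// Create() simply forwards `args` to OrtNet::Init(), so the config is expected to carry the
// same "name"/"context" entries sketched above Init().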
class OrtNetCreator : public Creator<Net> {
 public:
  const char* GetName() const override { return "onnxruntime"; }
  int GetVersion() const override { return 0; }
  std::unique_ptr<Net> Create(const Value& args) override {
    try {
      auto p = std::make_unique<OrtNet>();
      if (auto r = p->Init(args)) {
        return p;
      } else {
        ERROR("error creating OrtNet: {}", r.error().message().c_str());
        return nullptr;
      }
    } catch (const std::exception& e) {
      ERROR("unhandled exception when creating ORTNet: {}", e.what());
      return nullptr;
    }
  }
};

REGISTER_MODULE(Net, OrtNetCreator);

}  // namespace mmdeploy