// mmdeploy/csrc/net/trt/trt_net.cpp

// Copyright (c) OpenMMLab. All rights reserved.
#include "trt_net.h"
#include <sstream>
#include "core/logger.h"
#include "core/model.h"
#include "core/module.h"
#include "core/utils/formatter.h"
namespace mmdeploy {
namespace trt_detail {
class TRTLogger : public nvinfer1::ILogger {
public:
void log(Severity severity, const char* msg) noexcept override {
switch (severity) {
case Severity::kINFO:
// INFO("TRTNet: {}", msg);
break;
case Severity::kWARNING:
WARN("TRTNet: {}", msg);
break;
case Severity::kERROR:
case Severity::kINTERNAL_ERROR:
ERROR("TRTNet: {}", msg);
break;
default:
break;
}
}
static TRTLogger& get() {
static TRTLogger trt_logger{};
return trt_logger;
}
};
// Convert an mmdeploy TensorShape into a TensorRT Dims struct.
// NOTE(review): assumes shape.size() <= nvinfer1::Dims::MAX_DIMS — TODO confirm callers.
nvinfer1::Dims to_dims(const TensorShape& shape) {
  nvinfer1::Dims dims{};
  dims.nbDims = static_cast<int>(shape.size());
  int axis = 0;
  for (const auto& extent : shape) {
    dims.d[axis++] = extent;
  }
  return dims;
}
// Convert a TensorRT Dims struct back into an mmdeploy TensorShape.
TensorShape to_shape(const nvinfer1::Dims& dims) {
  TensorShape shape(dims.nbDims);
  for (int axis = 0; axis < dims.nbDims; ++axis) {
    shape[axis] = dims.d[axis];
  }
  return shape;
}
} // namespace trt_detail
// Render dims as "(d0, d1, ...)" for log messages.
std::string to_string(const nvinfer1::Dims& dims) {
  std::stringstream ss;
  ss << "(";
  const char* sep = "";
  for (int i = 0; i < dims.nbDims; ++i) {
    ss << sep << dims.d[i];
    sep = ", ";
  }
  ss << ")";
  return ss.str();
}
// Turn a boolean success flag into a Result: log `msg` (if given) and return
// `e` on failure, success() otherwise. Used via the TRT_TRY macro below.
static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status e = Status(eFail)) {
  if (!code) {
    if (msg) {
      ERROR("{}", msg);
    }
    return e;
  }
  return success();
}
#define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))
TRTNet::~TRTNet() = default;
// Translate a TensorRT dtype enum into mmdeploy's DataType.
// Returns Status(eNotSupported) for dtypes without an mmdeploy equivalent.
static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
  if (dtype == nvinfer1::DataType::kFLOAT) {
    return DataType::kFLOAT;
  }
  if (dtype == nvinfer1::DataType::kHALF) {
    return DataType::kHALF;
  }
  if (dtype == nvinfer1::DataType::kINT8) {
    return DataType::kINT8;
  }
  if (dtype == nvinfer1::DataType::kINT32) {
    return DataType::kINT32;
  }
  return Status(eNotSupported);
}
// Deserialize the TensorRT engine named by args["name"] from the model,
// enumerate its bindings into input/output tensor lists, and create an
// execution context bound to the configured CUDA device/stream.
// Returns eNotSupported for host devices, shape bindings, unsupported dtypes,
// or multi-profile engines; eFail on TensorRT object creation failures.
Result<void> TRTNet::Init(const Value& args) {
  using namespace trt_detail;
  auto& context = args["context"];
  device_ = context["device"].get<Device>();
  // TensorRT engines execute on CUDA devices only.
  if (device_.is_host()) {
    ERROR("TRTNet: device must be a GPU!");
    return Status(eNotSupported);
  }
  stream_ = context["stream"].get<Stream>();
  event_ = Event(device_);

  auto name = args["name"].get<std::string>();
  auto model = context["model"].get<Model>();
  OUTCOME_TRY(auto config, model.GetModelConfig(name));
  // The serialized engine ("plan") is stored as a file inside the model.
  OUTCOME_TRY(auto plan, model.ReadFile(config.net));

  TRTWrapper runtime = nvinfer1::createInferRuntime(TRTLogger::get());
  TRT_TRY(!!runtime, "failed to create TRT infer runtime");
  engine_ = runtime->deserializeCudaEngine(plan.data(), plan.size());
  TRT_TRY(!!engine_, "failed to deserialize TRT CUDA engine");
  // Binding indices below assume a single profile; multiple profiles would
  // require per-profile binding offsets.
  TRT_TRY(engine_->getNbOptimizationProfiles() == 1, "only 1 optimization profile supported",
          Status(eNotSupported));

  // Partition the engine's bindings into input and output tensors.
  auto n_bindings = engine_->getNbBindings();
  for (int i = 0; i < n_bindings; ++i) {
    auto binding_name = engine_->getBindingName(i);
    auto dims = engine_->getBindingDimensions(i);
    if (engine_->isShapeBinding(i)) {
      ERROR("shape binding is not supported.");
      return Status(eNotSupported);
    }
    OUTCOME_TRY(auto dtype, MapDataType(engine_->getBindingDataType(i)));
    // Buffers stay empty here; they are allocated/attached by the caller later.
    TensorDesc desc{
        .device = device_, .data_type = dtype, .shape = to_shape(dims), .name = binding_name};
    if (engine_->bindingIsInput(i)) {
      DEBUG("input binding {} {} {}", i, binding_name, to_string(dims));
      input_ids_.push_back(i);
      input_names_.emplace_back(binding_name);
      input_tensors_.emplace_back(desc, Buffer());
    } else {
      DEBUG("output binding {} {} {}", i, binding_name, to_string(dims));
      output_ids_.push_back(i);
      output_names_.emplace_back(binding_name);
      output_tensors_.emplace_back(desc, Buffer());
    }
  }

  context_ = engine_->createExecutionContext();
  TRT_TRY(!!context_, "failed to create TRT execution context");
  // FIX: the bool result of setOptimizationProfileAsync was previously
  // ignored; a failed profile selection would only surface later as a
  // confusing enqueue error.
  TRT_TRY(context_->setOptimizationProfileAsync(0, static_cast<cudaStream_t>(stream_.GetNative())),
          "failed to set optimization profile");
  // Profile selection is asynchronous on the stream; wait for it to complete
  // before the context is used.
  OUTCOME_TRY(stream_.Wait());
  return success();
}
// Release TensorRT objects. The execution context is reset before the engine
// that created it — contexts must not outlive their engine.
Result<void> TRTNet::Deinit() {
  context_.reset();
  engine_.reset();
  return success();
}
// Apply new input shapes (one per input binding, in binding order) to the
// execution context and resize the input/output tensors to match.
// Returns eInvalidArgument on a shape-count mismatch, eFail if TensorRT
// rejects the dimensions or some dynamic dimension remains unspecified.
Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
  using namespace trt_detail;
  if (input_shapes.size() != input_tensors_.size()) {
    return Status(eInvalidArgument);
  }
  // FIX: loop counters were signed ints compared against size_t sizes.
  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    auto dims = to_dims(input_shapes[i]);
    TRT_TRY(context_->setBindingDimensions(input_ids_[i], dims));
    input_tensors_[i].Reshape(input_shapes[i]);
  }
  // Every dynamic input dimension must be fixed before output shapes can be
  // queried from the context.
  if (!context_->allInputDimensionsSpecified()) {
    ERROR("not all input dimensions specified");
    return Status(eFail);
  }
  // Output shapes are inferred by TensorRT from the input shapes just set.
  for (size_t i = 0; i < output_tensors_.size(); ++i) {
    auto dims = context_->getBindingDimensions(output_ids_[i]);
    output_tensors_[i].Reshape(to_shape(dims));
  }
  return success();
}
Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }
// Enqueue inference on the CUDA stream using the current tensor buffers.
// Blocks only until the input buffers have been consumed (via the
// inputConsumed event), not until inference completes on the stream.
// Returns eFail if TensorRT rejects the enqueue.
Result<void> TRTNet::Forward() {
  using namespace trt_detail;
  // Gather device pointers in binding-index order, as enqueueV2 expects.
  std::vector<void*> bindings(engine_->getNbBindings());
  // FIX: loop counters were signed ints compared against size_t sizes.
  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    bindings[input_ids_[i]] = input_tensors_[i].data();
  }
  for (size_t i = 0; i < output_tensors_.size(); ++i) {
    bindings[output_ids_[i]] = output_tensors_[i].data();
  }
  // enqueueV2 records `event` once the input buffers may safely be reused.
  auto event = GetNative<cudaEvent_t>(event_);
  auto status = context_->enqueueV2(bindings.data(), GetNative<cudaStream_t>(stream_), &event);
  TRT_TRY(status, "TRT forward failed", Status(eFail));
  // Wait until inputs are consumed so callers can overwrite them afterwards.
  OUTCOME_TRY(event_.Wait());
  return success();
}
Result<void> TRTNet::ForwardAsync(Event* event) { return Status(eNotSupported); }
// Factory registered with the module system; creates TRTNet instances for
// configurations that request the "tensorrt" backend.
class TRTNetCreator : public Creator<Net> {
 public:
  const char* GetName() const override { return "tensorrt"; }
  int GetVersion() const override { return 0; }
  // Build and initialize a TRTNet; a failed Init yields nullptr.
  std::unique_ptr<Net> Create(const Value& args) override {
    auto net = std::make_unique<TRTNet>();
    if (!net->Init(args)) {
      return nullptr;
    }
    return net;
  }
};
REGISTER_MODULE(Net, TRTNetCreator);
} // namespace mmdeploy