// Copyright (c) OpenMMLab. All rights reserved. #include "trt_net.h" #include <sstream> #include "core/logger.h" #include "core/model.h" #include "core/module.h" #include "core/utils/formatter.h" namespace mmdeploy { namespace trt_detail { class TRTLogger : public nvinfer1::ILogger { public: void log(Severity severity, const char* msg) noexcept override { switch (severity) { case Severity::kINFO: // INFO("TRTNet: {}", msg); break; case Severity::kWARNING: WARN("TRTNet: {}", msg); break; case Severity::kERROR: case Severity::kINTERNAL_ERROR: ERROR("TRTNet: {}", msg); break; default: break; } } static TRTLogger& get() { static TRTLogger trt_logger{}; return trt_logger; } }; nvinfer1::Dims to_dims(const TensorShape& shape) { nvinfer1::Dims dims{}; dims.nbDims = shape.size(); for (size_t i = 0; i < shape.size(); ++i) { dims.d[i] = shape[i]; } return dims; } TensorShape to_shape(const nvinfer1::Dims& dims) { TensorShape shape(dims.nbDims); for (int i = 0; i < shape.size(); ++i) { shape[i] = dims.d[i]; } return shape; } } // namespace trt_detail std::string to_string(const nvinfer1::Dims& dims) { std::stringstream ss; ss << "("; for (int i = 0; i < dims.nbDims; ++i) { if (i) ss << ", "; ss << dims.d[i]; } ss << ")"; return ss.str(); } static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status e = Status(eFail)) { if (code) { return success(); } if (msg) { ERROR("{}", msg); } return e; } #define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__)) TRTNet::~TRTNet() = default; static Result<DataType> MapDataType(nvinfer1::DataType dtype) { switch (dtype) { case nvinfer1::DataType::kFLOAT: return DataType::kFLOAT; case nvinfer1::DataType::kHALF: return DataType::kHALF; case nvinfer1::DataType::kINT8: return DataType::kINT8; case nvinfer1::DataType::kINT32: return DataType::kINT32; default: return Status(eNotSupported); } } Result<void> TRTNet::Init(const Value& args) { using namespace trt_detail; auto& context = args["context"]; device_ = context["device"].get<Device>(); if (device_.is_host()) { ERROR("TRTNet: device must be a GPU!"); return Status(eNotSupported); } stream_ = context["stream"].get<Stream>(); event_ = Event(device_); auto name = args["name"].get<std::string>(); auto model = context["model"].get<Model>(); OUTCOME_TRY(auto config, model.GetModelConfig(name)); OUTCOME_TRY(auto plan, model.ReadFile(config.net)); TRTWrapper runtime = nvinfer1::createInferRuntime(TRTLogger::get()); TRT_TRY(!!runtime, "failed to create TRT infer runtime"); engine_ = runtime->deserializeCudaEngine(plan.data(), plan.size()); TRT_TRY(!!engine_, "failed to deserialize TRT CUDA engine"); TRT_TRY(engine_->getNbOptimizationProfiles() == 1, "only 1 optimization profile supported", Status(eNotSupported)); auto n_bindings = engine_->getNbBindings(); for (int i = 0; i < n_bindings; ++i) { auto binding_name = engine_->getBindingName(i); auto dims = engine_->getBindingDimensions(i); if (engine_->isShapeBinding(i)) { ERROR("shape binding is not supported."); return Status(eNotSupported); } OUTCOME_TRY(auto dtype, MapDataType(engine_->getBindingDataType(i))); TensorDesc desc{ .device = device_, .data_type = dtype, .shape = to_shape(dims), .name = binding_name}; if (engine_->bindingIsInput(i)) { DEBUG("input binding {} {} {}", i, binding_name, to_string(dims)); input_ids_.push_back(i); input_names_.emplace_back(binding_name); input_tensors_.emplace_back(desc, Buffer()); } else { DEBUG("output binding {} {} {}", i, binding_name, to_string(dims)); output_ids_.push_back(i); output_names_.emplace_back(binding_name); output_tensors_.emplace_back(desc, Buffer()); } } context_ = engine_->createExecutionContext(); TRT_TRY(!!context_, "failed to create TRT execution context"); context_->setOptimizationProfileAsync(0, static_cast<cudaStream_t>(stream_.GetNative())); OUTCOME_TRY(stream_.Wait()); return success(); } Result<void> TRTNet::Deinit() { context_.reset(); engine_.reset(); return success(); } Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) { using namespace trt_detail; if (input_shapes.size() != input_tensors_.size()) { return Status(eInvalidArgument); } for (int i = 0; i < input_tensors_.size(); ++i) { auto dims = to_dims(input_shapes[i]); // ERROR("input shape: {}", to_string(dims)); TRT_TRY(context_->setBindingDimensions(input_ids_[i], dims)); input_tensors_[i].Reshape(input_shapes[i]); } if (!context_->allInputDimensionsSpecified()) { ERROR("not all input dimensions specified"); return Status(eFail); } for (int i = 0; i < output_tensors_.size(); ++i) { auto dims = context_->getBindingDimensions(output_ids_[i]); // ERROR("output shape: {}", to_string(dims)); output_tensors_[i].Reshape(to_shape(dims)); } return success(); } Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; } Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; } Result<void> TRTNet::Forward() { using namespace trt_detail; std::vector<void*> bindings(engine_->getNbBindings()); for (int i = 0; i < input_tensors_.size(); ++i) { bindings[input_ids_[i]] = input_tensors_[i].data(); } for (int i = 0; i < output_tensors_.size(); ++i) { bindings[output_ids_[i]] = output_tensors_[i].data(); } auto event = GetNative<cudaEvent_t>(event_); auto status = context_->enqueueV2(bindings.data(), GetNative<cudaStream_t>(stream_), &event); TRT_TRY(status, "TRT forward failed", Status(eFail)); OUTCOME_TRY(event_.Wait()); return success(); } Result<void> TRTNet::ForwardAsync(Event* event) { return Status(eNotSupported); } class TRTNetCreator : public Creator<Net> { public: const char* GetName() const override { return "tensorrt"; } int GetVersion() const override { return 0; } std::unique_ptr<Net> Create(const Value& args) override { auto p = std::make_unique<TRTNet>(); if (p->Init(args)) { return p; } return nullptr; } }; REGISTER_MODULE(Net, TRTNetCreator); } // namespace mmdeploy