// Copyright (c) OpenMMLab. All rights reserved.

#include "trt_net.h"

#include <sstream>

#include "core/logger.h"
#include "core/model.h"
#include "core/module.h"
#include "core/utils/formatter.h"

namespace mmdeploy {

namespace trt_detail {

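// Adapter that forwards TensorRT log messages to the mmdeploy logger,
// mapping TensorRT severities to the matching log levels. INFO messages
// are intentionally suppressed.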
class TRTLogger : public nvinfer1::ILogger {
 public:
  void log(Severity severity, const char* msg) noexcept override {
    switch (severity) {
      case Severity::kINFO:
        // INFO("TRTNet: {}", msg);
        break;
      case Severity::kWARNING:
        WARN("TRTNet: {}", msg);
        break;
      case Severity::kERROR:
      case Severity::kINTERNAL_ERROR:
        ERROR("TRTNet: {}", msg);
        break;
      default:
        break;
    }
  }
  static TRTLogger& get() {
    static TRTLogger trt_logger{};
    return trt_logger;
  }
};

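// Conversion helpers between mmdeploy's TensorShape and nvinfer1::Dims.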
nvinfer1::Dims to_dims(const TensorShape& shape) {
  nvinfer1::Dims dims{};
  dims.nbDims = static_cast<int>(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    dims.d[i] = shape[i];
  }
  return dims;
}

TensorShape to_shape(const nvinfer1::Dims& dims) {
  TensorShape shape(dims.nbDims);
  for (int i = 0; i < dims.nbDims; ++i) {
    shape[i] = dims.d[i];
  }
  return shape;
}

}  // namespace trt_detail

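// Formats dimensions as "(d0, d1, ...)" for logging.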
std::string to_string(const nvinfer1::Dims& dims) {
  std::stringstream ss;
  ss << "(";
  for (int i = 0; i < dims.nbDims; ++i) {
    if (i) ss << ", ";
    ss << dims.d[i];
  }
  ss << ")";
  return ss.str();
}

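// Converts a boolean success flag into a Result<void>, logging `msg` and
// returning `e` on failure; TRT_TRY propagates the failure via OUTCOME_TRY.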
static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status e = Status(eFail)) {
  if (code) {
    return success();
  }
  if (msg) {
    ERROR("{}", msg);
  }
  return e;
}

#define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))

TRTNet::~TRTNet() = default;

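// Maps a TensorRT data type to the mmdeploy DataType, or eNotSupported.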
static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
  switch (dtype) {
    case nvinfer1::DataType::kFLOAT:
      return DataType::kFLOAT;
    case nvinfer1::DataType::kHALF:
      return DataType::kHALF;
    case nvinfer1::DataType::kINT8:
      return DataType::kINT8;
    case nvinfer1::DataType::kINT32:
      return DataType::kINT32;
    default:
      return Status(eNotSupported);
  }
}

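// Initializes the net: validates that the device is a GPU, reads the
// serialized engine plan from the model, deserializes it, and builds a
// tensor descriptor for every input/output binding.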
Result<void> TRTNet::Init(const Value& args) {
  using namespace trt_detail;

  auto& context = args["context"];
  device_ = context["device"].get<Device>();
  if (device_.is_host()) {
    ERROR("TRTNet: device must be a GPU!");
    return Status(eNotSupported);
  }
  stream_ = context["stream"].get<Stream>();

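  // This event is handed to enqueueV2 in Forward() and is signaled by
  // TensorRT once the input bindings have been consumed.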
  event_ = Event(device_);

  auto name = args["name"].get<std::string>();
  auto model = context["model"].get<Model>();
  OUTCOME_TRY(auto config, model.GetModelConfig(name));

  OUTCOME_TRY(auto plan, model.ReadFile(config.net));

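  // Create the inference runtime and deserialize the engine from the plan.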
  TRTWrapper runtime = nvinfer1::createInferRuntime(TRTLogger::get());
  TRT_TRY(!!runtime, "failed to create TRT infer runtime");

  engine_ = runtime->deserializeCudaEngine(plan.data(), plan.size());
  TRT_TRY(!!engine_, "failed to deserialize TRT CUDA engine");

  TRT_TRY(engine_->getNbOptimizationProfiles() == 1, "only 1 optimization profile supported",
          Status(eNotSupported));

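  // Enumerate engine bindings and create a tensor descriptor for each,
  // partitioning them into inputs and outputs.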
  auto n_bindings = engine_->getNbBindings();
  for (int i = 0; i < n_bindings; ++i) {
    auto binding_name = engine_->getBindingName(i);
    auto dims = engine_->getBindingDimensions(i);
    if (engine_->isShapeBinding(i)) {
      ERROR("shape binding is not supported.");
      return Status(eNotSupported);
    }
    OUTCOME_TRY(auto dtype, MapDataType(engine_->getBindingDataType(i)));
    TensorDesc desc{
        .device = device_, .data_type = dtype, .shape = to_shape(dims), .name = binding_name};
    if (engine_->bindingIsInput(i)) {
      DEBUG("input binding {} {} {}", i, binding_name, to_string(dims));
      input_ids_.push_back(i);
      input_names_.emplace_back(binding_name);
      input_tensors_.emplace_back(desc, Buffer());
    } else {
      DEBUG("output binding {} {} {}", i, binding_name, to_string(dims));
      output_ids_.push_back(i);
      output_names_.emplace_back(binding_name);
      output_tensors_.emplace_back(desc, Buffer());
    }
  }
  context_ = engine_->createExecutionContext();
  TRT_TRY(!!context_, "failed to create TRT execution context");

  context_->setOptimizationProfileAsync(0, static_cast<cudaStream_t>(stream_.GetNative()));
  OUTCOME_TRY(stream_.Wait());

  return success();
}

Result<void> TRTNet::Deinit() {
  context_.reset();
  engine_.reset();
  return success();
}

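// Sets binding dimensions for each input, then queries the execution context
// for the resulting output shapes and resizes the output tensors to match.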
Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
  using namespace trt_detail;
  if (input_shapes.size() != input_tensors_.size()) {
    return Status(eInvalidArgument);
  }
  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    auto dims = to_dims(input_shapes[i]);
    // ERROR("input shape: {}", to_string(dims));
    TRT_TRY(context_->setBindingDimensions(input_ids_[i], dims));
    input_tensors_[i].Reshape(input_shapes[i]);
  }
  if (!context_->allInputDimensionsSpecified()) {
    ERROR("not all input dimensions specified");
    return Status(eFail);
  }
  for (size_t i = 0; i < output_tensors_.size(); ++i) {
    auto dims = context_->getBindingDimensions(output_ids_[i]);
    // ERROR("output shape: {}", to_string(dims));
    output_tensors_[i].Reshape(to_shape(dims));
  }
  return success();
}

Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }

Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }

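// Gathers device pointers for all bindings and enqueues inference on the
// CUDA stream; the event passed to enqueueV2 is signaled once the input
// buffers have been consumed, and is waited on before returning.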
Result<void> TRTNet::Forward() {
  using namespace trt_detail;
  std::vector<void*> bindings(engine_->getNbBindings());

  for (size_t i = 0; i < input_tensors_.size(); ++i) {
    bindings[input_ids_[i]] = input_tensors_[i].data();
  }
  for (size_t i = 0; i < output_tensors_.size(); ++i) {
    bindings[output_ids_[i]] = output_tensors_[i].data();
  }

  auto event = GetNative<cudaEvent_t>(event_);
  auto status = context_->enqueueV2(bindings.data(), GetNative<cudaStream_t>(stream_), &event);
  TRT_TRY(status, "TRT forward failed", Status(eFail));
  OUTCOME_TRY(event_.Wait());

  return success();
}

Result<void> TRTNet::ForwardAsync(Event* event) { return Status(eNotSupported); }

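// Factory that registers TRTNet in the Net registry under the name
// "tensorrt"; Create returns nullptr if initialization fails.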
class TRTNetCreator : public Creator<Net> {
 public:
  const char* GetName() const override { return "tensorrt"; }
  int GetVersion() const override { return 0; }
  std::unique_ptr<Net> Create(const Value& args) override {
    auto p = std::make_unique<TRTNet>();
    if (p->Init(args)) {
      return p;
    }
    return nullptr;
  }
};

REGISTER_MODULE(Net, TRTNetCreator);

}  // namespace mmdeploy