// Copyright (c) OpenMMLab. All rights reserved. #ifndef MMDEPLOY_SRC_NET_TRT_TRT_NET_H_ #define MMDEPLOY_SRC_NET_TRT_TRT_NET_H_ #include #include "NvInferRuntime.h" #include "core/net.h" namespace mmdeploy { namespace trt_detail { template class TRTWrapper { public: TRTWrapper() : ptr_(nullptr) {} TRTWrapper(T* ptr) : ptr_(ptr) {} // NOLINT ~TRTWrapper() { reset(); } TRTWrapper(const TRTWrapper&) = delete; TRTWrapper& operator=(const TRTWrapper&) = delete; TRTWrapper(TRTWrapper&& other) noexcept { *this = std::move(other); } TRTWrapper& operator=(TRTWrapper&& other) noexcept { reset(std::exchange(other.ptr_, nullptr)); return *this; } T& operator*() { return *ptr_; } T* operator->() { return ptr_; } void reset(T* p = nullptr) { if (auto old = std::exchange(ptr_, p)) { // NOLINT #if NV_TENSORRT_MAJOR < 8 old->destroy(); #else delete old; #endif } } explicit operator bool() const noexcept { return ptr_ != nullptr; } private: T* ptr_; }; // clang-format off template explicit TRTWrapper(T*) -> TRTWrapper; // clang-format on } // namespace trt_detail class TRTNet : public Net { public: ~TRTNet() override; Result Init(const Value& cfg) override; Result Deinit() override; Result> GetInputTensors() override; Result> GetOutputTensors() override; Result Reshape(Span input_shapes) override; Result Forward() override; Result ForwardAsync(Event* event) override; private: private: trt_detail::TRTWrapper engine_; trt_detail::TRTWrapper context_; std::vector input_ids_; std::vector output_ids_; std::vector input_names_; std::vector output_names_; std::vector input_tensors_; std::vector output_tensors_; Device device_; Stream stream_; Event event_; }; } // namespace mmdeploy #endif // MMDEPLOY_SRC_NET_TRT_TRT_NET_H_