From 0da1ed2311307ed0a8084ce6fc616b30016bd7d9 Mon Sep 17 00:00:00 2001 From: Li Zhang <lzhang329@gmail.com> Date: Fri, 10 Mar 2023 19:16:13 +0800 Subject: [PATCH] add unified device guard (#1855) --- csrc/mmdeploy/core/device.h | 31 +++++++++++++++++++++++ csrc/mmdeploy/core/device_impl.cpp | 2 ++ csrc/mmdeploy/core/device_impl.h | 2 +- csrc/mmdeploy/device/cpu/cpu_device.cpp | 8 ++++++ csrc/mmdeploy/device/cpu/cpu_device.h | 2 ++ csrc/mmdeploy/device/cuda/cuda_device.cpp | 26 +++++++++++++++++++ csrc/mmdeploy/device/cuda/cuda_device.h | 6 ++++- csrc/mmdeploy/net/ort/ort_net.cpp | 8 +++++- csrc/mmdeploy/net/ort/ort_net.h | 2 +- csrc/mmdeploy/net/trt/trt_net.cpp | 15 ++++++----- csrc/mmdeploy/net/trt/trt_net.h | 1 + 11 files changed, 93 insertions(+), 10 deletions(-) diff --git a/csrc/mmdeploy/core/device.h b/csrc/mmdeploy/core/device.h index 328ccb310..5efc80813 100644 --- a/csrc/mmdeploy/core/device.h +++ b/csrc/mmdeploy/core/device.h @@ -7,12 +7,14 @@ #include <functional> #include <memory> #include <optional> +#include <ostream> #include <string> #include <vector> #include "mmdeploy/core/macro.h" #include "mmdeploy/core/mpl/type_traits.h" #include "mmdeploy/core/status_code.h" +#include "mmdeploy/core/utils/formatter.h" namespace mmdeploy { @@ -97,6 +99,11 @@ class Device { return PlatformId(platform_id_); } + friend std::ostream& operator<<(std::ostream& os, const Device& device) { + os << "(" << device.platform_id_ << ", " << device.device_id_ << ")"; + return os; + } + private: int platform_id_{0}; int device_id_{0}; @@ -112,6 +119,9 @@ class MMDEPLOY_API Platform { // throws if not found explicit Platform(int platform_id); + // bind device with the current thread + Result<void> Bind(Device device, Device* prev); + // -1 if invalid int GetPlatformId() const; @@ -135,6 +145,27 @@ class MMDEPLOY_API Platform { MMDEPLOY_API const char* GetPlatformName(PlatformId id); +class DeviceGuard { + public: + explicit DeviceGuard(Device device) : 
platform_(device.platform_id()) { + auto r = platform_.Bind(device, &prev_); + if (!r) { + MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str()); + } + } + + ~DeviceGuard() { + auto r = platform_.Bind(prev_, nullptr); + if (!r) { + MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str()); + } + } + + private: + Platform platform_; + Device prev_; +}; + class MMDEPLOY_API Stream { public: Stream() = default; diff --git a/csrc/mmdeploy/core/device_impl.cpp b/csrc/mmdeploy/core/device_impl.cpp index 257df4d0d..b65b82be0 100644 --- a/csrc/mmdeploy/core/device_impl.cpp +++ b/csrc/mmdeploy/core/device_impl.cpp @@ -54,6 +54,8 @@ Platform::Platform(int platform_id) { } } +Result<void> Platform::Bind(Device device, Device* prev) { return impl_->BindDevice(device, prev); } + const char* GetPlatformName(PlatformId id) { if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) { return impl->GetPlatformName(); diff --git a/csrc/mmdeploy/core/device_impl.h b/csrc/mmdeploy/core/device_impl.h index 196952b6d..8860c9610 100644 --- a/csrc/mmdeploy/core/device_impl.h +++ b/csrc/mmdeploy/core/device_impl.h @@ -27,7 +27,7 @@ class PlatformImpl { virtual void SetPlatformId(int id) { platform_id_ = id; } - virtual Result<void> SetDevice(Device device) { return success(); }; + virtual Result<void> BindDevice(Device device, Device* prev) = 0; virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0; diff --git a/csrc/mmdeploy/device/cpu/cpu_device.cpp b/csrc/mmdeploy/device/cpu/cpu_device.cpp index 23c90d95f..9ce6ff1c1 100644 --- a/csrc/mmdeploy/device/cpu/cpu_device.cpp +++ b/csrc/mmdeploy/device/cpu/cpu_device.cpp @@ -70,6 +70,14 @@ class CpuHostMemory : public NonCopyable { //////////////////////////////////////////////////////////////////////////////// /// CpuPlatformImpl +Result<void> CpuPlatformImpl::BindDevice(Device device, Device* prev) { + // do nothing + if (prev) { + *prev = device; + } + return success(); 
+} + shared_ptr<BufferImpl> CpuPlatformImpl::CreateBuffer(Device device) { return std::make_shared<CpuBufferImpl>(device); } diff --git a/csrc/mmdeploy/device/cpu/cpu_device.h b/csrc/mmdeploy/device/cpu/cpu_device.h index 3b27c7fa1..c508e030d 100644 --- a/csrc/mmdeploy/device/cpu/cpu_device.h +++ b/csrc/mmdeploy/device/cpu/cpu_device.h @@ -17,6 +17,8 @@ class CpuPlatformImpl : public PlatformImpl { const char* GetPlatformName() const noexcept override; + Result<void> BindDevice(Device device, Device* prev) override; + shared_ptr<BufferImpl> CreateBuffer(Device device) override; shared_ptr<StreamImpl> CreateStream(Device device) override; diff --git a/csrc/mmdeploy/device/cuda/cuda_device.cpp b/csrc/mmdeploy/device/cuda/cuda_device.cpp index 7633b6166..70cac8802 100644 --- a/csrc/mmdeploy/device/cuda/cuda_device.cpp +++ b/csrc/mmdeploy/device/cuda/cuda_device.cpp @@ -127,6 +127,32 @@ shared_ptr<EventImpl> CudaPlatformImpl::CreateEvent(Device device) { return std::make_shared<CudaEventImpl>(device); } +Result<void> CudaPlatformImpl::BindDevice(Device device, Device* prev) { + if (device.platform_id() != platform_id_) { + return Status(eInvalidArgument); + } + // skip null device + if (device.device_id() == -1) { + return success(); + } + int prev_device_id = -1; + if (prev) { + CUcontext ctx{}; + cuCtxGetCurrent(&ctx); + if (ctx) { + cudaGetDevice(&prev_device_id); + *prev = Device(platform_id_, prev_device_id); + } else { + // CUDA is not initialized; return a null device as the previous device + *prev = Device(platform_id_, -1); + } + } + if (device.device_id() != prev_device_id) { + cudaSetDevice(device.device_id()); + } + return success(); +} + bool CudaPlatformImpl::CheckCopyDevice(const Device& src, const Device& dst, const Device& st) { return st.is_device() && (src.is_host() || src == st) && (dst.is_host() || dst == st); } diff --git a/csrc/mmdeploy/device/cuda/cuda_device.h b/csrc/mmdeploy/device/cuda/cuda_device.h index e7695a20d..20b894652 100644 ---
a/csrc/mmdeploy/device/cuda/cuda_device.h +++ b/csrc/mmdeploy/device/cuda/cuda_device.h @@ -28,6 +28,8 @@ class CudaPlatformImpl : public PlatformImpl { const char* GetPlatformName() const noexcept override { return "cuda"; } + Result<void> BindDevice(Device device, Device* prev) override; + shared_ptr<BufferImpl> CreateBuffer(Device device) override; shared_ptr<StreamImpl> CreateStream(Device device) override; @@ -178,7 +180,9 @@ class CudaDeviceGuard { if (ctx) { cudaGetDevice(&prev_device_id_); } - if (prev_device_id_ != device_id_) cudaSetDevice(device_id_); + if (prev_device_id_ != device_id_) { + cudaSetDevice(device_id_); + } } ~CudaDeviceGuard() { if (prev_device_id_ >= 0 && prev_device_id_ != device_id_) { diff --git a/csrc/mmdeploy/net/ort/ort_net.cpp b/csrc/mmdeploy/net/ort/ort_net.cpp index f26abb7a0..38b04e70f 100644 --- a/csrc/mmdeploy/net/ort/ort_net.cpp +++ b/csrc/mmdeploy/net/ort/ort_net.cpp @@ -40,7 +40,7 @@ Result<void> OrtNet::Init(const Value& args) { auto& context = args["context"]; device_ = context["device"].get<Device>(); stream_ = context["stream"].get<Stream>(); - + DeviceGuard guard(device_); auto name = args["name"].get<std::string>(); auto model = context["model"].get<Model>(); @@ -150,6 +150,7 @@ static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) { } Result<void> OrtNet::Forward() { + DeviceGuard guard(device_); try { OUTCOME_TRY(stream_.Wait()); Ort::IoBinding binding(session_); @@ -186,6 +187,11 @@ Result<void> OrtNet::Forward() { return success(); } +OrtNet::~OrtNet() { + DeviceGuard guard(device_); + session_ = Ort::Session{nullptr}; +} + static std::unique_ptr<Net> Create(const Value& args) { try { auto p = std::make_unique<OrtNet>(); diff --git a/csrc/mmdeploy/net/ort/ort_net.h b/csrc/mmdeploy/net/ort/ort_net.h index b325b4ad2..94f5095d8 100644 --- a/csrc/mmdeploy/net/ort/ort_net.h +++ b/csrc/mmdeploy/net/ort/ort_net.h @@ -11,7 +11,7 @@ namespace mmdeploy::framework { class OrtNet : public Net { public: - 
~OrtNet() override = default; + ~OrtNet() override; Result<void> Init(const Value& cfg) override; Result<void> Deinit() override; Result<Span<Tensor>> GetInputTensors() override; diff --git a/csrc/mmdeploy/net/trt/trt_net.cpp b/csrc/mmdeploy/net/trt/trt_net.cpp index 359ece60b..8b9b98b5d 100644 --- a/csrc/mmdeploy/net/trt/trt_net.cpp +++ b/csrc/mmdeploy/net/trt/trt_net.cpp @@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status #define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__)) -TRTNet::~TRTNet() = default; +TRTNet::~TRTNet() { + CudaDeviceGuard guard(device_); + context_.reset(); + engine_.reset(); +} static Result<DataType> MapDataType(nvinfer1::DataType dtype) { switch (dtype) { @@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) { MMDEPLOY_ERROR("TRTNet: device must be a GPU!"); return Status(eNotSupported); } + CudaDeviceGuard guard(device_); stream_ = context["stream"].get<Stream>(); event_ = Event(device_); @@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) { return success(); } -Result<void> TRTNet::Deinit() { - context_.reset(); - engine_.reset(); - return success(); -} +Result<void> TRTNet::Deinit() { return success(); } Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) { + CudaDeviceGuard guard(device_); using namespace trt_detail; if (input_shapes.size() != input_tensors_.size()) { return Status(eInvalidArgument); @@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; } Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; } Result<void> TRTNet::Forward() { + CudaDeviceGuard guard(device_); using namespace trt_detail; std::vector<void*> bindings(engine_->getNbBindings()); diff --git a/csrc/mmdeploy/net/trt/trt_net.h b/csrc/mmdeploy/net/trt/trt_net.h index e037265b6..075f27d3f 100644 --- a/csrc/mmdeploy/net/trt/trt_net.h +++ b/csrc/mmdeploy/net/trt/trt_net.h @@ -6,6 +6,7 @@ #include "NvInferRuntime.h" #include 
"mmdeploy/core/mpl/span.h" #include "mmdeploy/core/net.h" +#include "mmdeploy/device/cuda/cuda_device.h" namespace mmdeploy::framework {