From 0da1ed2311307ed0a8084ce6fc616b30016bd7d9 Mon Sep 17 00:00:00 2001
From: Li Zhang <lzhang329@gmail.com>
Date: Fri, 10 Mar 2023 19:16:13 +0800
Subject: [PATCH] add unified device guard (#1855)

---
 csrc/mmdeploy/core/device.h               | 31 +++++++++++++++++++++++
 csrc/mmdeploy/core/device_impl.cpp        |  2 ++
 csrc/mmdeploy/core/device_impl.h          |  2 +-
 csrc/mmdeploy/device/cpu/cpu_device.cpp   |  8 ++++++
 csrc/mmdeploy/device/cpu/cpu_device.h     |  2 ++
 csrc/mmdeploy/device/cuda/cuda_device.cpp | 26 +++++++++++++++++++
 csrc/mmdeploy/device/cuda/cuda_device.h   |  6 ++++-
 csrc/mmdeploy/net/ort/ort_net.cpp         |  8 +++++-
 csrc/mmdeploy/net/ort/ort_net.h           |  2 +-
 csrc/mmdeploy/net/trt/trt_net.cpp         | 15 ++++++-----
 csrc/mmdeploy/net/trt/trt_net.h           |  1 +
 11 files changed, 93 insertions(+), 10 deletions(-)

diff --git a/csrc/mmdeploy/core/device.h b/csrc/mmdeploy/core/device.h
index 328ccb310..5efc80813 100644
--- a/csrc/mmdeploy/core/device.h
+++ b/csrc/mmdeploy/core/device.h
@@ -7,12 +7,14 @@
 #include <functional>
 #include <memory>
 #include <optional>
+#include <ostream>
 #include <string>
 #include <vector>
 
 #include "mmdeploy/core/macro.h"
 #include "mmdeploy/core/mpl/type_traits.h"
 #include "mmdeploy/core/status_code.h"
+#include "mmdeploy/core/utils/formatter.h"
 
 namespace mmdeploy {
 
@@ -97,6 +99,11 @@ class Device {
     return PlatformId(platform_id_);
   }
 
+  friend std::ostream& operator<<(std::ostream& os, const Device& device) {
+    os << "(" << device.platform_id_ << ", " << device.device_id_ << ")";
+    return os;
+  }
+
  private:
   int platform_id_{0};
   int device_id_{0};
@@ -112,6 +119,9 @@ class MMDEPLOY_API Platform {
   // throws if not found
   explicit Platform(int platform_id);
 
+  // bind device with the current thread
+  Result<void> Bind(Device device, Device* prev);
+
   // -1 if invalid
   int GetPlatformId() const;
 
@@ -135,6 +145,27 @@ class MMDEPLOY_API Platform {
 
 MMDEPLOY_API const char* GetPlatformName(PlatformId id);
 
+class DeviceGuard {
+ public:
+  explicit DeviceGuard(Device device) : platform_(device.platform_id()) {
+    auto r = platform_.Bind(device, &prev_);
+    if (!r) {
+      MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str());
+    }
+  }
+
+  ~DeviceGuard() {
+    auto r = platform_.Bind(prev_, nullptr);
+    if (!r) {
+      MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str());
+    }
+  }
+
+ private:
+  Platform platform_;
+  Device prev_;
+};
+
 class MMDEPLOY_API Stream {
  public:
   Stream() = default;
diff --git a/csrc/mmdeploy/core/device_impl.cpp b/csrc/mmdeploy/core/device_impl.cpp
index 257df4d0d..b65b82be0 100644
--- a/csrc/mmdeploy/core/device_impl.cpp
+++ b/csrc/mmdeploy/core/device_impl.cpp
@@ -54,6 +54,8 @@ Platform::Platform(int platform_id) {
   }
 }
 
+Result<void> Platform::Bind(Device device, Device* prev) { return impl_->BindDevice(device, prev); }
+
 const char* GetPlatformName(PlatformId id) {
   if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
     return impl->GetPlatformName();
diff --git a/csrc/mmdeploy/core/device_impl.h b/csrc/mmdeploy/core/device_impl.h
index 196952b6d..8860c9610 100644
--- a/csrc/mmdeploy/core/device_impl.h
+++ b/csrc/mmdeploy/core/device_impl.h
@@ -27,7 +27,7 @@ class PlatformImpl {
 
   virtual void SetPlatformId(int id) { platform_id_ = id; }
 
-  virtual Result<void> SetDevice(Device device) { return success(); };
+  virtual Result<void> BindDevice(Device device, Device* prev) = 0;
 
   virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0;
 
diff --git a/csrc/mmdeploy/device/cpu/cpu_device.cpp b/csrc/mmdeploy/device/cpu/cpu_device.cpp
index 23c90d95f..9ce6ff1c1 100644
--- a/csrc/mmdeploy/device/cpu/cpu_device.cpp
+++ b/csrc/mmdeploy/device/cpu/cpu_device.cpp
@@ -70,6 +70,14 @@ class CpuHostMemory : public NonCopyable {
 ////////////////////////////////////////////////////////////////////////////////
 /// CpuPlatformImpl
 
+Result<void> CpuPlatformImpl::BindDevice(Device device, Device* prev) {
+  // do nothing
+  if (prev) {
+    *prev = device;
+  }
+  return success();
+}
+
 shared_ptr<BufferImpl> CpuPlatformImpl::CreateBuffer(Device device) {
   return std::make_shared<CpuBufferImpl>(device);
 }
diff --git a/csrc/mmdeploy/device/cpu/cpu_device.h b/csrc/mmdeploy/device/cpu/cpu_device.h
index 3b27c7fa1..c508e030d 100644
--- a/csrc/mmdeploy/device/cpu/cpu_device.h
+++ b/csrc/mmdeploy/device/cpu/cpu_device.h
@@ -17,6 +17,8 @@ class CpuPlatformImpl : public PlatformImpl {
 
   const char* GetPlatformName() const noexcept override;
 
+  Result<void> BindDevice(Device device, Device* prev) override;
+
   shared_ptr<BufferImpl> CreateBuffer(Device device) override;
 
   shared_ptr<StreamImpl> CreateStream(Device device) override;
diff --git a/csrc/mmdeploy/device/cuda/cuda_device.cpp b/csrc/mmdeploy/device/cuda/cuda_device.cpp
index 7633b6166..70cac8802 100644
--- a/csrc/mmdeploy/device/cuda/cuda_device.cpp
+++ b/csrc/mmdeploy/device/cuda/cuda_device.cpp
@@ -127,6 +127,32 @@ shared_ptr<EventImpl> CudaPlatformImpl::CreateEvent(Device device) {
   return std::make_shared<CudaEventImpl>(device);
 }
 
+Result<void> CudaPlatformImpl::BindDevice(Device device, Device* prev) {
+  if (device.platform_id() != platform_id_) {
+    return Status(eInvalidArgument);
+  }
+  // skip null device
+  if (device.device_id() == -1) {
+    return success();
+  }
+  int prev_device_id = -1;
+  if (prev) {
+    CUcontext ctx{};
+    cuCtxGetCurrent(&ctx);
+    if (ctx) {
+      cudaGetDevice(&prev_device_id);
+      *prev = Device(platform_id_, prev_device_id);
+    } else {
+      // cuda is not initialized return a null device as previous
+      *prev = Device(platform_id_, -1);
+    }
+  }
+  if (device.device_id() != prev_device_id) {
+    cudaSetDevice(device.device_id());
+  }
+  return success();
+}
+
 bool CudaPlatformImpl::CheckCopyDevice(const Device& src, const Device& dst, const Device& st) {
   return st.is_device() && (src.is_host() || src == st) && (dst.is_host() || dst == st);
 }
diff --git a/csrc/mmdeploy/device/cuda/cuda_device.h b/csrc/mmdeploy/device/cuda/cuda_device.h
index e7695a20d..20b894652 100644
--- a/csrc/mmdeploy/device/cuda/cuda_device.h
+++ b/csrc/mmdeploy/device/cuda/cuda_device.h
@@ -28,6 +28,8 @@ class CudaPlatformImpl : public PlatformImpl {
 
   const char* GetPlatformName() const noexcept override { return "cuda"; }
 
+  Result<void> BindDevice(Device device, Device* prev) override;
+
   shared_ptr<BufferImpl> CreateBuffer(Device device) override;
 
   shared_ptr<StreamImpl> CreateStream(Device device) override;
@@ -178,7 +180,9 @@ class CudaDeviceGuard {
     if (ctx) {
       cudaGetDevice(&prev_device_id_);
     }
-    if (prev_device_id_ != device_id_) cudaSetDevice(device_id_);
+    if (prev_device_id_ != device_id_) {
+      cudaSetDevice(device_id_);
+    }
   }
   ~CudaDeviceGuard() {
     if (prev_device_id_ >= 0 && prev_device_id_ != device_id_) {
diff --git a/csrc/mmdeploy/net/ort/ort_net.cpp b/csrc/mmdeploy/net/ort/ort_net.cpp
index f26abb7a0..38b04e70f 100644
--- a/csrc/mmdeploy/net/ort/ort_net.cpp
+++ b/csrc/mmdeploy/net/ort/ort_net.cpp
@@ -40,7 +40,7 @@ Result<void> OrtNet::Init(const Value& args) {
   auto& context = args["context"];
   device_ = context["device"].get<Device>();
   stream_ = context["stream"].get<Stream>();
-
+  DeviceGuard guard(device_);
   auto name = args["name"].get<std::string>();
   auto model = context["model"].get<Model>();
 
@@ -150,6 +150,7 @@ static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) {
 }
 
 Result<void> OrtNet::Forward() {
+  DeviceGuard guard(device_);
   try {
     OUTCOME_TRY(stream_.Wait());
     Ort::IoBinding binding(session_);
@@ -186,6 +187,11 @@ Result<void> OrtNet::Forward() {
   return success();
 }
 
+OrtNet::~OrtNet() {
+  DeviceGuard guard(device_);
+  session_ = Ort::Session{nullptr};
+}
+
 static std::unique_ptr<Net> Create(const Value& args) {
   try {
     auto p = std::make_unique<OrtNet>();
diff --git a/csrc/mmdeploy/net/ort/ort_net.h b/csrc/mmdeploy/net/ort/ort_net.h
index b325b4ad2..94f5095d8 100644
--- a/csrc/mmdeploy/net/ort/ort_net.h
+++ b/csrc/mmdeploy/net/ort/ort_net.h
@@ -11,7 +11,7 @@ namespace mmdeploy::framework {
 
 class OrtNet : public Net {
  public:
-  ~OrtNet() override = default;
+  ~OrtNet() override;
   Result<void> Init(const Value& cfg) override;
   Result<void> Deinit() override;
   Result<Span<Tensor>> GetInputTensors() override;
diff --git a/csrc/mmdeploy/net/trt/trt_net.cpp b/csrc/mmdeploy/net/trt/trt_net.cpp
index 359ece60b..8b9b98b5d 100644
--- a/csrc/mmdeploy/net/trt/trt_net.cpp
+++ b/csrc/mmdeploy/net/trt/trt_net.cpp
@@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
 
 #define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))
 
-TRTNet::~TRTNet() = default;
+TRTNet::~TRTNet() {
+  CudaDeviceGuard guard(device_);
+  context_.reset();
+  engine_.reset();
+}
 
 static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
   switch (dtype) {
@@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
     MMDEPLOY_ERROR("TRTNet: device must be a GPU!");
     return Status(eNotSupported);
   }
+  CudaDeviceGuard guard(device_);
   stream_ = context["stream"].get<Stream>();
 
   event_ = Event(device_);
@@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
   return success();
 }
 
-Result<void> TRTNet::Deinit() {
-  context_.reset();
-  engine_.reset();
-  return success();
-}
+Result<void> TRTNet::Deinit() { return success(); }
 
 Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
+  CudaDeviceGuard guard(device_);
   using namespace trt_detail;
   if (input_shapes.size() != input_tensors_.size()) {
     return Status(eInvalidArgument);
@@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
 Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }
 
 Result<void> TRTNet::Forward() {
+  CudaDeviceGuard guard(device_);
   using namespace trt_detail;
   std::vector<void*> bindings(engine_->getNbBindings());
 
diff --git a/csrc/mmdeploy/net/trt/trt_net.h b/csrc/mmdeploy/net/trt/trt_net.h
index e037265b6..075f27d3f 100644
--- a/csrc/mmdeploy/net/trt/trt_net.h
+++ b/csrc/mmdeploy/net/trt/trt_net.h
@@ -6,6 +6,7 @@
 #include "NvInferRuntime.h"
 #include "mmdeploy/core/mpl/span.h"
 #include "mmdeploy/core/net.h"
+#include "mmdeploy/device/cuda/cuda_device.h"
 
 namespace mmdeploy::framework {