add unified device guard (#1855)

pull/1935/head
Li Zhang 2023-03-10 19:16:13 +08:00 committed by Xin Chen
parent c39438658f
commit 0da1ed2311
11 changed files with 93 additions and 10 deletions

View File

@ -7,12 +7,14 @@
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "mmdeploy/core/macro.h"
#include "mmdeploy/core/mpl/type_traits.h"
#include "mmdeploy/core/status_code.h"
#include "mmdeploy/core/utils/formatter.h"
namespace mmdeploy {
@ -97,6 +99,11 @@ class Device {
return PlatformId(platform_id_);
}
friend std::ostream& operator<<(std::ostream& os, const Device& device) {
  // Render the device as "(platform_id, device_id)" for logs and error messages.
  return os << "(" << device.platform_id_ << ", " << device.device_id_ << ")";
}
private:
int platform_id_{0};
int device_id_{0};
@ -112,6 +119,9 @@ class MMDEPLOY_API Platform {
// throws if not found
explicit Platform(int platform_id);
// bind device with the current thread
Result<void> Bind(Device device, Device* prev);
// -1 if invalid
int GetPlatformId() const;
@ -135,6 +145,27 @@ class MMDEPLOY_API Platform {
MMDEPLOY_API const char* GetPlatformName(PlatformId id);
// RAII guard that binds `device` to the calling thread for the guard's
// lifetime and restores the previously bound device on destruction.
class DeviceGuard {
 public:
  explicit DeviceGuard(Device device) : platform_(device.platform_id()) {
    if (auto r = platform_.Bind(device, &prev_); r) {
      active_ = true;
    } else {
      // Bind may fail without writing `prev_` (e.g. platform mismatch), so
      // leave the guard inactive: the destructor must not try to restore an
      // indeterminate previous device.
      MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str());
    }
  }
  ~DeviceGuard() {
    if (!active_) {
      return;
    }
    auto r = platform_.Bind(prev_, nullptr);
    if (!r) {
      MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str());
    }
  }
  // The guard owns a per-thread binding; copying it would restore twice.
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

 private:
  Platform platform_;
  Device prev_;
  bool active_{false};  // true only if the constructor's bind succeeded
};
class MMDEPLOY_API Stream {
public:
Stream() = default;

View File

@ -54,6 +54,8 @@ Platform::Platform(int platform_id) {
}
}
Result<void> Platform::Bind(Device device, Device* prev) { return impl_->BindDevice(device, prev); }
const char* GetPlatformName(PlatformId id) {
if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
return impl->GetPlatformName();

View File

@ -27,7 +27,7 @@ class PlatformImpl {
virtual void SetPlatformId(int id) { platform_id_ = id; }
virtual Result<void> SetDevice(Device device) { return success(); };
virtual Result<void> BindDevice(Device device, Device* prev) = 0;
virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0;

View File

@ -70,6 +70,14 @@ class CpuHostMemory : public NonCopyable {
////////////////////////////////////////////////////////////////////////////////
/// CpuPlatformImpl
Result<void> CpuPlatformImpl::BindDevice(Device device, Device* prev) {
  // The CPU platform has no per-thread device state, so binding is a no-op.
  // Echo the requested device back as the "previous" one so that a later
  // restore via Bind(prev, ...) is equally harmless.
  if (prev != nullptr) {
    *prev = device;
  }
  return success();
}
shared_ptr<BufferImpl> CpuPlatformImpl::CreateBuffer(Device device) {
  // CPU buffers are plain host-memory buffers backed by CpuBufferImpl.
  auto buffer = std::make_shared<CpuBufferImpl>(device);
  return buffer;
}

View File

@ -17,6 +17,8 @@ class CpuPlatformImpl : public PlatformImpl {
const char* GetPlatformName() const noexcept override;
Result<void> BindDevice(Device device, Device* prev) override;
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
shared_ptr<StreamImpl> CreateStream(Device device) override;

View File

@ -127,6 +127,32 @@ shared_ptr<EventImpl> CudaPlatformImpl::CreateEvent(Device device) {
return std::make_shared<CudaEventImpl>(device);
}
// Makes `device` current for the calling thread; if `prev` is non-null,
// reports the device that was current before the call (or a null device with
// id -1 when CUDA had no current context yet).
Result<void> CudaPlatformImpl::BindDevice(Device device, Device* prev) {
  if (device.platform_id() != platform_id_) {
    // Caller passed a device from a different platform; `*prev` is NOT
    // written on this path.
    return Status(eInvalidArgument);
  }
  // skip null device
  if (device.device_id() == -1) {
    return success();
  }
  int prev_device_id = -1;
  if (prev) {
    // Query the driver API for a current context first; presumably this avoids
    // forcing CUDA initialization just to read the current device — TODO confirm.
    CUcontext ctx{};
    cuCtxGetCurrent(&ctx);
    if (ctx) {
      cudaGetDevice(&prev_device_id);
      *prev = Device(platform_id_, prev_device_id);
    } else {
      // cuda is not initialized return a null device as previous
      *prev = Device(platform_id_, -1);
    }
  }
  // NOTE(review): when `prev` is nullptr, prev_device_id stays -1, so
  // cudaSetDevice is issued even if `device` is already current — confirm this
  // redundant set is acceptable. The cudaSetDevice return value is also
  // ignored; TODO confirm failures here should not surface as an error.
  if (device.device_id() != prev_device_id) {
    cudaSetDevice(device.device_id());
  }
  return success();
}
bool CudaPlatformImpl::CheckCopyDevice(const Device& src, const Device& dst, const Device& st) {
  // The stream must live on a GPU, and each endpoint must be either host
  // memory or resident on the same device as the stream.
  if (!st.is_device()) {
    return false;
  }
  const bool src_ok = src.is_host() || src == st;
  const bool dst_ok = dst.is_host() || dst == st;
  return src_ok && dst_ok;
}

View File

@ -28,6 +28,8 @@ class CudaPlatformImpl : public PlatformImpl {
const char* GetPlatformName() const noexcept override { return "cuda"; }
Result<void> BindDevice(Device device, Device* prev) override;
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
shared_ptr<StreamImpl> CreateStream(Device device) override;
@ -178,7 +180,9 @@ class CudaDeviceGuard {
if (ctx) {
cudaGetDevice(&prev_device_id_);
}
if (prev_device_id_ != device_id_) cudaSetDevice(device_id_);
if (prev_device_id_ != device_id_) {
cudaSetDevice(device_id_);
}
}
~CudaDeviceGuard() {
if (prev_device_id_ >= 0 && prev_device_id_ != device_id_) {

View File

@ -40,7 +40,7 @@ Result<void> OrtNet::Init(const Value& args) {
auto& context = args["context"];
device_ = context["device"].get<Device>();
stream_ = context["stream"].get<Stream>();
DeviceGuard guard(device_);
auto name = args["name"].get<std::string>();
auto model = context["model"].get<Model>();
@ -150,6 +150,7 @@ static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) {
}
Result<void> OrtNet::Forward() {
DeviceGuard guard(device_);
try {
OUTCOME_TRY(stream_.Wait());
Ort::IoBinding binding(session_);
@ -186,6 +187,11 @@ Result<void> OrtNet::Forward() {
return success();
}
OrtNet::~OrtNet() {
  // Bind the net's device for the duration of teardown — presumably so any
  // device-side cleanup done by the ORT execution provider targets the right
  // GPU; confirm against provider requirements.
  DeviceGuard guard(device_);
  // Explicitly release the session while the guard is still active, before
  // the previous device binding is restored.
  session_ = Ort::Session{nullptr};
}
static std::unique_ptr<Net> Create(const Value& args) {
try {
auto p = std::make_unique<OrtNet>();

View File

@ -11,7 +11,7 @@ namespace mmdeploy::framework {
class OrtNet : public Net {
public:
~OrtNet() override = default;
~OrtNet() override;
Result<void> Init(const Value& cfg) override;
Result<void> Deinit() override;
Result<Span<Tensor>> GetInputTensors() override;

View File

@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
#define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))
TRTNet::~TRTNet() = default;
TRTNet::~TRTNet() {
  // Destroy the TensorRT objects while the net's device is bound to this
  // thread. The execution context is released first, then the engine that
  // created it — keep this order.
  CudaDeviceGuard guard(device_);
  context_.reset();
  engine_.reset();
}
static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
switch (dtype) {
@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
MMDEPLOY_ERROR("TRTNet: device must be a GPU!");
return Status(eNotSupported);
}
CudaDeviceGuard guard(device_);
stream_ = context["stream"].get<Stream>();
event_ = Event(device_);
@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
return success();
}
Result<void> TRTNet::Deinit() {
context_.reset();
engine_.reset();
return success();
}
Result<void> TRTNet::Deinit() { return success(); }
Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
CudaDeviceGuard guard(device_);
using namespace trt_detail;
if (input_shapes.size() != input_tensors_.size()) {
return Status(eInvalidArgument);
@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }
Result<void> TRTNet::Forward() {
CudaDeviceGuard guard(device_);
using namespace trt_detail;
std::vector<void*> bindings(engine_->getNbBindings());

View File

@ -6,6 +6,7 @@
#include "NvInferRuntime.h"
#include "mmdeploy/core/mpl/span.h"
#include "mmdeploy/core/net.h"
#include "mmdeploy/device/cuda/cuda_device.h"
namespace mmdeploy::framework {