add unified device guard (#1855)
parent
c39438658f
commit
0da1ed2311
|
@ -7,12 +7,14 @@
|
|||
#include <functional>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mmdeploy/core/macro.h"
|
||||
#include "mmdeploy/core/mpl/type_traits.h"
|
||||
#include "mmdeploy/core/status_code.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
|
@ -97,6 +99,11 @@ class Device {
|
|||
return PlatformId(platform_id_);
|
||||
}
|
||||
|
||||
// Streams a human-readable "(platform_id, device_id)" pair, e.g. for logging.
friend std::ostream& operator<<(std::ostream& os, const Device& device) {
  return os << "(" << device.platform_id_ << ", " << device.device_id_ << ")";
}
|
||||
|
||||
private:
|
||||
int platform_id_{0};
|
||||
int device_id_{0};
|
||||
|
@ -112,6 +119,9 @@ class MMDEPLOY_API Platform {
|
|||
// throws if not found
|
||||
explicit Platform(int platform_id);
|
||||
|
||||
// bind device with the current thread
|
||||
Result<void> Bind(Device device, Device* prev);
|
||||
|
||||
// -1 if invalid
|
||||
int GetPlatformId() const;
|
||||
|
||||
|
@ -135,6 +145,27 @@ class MMDEPLOY_API Platform {
|
|||
|
||||
MMDEPLOY_API const char* GetPlatformName(PlatformId id);
|
||||
|
||||
class DeviceGuard {
|
||||
public:
|
||||
explicit DeviceGuard(Device device) : platform_(device.platform_id()) {
|
||||
auto r = platform_.Bind(device, &prev_);
|
||||
if (!r) {
|
||||
MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceGuard() {
|
||||
auto r = platform_.Bind(prev_, nullptr);
|
||||
if (!r) {
|
||||
MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Platform platform_;
|
||||
Device prev_;
|
||||
};
|
||||
|
||||
class MMDEPLOY_API Stream {
|
||||
public:
|
||||
Stream() = default;
|
||||
|
|
|
@ -54,6 +54,8 @@ Platform::Platform(int platform_id) {
|
|||
}
|
||||
}
|
||||
|
||||
// Delegates thread-local device binding to the platform implementation.
Result<void> Platform::Bind(Device device, Device* prev) {
  return impl_->BindDevice(device, prev);
}
|
||||
|
||||
const char* GetPlatformName(PlatformId id) {
|
||||
if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
|
||||
return impl->GetPlatformName();
|
||||
|
|
|
@ -27,7 +27,7 @@ class PlatformImpl {
|
|||
|
||||
virtual void SetPlatformId(int id) { platform_id_ = id; }
|
||||
|
||||
virtual Result<void> SetDevice(Device device) { return success(); };
|
||||
virtual Result<void> BindDevice(Device device, Device* prev) = 0;
|
||||
|
||||
virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0;
|
||||
|
||||
|
|
|
@ -70,6 +70,14 @@ class CpuHostMemory : public NonCopyable {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// CpuPlatformImpl
|
||||
|
||||
// CPU has no per-thread device state, so binding is a no-op. The requested
// device is still reported back as the "previous" one so callers can restore
// it symmetrically.
Result<void> CpuPlatformImpl::BindDevice(Device device, Device* prev) {
  if (prev != nullptr) {
    *prev = device;
  }
  return success();
}
|
||||
|
||||
// Creates a CPU-backed buffer implementation for the given device.
shared_ptr<BufferImpl> CpuPlatformImpl::CreateBuffer(Device device) {
  auto buffer = std::make_shared<CpuBufferImpl>(device);
  return buffer;
}
|
||||
|
|
|
@ -17,6 +17,8 @@ class CpuPlatformImpl : public PlatformImpl {
|
|||
|
||||
const char* GetPlatformName() const noexcept override;
|
||||
|
||||
Result<void> BindDevice(Device device, Device* prev) override;
|
||||
|
||||
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
|
||||
|
||||
shared_ptr<StreamImpl> CreateStream(Device device) override;
|
||||
|
|
|
@ -127,6 +127,32 @@ shared_ptr<EventImpl> CudaPlatformImpl::CreateEvent(Device device) {
|
|||
return std::make_shared<CudaEventImpl>(device);
|
||||
}
|
||||
|
||||
// Makes `device` the current CUDA device on the calling thread.
//
// Returns eInvalidArgument if `device` belongs to another platform. When
// `prev` is non-null it receives the device that was current before the call,
// or a null device (id -1) if CUDA was not yet initialized on this thread,
// so the caller can restore it afterwards.
Result<void> CudaPlatformImpl::BindDevice(Device device, Device* prev) {
  if (device.platform_id() != platform_id_) {
    return Status(eInvalidArgument);
  }
  // skip null device: id -1 is the "no device" sentinel, nothing to bind
  if (device.device_id() == -1) {
    return success();
  }
  int prev_device_id = -1;
  if (prev) {
    // Check for an existing driver context first — presumably so that an
    // uninitialized thread reports a null previous device instead of having
    // the runtime query initialize CUDA as a side effect (TODO confirm).
    CUcontext ctx{};
    cuCtxGetCurrent(&ctx);
    if (ctx) {
      cudaGetDevice(&prev_device_id);
      *prev = Device(platform_id_, prev_device_id);
    } else {
      // cuda is not initialized; return a null device as previous
      *prev = Device(platform_id_, -1);
    }
  }
  // NOTE(review): the cudaSetDevice return code is ignored, so a failed bind
  // is silently dropped — consider propagating it as an error Status.
  if (device.device_id() != prev_device_id) {
    cudaSetDevice(device.device_id());
  }
  return success();
}
|
||||
|
||||
bool CudaPlatformImpl::CheckCopyDevice(const Device& src, const Device& dst, const Device& st) {
|
||||
return st.is_device() && (src.is_host() || src == st) && (dst.is_host() || dst == st);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,8 @@ class CudaPlatformImpl : public PlatformImpl {
|
|||
|
||||
const char* GetPlatformName() const noexcept override { return "cuda"; }
|
||||
|
||||
Result<void> BindDevice(Device device, Device* prev) override;
|
||||
|
||||
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
|
||||
|
||||
shared_ptr<StreamImpl> CreateStream(Device device) override;
|
||||
|
@ -178,7 +180,9 @@ class CudaDeviceGuard {
|
|||
if (ctx) {
|
||||
cudaGetDevice(&prev_device_id_);
|
||||
}
|
||||
if (prev_device_id_ != device_id_) cudaSetDevice(device_id_);
|
||||
if (prev_device_id_ != device_id_) {
|
||||
cudaSetDevice(device_id_);
|
||||
}
|
||||
}
|
||||
~CudaDeviceGuard() {
|
||||
if (prev_device_id_ >= 0 && prev_device_id_ != device_id_) {
|
||||
|
|
|
@ -40,7 +40,7 @@ Result<void> OrtNet::Init(const Value& args) {
|
|||
auto& context = args["context"];
|
||||
device_ = context["device"].get<Device>();
|
||||
stream_ = context["stream"].get<Stream>();
|
||||
|
||||
DeviceGuard guard(device_);
|
||||
auto name = args["name"].get<std::string>();
|
||||
auto model = context["model"].get<Model>();
|
||||
|
||||
|
@ -150,6 +150,7 @@ static Result<Tensor> AsTensor(Ort::Value& value, const Device& device) {
|
|||
}
|
||||
|
||||
Result<void> OrtNet::Forward() {
|
||||
DeviceGuard guard(device_);
|
||||
try {
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
Ort::IoBinding binding(session_);
|
||||
|
@ -186,6 +187,11 @@ Result<void> OrtNet::Forward() {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Destroys the ONNX Runtime session while the net's device is bound to this
// thread — presumably so any device-side resources the session holds are
// released in the correct device context (TODO confirm against Ort docs).
OrtNet::~OrtNet() {
  DeviceGuard guard(device_);
  // Assigning a null session runs the session's cleanup while the guard is
  // still active.
  session_ = Ort::Session{nullptr};
}
|
||||
|
||||
static std::unique_ptr<Net> Create(const Value& args) {
|
||||
try {
|
||||
auto p = std::make_unique<OrtNet>();
|
||||
|
|
|
@ -11,7 +11,7 @@ namespace mmdeploy::framework {
|
|||
|
||||
class OrtNet : public Net {
|
||||
public:
|
||||
~OrtNet() override = default;
|
||||
~OrtNet() override;
|
||||
Result<void> Init(const Value& cfg) override;
|
||||
Result<void> Deinit() override;
|
||||
Result<Span<Tensor>> GetInputTensors() override;
|
||||
|
|
|
@ -79,7 +79,11 @@ static inline Result<void> trt_try(bool code, const char* msg = nullptr, Status
|
|||
|
||||
#define TRT_TRY(...) OUTCOME_TRY(trt_try(__VA_ARGS__))
|
||||
|
||||
TRTNet::~TRTNet() = default;
|
||||
// Releases TensorRT objects with the net's device active on this thread.
// The execution context is reset before the engine it was created from —
// NOTE(review): TensorRT requires contexts not to outlive their engine;
// keep this ordering.
TRTNet::~TRTNet() {
  CudaDeviceGuard guard(device_);
  context_.reset();
  engine_.reset();
}
|
||||
|
||||
static Result<DataType> MapDataType(nvinfer1::DataType dtype) {
|
||||
switch (dtype) {
|
||||
|
@ -106,6 +110,7 @@ Result<void> TRTNet::Init(const Value& args) {
|
|||
MMDEPLOY_ERROR("TRTNet: device must be a GPU!");
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
CudaDeviceGuard guard(device_);
|
||||
stream_ = context["stream"].get<Stream>();
|
||||
|
||||
event_ = Event(device_);
|
||||
|
@ -156,13 +161,10 @@ Result<void> TRTNet::Init(const Value& args) {
|
|||
return success();
|
||||
}
|
||||
|
||||
Result<void> TRTNet::Deinit() {
|
||||
context_.reset();
|
||||
engine_.reset();
|
||||
return success();
|
||||
}
|
||||
// Teardown now happens in ~TRTNet() under a CudaDeviceGuard; Deinit is kept
// as a no-op for interface compatibility.
Result<void> TRTNet::Deinit() { return success(); }
|
||||
|
||||
Result<void> TRTNet::Reshape(Span<TensorShape> input_shapes) {
|
||||
CudaDeviceGuard guard(device_);
|
||||
using namespace trt_detail;
|
||||
if (input_shapes.size() != input_tensors_.size()) {
|
||||
return Status(eInvalidArgument);
|
||||
|
@ -190,6 +192,7 @@ Result<Span<Tensor>> TRTNet::GetInputTensors() { return input_tensors_; }
|
|||
Result<Span<Tensor>> TRTNet::GetOutputTensors() { return output_tensors_; }
|
||||
|
||||
Result<void> TRTNet::Forward() {
|
||||
CudaDeviceGuard guard(device_);
|
||||
using namespace trt_detail;
|
||||
std::vector<void*> bindings(engine_->getNbBindings());
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include "NvInferRuntime.h"
|
||||
#include "mmdeploy/core/mpl/span.h"
|
||||
#include "mmdeploy/core/net.h"
|
||||
#include "mmdeploy/device/cuda/cuda_device.h"
|
||||
|
||||
namespace mmdeploy::framework {
|
||||
|
||||
|
|
Loading…
Reference in New Issue