mmdeploy/csrc/mmdeploy/core/device_impl.cpp

409 lines
10 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/core/device_impl.h"
#include <cassert>
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/logger.h"
namespace mmdeploy::framework {
template <typename T>
T SetError(ErrorCode* ec, ErrorCode code, T ret) {
if (ec) {
*ec = code;
}
return ret;
}
////////////////////////////////////////////////////////////////////////////////
/// Device
Device::Device(const char* platform_name, int device_id) {
platform_id_ = gPlatformRegistry().GetPlatformId(platform_name);
device_id_ = device_id;
}
//////////////////////////////////////////////////
/// Platform
int Platform::GetPlatformId() const {
if (impl_) {
return impl_->GetPlatformId();
}
return -1;
}
const char* Platform::GetPlatformName() const {
if (impl_) {
return impl_->GetPlatformName();
}
return "";
}
Platform::Platform(const char* platform_name) {
if (-1 == gPlatformRegistry().GetPlatform(platform_name, this)) {
throw_exception(eInvalidArgument);
}
}
Platform::Platform(int platform_id) {
if (-1 == gPlatformRegistry().GetPlatform(platform_id, this)) {
throw_exception(eInvalidArgument);
}
}
const char* GetPlatformName(PlatformId id) {
if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
return impl->GetPlatformName();
}
return nullptr;
}
////////////////////////////////////////////////////////////////////////////////
/// Buffer
Buffer::Buffer(Device device, size_t size, Allocator allocator, size_t alignment, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
impl_ = p->CreateBuffer(device);
if (auto r = impl_->Init(size, std::move(allocator), alignment, flags); r.has_error()) {
impl_.reset();
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Buffer::Buffer(Device device, size_t size, void* native, uint64_t flags)
: Buffer(device, size, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Buffer::Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
impl_ = p->CreateBuffer(device);
if (auto r = impl_->Init(size, std::move(native), flags); r.has_error()) {
impl_.reset();
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Device Buffer::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
Allocator Buffer::GetAllocator() const { return impl_ ? impl_->GetAllocator() : Allocator{}; }
void* Buffer::GetNative(ErrorCode* ec) const {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
size_t Buffer::GetSize(ErrorCode* ec) const {
return impl_ ? impl_->GetSize(ec) : SetError(ec, eInvalidArgument, 0);
}
Buffer::Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags) {
auto impl = buffer.impl_->SubBuffer(offset, size, flags);
if (!impl) {
impl.error().throw_exception();
}
impl_ = std::move(impl).value();
}
#if 0
int Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
Stream stream;
GetDefaultStream(dst.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(host_ptr, dst, size, dst_offset);
}
int Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, host_ptr, size, src_offset);
}
int Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
size_t dst_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, dst, size, src_offset, dst_offset);
}
#endif
//////////////////////////////////////////////////
/// Stream
Stream::Stream(Device device, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateStream(device);
if (auto r = impl->Init(flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
MMDEPLOY_ERROR("{}, {}", device.device_id(), device.platform_id());
throw_exception(eInvalidArgument);
}
}
Stream::Stream(Device device, void* native, uint64_t flags)
: Stream(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Stream::Stream(Device device, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateStream(device);
if (auto r = impl->Init(std::move(native), flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Result<void> Stream::Query() {
if (impl_) {
return impl_->Query();
}
return Status(eInvalidArgument);
}
Result<void> Stream::Wait() {
if (impl_) {
return impl_->Wait();
}
return Status(eInvalidArgument);
}
Result<void> Stream::DependsOn(Event& event) {
return impl_ ? impl_->DependsOn(event) : Status(eInvalidArgument);
}
void* Stream::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
Result<void> Stream::Submit(Kernel& kernel) {
return impl_ ? impl_->Submit(kernel) : Status(eInvalidArgument);
}
Result<void> Stream::Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
size_t dst_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = src.GetSize();
}
if (auto p = GetPlatformImpl(GetDevice())) {
return p->Copy(src, dst, size, src_offset, dst_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = dst.GetSize();
}
auto device = GetDevice();
if (auto p = GetPlatformImpl(device)) {
return p->Copy(host_ptr, dst, size, dst_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = src.GetSize();
}
if (auto p = GetPlatformImpl(GetDevice())) {
return p->Copy(src, host_ptr, size, src_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size,
size_t offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
return Status(eNotSupported);
}
Device Stream::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
Stream Stream::GetDefault(Device device) {
Platform platform(device.platform_id());
assert(platform);
Stream stream = Access::get<PlatformImpl>(platform).GetDefaultStream(device.device_id()).value();
return stream;
}
/////////////////////////////////////////////////
/// Event
Event::Event(Device device, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateEvent(device);
if (auto r = impl->Init(flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Event::Event(Device device, void* native, uint64_t flags)
: Event(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Event::Event(Device device, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateEvent(device);
if (auto r = impl->Init(std::move(native), flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Result<void> Event::Query() { return impl_ ? impl_->Query() : Status(eInvalidArgument); }
Result<void> Event::Wait() { return impl_ ? impl_->Wait() : Status(eInvalidArgument); }
Result<void> Event::Record(Stream& stream) {
return impl_ ? impl_->Record(stream) : Status(eInvalidArgument);
}
void* Event::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
Device Event::GetDevice() { return impl_ ? impl_->GetDevice() : Device{}; }
/////////////////////////////////////////////////
/// Kernel
Device Kernel::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
void* Kernel::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
/////////////////////////////////////////////////
/// PlatformRegistry
int PlatformRegistry::Register(Creator creator) {
Platform platform(creator());
auto proposed_id = platform.GetPlatformId();
std::string name = platform.GetPlatformName();
if (proposed_id == -1) {
proposed_id = GetNextId();
platform.impl_->SetPlatformId(proposed_id);
} else if (!IsAvailable(proposed_id)) {
return -1;
}
entries_.push_back({name, proposed_id, platform});
return 0;
}
int PlatformRegistry::AddAlias(const char* name, const char* target) {
aliases_.emplace_back(name, target);
return 0;
}
int PlatformRegistry::GetNextId() {
for (int i = 1;; ++i) {
if (IsAvailable(i)) {
return i;
}
}
}
bool PlatformRegistry::IsAvailable(int id) {
for (const auto& entry : entries_) {
if (entry.id == id) {
return false;
}
}
return true;
}
int PlatformRegistry::GetPlatform(const char* name, Platform* platform) {
for (const auto& alias : aliases_) {
if (name == alias.first) {
name = alias.second.c_str();
break;
}
}
for (const auto& entry : entries_) {
if (entry.name == name) {
*platform = entry.platform;
return 0;
}
}
return -1;
}
int PlatformRegistry::GetPlatform(int id, Platform* platform) {
for (const auto& entry : entries_) {
if (entry.id == id) {
*platform = entry.platform;
return 0;
}
}
return -1;
}
int PlatformRegistry::GetPlatformId(const char* name) {
for (const auto& alias : aliases_) {
if (name == alias.first) {
name = alias.second.c_str();
break;
}
}
for (const auto& entry : entries_) {
if (entry.name == name) {
return entry.id;
}
}
return -1;
}
PlatformImpl* PlatformRegistry::GetPlatformImpl(PlatformId id) {
for (const auto& entry : entries_) {
if (entry.id == id) {
return entry.platform.impl_.get();
}
}
return nullptr;
}
PlatformRegistry& gPlatformRegistry() {
static PlatformRegistry instance;
return instance;
}
} // namespace mmdeploy::framework