mmdeploy/csrc/core/device_impl.cpp

384 lines
9.9 KiB
C++
Raw Normal View History

Merge sdk (#251) * check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00
// Copyright (c) OpenMMLab. All rights reserved.
#include "device_impl.h"
#include <cassert>
#include "core/device.h"
#include "core/logger.h"
namespace mmdeploy {
/// Writes `code` through `ec` when the caller supplied an error slot, then returns `ret`
/// unchanged. Lets an accessor report a failure and yield its fallback value in one expression.
template <typename T>
T SetError(ErrorCode* ec, ErrorCode code, T ret) {
  if (ec != nullptr) {
    *ec = code;
  }
  return ret;
}
////////////////////////////////////////////////////////////////////////////////
/// Device
/// Builds a device handle from a platform name and a device index.
/// The name is translated to a numeric id through the global platform registry.
Device::Device(const char* platform_name, int device_id) {
  auto& registry = gPlatformRegistry();
  platform_id_ = registry.GetPlatformId(platform_name);
  device_id_ = device_id;
}
//////////////////////////////////////////////////
/// Platform
/// Returns the id reported by the underlying implementation, or -1 for an empty handle.
int Platform::GetPlatformId() const {
  if (!impl_) {
    return -1;
  }
  return impl_->GetPlatformId();
}
/// Returns the name reported by the underlying implementation, or "" for an empty handle.
const char* Platform::GetPlatformName() const {
  if (!impl_) {
    return "";
  }
  return impl_->GetPlatformName();
}
/// Looks the platform up by name in the global registry and copies it into this object.
/// Calls throw_exception(eInvalidArgument) when the registry reports failure (-1).
Platform::Platform(const char* platform_name) {
  const int status = gPlatformRegistry().GetPlatform(platform_name, this);
  if (status == -1) {
    throw_exception(eInvalidArgument);
  }
}
/// Looks the platform up by numeric id in the global registry and copies it into this object.
/// Calls throw_exception(eInvalidArgument) when the registry reports failure (-1).
Platform::Platform(int platform_id) {
  const int status = gPlatformRegistry().GetPlatform(platform_id, this);
  if (status == -1) {
    throw_exception(eInvalidArgument);
  }
}
////////////////////////////////////////////////////////////////////////////////
/// Buffer
/// Creates a buffer of `size` bytes on `device`, initialised with the given allocator,
/// alignment and flags.
/// Calls throw_exception(eInvalidArgument) when no platform implementation is found for
/// `device`; when Init fails, drops the half-built implementation and rethrows Init's error.
Buffer::Buffer(Device device, size_t size, Allocator allocator, size_t alignment, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto init_result = impl_->Init(size, std::move(allocator), alignment, flags);
    if (init_result.has_error()) {
      impl_.reset();
      init_result.error().throw_exception();
    }
  }
}
/// Wraps a raw native handle. The handle is placed in a shared_ptr whose deleter does nothing,
/// so this object never frees `native`; construction is delegated to the shared_ptr overload.
Buffer::Buffer(Device device, size_t size, void* native, uint64_t flags)
    : Buffer(device, size, std::shared_ptr<void>(native, [](void* /*not owned*/) {}), flags) {
}
/// Creates a buffer on `device` around an existing native handle of `size` bytes.
/// Calls throw_exception(eInvalidArgument) when no platform implementation is found for
/// `device`; when Init fails, drops the half-built implementation and rethrows Init's error.
Buffer::Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto init_result = impl_->Init(size, std::move(native), flags);
    if (init_result.has_error()) {
      impl_.reset();
      init_result.error().throw_exception();
    }
  }
}
/// Returns the implementation's device, or a default-constructed Device for an empty buffer.
Device Buffer::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}
/// Returns the implementation's allocator, or a default-constructed Allocator for an empty
/// buffer.
Allocator Buffer::GetAllocator() const {
  if (impl_) {
    return impl_->GetAllocator();
  }
  return Allocator{};
}
/// Returns the implementation's native handle. For an empty buffer, stores eInvalidArgument
/// in `*ec` (when `ec` is non-null) and returns nullptr.
void* Buffer::GetNative(ErrorCode* ec) const {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
/// Returns the implementation's size. For an empty buffer, stores eInvalidArgument in `*ec`
/// (when `ec` is non-null) and returns 0.
size_t Buffer::GetSize(ErrorCode* ec) const {
  if (impl_) {
    return impl_->GetSize(ec);
  }
  return SetError(ec, eInvalidArgument, 0);
}
/// Creates a sub-buffer of `buffer` described by `offset`, `size` and `flags`.
///
/// Calls throw_exception(eInvalidArgument) when `buffer` is an empty handle, matching how the
/// other Buffer accessors treat a missing implementation (previously this dereferenced a null
/// implementation pointer). When SubBuffer fails, its error is rethrown.
Buffer::Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags) {
  if (!buffer.impl_) {
    throw_exception(eInvalidArgument);
  }
  auto impl = buffer.impl_->SubBuffer(offset, size, flags);
  if (!impl) {
    impl.error().throw_exception();
  }
  impl_ = std::move(impl).value();
}
// Disabled free-function Copy helpers, excluded from compilation by `#if 0`.
// Each one fetches the default stream for a buffer's device via GetDefaultStream and forwards
// to the matching Stream::Copy overload, returning Status(eFail) when no stream is obtained.
// NOTE(review): the declared return type is `int` while the bodies return Status / the result
// of Stream::Copy — confirm the intended signature before re-enabling.
#if 0
int Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
Stream stream;
GetDefaultStream(dst.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(host_ptr, dst, size, dst_offset);
}
int Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, host_ptr, size, src_offset);
}
int Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
size_t dst_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, dst, size, src_offset, dst_offset);
}
#endif
//////////////////////////////////////////////////
/// Stream
/// Creates a stream on `device` with the given flags.
/// When no platform implementation is found, logs the device and platform ids and calls
/// throw_exception(eInvalidArgument). When Init fails, its error is rethrown and the stream
/// keeps no implementation.
Stream::Stream(Device device, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    ERROR("{}, {}", device.device_id(), device.platform_id());
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateStream(device);
    auto init_result = candidate->Init(flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
/// Wraps a raw native stream handle. The handle is placed in a shared_ptr whose deleter does
/// nothing, so this object never frees `native`; construction is delegated to the shared_ptr
/// overload.
Stream::Stream(Device device, void* native, uint64_t flags)
    : Stream(device, std::shared_ptr<void>(native, [](void* /*not owned*/) {}), flags) {
}
/// Creates a stream on `device` around an existing native handle.
/// Calls throw_exception(eInvalidArgument) when no platform implementation is found; when Init
/// fails, its error is rethrown and the stream keeps no implementation.
Stream::Stream(Device device, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateStream(device);
    auto init_result = candidate->Init(std::move(native), flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
/// Forwards to the implementation's Query; an empty stream yields Status(eInvalidArgument).
Result<void> Stream::Query() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Query();
}
/// Forwards to the implementation's Wait; an empty stream yields Status(eInvalidArgument).
Result<void> Stream::Wait() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Wait();
}
/// Forwards `event` to the implementation's DependsOn; an empty stream yields
/// Status(eInvalidArgument).
Result<void> Stream::DependsOn(Event& event) {
  if (impl_) {
    return impl_->DependsOn(event);
  }
  return Status(eInvalidArgument);
}
/// Returns the implementation's native handle. For an empty stream, stores eInvalidArgument
/// in `*ec` (when `ec` is non-null) and returns nullptr.
void* Stream::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
/// Forwards `kernel` to the implementation's Submit; an empty stream yields
/// Status(eInvalidArgument).
Result<void> Stream::Submit(Kernel& kernel) {
  if (impl_) {
    return impl_->Submit(kernel);
  }
  return Status(eInvalidArgument);
}
/// Buffer-to-buffer copy on this stream.
/// A `size` of static_cast<size_t>(-1) is replaced by `src.GetSize()`. The copy itself is
/// delegated to the platform implementation of this stream's device. Yields
/// Status(eInvalidArgument) when the stream is empty or no platform implementation is found.
Result<void> Stream::Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
                          size_t dst_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const auto whole_buffer = static_cast<size_t>(-1);
  if (size == whole_buffer) {
    size = src.GetSize();
  }
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(src, dst, size, src_offset, dst_offset, *this);
}
/// Host-to-buffer copy on this stream.
/// A `size` of static_cast<size_t>(-1) is replaced by `dst.GetSize()`. The copy itself is
/// delegated to the platform implementation of this stream's device. Yields
/// Status(eInvalidArgument) when the stream is empty or no platform implementation is found.
Result<void> Stream::Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const auto whole_buffer = static_cast<size_t>(-1);
  if (size == whole_buffer) {
    size = dst.GetSize();
  }
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(host_ptr, dst, size, dst_offset, *this);
}
/// Buffer-to-host copy on this stream.
/// A `size` of static_cast<size_t>(-1) is replaced by `src.GetSize()`. The copy itself is
/// delegated to the platform implementation of this stream's device. Yields
/// Status(eInvalidArgument) when the stream is empty or no platform implementation is found.
Result<void> Stream::Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const auto whole_buffer = static_cast<size_t>(-1);
  if (size == whole_buffer) {
    size = src.GetSize();
  }
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(src, host_ptr, size, src_offset, *this);
}
/// Pattern fill is not implemented: yields Status(eInvalidArgument) for an empty stream and
/// Status(eNotSupported) otherwise. None of the arguments are read.
Result<void> Stream::Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size,
                          size_t offset) {
  if (impl_) {
    return Status(eNotSupported);
  }
  return Status(eInvalidArgument);
}
/// Returns the implementation's device, or a default-constructed Device for an empty stream.
Device Stream::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}
/// Returns the default stream that the device's platform implementation provides for
/// `device.device_id()`. The Platform is looked up by id first; the result of
/// GetDefaultStream is unwrapped with value().
Stream Stream::GetDefault(Device device) {
  Platform platform(device.platform_id());
  assert(platform);
  auto&& platform_impl = Access::get<PlatformImpl>(platform);
  auto default_stream = platform_impl.GetDefaultStream(device.device_id());
  Stream stream = std::move(default_stream).value();
  return stream;
}
/////////////////////////////////////////////////
/// Event
/// Creates an event on `device` with the given flags.
/// Calls throw_exception(eInvalidArgument) when no platform implementation is found; when Init
/// fails, its error is rethrown and the event keeps no implementation.
Event::Event(Device device, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateEvent(device);
    auto init_result = candidate->Init(flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
/// Wraps a raw native event handle. The handle is placed in a shared_ptr whose deleter does
/// nothing, so this object never frees `native`; construction is delegated to the shared_ptr
/// overload.
Event::Event(Device device, void* native, uint64_t flags)
    : Event(device, std::shared_ptr<void>(native, [](void* /*not owned*/) {}), flags) {
}
/// Creates an event on `device` around an existing native handle.
/// Calls throw_exception(eInvalidArgument) when no platform implementation is found; when Init
/// fails, its error is rethrown and the event keeps no implementation.
Event::Event(Device device, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateEvent(device);
    auto init_result = candidate->Init(std::move(native), flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
/// Forwards to the implementation's Query; an empty event yields Status(eInvalidArgument).
Result<void> Event::Query() {
  if (impl_) {
    return impl_->Query();
  }
  return Status(eInvalidArgument);
}
/// Forwards to the implementation's Wait; an empty event yields Status(eInvalidArgument).
Result<void> Event::Wait() {
  if (impl_) {
    return impl_->Wait();
  }
  return Status(eInvalidArgument);
}
/// Forwards `stream` to the implementation's Record; an empty event yields
/// Status(eInvalidArgument).
Result<void> Event::Record(Stream& stream) {
  if (impl_) {
    return impl_->Record(stream);
  }
  return Status(eInvalidArgument);
}
/// Returns the implementation's native handle. For an empty event, stores eInvalidArgument
/// in `*ec` (when `ec` is non-null) and returns nullptr.
void* Event::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
/// Returns the implementation's device, or a default-constructed Device for an empty event.
Device Event::GetDevice() {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}
/////////////////////////////////////////////////
/// Kernel
/// Returns the implementation's device, or a default-constructed Device for an empty kernel.
Device Kernel::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}
/// Returns the implementation's native handle. For an empty kernel, stores eInvalidArgument
/// in `*ec` (when `ec` is non-null) and returns nullptr.
void* Kernel::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
/////////////////////////////////////////////////
/// PlatformRegistry
/// Registers the platform produced by `creator`.
///
/// If the platform reports id -1, the next free id is assigned to it; otherwise its own id is
/// used, provided no registered entry already has it.
///
/// @return 0 on success; -1 when the creator yields no implementation or the platform's
///         requested id is already taken.
int PlatformRegistry::Register(Creator creator) {
  Platform platform(creator());
  if (!platform.impl_) {
    // Without an implementation GetPlatformId() reports -1, which previously led straight to
    // a null dereference in SetPlatformId below. Reject the registration instead.
    return -1;
  }
  auto proposed_id = platform.GetPlatformId();
  std::string name = platform.GetPlatformName();
  if (proposed_id == -1) {
    proposed_id = GetNextId();
    platform.impl_->SetPlatformId(proposed_id);
  } else if (!IsAvailable(proposed_id)) {
    return -1;
  }
  entries_.push_back({name, proposed_id, platform});
  return 0;
}
/// Returns the smallest id, counting up from 1, that no registered entry uses.
int PlatformRegistry::GetNextId() {
  int candidate = 1;
  while (!IsAvailable(candidate)) {
    ++candidate;
  }
  return candidate;
}
/// Returns true when no registered entry carries `id`.
bool PlatformRegistry::IsAvailable(int id) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id == id) {
      return false;
    }
  }
  return true;
}
/// Finds the first registered platform whose name equals `name` and copies it into `*platform`.
///
/// @param name     Platform name to look for. A null pointer matches nothing (comparing a
///                 std::string against a null `const char*` would be undefined behaviour).
/// @param platform Destination for the matching platform; written only on success.
/// @return 0 when a match was found, -1 otherwise.
int PlatformRegistry::GetPlatform(const char* name, Platform* platform) {
  if (name == nullptr) {
    return -1;
  }
  for (const auto& entry : entries_) {
    if (entry.name == name) {
      *platform = entry.platform;
      return 0;
    }
  }
  return -1;
}
/// Finds the first registered platform whose id equals `id` and copies it into `*platform`.
/// Returns 0 when a match was found, -1 otherwise (leaving `*platform` untouched).
int PlatformRegistry::GetPlatform(int id, Platform* platform) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id != id) {
      continue;
    }
    *platform = it->platform;
    return 0;
  }
  return -1;
}
/// Returns the id of the first registered platform whose name equals `name`.
///
/// @param name Platform name to look for. A null pointer matches nothing (comparing a
///             std::string against a null `const char*` would be undefined behaviour).
/// @return The matching entry's id, or -1 when there is no match.
int PlatformRegistry::GetPlatformId(const char* name) {
  if (name == nullptr) {
    return -1;
  }
  for (const auto& entry : entries_) {
    if (entry.name == name) {
      return entry.id;
    }
  }
  return -1;
}
/// Returns the raw implementation pointer of the first registered platform whose id equals
/// `id`, or nullptr when there is no such entry. The registry entry keeps ownership.
PlatformImpl* PlatformRegistry::GetPlatformImpl(PlatformId id) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id == id) {
      return it->platform.impl_.get();
    }
  }
  return nullptr;
}
/// Accessor for the single registry object, held as a function-local static and returned by
/// reference on every call.
PlatformRegistry& gPlatformRegistry() {
  static PlatformRegistry registry;
  return registry;
}
} // namespace mmdeploy