mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
384 lines
9.9 KiB
C++
384 lines
9.9 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#include "device_impl.h"
|
|
|
|
#include <cassert>
|
|
|
|
#include "core/device.h"
|
|
#include "core/logger.h"
|
|
|
|
namespace mmdeploy {
|
|
|
|
template <typename T>
|
|
T SetError(ErrorCode* ec, ErrorCode code, T ret) {
|
|
if (ec) {
|
|
*ec = code;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// Device
|
|
|
|
/// Builds a device handle from a platform name and a device index.
/// The name is resolved through the global platform registry; an unregistered
/// name leaves the device with platform id -1 (the registry's "not found"
/// value) instead of raising an error.
Device::Device(const char* platform_name, int device_id) {
  device_id_ = device_id;
  platform_id_ = gPlatformRegistry().GetPlatformId(platform_name);
}
|
|
|
|
//////////////////////////////////////////////////
|
|
/// Platform
|
|
|
|
int Platform::GetPlatformId() const {
|
|
if (impl_) {
|
|
return impl_->GetPlatformId();
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
const char* Platform::GetPlatformName() const {
|
|
if (impl_) {
|
|
return impl_->GetPlatformName();
|
|
}
|
|
return "";
|
|
}
|
|
|
|
/// Binds this handle to the registered platform called `platform_name`.
/// Throws eInvalidArgument (via throw_exception) when the global registry has
/// no platform with that name.
Platform::Platform(const char* platform_name) {
  auto& registry = gPlatformRegistry();
  const int lookup = registry.GetPlatform(platform_name, this);
  if (lookup == -1) {
    throw_exception(eInvalidArgument);
  }
}
|
|
|
|
/// Binds this handle to the registered platform whose id is `platform_id`.
/// Throws eInvalidArgument (via throw_exception) when no registered platform
/// has that id.
Platform::Platform(int platform_id) {
  auto& registry = gPlatformRegistry();
  const int lookup = registry.GetPlatform(platform_id, this);
  if (lookup == -1) {
    throw_exception(eInvalidArgument);
  }
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// Buffer
|
|
|
|
/// Allocates a new buffer of `size` on `device` using `allocator`.
/// Throws eInvalidArgument when `device` has no platform implementation, and
/// rethrows any error reported by the implementation's Init. On an Init
/// failure impl_ is cleared before the error propagates.
Buffer::Buffer(Device device, size_t size, Allocator allocator, size_t alignment, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto init_result = impl_->Init(size, std::move(allocator), alignment, flags);
    if (init_result.has_error()) {
      impl_.reset();
      init_result.error().throw_exception();
    }
  }
}
|
|
|
|
Buffer::Buffer(Device device, size_t size, void* native, uint64_t flags)
|
|
: Buffer(device, size, std::shared_ptr<void>(native, [](void*) {}), flags) {}
|
|
|
|
/// Wraps an existing native buffer handle (`native`) on `device`.
/// Throws eInvalidArgument when `device` has no platform implementation, and
/// rethrows any error reported by the implementation's Init. On an Init
/// failure impl_ is cleared before the error propagates.
Buffer::Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto init_result = impl_->Init(size, std::move(native), flags);
    if (init_result.has_error()) {
      impl_.reset();
      init_result.error().throw_exception();
    }
  }
}
|
|
|
|
/// Device the buffer belongs to; a default-constructed Device for an empty
/// Buffer.
Device Buffer::GetDevice() const {
  if (!impl_) {
    return Device{};
  }
  return impl_->GetDevice();
}
|
|
|
|
/// Allocator recorded by the implementation; a default-constructed Allocator
/// for an empty Buffer.
Allocator Buffer::GetAllocator() const {
  if (!impl_) {
    return Allocator{};
  }
  return impl_->GetAllocator();
}
|
|
|
|
/// Native handle of the buffer as reported by the implementation.
/// For an empty Buffer it returns nullptr and writes eInvalidArgument to `*ec`
/// (if `ec` is non-null).
void* Buffer::GetNative(ErrorCode* ec) const {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
|
|
|
|
/// Size reported by the implementation (presumably in bytes).
/// For an empty Buffer it returns 0 and writes eInvalidArgument to `*ec`
/// (if `ec` is non-null).
size_t Buffer::GetSize(ErrorCode* ec) const {
  if (impl_) {
    return impl_->GetSize(ec);
  }
  return SetError(ec, eInvalidArgument, size_t{0});
}
|
|
|
|
/// Creates a view over `size` units of `buffer` starting at `offset` by asking
/// the parent's implementation for a sub-buffer.
/// Throws eInvalidArgument when `buffer` is empty (has no implementation), and
/// rethrows any error reported by the implementation's SubBuffer.
Buffer::Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags) {
  // An empty parent has no implementation to slice. Report it the same way
  // the other constructors report an unusable device, instead of
  // dereferencing a null impl_.
  if (!buffer.impl_) {
    throw_exception(eInvalidArgument);
  } else {
    auto sub = buffer.impl_->SubBuffer(offset, size, flags);
    if (!sub) {
      sub.error().throw_exception();
    }
    impl_ = std::move(sub).value();
  }
}
|
|
|
|
#if 0
|
|
int Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
|
|
Stream stream;
|
|
GetDefaultStream(dst.GetDevice(), &stream);
|
|
if (!stream) {
|
|
return Status(eFail);
|
|
}
|
|
return stream.Copy(host_ptr, dst, size, dst_offset);
|
|
}
|
|
int Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
|
|
Stream stream;
|
|
GetDefaultStream(src.GetDevice(), &stream);
|
|
if (!stream) {
|
|
return Status(eFail);
|
|
}
|
|
return stream.Copy(src, host_ptr, size, src_offset);
|
|
}
|
|
int Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
|
|
size_t dst_offset) {
|
|
Stream stream;
|
|
GetDefaultStream(src.GetDevice(), &stream);
|
|
if (!stream) {
|
|
return Status(eFail);
|
|
}
|
|
return stream.Copy(src, dst, size, src_offset, dst_offset);
|
|
}
|
|
#endif
|
|
|
|
//////////////////////////////////////////////////
|
|
/// Stream
|
|
|
|
/// Creates a new stream on `device` through the device's platform
/// implementation.
/// Logs the device/platform ids and throws eInvalidArgument when `device` has
/// no platform implementation. Rethrows any error reported by the
/// implementation's Init(flags). impl_ is assigned only after Init succeeds.
Stream::Stream(Device device, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    ERROR("{}, {}", device.device_id(), device.platform_id());
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateStream(device);
    auto init_result = candidate->Init(flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
|
|
|
|
Stream::Stream(Device device, void* native, uint64_t flags)
|
|
: Stream(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
|
|
|
|
/// Wraps an existing native stream handle (`native`) on `device`.
/// Throws eInvalidArgument when `device` has no platform implementation, and
/// rethrows any error reported by the implementation's Init. impl_ is
/// assigned only after Init succeeds.
Stream::Stream(Device device, std::shared_ptr<void> native, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateStream(device);
    auto init_result = candidate->Init(std::move(native), flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
|
|
|
|
/// Forwards to the implementation's Query; eInvalidArgument for an empty Stream.
Result<void> Stream::Query() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Query();
}
|
|
|
|
/// Forwards to the implementation's Wait; eInvalidArgument for an empty Stream.
Result<void> Stream::Wait() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Wait();
}
|
|
|
|
/// Forwards `event` to the implementation's DependsOn; eInvalidArgument for an
/// empty Stream.
Result<void> Stream::DependsOn(Event& event) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->DependsOn(event);
}
|
|
|
|
/// Native stream handle as reported by the implementation.
/// For an empty Stream it returns nullptr and writes eInvalidArgument to `*ec`
/// (if `ec` is non-null).
void* Stream::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
|
|
|
|
/// Hands `kernel` to the implementation's Submit; eInvalidArgument for an
/// empty Stream.
Result<void> Stream::Submit(Kernel& kernel) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Submit(kernel);
}
|
|
|
|
/// Enqueues a Buffer-to-Buffer copy on this stream via the platform of the
/// stream's device.
/// `size == static_cast<size_t>(-1)` means "use src.GetSize()".
/// Returns eInvalidArgument for an empty Stream or when no platform
/// implementation is found for the stream's device.
// NOTE(review): the default size is the full src size and src_offset is not
// subtracted. Presumably the platform Copy validates the range; confirm.
Result<void> Stream::Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
                          size_t dst_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const size_t count = (size == static_cast<size_t>(-1)) ? src.GetSize() : size;
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(src, dst, count, src_offset, dst_offset, *this);
}
|
|
|
|
/// Enqueues a host-to-Buffer copy on this stream via the platform of the
/// stream's device.
/// `size == static_cast<size_t>(-1)` means "use dst.GetSize()".
/// Returns eInvalidArgument for an empty Stream or when no platform
/// implementation is found for the stream's device.
Result<void> Stream::Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const size_t count = (size == static_cast<size_t>(-1)) ? dst.GetSize() : size;
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(host_ptr, dst, count, dst_offset, *this);
}
|
|
|
|
/// Enqueues a Buffer-to-host copy on this stream via the platform of the
/// stream's device.
/// `size == static_cast<size_t>(-1)` means "use src.GetSize()".
/// Returns eInvalidArgument for an empty Stream or when no platform
/// implementation is found for the stream's device.
Result<void> Stream::Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  const size_t count = (size == static_cast<size_t>(-1)) ? src.GetSize() : size;
  auto platform = GetPlatformImpl(GetDevice());
  if (!platform) {
    return Status(eInvalidArgument);
  }
  return platform->Copy(src, host_ptr, count, src_offset, *this);
}
|
|
|
|
/// Fill is not implemented: every argument is ignored.
/// Returns eNotSupported for a valid Stream and eInvalidArgument for an empty
/// one.
Result<void> Stream::Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size,
                          size_t offset) {
  return impl_ ? Status(eNotSupported) : Status(eInvalidArgument);
}
|
|
|
|
/// Device the stream runs on; a default-constructed Device for an empty Stream.
Device Stream::GetDevice() const {
  if (!impl_) {
    return Device{};
  }
  return impl_->GetDevice();
}
|
|
|
|
/// Returns the platform's default stream for `device`.
/// The Platform constructor already throws for an unknown platform id, so the
/// assert is only a debug-build sanity check. `.value()` is called on the
/// platform's result; presumably it throws if that result holds an error.
Stream Stream::GetDefault(Device device) {
  Platform platform(device.platform_id());
  assert(platform);
  auto&& platform_impl = Access::get<PlatformImpl>(platform);
  return platform_impl.GetDefaultStream(device.device_id()).value();
}
|
|
|
|
/////////////////////////////////////////////////
|
|
/// Event
|
|
|
|
/// Creates a new event on `device` through the device's platform
/// implementation.
/// Throws eInvalidArgument when `device` has no platform implementation, and
/// rethrows any error reported by the implementation's Init(flags). impl_ is
/// assigned only after Init succeeds.
Event::Event(Device device, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateEvent(device);
    auto init_result = candidate->Init(flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
|
|
|
|
Event::Event(Device device, void* native, uint64_t flags)
|
|
: Event(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
|
|
|
|
/// Wraps an existing native event handle (`native`) on `device`.
/// Throws eInvalidArgument when `device` has no platform implementation, and
/// rethrows any error reported by the implementation's Init. impl_ is
/// assigned only after Init succeeds.
Event::Event(Device device, std::shared_ptr<void> native, uint64_t flags) {
  if (auto platform = GetPlatformImpl(device); !platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto candidate = platform->CreateEvent(device);
    auto init_result = candidate->Init(std::move(native), flags);
    if (!init_result) {
      init_result.error().throw_exception();
    } else {
      impl_ = std::move(candidate);
    }
  }
}
|
|
|
|
/// Forwards to the implementation's Query; eInvalidArgument for an empty Event.
Result<void> Event::Query() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Query();
}
|
|
|
|
/// Forwards to the implementation's Wait; eInvalidArgument for an empty Event.
Result<void> Event::Wait() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Wait();
}
|
|
|
|
/// Forwards `stream` to the implementation's Record; eInvalidArgument for an
/// empty Event.
Result<void> Event::Record(Stream& stream) {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Record(stream);
}
|
|
|
|
/// Native event handle as reported by the implementation.
/// For an empty Event it returns nullptr and writes eInvalidArgument to `*ec`
/// (if `ec` is non-null).
void* Event::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
|
|
|
|
/// Device the event belongs to; a default-constructed Device for an empty
/// Event.
Device Event::GetDevice() {
  if (!impl_) {
    return Device{};
  }
  return impl_->GetDevice();
}
|
|
|
|
/////////////////////////////////////////////////
|
|
/// Kernel
|
|
|
|
/// Device the kernel belongs to; a default-constructed Device for an empty
/// Kernel.
Device Kernel::GetDevice() const {
  if (!impl_) {
    return Device{};
  }
  return impl_->GetDevice();
}
|
|
|
|
/// Native kernel handle as reported by the implementation.
/// For an empty Kernel it returns nullptr and writes eInvalidArgument to `*ec`
/// (if `ec` is non-null).
void* Kernel::GetNative(ErrorCode* ec) {
  if (impl_) {
    return impl_->GetNative(ec);
  }
  return SetError(ec, eInvalidArgument, nullptr);
}
|
|
|
|
/////////////////////////////////////////////////
|
|
/// PlatformRegistry
|
|
|
|
/// Creates a platform with `creator` and appends it to the registry.
/// A platform that reports id -1 is given the smallest unused positive id
/// (see GetNextId). Returns 0 on success. Returns -1 when the creator produced
/// a Platform with no implementation, or when the platform's own id is
/// already registered.
int PlatformRegistry::Register(Creator creator) {
  Platform platform(creator());
  // An implementation-less Platform reports id -1, which would send it down
  // the SetPlatformId path below and dereference a null impl_.
  if (!platform.impl_) {
    return -1;
  }
  auto proposed_id = platform.GetPlatformId();
  std::string name = platform.GetPlatformName();
  if (proposed_id == -1) {
    proposed_id = GetNextId();
    platform.impl_->SetPlatformId(proposed_id);
  } else if (!IsAvailable(proposed_id)) {
    return -1;
  }
  // name and platform are not used after this point; move them into the entry.
  entries_.push_back({std::move(name), proposed_id, std::move(platform)});
  return 0;
}
|
|
|
|
int PlatformRegistry::GetNextId() {
|
|
for (int i = 1;; ++i) {
|
|
if (IsAvailable(i)) {
|
|
return i;
|
|
}
|
|
}
|
|
}
|
|
|
|
bool PlatformRegistry::IsAvailable(int id) {
|
|
for (const auto& entry : entries_) {
|
|
if (entry.id == id) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
int PlatformRegistry::GetPlatform(const char* name, Platform* platform) {
|
|
for (const auto& entry : entries_) {
|
|
if (entry.name == name) {
|
|
*platform = entry.platform;
|
|
return 0;
|
|
}
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
int PlatformRegistry::GetPlatform(int id, Platform* platform) {
|
|
for (const auto& entry : entries_) {
|
|
if (entry.id == id) {
|
|
*platform = entry.platform;
|
|
return 0;
|
|
}
|
|
}
|
|
return -1;
|
|
}
|
|
int PlatformRegistry::GetPlatformId(const char* name) {
|
|
for (const auto& entry : entries_) {
|
|
if (entry.name == name) {
|
|
return entry.id;
|
|
}
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
/// Non-owning pointer to the implementation of the first registered platform
/// whose id is `id`, or nullptr if none matches.
PlatformImpl* PlatformRegistry::GetPlatformImpl(PlatformId id) {
  for (const auto& entry : entries_) {
    if (entry.id != id) {
      continue;
    }
    return entry.platform.impl_.get();
  }
  return nullptr;
}
|
|
|
|
/// Process-wide platform registry, constructed on first call (function-local
/// static). Presumably this design lets platforms that register during static
/// initialization of other translation units find a constructed registry.
PlatformRegistry& gPlatformRegistry() {
  static PlatformRegistry registry;
  return registry;
}
|
|
|
|
} // namespace mmdeploy
|