mmdeploy/csrc/device/opencl/opencl_device.cpp

258 lines
7.2 KiB
C++
Raw Normal View History

Merge sdk (#251) * check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00
// Copyright (c) OpenMMLab. All rights reserved.
#if ENABLE_OPENCL && 0
#include "opencl_device.h"
#include <iostream>
#include <mutex>
namespace mmdeploy {
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformImpl
// Wraps `platform`: enumerates all of its devices, reserves one default-stream
// slot plus one once_flag per device, and creates a single context that spans
// every enumerated device.
// NOTE(review): the result of getDevices() is not inspected here.
OclPlatformImpl::OclPlatformImpl(cl::Platform platform) : platform_(std::move(platform)) {
  platform_.getDevices(CL_DEVICE_TYPE_ALL, &devices_);
  const auto device_count = devices_.size();
  queues_.resize(device_count);
  // The flags guard the lazy stream creation performed in GetDefaultStream().
  for (auto remaining = device_count; remaining > 0; --remaining) {
    init_flag_.push_back(std::make_unique<std::once_flag>());
  }
  ctx_ = cl::Context(devices_);
}
// Creates an OpenCL buffer implementation bound to `device`.
// Storage is attached later through OclBufferImpl::Init().
shared_ptr<BufferImpl> OclPlatformImpl::CreateBuffer(Device device) {
  shared_ptr<BufferImpl> impl = std::make_shared<OclBufferImpl>(device);
  return impl;
}
// Creates an OpenCL stream implementation bound to `device`.
// The command queue itself is set up later through OclStreamImpl::Init().
shared_ptr<StreamImpl> OclPlatformImpl::CreateStream(Device device) {
  shared_ptr<StreamImpl> impl = std::make_shared<OclStreamImpl>(device);
  return impl;
}
// Creates an OpenCL event implementation bound to `device`.
shared_ptr<EventImpl> OclPlatformImpl::CreateEvent(Device device) {
  shared_ptr<EventImpl> impl = std::make_shared<OclEventImpl>(device);
  return impl;
}
// Host -> device copy: enqueues a write of `size` bytes from `host_ptr` into
// `dst` at byte offset `dst_offset` on `stream`.
//
// Returns eInvalidArgument when `dst` or `stream` is empty, when `dst` does not
// belong to this platform, or when `stream` lives on a different device than
// `dst`. Returns eFail when the enqueue call reports an error.
Result<void> OclPlatformImpl::Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset,
                                   Stream stream) {
  if (!dst || !stream) {
    return Status(eInvalidArgument);
  }

  auto target_device = dst.GetDevice();
  if (target_device.platform_id() != GetPlatformId() || stream.GetDevice() != target_device) {
    return Status(eInvalidArgument);
  }

  auto& cmd_queue = Access::get<OclStreamImpl>(stream).queue();
  auto& target = Access::get<OclBufferImpl>(dst).buffer();

  // The write is enqueued as non-blocking (second argument is `false`), so
  // `host_ptr` has to stay valid until the command has actually executed.
  const auto err = cmd_queue.enqueueWriteBuffer(target, false, dst_offset, size, host_ptr);
  if (err != CL_SUCCESS) {
    return Status(eFail);
  }
  return success();
}
// Device -> host copy: enqueues a read of `size` bytes from `src`, starting at
// byte offset `src_offset`, into `host_ptr` on `stream`.
//
// Arguments are validated the same way as in the host -> device overload:
// eInvalidArgument is returned when `src` or `stream` is empty, when `src` does
// not belong to this platform, or when `stream` lives on a different device
// than `src`. eFail is returned (and the OpenCL status is printed to stderr)
// when the enqueue call reports an error.
Result<void> OclPlatformImpl::Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset,
                                   Stream stream) {
  if (!src || !stream) {
    return Status(eInvalidArgument);
  }
  auto device = src.GetDevice();
  if (device.platform_id() != GetPlatformId()) {
    return Status(eInvalidArgument);
  }
  if (stream.GetDevice() != device) {
    return Status(eInvalidArgument);
  }

  auto& queue = Access::get<OclStreamImpl>(stream).queue();
  auto& from = Access::get<OclBufferImpl>(src).buffer();

  // Non-blocking read (second argument is `false`): the contents of `host_ptr`
  // are only defined once the command has completed.
  auto status = queue.enqueueReadBuffer(from, false, src_offset, size, host_ptr);
  if (status != CL_SUCCESS) {
    fprintf(stderr, "status = %d\n", static_cast<int>(status));
    return Status(eFail);
  }
  return success();
}
// Device -> device copy: enqueues a copy of `size` bytes from `src` (at byte
// offset `src_offset`) to `dst` (at byte offset `dst_offset`) on `stream`.
//
// Returns eInvalidArgument when any of `src`, `dst` or `stream` is empty or
// does not belong to this platform. Returns eFail (and prints the OpenCL
// status to stderr) when the enqueue call reports an error.
Result<void> OclPlatformImpl::Copy(Buffer src, Buffer dst, size_t size, size_t src_offset,
                                   size_t dst_offset, Stream stream) {
  if (!src || !dst || !stream) {
    return Status(eInvalidArgument);
  }
  // All three handles must come from this platform, otherwise the casts
  // below would reinterpret foreign implementation objects.
  const auto platform_id = GetPlatformId();
  if (src.GetDevice().platform_id() != platform_id ||
      dst.GetDevice().platform_id() != platform_id ||
      stream.GetDevice().platform_id() != platform_id) {
    return Status(eInvalidArgument);
  }

  auto& queue = Access::get<OclStreamImpl>(stream).queue();
  auto& from = Access::get<OclBufferImpl>(src).buffer();
  auto& to = Access::get<OclBufferImpl>(dst).buffer();

  auto status = queue.enqueueCopyBuffer(from, to, src_offset, dst_offset, size);
  if (status != CL_SUCCESS) {
    fprintf(stderr, "status = %d\n", static_cast<int>(status));
    return Status(eFail);
  }
  return success();
}
// Returns the default stream of device `device_id`, creating it on first use.
//
// Returns eInvalidArgument when `device_id` does not index an enumerated
// device, and eFail when creating the stream throws.
Result<Stream> OclPlatformImpl::GetDefaultStream(int32_t device_id) {
  // Negative ids are rejected as well (the original unsigned comparison
  // rejected them implicitly through wrap-around).
  if (device_id < 0 || static_cast<size_t>(device_id) >= queues_.size()) {
    return Status(eInvalidArgument);
  }
  try {
    auto& slot = queues_[device_id];
    // The per-device once_flag makes sure the stream is constructed only once.
    std::call_once(*init_flag_[device_id], [&] { slot = Stream(GetDevice((device_id))); });
    return slot;
  } catch (...) {
    return Status(eFail);
  }
}
// Accessor for the process-wide OpenCL platform implementation.
// The function-local static `Platform` is looked up by the name "opencl" on
// first call and reused afterwards.
OclPlatformImpl& gOclPlatform() {
  static Platform opencl_platform("opencl");
  auto& impl = Access::get<OclPlatformImpl>(opencl_platform);
  return impl;
}
////////////////////////////////////////////////////////////////////////////////
/// OclBufferImpl
// Binds the buffer to `device` and creates the (still empty) device-memory
// holder; the memory itself is set up by one of the Init() overloads.
OclBufferImpl::OclBufferImpl(Device device) : BufferImpl(device) {
  auto holder = std::make_shared<OclDeviceMemory>();
  memory_ = std::move(holder);
}
// Initializes the buffer with `size` bytes in the global OpenCL context.
// `alignment` and `flags` are forwarded to the device-memory object;
// `allocator` is not used by this implementation.
// Any error from the device-memory initialization is propagated to the caller.
Result<void> OclBufferImpl::Init(size_t size, Allocator allocator, size_t alignment,
                                 uint64_t flags) {
  OUTCOME_TRY(memory_->Init(gOclPlatform().GetContext(), size, alignment, flags));
  // Only record the size once the memory object has been set up successfully.
  size_ = size;
  return success();
}
// Initializes the buffer from an externally provided native handle.
// `native` and `flags` are handed to the device-memory object; its error, if
// any, is propagated to the caller and `size_` is left untouched in that case.
Result<void> OclBufferImpl::Init(size_t size, std::shared_ptr<void> native, uint64_t flags) {
  OUTCOME_TRY(memory_->Init(size, std::move(native), flags));

  size_ = size;
  return success();
}
// Returns the native handle exposed by the device-memory object.
// `ec` is not written by this implementation.
void* OclBufferImpl::GetNative(ErrorCode* ec) {
  void* native = memory_->data();
  return native;
}
// Returns the size recorded by the last successful Init() call.
// `ec` is not written by this implementation.
size_t OclBufferImpl::GetSize(ErrorCode* ec) {
  return size_;
}
////////////////////////////////////////////////////////////////////////////////
/// OclStreamImpl
// Binds the stream to `device`. No command queue exists yet: the handle is
// value-initialized and marked as not owned until an Init() overload runs.
OclStreamImpl::OclStreamImpl(Device device)
    : StreamImpl(device),
      queue_(),
      owned_queue_(false) {}
// Tears the stream down: destroys the command queue if this object created it
// and drops the reference to an externally supplied queue, if any.
OclStreamImpl::~OclStreamImpl() {
if (owned_queue_) {
// Init(uint64_t) constructs a cl::CommandQueue in place inside queue_, so the
// matching destructor is invoked explicitly on that same storage.
// NOTE(review): presumably detail::Cast reinterprets the raw handle as the
// C++ wrapper type - confirm against the header.
detail::Cast(queue_).~CommandQueue();
// Clear the handle and the ownership flag after the queue has been destroyed.
queue_ = {};
owned_queue_ = false;
}
// Drops this stream's reference to the externally provided queue.
external_.reset();
}
// Creates a command queue for this stream's device in the global OpenCL
// context and takes ownership of it.
//
// `flags` is not used by this implementation.
// Returns eFail when OpenCL reports an error while creating the queue; in that
// case the stream is not marked as owning a queue.
Result<void> OclStreamImpl::Init(uint64_t flags) {
  auto& platform = gOclPlatform();
  auto& ctx = platform.GetContext();
  auto& dev = platform.GetNativeDevice(device_.device_id());

  // The wrapper is constructed in place inside queue_; the destructor undoes
  // this with an explicit destructor call when owned_queue_ is set.
  cl_int err = CL_SUCCESS;
  new (&queue_) cl::CommandQueue(ctx, dev, static_cast<cl_command_queue_properties>(0), &err);
  if (err != CL_SUCCESS) {
    // Previously the creation status was ignored and a failed creation was
    // reported as success, leaving the stream with an unusable queue.
    return Status(eFail);
  }
  owned_queue_ = true;
  return success();
}
// Attaches an externally created command queue to this stream.
// The shared pointer is kept in `external_` for as long as the stream uses the
// queue, and the queue is marked as not owned so the destructor will not
// destroy it. `flags` is not used by this implementation.
Result<void> OclStreamImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
  queue_ = static_cast<cl_command_queue>(native.get());
  external_ = std::move(native);
  owned_queue_ = false;
  return success();
}
// Making a stream wait on an event is not implemented for OpenCL;
// always returns eNotSupported.
Result<void> OclStreamImpl::DependsOn(Event& event) {
  return Status(eNotSupported);
}
// Checks whether the work submitted to this stream so far has finished.
//
// A marker is enqueued behind the pending commands and its execution status is
// read back. Returns success when the marker is already complete and eFail
// otherwise - that is, both when the marker has not completed yet and when
// OpenCL reports an error.
Result<void> OclStreamImpl::Query() {
  cl::Event event;
  // If the marker cannot be enqueued, `event` holds no valid handle and must
  // not be queried.
  if (queue().enqueueMarkerWithWaitList(nullptr, &event) != CL_SUCCESS) {
    return Status(eFail);
  }

  // Pass an error out-parameter: without it a failed query would hand back an
  // unspecified value that could compare equal to CL_COMPLETE.
  cl_int err = CL_SUCCESS;
  auto status = event.getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>(&err);
  if (err != CL_SUCCESS || status != CL_COMPLETE) {
    return Status(eFail);
  }
  return success();
}
// Blocks until the command queue reports that it has finished.
// Returns eFail when finish() reports an error.
Result<void> OclStreamImpl::Wait() {
  if (queue().finish() != CL_SUCCESS) {
    return Status(eFail);
  }
  return success();
}
// Kernel submission is not implemented for OpenCL streams;
// always returns eNotSupported.
Result<void> OclStreamImpl::Submit(Kernel& kernel) {
  return Status(eNotSupported);
}
// Returns the stored command-queue handle (value-initialized until an Init()
// overload has set it). `ec` is not written by this implementation.
void* OclStreamImpl::GetNative(ErrorCode* ec) {
  void* native = queue_;
  return native;
}
////////////////////////////////////////////////////////////////////////////////
/// OclEventImpl
// Binds the event to `device`; the event handle and the ownership flag are
// value-initialized.
OclEventImpl::OclEventImpl(Device device)
    : EventImpl(device),
      event_(),
      owned_event_() {}
// Defaulted destructor: members are destroyed by their own destructors only.
// NOTE(review): unlike OclStreamImpl's destructor, nothing here explicitly
// releases event_, although Record() passes it to enqueueMarkerWithWaitList()
// as the output event - confirm who owns that handle.
OclEventImpl::~OclEventImpl() = default;
// Nothing is created up front: the event handle stays empty after this call.
// `flags` is not used. Always succeeds.
Result<void> OclEventImpl::Init(uint64_t flags) {
  return success();
}
// Attaches an externally created OpenCL event.
// The shared pointer is kept in `external_` and the event is marked as not
// owned by this object. `flags` is not used by this implementation.
Result<void> OclEventImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
  event_ = static_cast<cl_event>(native.get());
  external_ = std::move(native);
  owned_event_ = false;
  return success();
}
// Checks whether the event has completed.
//
// An event that holds no handle (nothing recorded or attached yet) is treated
// as complete, mirroring Wait(). Otherwise returns success when the execution
// status is CL_COMPLETE and eFail when it is not complete yet or when OpenCL
// reports an error while querying it.
Result<void> OclEventImpl::Query() {
  if (!event_) {
    // Without this guard the status of an empty handle would be queried.
    return success();
  }

  // Pass an error out-parameter: without it a failed query would hand back an
  // unspecified value that could compare equal to CL_COMPLETE.
  cl_int err = CL_SUCCESS;
  auto status = event().getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>(&err);
  if (err != CL_SUCCESS || status != CL_COMPLETE) {
    return Status(eFail);
  }
  return success();
}
// Records this event on `stream` by enqueueing a marker on the stream's
// command queue with this event as the marker's output event.
// Returns eFail when the enqueue call reports an error.
Result<void> OclEventImpl::Record(Stream& stream) {
  auto& cmd_queue = Access::get<OclStreamImpl>(stream).queue();
  auto* marker = &event();
  if (cmd_queue.enqueueMarkerWithWaitList(nullptr, marker) != CL_SUCCESS) {
    return Status(eFail);
  }
  return success();
}
// Blocks until the event has completed.
// An event that holds no handle has nothing to wait for and succeeds
// immediately. Returns eFail when wait() reports an error.
Result<void> OclEventImpl::Wait() {
  if (event_) {
    if (event().wait() != CL_SUCCESS) {
      return Status(eFail);
    }
  }
  return success();
}
// Returns the stored event handle (empty until an event has been attached or
// recorded). `ec` is not written by this implementation.
void* OclEventImpl::GetNative(ErrorCode* ec) {
  void* native = event_;
  return native;
}
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformRegisterer
// Helper whose constructor registers a factory for the OpenCL platform with
// the global platform registry.
class OclPlatformRegisterer {
public:
OclPlatformRegisterer() {
// The lambda builds an OclPlatformImpl around the default OpenCL platform;
// presumably the registry invokes it when the platform is first requested.
gPlatformRegistry().Register([] {
// NOTE(review): this switches the logger to debug level as a side effect of
// creating the platform - looks like leftover debugging; confirm it is intended.
Logger::GetInstance().SetLogLevel(spdlog::level::debug);
return std::make_shared<OclPlatformImpl>(cl::Platform::getDefault());
});
}
};
// Namespace-scope instance: constructing it during static initialization of
// this translation unit is what performs the registration.
OclPlatformRegisterer g_ocl_platform_registerer;
} // namespace mmdeploy
#endif