mmdeploy/csrc/device/opencl/opencl_device.cpp
lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* target include directory is wrong

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* remove bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialize

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory, otherwise compile failed in case of tensorrt8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

258 lines
7.2 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#if ENABLE_OPENCL && 0
#include "opencl_device.h"
#include <iostream>
#include <mutex>
namespace mmdeploy {
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformImpl
OclPlatformImpl::OclPlatformImpl(cl::Platform platform) : platform_(std::move(platform)) {
platform_.getDevices(CL_DEVICE_TYPE_ALL, &devices_);
queues_.resize(devices_.size());
for (int i = 0; i < devices_.size(); ++i) {
init_flag_.push_back(std::make_unique<std::once_flag>());
}
ctx_ = cl::Context(devices_);
}
shared_ptr<BufferImpl> OclPlatformImpl::CreateBuffer(Device device) {
return std::make_shared<OclBufferImpl>(device);
}
shared_ptr<StreamImpl> OclPlatformImpl::CreateStream(Device device) {
return std::make_shared<OclStreamImpl>(device);
}
shared_ptr<EventImpl> OclPlatformImpl::CreateEvent(Device device) {
return std::make_shared<OclEventImpl>(device);
}
Result<void> OclPlatformImpl::Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset,
Stream stream) {
if (!dst || !stream) {
return Status(eInvalidArgument);
}
auto device = dst.GetDevice();
if (device.platform_id() != GetPlatformId()) {
return Status(eInvalidArgument);
}
if (stream.GetDevice() != device) {
return Status(eInvalidArgument);
}
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& to = Access::get<OclBufferImpl>(dst).buffer();
auto status = queue.enqueueWriteBuffer(to, false, dst_offset, size, host_ptr);
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclPlatformImpl::Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset,
Stream stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& from = Access::get<OclBufferImpl>(src).buffer();
auto status = queue.enqueueReadBuffer(from, false, src_offset, size, host_ptr);
if (status) {
fprintf(stderr, "status = %d\n", (int)status);
}
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclPlatformImpl::Copy(Buffer src, Buffer dst, size_t size, size_t src_offset,
size_t dst_offset, Stream stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& from = Access::get<OclBufferImpl>(src).buffer();
auto& to = Access::get<OclBufferImpl>(dst).buffer();
auto status = queue.enqueueCopyBuffer(from, to, src_offset, dst_offset, size);
if (status) {
fprintf(stderr, "status = %d\n", (int)status);
}
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<Stream> OclPlatformImpl::GetDefaultStream(int32_t device_id) {
if (device_id >= queues_.size()) {
return Status(eInvalidArgument);
}
try {
std::call_once(*init_flag_[device_id],
[&] { queues_[device_id] = Stream(GetDevice((device_id))); });
return queues_[device_id];
} catch (...) {
return Status(eFail);
}
}
OclPlatformImpl& gOclPlatform() {
static Platform platform("opencl");
return Access::get<OclPlatformImpl>(platform);
}
////////////////////////////////////////////////////////////////////////////////
/// OclBufferImpl
OclBufferImpl::OclBufferImpl(Device device) : BufferImpl(device) {
memory_ = std::make_shared<OclDeviceMemory>();
}
Result<void> OclBufferImpl::Init(size_t size, Allocator allocator, size_t alignment,
uint64_t flags) {
auto& ctx = gOclPlatform().GetContext();
OUTCOME_TRY(memory_->Init(ctx, size, alignment, flags));
size_ = size;
return success();
}
Result<void> OclBufferImpl::Init(size_t size, std::shared_ptr<void> native, uint64_t flags) {
OUTCOME_TRY(memory_->Init(size, std::move(native), flags));
size_ = size;
return success();
}
void* OclBufferImpl::GetNative(ErrorCode* ec) { return memory_->data(); }
size_t OclBufferImpl::GetSize(ErrorCode* ec) { return size_; }
////////////////////////////////////////////////////////////////////////////////
/// OclStreamImpl
OclStreamImpl::OclStreamImpl(Device device) : StreamImpl(device), queue_(), owned_queue_(false) {}
OclStreamImpl::~OclStreamImpl() {
if (owned_queue_) {
detail::Cast(queue_).~CommandQueue();
queue_ = {};
owned_queue_ = false;
}
external_.reset();
}
Result<void> OclStreamImpl::Init(uint64_t flags) {
auto& platform = gOclPlatform();
auto& ctx = platform.GetContext();
auto& dev = platform.GetNativeDevice(device_.device_id());
new (&queue_) cl::CommandQueue(ctx, dev);
owned_queue_ = true;
return success();
}
Result<void> OclStreamImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
external_ = std::move(native);
queue_ = static_cast<cl_command_queue>(external_.get());
owned_queue_ = false;
return success();
}
Result<void> OclStreamImpl::DependsOn(Event& event) { return Status(eNotSupported); }
Result<void> OclStreamImpl::Query() {
cl::Event event;
queue().enqueueMarkerWithWaitList(nullptr, &event);
auto status = event.getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>();
if (status == CL_COMPLETE) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclStreamImpl::Wait() {
auto status = queue().finish();
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclStreamImpl::Submit(Kernel& kernel) { return Status(eNotSupported); }
void* OclStreamImpl::GetNative(ErrorCode* ec) { return queue_; }
////////////////////////////////////////////////////////////////////////////////
/// OclEventImpl
OclEventImpl::OclEventImpl(Device device) : EventImpl(device), event_(), owned_event_() {}
OclEventImpl::~OclEventImpl() = default;
Result<void> OclEventImpl::Init(uint64_t flags) { return success(); }
Result<void> OclEventImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
external_ = std::move(native);
event_ = static_cast<cl_event>(external_.get());
owned_event_ = false;
return success();
}
Result<void> OclEventImpl::Query() {
auto status = event().getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>();
if (status == CL_COMPLETE) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclEventImpl::Record(Stream& stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto status = queue.enqueueMarkerWithWaitList(nullptr, &event());
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclEventImpl::Wait() {
if (!event_) {
return success();
}
auto status = event().wait();
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
void* OclEventImpl::GetNative(ErrorCode* ec) { return event_; }
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformRegisterer
class OclPlatformRegisterer {
public:
OclPlatformRegisterer() {
gPlatformRegistry().Register([] {
Logger::GetInstance().SetLogLevel(spdlog::level::debug);
return std::make_shared<OclPlatformImpl>(cl::Platform::getDefault());
});
}
};
OclPlatformRegisterer g_ocl_platform_registerer;
} // namespace mmdeploy
#endif