mmdeploy/csrc/device/opencl/opencl_device.cpp

258 lines
7.2 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#if ENABLE_OPENCL && 0
#include "opencl_device.h"
#include <iostream>
#include <mutex>
namespace mmdeploy {
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformImpl
OclPlatformImpl::OclPlatformImpl(cl::Platform platform) : platform_(std::move(platform)) {
platform_.getDevices(CL_DEVICE_TYPE_ALL, &devices_);
queues_.resize(devices_.size());
for (int i = 0; i < devices_.size(); ++i) {
init_flag_.push_back(std::make_unique<std::once_flag>());
}
ctx_ = cl::Context(devices_);
}
shared_ptr<BufferImpl> OclPlatformImpl::CreateBuffer(Device device) {
return std::make_shared<OclBufferImpl>(device);
}
shared_ptr<StreamImpl> OclPlatformImpl::CreateStream(Device device) {
return std::make_shared<OclStreamImpl>(device);
}
shared_ptr<EventImpl> OclPlatformImpl::CreateEvent(Device device) {
return std::make_shared<OclEventImpl>(device);
}
Result<void> OclPlatformImpl::Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset,
Stream stream) {
if (!dst || !stream) {
return Status(eInvalidArgument);
}
auto device = dst.GetDevice();
if (device.platform_id() != GetPlatformId()) {
return Status(eInvalidArgument);
}
if (stream.GetDevice() != device) {
return Status(eInvalidArgument);
}
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& to = Access::get<OclBufferImpl>(dst).buffer();
auto status = queue.enqueueWriteBuffer(to, false, dst_offset, size, host_ptr);
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclPlatformImpl::Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset,
Stream stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& from = Access::get<OclBufferImpl>(src).buffer();
auto status = queue.enqueueReadBuffer(from, false, src_offset, size, host_ptr);
if (status) {
fprintf(stderr, "status = %d\n", (int)status);
}
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclPlatformImpl::Copy(Buffer src, Buffer dst, size_t size, size_t src_offset,
size_t dst_offset, Stream stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto& from = Access::get<OclBufferImpl>(src).buffer();
auto& to = Access::get<OclBufferImpl>(dst).buffer();
auto status = queue.enqueueCopyBuffer(from, to, src_offset, dst_offset, size);
if (status) {
fprintf(stderr, "status = %d\n", (int)status);
}
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<Stream> OclPlatformImpl::GetDefaultStream(int32_t device_id) {
if (device_id >= queues_.size()) {
return Status(eInvalidArgument);
}
try {
std::call_once(*init_flag_[device_id],
[&] { queues_[device_id] = Stream(GetDevice((device_id))); });
return queues_[device_id];
} catch (...) {
return Status(eFail);
}
}
OclPlatformImpl& gOclPlatform() {
static Platform platform("opencl");
return Access::get<OclPlatformImpl>(platform);
}
////////////////////////////////////////////////////////////////////////////////
/// OclBufferImpl
OclBufferImpl::OclBufferImpl(Device device) : BufferImpl(device) {
memory_ = std::make_shared<OclDeviceMemory>();
}
Result<void> OclBufferImpl::Init(size_t size, Allocator allocator, size_t alignment,
uint64_t flags) {
auto& ctx = gOclPlatform().GetContext();
OUTCOME_TRY(memory_->Init(ctx, size, alignment, flags));
size_ = size;
return success();
}
Result<void> OclBufferImpl::Init(size_t size, std::shared_ptr<void> native, uint64_t flags) {
OUTCOME_TRY(memory_->Init(size, std::move(native), flags));
size_ = size;
return success();
}
void* OclBufferImpl::GetNative(ErrorCode* ec) { return memory_->data(); }
size_t OclBufferImpl::GetSize(ErrorCode* ec) { return size_; }
////////////////////////////////////////////////////////////////////////////////
/// OclStreamImpl
OclStreamImpl::OclStreamImpl(Device device) : StreamImpl(device), queue_(), owned_queue_(false) {}
OclStreamImpl::~OclStreamImpl() {
if (owned_queue_) {
detail::Cast(queue_).~CommandQueue();
queue_ = {};
owned_queue_ = false;
}
external_.reset();
}
Result<void> OclStreamImpl::Init(uint64_t flags) {
auto& platform = gOclPlatform();
auto& ctx = platform.GetContext();
auto& dev = platform.GetNativeDevice(device_.device_id());
new (&queue_) cl::CommandQueue(ctx, dev);
owned_queue_ = true;
return success();
}
Result<void> OclStreamImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
external_ = std::move(native);
queue_ = static_cast<cl_command_queue>(external_.get());
owned_queue_ = false;
return success();
}
Result<void> OclStreamImpl::DependsOn(Event& event) { return Status(eNotSupported); }
Result<void> OclStreamImpl::Query() {
cl::Event event;
queue().enqueueMarkerWithWaitList(nullptr, &event);
auto status = event.getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>();
if (status == CL_COMPLETE) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclStreamImpl::Wait() {
auto status = queue().finish();
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclStreamImpl::Submit(Kernel& kernel) { return Status(eNotSupported); }
void* OclStreamImpl::GetNative(ErrorCode* ec) { return queue_; }
////////////////////////////////////////////////////////////////////////////////
/// OclEventImpl
OclEventImpl::OclEventImpl(Device device) : EventImpl(device), event_(), owned_event_() {}
OclEventImpl::~OclEventImpl() = default;
Result<void> OclEventImpl::Init(uint64_t flags) { return success(); }
Result<void> OclEventImpl::Init(std::shared_ptr<void> native, uint64_t flags) {
external_ = std::move(native);
event_ = static_cast<cl_event>(external_.get());
owned_event_ = false;
return success();
}
Result<void> OclEventImpl::Query() {
auto status = event().getInfo<CL_EVENT_COMMAND_EXECUTION_STATUS>();
if (status == CL_COMPLETE) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclEventImpl::Record(Stream& stream) {
auto& queue = Access::get<OclStreamImpl>(stream).queue();
auto status = queue.enqueueMarkerWithWaitList(nullptr, &event());
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
Result<void> OclEventImpl::Wait() {
if (!event_) {
return success();
}
auto status = event().wait();
if (status == CL_SUCCESS) {
return success();
} else {
return Status(eFail);
}
}
void* OclEventImpl::GetNative(ErrorCode* ec) { return event_; }
////////////////////////////////////////////////////////////////////////////////
/// OclPlatformRegisterer
class OclPlatformRegisterer {
public:
OclPlatformRegisterer() {
gPlatformRegistry().Register([] {
Logger::GetInstance().SetLogLevel(spdlog::level::debug);
return std::make_shared<OclPlatformImpl>(cl::Platform::getDefault());
});
}
};
OclPlatformRegisterer g_ocl_platform_registerer;
} // namespace mmdeploy
#endif