// Copyright (c) OpenMMLab. All rights reserved. #if ENABLE_OPENCL && 0 #include "opencl_device.h" #include #include namespace mmdeploy { //////////////////////////////////////////////////////////////////////////////// /// OclPlatformImpl OclPlatformImpl::OclPlatformImpl(cl::Platform platform) : platform_(std::move(platform)) { platform_.getDevices(CL_DEVICE_TYPE_ALL, &devices_); queues_.resize(devices_.size()); for (int i = 0; i < devices_.size(); ++i) { init_flag_.push_back(std::make_unique()); } ctx_ = cl::Context(devices_); } shared_ptr OclPlatformImpl::CreateBuffer(Device device) { return std::make_shared(device); } shared_ptr OclPlatformImpl::CreateStream(Device device) { return std::make_shared(device); } shared_ptr OclPlatformImpl::CreateEvent(Device device) { return std::make_shared(device); } Result OclPlatformImpl::Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset, Stream stream) { if (!dst || !stream) { return Status(eInvalidArgument); } auto device = dst.GetDevice(); if (device.platform_id() != GetPlatformId()) { return Status(eInvalidArgument); } if (stream.GetDevice() != device) { return Status(eInvalidArgument); } auto& queue = Access::get(stream).queue(); auto& to = Access::get(dst).buffer(); auto status = queue.enqueueWriteBuffer(to, false, dst_offset, size, host_ptr); if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } Result OclPlatformImpl::Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset, Stream stream) { auto& queue = Access::get(stream).queue(); auto& from = Access::get(src).buffer(); auto status = queue.enqueueReadBuffer(from, false, src_offset, size, host_ptr); if (status) { fprintf(stderr, "status = %d\n", (int)status); } if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } Result OclPlatformImpl::Copy(Buffer src, Buffer dst, size_t size, size_t src_offset, size_t dst_offset, Stream stream) { auto& queue = Access::get(stream).queue(); auto& from = Access::get(src).buffer(); auto& to = Access::get(dst).buffer(); auto status = queue.enqueueCopyBuffer(from, to, src_offset, dst_offset, size); if (status) { fprintf(stderr, "status = %d\n", (int)status); } if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } Result OclPlatformImpl::GetDefaultStream(int32_t device_id) { if (device_id >= queues_.size()) { return Status(eInvalidArgument); } try { std::call_once(*init_flag_[device_id], [&] { queues_[device_id] = Stream(GetDevice((device_id))); }); return queues_[device_id]; } catch (...) { return Status(eFail); } } OclPlatformImpl& gOclPlatform() { static Platform platform("opencl"); return Access::get(platform); } //////////////////////////////////////////////////////////////////////////////// /// OclBufferImpl OclBufferImpl::OclBufferImpl(Device device) : BufferImpl(device) { memory_ = std::make_shared(); } Result OclBufferImpl::Init(size_t size, Allocator allocator, size_t alignment, uint64_t flags) { auto& ctx = gOclPlatform().GetContext(); OUTCOME_TRY(memory_->Init(ctx, size, alignment, flags)); size_ = size; return success(); } Result OclBufferImpl::Init(size_t size, std::shared_ptr native, uint64_t flags) { OUTCOME_TRY(memory_->Init(size, std::move(native), flags)); size_ = size; return success(); } void* OclBufferImpl::GetNative(ErrorCode* ec) { return memory_->data(); } size_t OclBufferImpl::GetSize(ErrorCode* ec) { return size_; } //////////////////////////////////////////////////////////////////////////////// /// OclStreamImpl OclStreamImpl::OclStreamImpl(Device device) : StreamImpl(device), queue_(), owned_queue_(false) {} OclStreamImpl::~OclStreamImpl() { if (owned_queue_) { detail::Cast(queue_).~CommandQueue(); queue_ = {}; owned_queue_ = false; } external_.reset(); } Result OclStreamImpl::Init(uint64_t flags) { auto& platform = gOclPlatform(); auto& ctx = platform.GetContext(); auto& dev = platform.GetNativeDevice(device_.device_id()); new (&queue_) cl::CommandQueue(ctx, dev); owned_queue_ = true; return success(); } Result OclStreamImpl::Init(std::shared_ptr native, uint64_t flags) { external_ = std::move(native); queue_ = static_cast(external_.get()); owned_queue_ = false; return success(); } Result OclStreamImpl::DependsOn(Event& event) { return Status(eNotSupported); } Result OclStreamImpl::Query() { cl::Event event; queue().enqueueMarkerWithWaitList(nullptr, &event); auto status = event.getInfo(); if (status == CL_COMPLETE) { return success(); } else { return Status(eFail); } } Result OclStreamImpl::Wait() { auto status = queue().finish(); if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } Result OclStreamImpl::Submit(Kernel& kernel) { return Status(eNotSupported); } void* OclStreamImpl::GetNative(ErrorCode* ec) { return queue_; } //////////////////////////////////////////////////////////////////////////////// /// OclEventImpl OclEventImpl::OclEventImpl(Device device) : EventImpl(device), event_(), owned_event_() {} OclEventImpl::~OclEventImpl() = default; Result OclEventImpl::Init(uint64_t flags) { return success(); } Result OclEventImpl::Init(std::shared_ptr native, uint64_t flags) { external_ = std::move(native); event_ = static_cast(external_.get()); owned_event_ = false; return success(); } Result OclEventImpl::Query() { auto status = event().getInfo(); if (status == CL_COMPLETE) { return success(); } else { return Status(eFail); } } Result OclEventImpl::Record(Stream& stream) { auto& queue = Access::get(stream).queue(); auto status = queue.enqueueMarkerWithWaitList(nullptr, &event()); if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } Result OclEventImpl::Wait() { if (!event_) { return success(); } auto status = event().wait(); if (status == CL_SUCCESS) { return success(); } else { return Status(eFail); } } void* OclEventImpl::GetNative(ErrorCode* ec) { return event_; } //////////////////////////////////////////////////////////////////////////////// /// OclPlatformRegisterer class OclPlatformRegisterer { public: OclPlatformRegisterer() { gPlatformRegistry().Register([] { Logger::GetInstance().SetLogLevel(spdlog::level::debug); return std::make_shared(cl::Platform::getDefault()); }); } }; OclPlatformRegisterer g_ocl_platform_registerer; } // namespace mmdeploy #endif