// Copyright (c) OpenMMLab. All rights reserved. #ifndef MMDEPLOY_DEVICE_OPENCL_DEVICE_H_ #define MMDEPLOY_DEVICE_OPENCL_DEVICE_H_ #include #include "CL/cl.hpp" #include "core/device_impl.h" #include "core/logger.h" namespace mmdeploy { namespace detail { static inline cl::CommandQueue& Cast(cl_command_queue& queue) { return *reinterpret_cast(&queue); } static inline cl::Buffer& Cast(cl_mem& buffer) { return *reinterpret_cast(&buffer); } static inline cl::Event& Cast(cl_event& event) { return *reinterpret_cast(&event); } static inline cl::Context& Cast(cl_context& context) { return *reinterpret_cast(&context); } static inline cl::Platform& Cast(cl_platform_id& platform) { return *reinterpret_cast(&platform); } static inline cl::Device& Cast(cl_device_id& device) { return *reinterpret_cast(&device); } } // namespace detail class OclPlatformImpl : public PlatformImpl { public: explicit OclPlatformImpl(cl::Platform platform); const char* GetPlatformName() const noexcept override { return "opencl"; } shared_ptr CreateBuffer(Device device) override; shared_ptr CreateStream(Device device) override; shared_ptr CreateEvent(Device device) override; Result Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset, Stream stream) override; Result Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset, Stream stream) override; Result Copy(Buffer src, Buffer dst, size_t size, size_t src_offset, size_t dst_offset, Stream stream) override; Result GetDefaultStream(int32_t device_id) override; Device GetDevice(int device_id) { return Device(platform_id_, device_id); } cl::Device& GetNativeDevice(int device_id) { return devices_[device_id]; } cl::Context& GetContext() { return ctx_; } private: cl::Platform platform_; std::vector devices_; std::vector queues_; std::vector> init_flag_; cl::Context ctx_; }; OclPlatformImpl& gOclPlatform(); class OclDeviceMemory { public: OclDeviceMemory() : size_(), data_(), owned_data_(false) {} Result Init(const cl::Context& ctx, size_t size, size_t alignment, uint64_t flags) { if (alignment != 1) { return Status(eNotSupported); } new (&data_) cl::Buffer(ctx, CL_MEM_READ_WRITE, size); owned_data_ = true; size_ = size; return success(); } Result Init(size_t size, shared_ptr data, uint64_t flags) { external_ = std::move(data); data_ = static_cast(external_.get()); size_ = size; return success(); } ~OclDeviceMemory() { if (owned_data_) { detail::Cast(data_).~Buffer(); owned_data_ = false; } size_ = 0; data_ = cl_mem{}; external_.reset(); } size_t size() const { return size_; } cl_mem& data() { return data_; } private: size_t size_; cl_mem data_; bool owned_data_; shared_ptr external_; }; class OclBufferImpl : public BufferImpl { public: explicit OclBufferImpl(Device device); Result Init(size_t size, Allocator allocator, size_t alignment, uint64_t flags) override; Result Init(size_t size, std::shared_ptr native, uint64_t flags) override; Result SubBuffer(size_t offset, size_t size, uint64_t flags) override { return Status(eNotSupported); } void* GetNative(ErrorCode* ec) override; size_t GetSize(ErrorCode* ec) override; cl::Buffer& buffer() { return detail::Cast(memory_->data()); } private: std::shared_ptr memory_; size_t size_{0}; }; class OclStreamImpl : public StreamImpl { public: explicit OclStreamImpl(Device device); ~OclStreamImpl() override; Result Init(uint64_t flags) override; Result Init(std::shared_ptr native, uint64_t flags) override; Result DependsOn(Event& event) override; Result Query() override; Result Wait() override; Result Submit(Kernel& kernel) override; void* GetNative(ErrorCode* ec) override; cl::CommandQueue& queue() { return detail::Cast(queue_); } private: cl_command_queue queue_; bool owned_queue_; std::shared_ptr external_; }; class OclEventImpl : public EventImpl { public: explicit OclEventImpl(Device device); ~OclEventImpl() override; Result Init(uint64_t flags) override; Result Init(std::shared_ptr native, uint64_t flags) override; Result Query() override; Result Record(Stream& stream) override; Result Wait() override; void* GetNative(ErrorCode* ec) override; cl::Event& event() { return detail::Cast(event_); } private: cl_event event_; bool owned_event_; std::shared_ptr external_; }; } // namespace mmdeploy #endif // MMDEPLOY_DEVICE_OPENCL_DEVICE_H_