mmdeploy/csrc/device/opencl/opencl_device.h

194 lines
4.9 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_DEVICE_OPENCL_DEVICE_H_
#define MMDEPLOY_DEVICE_OPENCL_DEVICE_H_
#include <mutex>
#include "CL/cl.hpp"
#include "core/device_impl.h"
#include "core/logger.h"
namespace mmdeploy {
namespace detail {
static inline cl::CommandQueue& Cast(cl_command_queue& queue) {
return *reinterpret_cast<cl::CommandQueue*>(&queue);
}
static inline cl::Buffer& Cast(cl_mem& buffer) { return *reinterpret_cast<cl::Buffer*>(&buffer); }
static inline cl::Event& Cast(cl_event& event) { return *reinterpret_cast<cl::Event*>(&event); }
static inline cl::Context& Cast(cl_context& context) {
return *reinterpret_cast<cl::Context*>(&context);
}
static inline cl::Platform& Cast(cl_platform_id& platform) {
return *reinterpret_cast<cl::Platform*>(&platform);
}
static inline cl::Device& Cast(cl_device_id& device) {
return *reinterpret_cast<cl::Device*>(&device);
}
} // namespace detail
class OclPlatformImpl : public PlatformImpl {
public:
explicit OclPlatformImpl(cl::Platform platform);
const char* GetPlatformName() const noexcept override { return "opencl"; }
shared_ptr<BufferImpl> CreateBuffer(Device device) override;
shared_ptr<StreamImpl> CreateStream(Device device) override;
shared_ptr<EventImpl> CreateEvent(Device device) override;
Result<void> Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset,
Stream stream) override;
Result<void> Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset,
Stream stream) override;
Result<void> Copy(Buffer src, Buffer dst, size_t size, size_t src_offset, size_t dst_offset,
Stream stream) override;
Result<Stream> GetDefaultStream(int32_t device_id) override;
Device GetDevice(int device_id) { return Device(platform_id_, device_id); }
cl::Device& GetNativeDevice(int device_id) { return devices_[device_id]; }
cl::Context& GetContext() { return ctx_; }
private:
cl::Platform platform_;
std::vector<cl::Device> devices_;
std::vector<Stream> queues_;
std::vector<std::unique_ptr<std::once_flag>> init_flag_;
cl::Context ctx_;
};
OclPlatformImpl& gOclPlatform();
class OclDeviceMemory {
public:
OclDeviceMemory() : size_(), data_(), owned_data_(false) {}
Result<void> Init(const cl::Context& ctx, size_t size, size_t alignment, uint64_t flags) {
if (alignment != 1) {
return Status(eNotSupported);
}
new (&data_) cl::Buffer(ctx, CL_MEM_READ_WRITE, size);
owned_data_ = true;
size_ = size;
return success();
}
Result<void> Init(size_t size, shared_ptr<void> data, uint64_t flags) {
external_ = std::move(data);
data_ = static_cast<cl_mem>(external_.get());
size_ = size;
return success();
}
~OclDeviceMemory() {
if (owned_data_) {
detail::Cast(data_).~Buffer();
owned_data_ = false;
}
size_ = 0;
data_ = cl_mem{};
external_.reset();
}
size_t size() const { return size_; }
cl_mem& data() { return data_; }
private:
size_t size_;
cl_mem data_;
bool owned_data_;
shared_ptr<void> external_;
};
class OclBufferImpl : public BufferImpl {
public:
explicit OclBufferImpl(Device device);
Result<void> Init(size_t size, Allocator allocator, size_t alignment, uint64_t flags) override;
Result<void> Init(size_t size, std::shared_ptr<void> native, uint64_t flags) override;
Result<BufferImplPtr> SubBuffer(size_t offset, size_t size, uint64_t flags) override {
return Status(eNotSupported);
}
void* GetNative(ErrorCode* ec) override;
size_t GetSize(ErrorCode* ec) override;
cl::Buffer& buffer() { return detail::Cast(memory_->data()); }
private:
std::shared_ptr<OclDeviceMemory> memory_;
size_t size_{0};
};
class OclStreamImpl : public StreamImpl {
public:
explicit OclStreamImpl(Device device);
~OclStreamImpl() override;
Result<void> Init(uint64_t flags) override;
Result<void> Init(std::shared_ptr<void> native, uint64_t flags) override;
Result<void> DependsOn(Event& event) override;
Result<void> Query() override;
Result<void> Wait() override;
Result<void> Submit(Kernel& kernel) override;
void* GetNative(ErrorCode* ec) override;
cl::CommandQueue& queue() { return detail::Cast(queue_); }
private:
cl_command_queue queue_;
bool owned_queue_;
std::shared_ptr<void> external_;
};
class OclEventImpl : public EventImpl {
public:
explicit OclEventImpl(Device device);
~OclEventImpl() override;
Result<void> Init(uint64_t flags) override;
Result<void> Init(std::shared_ptr<void> native, uint64_t flags) override;
Result<void> Query() override;
Result<void> Record(Stream& stream) override;
Result<void> Wait() override;
void* GetNative(ErrorCode* ec) override;
cl::Event& event() { return detail::Cast(event_); }
private:
cl_event event_;
bool owned_event_;
std::shared_ptr<void> external_;
};
} // namespace mmdeploy
#endif // MMDEPLOY_DEVICE_OPENCL_DEVICE_H_