343 lines
8.5 KiB
C++
343 lines
8.5 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <cstdint>
|
|
#include <cstdlib>
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <optional>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "core/macro.h"
|
|
#include "core/status_code.h"
|
|
|
|
namespace mmdeploy {
|
|
|
|
class Platform;
|
|
class Device;
|
|
class Stream;
|
|
class Event;
|
|
class Allocator;
|
|
class Buffer;
|
|
class Kernel;
|
|
|
|
class PlatformImpl;
|
|
class StreamImpl;
|
|
class EventImpl;
|
|
class AllocatorImpl;
|
|
class BufferImpl;
|
|
class KernelImpl;
|
|
|
|
template <typename T>
|
|
using optional = std::optional<T>;
|
|
|
|
class DeviceId {
|
|
public:
|
|
using ValueType = int32_t;
|
|
constexpr explicit DeviceId(ValueType value) : value_(value) {}
|
|
constexpr operator ValueType() const { return value_; } // NOLINT
|
|
constexpr ValueType get() const { return value_; }
|
|
|
|
private:
|
|
ValueType value_;
|
|
};
|
|
|
|
class PlatformId {
|
|
public:
|
|
using ValueType = int32_t;
|
|
constexpr explicit PlatformId(ValueType value) : value_(value) {}
|
|
constexpr operator ValueType() const { return value_; } // NOLINT
|
|
constexpr ValueType get() const { return value_; }
|
|
|
|
private:
|
|
ValueType value_;
|
|
};
|
|
|
|
class Device {
|
|
public:
|
|
constexpr Device() : platform_id_(-1), device_id_(-1) {}
|
|
|
|
constexpr explicit Device(DeviceId device_id, PlatformId platform_id = PlatformId(-1))
|
|
: Device(platform_id.get(), device_id.get()) {}
|
|
|
|
constexpr explicit Device(PlatformId platform_id, DeviceId device_id = DeviceId(-1))
|
|
: Device(platform_id.get(), device_id.get()) {}
|
|
|
|
constexpr explicit Device(int platform_id, int device_id = 0)
|
|
: platform_id_(platform_id), device_id_(device_id) {}
|
|
|
|
MMDEPLOY_API explicit Device(const char *platform_name, int device_id = 0);
|
|
|
|
constexpr int device_id() const noexcept { return device_id_; }
|
|
|
|
constexpr int platform_id() const noexcept { return platform_id_; }
|
|
|
|
constexpr bool is_host() const noexcept { return platform_id() == 0; }
|
|
|
|
constexpr bool is_device() const noexcept { return platform_id() > 0; }
|
|
|
|
constexpr bool operator==(const Device &other) const noexcept {
|
|
return platform_id_ == other.platform_id_ && device_id_ == other.device_id_;
|
|
}
|
|
|
|
constexpr bool operator!=(const Device &other) const noexcept { return !(*this == other); }
|
|
|
|
constexpr explicit operator bool() const noexcept { return platform_id_ >= 0 && device_id_ >= 0; }
|
|
|
|
constexpr operator DeviceId() const noexcept { // NOLINT
|
|
return DeviceId(device_id_);
|
|
}
|
|
|
|
constexpr operator PlatformId() const noexcept { // NOLINT
|
|
return PlatformId(platform_id_);
|
|
}
|
|
|
|
private:
|
|
int platform_id_{0};
|
|
int device_id_{0};
|
|
};
|
|
|
|
enum class MemcpyKind : int { HtoD, DtoH, DtoD };
|
|
|
|
class MMDEPLOY_API Platform {
|
|
public:
|
|
// throws if not found
|
|
explicit Platform(const char *platform_name);
|
|
|
|
// throws if not found
|
|
explicit Platform(int platform_id);
|
|
|
|
// -1 if invalid
|
|
int GetPlatformId() const;
|
|
|
|
// "" if invalid
|
|
const char *GetPlatformName() const;
|
|
|
|
bool operator==(const Platform &other) { return impl_ == other.impl_; }
|
|
|
|
bool operator!=(const Platform &other) { return !(*this == other); }
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
private:
|
|
explicit Platform(std::shared_ptr<PlatformImpl> impl) : impl_(std::move(impl)) {}
|
|
|
|
private:
|
|
friend class PlatformRegistry;
|
|
friend class Access;
|
|
std::shared_ptr<PlatformImpl> impl_;
|
|
};
|
|
|
|
Platform GetPlatform(int platform_id);
|
|
|
|
Platform GetPlatform(const char *platform_name);
|
|
|
|
class MMDEPLOY_API Stream {
|
|
public:
|
|
Stream() = default;
|
|
|
|
explicit Stream(Device device, uint64_t flags = 0);
|
|
|
|
explicit Stream(Device device, void *native, uint64_t flags = 0);
|
|
|
|
explicit Stream(Device device, std::shared_ptr<void> native, uint64_t flags = 0);
|
|
|
|
Device GetDevice() const;
|
|
|
|
Result<void> Query();
|
|
|
|
Result<void> Wait();
|
|
|
|
Result<void> DependsOn(Event &event);
|
|
|
|
Result<void> Submit(Kernel &kernel);
|
|
|
|
void *GetNative(ErrorCode *ec = nullptr);
|
|
|
|
Result<void> Copy(const Buffer &src, Buffer &dst, size_t size = -1, size_t src_offset = 0,
|
|
size_t dst_offset = 0);
|
|
|
|
Result<void> Copy(const void *host_ptr, Buffer &dst, size_t size = -1, size_t dst_offset = 0);
|
|
|
|
Result<void> Copy(const Buffer &src, void *host_ptr, size_t size = -1, size_t src_offset = 0);
|
|
|
|
Result<void> Fill(const Buffer &dst, void *pattern, size_t pattern_size, size_t size = -1,
|
|
size_t offset = 0);
|
|
|
|
bool operator==(const Stream &other) const { return impl_ == other.impl_; }
|
|
|
|
bool operator!=(const Stream &other) const { return !(*this == other); }
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
static Stream GetDefault(Device device);
|
|
|
|
private:
|
|
explicit Stream(std::shared_ptr<StreamImpl> impl) : impl_(std::move(impl)) {}
|
|
|
|
private:
|
|
friend class Access;
|
|
|
|
std::shared_ptr<StreamImpl> impl_;
|
|
};
|
|
|
|
template <typename T>
|
|
T GetNative(Stream &stream, ErrorCode *ec = nullptr) {
|
|
return reinterpret_cast<T>(stream.GetNative(ec));
|
|
}
|
|
|
|
class MMDEPLOY_API Event {
|
|
public:
|
|
Event() = default;
|
|
|
|
explicit Event(Device device, uint64_t flags = 0);
|
|
|
|
explicit Event(Device device, void *native, uint64_t flags = 0);
|
|
|
|
explicit Event(Device device, std::shared_ptr<void> native, uint64_t flags = 0);
|
|
|
|
Device GetDevice();
|
|
|
|
Result<void> Query();
|
|
|
|
Result<void> Wait();
|
|
|
|
Result<void> Record(Stream &stream);
|
|
|
|
void *GetNative(ErrorCode *ec = nullptr);
|
|
|
|
bool operator==(const Event &other) const { return impl_ == other.impl_; }
|
|
|
|
bool operator!=(const Event &other) const { return !(*this == other); }
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
private:
|
|
explicit Event(std::shared_ptr<EventImpl> impl) : impl_(std::move(impl)) {}
|
|
|
|
private:
|
|
friend class Access;
|
|
std::shared_ptr<EventImpl> impl_;
|
|
};
|
|
|
|
template <typename T>
|
|
T GetNative(Event &event, ErrorCode *ec = nullptr) {
|
|
return reinterpret_cast<T>(event.GetNative(ec));
|
|
}
|
|
|
|
class MMDEPLOY_API Kernel {
|
|
public:
|
|
Kernel() = default;
|
|
explicit Kernel(std::shared_ptr<KernelImpl> impl) : impl_(std::move(impl)) {}
|
|
|
|
Device GetDevice() const;
|
|
|
|
void *GetNative(ErrorCode *ec = nullptr);
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
private:
|
|
std::shared_ptr<KernelImpl> impl_;
|
|
};
|
|
|
|
template <typename T>
|
|
T GetNative(Kernel &kernel, ErrorCode *ec = nullptr) {
|
|
return reinterpret_cast<T>(kernel.GetNative(ec));
|
|
}
|
|
|
|
class MMDEPLOY_API Allocator {
|
|
friend class Access;
|
|
|
|
public:
|
|
Allocator() = default;
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
private:
|
|
explicit Allocator(std::shared_ptr<AllocatorImpl> impl) : impl_(std::move(impl)) {}
|
|
std::shared_ptr<AllocatorImpl> impl_;
|
|
};
|
|
|
|
class MMDEPLOY_API Buffer {
|
|
public:
|
|
Buffer() = default;
|
|
|
|
Buffer(Device device, size_t size, size_t alignment = 1, uint64_t flags = 0)
|
|
: Buffer(device, size, Allocator{}, alignment, flags) {}
|
|
|
|
Buffer(Device device, size_t size, Allocator allocator, size_t alignment = 1, uint64_t flags = 0);
|
|
|
|
Buffer(Device device, size_t size, void *native, uint64_t flags = 0);
|
|
|
|
Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags = 0);
|
|
// create sub-buffer
|
|
Buffer(Buffer &buffer, size_t offset, size_t size, uint64_t flags = 0);
|
|
|
|
size_t GetSize(ErrorCode *ec = nullptr) const;
|
|
|
|
// bool IsSubBuffer(ErrorCode *ec = nullptr);
|
|
|
|
void *GetNative(ErrorCode *ec = nullptr) const;
|
|
|
|
Device GetDevice() const;
|
|
|
|
Allocator GetAllocator() const;
|
|
|
|
bool operator==(const Buffer &other) const { return impl_ == other.impl_; }
|
|
|
|
bool operator!=(const Buffer &other) const { return !(*this == other); }
|
|
|
|
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
|
|
|
|
private:
|
|
explicit Buffer(std::shared_ptr<BufferImpl> impl) : impl_(std::move(impl)) {}
|
|
|
|
private:
|
|
friend class Access;
|
|
std::shared_ptr<BufferImpl> impl_;
|
|
};
|
|
|
|
template <typename T>
|
|
T GetNative(Buffer &buffer, ErrorCode *ec = nullptr) {
|
|
return reinterpret_cast<T>(buffer.GetNative(ec));
|
|
}
|
|
|
|
template <typename T>
|
|
T GetNative(const Buffer &buffer, ErrorCode *ec = nullptr) {
|
|
return reinterpret_cast<T>(buffer.GetNative(ec));
|
|
}
|
|
|
|
class MMDEPLOY_API PlatformRegistry {
|
|
public:
|
|
using Creator = std::function<std::shared_ptr<PlatformImpl>()>;
|
|
|
|
int Register(Creator creator);
|
|
|
|
int GetPlatform(const char *name, Platform *platform);
|
|
|
|
int GetPlatform(int id, Platform *platform);
|
|
|
|
int GetPlatformId(const char *name);
|
|
|
|
PlatformImpl *GetPlatformImpl(PlatformId id);
|
|
|
|
private:
|
|
int GetNextId();
|
|
|
|
bool IsAvailable(int id);
|
|
|
|
private:
|
|
struct Entry {
|
|
std::string name;
|
|
int id;
|
|
Platform platform;
|
|
};
|
|
std::vector<Entry> entries_;
|
|
};
|
|
|
|
MMDEPLOY_API PlatformRegistry &gPlatformRegistry();
|
|
|
|
} // namespace mmdeploy
|