// Copyright (c) OpenMMLab. All rights reserved.

#include "device_impl.h"

#include <cassert>

#include "core/device.h"
#include "core/logger.h"

namespace mmdeploy {

// Writes `code` through `ec` when the caller supplied an error sink, then
// hands `ret` straight back so a failing accessor can report and return in
// one expression.
template <typename T>
T SetError(ErrorCode* ec, ErrorCode code, T ret) {
  if (ec != nullptr) *ec = code;
  return ret;
}

////////////////////////////////////////////////////////////////////////////////
/// Device

// Looks `platform_name` up in the global registry (GetPlatformId yields -1
// when no entry matches) and stores `device_id` unchanged.
Device::Device(const char* platform_name, int device_id) {
  auto& registry = gPlatformRegistry();
  platform_id_ = registry.GetPlatformId(platform_name);
  device_id_ = device_id;
}

//////////////////////////////////////////////////
/// Platform

// Id reported by the bound implementation; -1 when this handle is empty.
int Platform::GetPlatformId() const {
  if (!impl_) {
    return -1;
  }
  return impl_->GetPlatformId();
}

// Name reported by the bound implementation; an empty string when this
// handle is empty.
const char* Platform::GetPlatformName() const {
  if (!impl_) {
    return "";
  }
  return impl_->GetPlatformName();
}

// Binds this handle to the registered platform called `platform_name`.
// Raises eInvalidArgument through throw_exception when the registry lookup
// reports failure (-1).
Platform::Platform(const char* platform_name) {
  const int status = gPlatformRegistry().GetPlatform(platform_name, this);
  if (status == -1) {
    throw_exception(eInvalidArgument);
  }
}

// Binds this handle to the registered platform whose id is `platform_id`.
// Raises eInvalidArgument through throw_exception when the registry lookup
// reports failure (-1).
Platform::Platform(int platform_id) {
  const int status = gPlatformRegistry().GetPlatform(platform_id, this);
  if (status == -1) {
    throw_exception(eInvalidArgument);
  }
}

////////////////////////////////////////////////////////////////////////////////
/// Buffer

// Creates a buffer of `size` bytes on `device` via the device's platform,
// initialising it with `allocator`, `alignment` and `flags`.
// Raises eInvalidArgument when no platform implementation matches `device`;
// when Init fails the half-built impl is dropped and Init's error is raised.
Buffer::Buffer(Device device, size_t size, Allocator allocator, size_t alignment, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto status = impl_->Init(size, std::move(allocator), alignment, flags);
    if (status.has_error()) {
      // Leave this Buffer empty before the error propagates.
      impl_.reset();
      status.error().throw_exception();
    }
  }
}

Buffer::Buffer(Device device, size_t size, void* native, uint64_t flags)
    : Buffer(device, size, std::shared_ptr<void>(native, [](void*) {}), flags) {}

// Creates a buffer on `device` that is initialised from an existing `native`
// handle of `size` bytes.
// Raises eInvalidArgument when no platform implementation matches `device`;
// when Init fails the half-built impl is dropped and Init's error is raised.
Buffer::Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    impl_ = platform->CreateBuffer(device);
    auto status = impl_->Init(size, std::move(native), flags);
    if (status.has_error()) {
      // Leave this Buffer empty before the error propagates.
      impl_.reset();
      status.error().throw_exception();
    }
  }
}

// Device of the underlying impl; a default-constructed Device when empty.
Device Buffer::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}

// Allocator of the underlying impl; a default-constructed Allocator when empty.
Allocator Buffer::GetAllocator() const {
  if (impl_) {
    return impl_->GetAllocator();
  }
  return Allocator{};
}

// Native handle of the underlying impl. For an empty Buffer this returns
// nullptr and, if `ec` is non-null, sets `*ec` to eInvalidArgument.
void* Buffer::GetNative(ErrorCode* ec) const {
  if (!impl_) {
    return SetError(ec, eInvalidArgument, nullptr);
  }
  return impl_->GetNative(ec);
}

// Size reported by the underlying impl. For an empty Buffer this returns 0
// and, if `ec` is non-null, sets `*ec` to eInvalidArgument.
size_t Buffer::GetSize(ErrorCode* ec) const {
  if (!impl_) {
    return SetError(ec, eInvalidArgument, 0);
  }
  return impl_->GetSize(ec);
}

// Creates a view onto `buffer` covering `size` bytes starting at `offset`,
// as produced by the parent impl's SubBuffer.
// Raises eInvalidArgument when `buffer` is empty (has no impl), and raises
// SubBuffer's own error when the sub-buffer cannot be created.
Buffer::Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags) {
  // An empty parent has nothing to slice; report it instead of dereferencing
  // a null impl.
  if (!buffer.impl_) {
    throw_exception(eInvalidArgument);
  }
  auto impl = buffer.impl_->SubBuffer(offset, size, flags);
  if (!impl) {
    impl.error().throw_exception();
  }
  impl_ = std::move(impl).value();
}

// NOTE: everything up to the matching #endif is compiled out. These free-function
// Copy helpers fetch the default stream for a buffer's device and forward to
// the corresponding stream.Copy call.
// NOTE(review): they are declared to return int but return Status(eFail) and
// the result of stream.Copy — confirm the intended return type before
// re-enabling.
#if 0
int Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
 Stream stream;
 GetDefaultStream(dst.GetDevice(), &stream);
 if (!stream) {
   return Status(eFail);
 }
 return stream.Copy(host_ptr, dst, size, dst_offset);
}
int Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
 Stream stream;
 GetDefaultStream(src.GetDevice(), &stream);
 if (!stream) {
   return Status(eFail);
 }
 return stream.Copy(src, host_ptr, size, src_offset);
}
int Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
        size_t dst_offset) {
 Stream stream;
 GetDefaultStream(src.GetDevice(), &stream);
 if (!stream) {
   return Status(eFail);
 }
 return stream.Copy(src, dst, size, src_offset, dst_offset);
}
#endif

//////////////////////////////////////////////////
/// Stream

// Creates a new stream on `device` through the device's platform.
// When no platform implementation matches `device`, the device and platform
// ids are logged and eInvalidArgument is raised. When Init fails, its error
// is raised and this Stream keeps no impl.
Stream::Stream(Device device, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    ERROR("{}, {}", device.device_id(), device.platform_id());
    throw_exception(eInvalidArgument);
  } else {
    auto stream_impl = platform->CreateStream(device);
    auto status = stream_impl->Init(flags);
    if (status) {
      // Adopt the impl only once it has been initialised successfully.
      impl_ = std::move(stream_impl);
    } else {
      status.error().throw_exception();
    }
  }
}

Stream::Stream(Device device, void* native, uint64_t flags)
    : Stream(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}

// Creates a stream on `device` that is initialised from an existing `native`
// handle. Raises eInvalidArgument when no platform implementation matches
// `device`; when Init fails, its error is raised and this Stream keeps no impl.
Stream::Stream(Device device, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto stream_impl = platform->CreateStream(device);
    auto status = stream_impl->Init(std::move(native), flags);
    if (status) {
      // Adopt the impl only once it has been initialised successfully.
      impl_ = std::move(stream_impl);
    } else {
      status.error().throw_exception();
    }
  }
}

// Forwards to the impl's Query(); eInvalidArgument for an empty Stream.
Result<void> Stream::Query() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Query();
}

// Forwards to the impl's Wait(); eInvalidArgument for an empty Stream.
Result<void> Stream::Wait() {
  if (!impl_) {
    return Status(eInvalidArgument);
  }
  return impl_->Wait();
}

// Forwards `event` to the impl's DependsOn(); eInvalidArgument for an empty
// Stream.
Result<void> Stream::DependsOn(Event& event) {
  if (impl_) {
    return impl_->DependsOn(event);
  }
  return Status(eInvalidArgument);
}

// Native handle of the underlying impl. For an empty Stream this returns
// nullptr and, if `ec` is non-null, sets `*ec` to eInvalidArgument.
void* Stream::GetNative(ErrorCode* ec) {
  if (!impl_) {
    return SetError(ec, eInvalidArgument, nullptr);
  }
  return impl_->GetNative(ec);
}

// Forwards `kernel` to the impl's Submit(); eInvalidArgument for an empty
// Stream.
Result<void> Stream::Submit(Kernel& kernel) {
  if (impl_) {
    return impl_->Submit(kernel);
  }
  return Status(eInvalidArgument);
}

// Buffer-to-buffer copy issued on this stream through the platform of this
// stream's device. A `size` of size_t(-1) is replaced by src.GetSize().
// Returns eInvalidArgument when the stream is empty or its device has no
// platform implementation.
// NOTE(review): the substituted size is not reduced by `src_offset` —
// confirm the platform Copy copes with a non-zero offset.
Result<void> Stream::Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
                          size_t dst_offset) {
  if (impl_) {
    if (size == static_cast<size_t>(-1)) {
      size = src.GetSize();
    }
    auto platform = GetPlatformImpl(GetDevice());
    if (platform) {
      return platform->Copy(src, dst, size, src_offset, dst_offset, *this);
    }
  }
  return Status(eInvalidArgument);
}

// Host-to-buffer copy issued on this stream through the platform of this
// stream's device. A `size` of size_t(-1) is replaced by dst.GetSize().
// Returns eInvalidArgument when the stream is empty or its device has no
// platform implementation.
// NOTE(review): the substituted size is not reduced by `dst_offset` —
// confirm the platform Copy copes with a non-zero offset.
Result<void> Stream::Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
  if (impl_) {
    if (size == static_cast<size_t>(-1)) {
      size = dst.GetSize();
    }
    auto platform = GetPlatformImpl(GetDevice());
    if (platform) {
      return platform->Copy(host_ptr, dst, size, dst_offset, *this);
    }
  }
  return Status(eInvalidArgument);
}

// Buffer-to-host copy issued on this stream through the platform of this
// stream's device. A `size` of size_t(-1) is replaced by src.GetSize().
// Returns eInvalidArgument when the stream is empty or its device has no
// platform implementation.
// NOTE(review): the substituted size is not reduced by `src_offset` —
// confirm the platform Copy copes with a non-zero offset.
Result<void> Stream::Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
  if (impl_) {
    if (size == static_cast<size_t>(-1)) {
      size = src.GetSize();
    }
    auto platform = GetPlatformImpl(GetDevice());
    if (platform) {
      return platform->Copy(src, host_ptr, size, src_offset, *this);
    }
  }
  return Status(eInvalidArgument);
}

// Not implemented: all arguments are ignored. An initialised stream reports
// eNotSupported; an empty one reports eInvalidArgument.
Result<void> Stream::Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size,
                          size_t offset) {
  if (impl_) {
    return Status(eNotSupported);
  }
  return Status(eInvalidArgument);
}

// Device of the underlying impl; a default-constructed Device when empty.
Device Stream::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}

// Returns the default stream that the platform of `device` provides for
// device.device_id().
// The Platform constructor raises eInvalidArgument for an unregistered
// platform id; the assert is an additional check in builds where assert is
// enabled.
// NOTE(review): the result is unwrapped with .value() — presumably that
// raises when GetDefaultStream failed; confirm.
Stream Stream::GetDefault(Device device) {
  Platform platform(device.platform_id());
  assert(platform);
  return Access::get<PlatformImpl>(platform).GetDefaultStream(device.device_id()).value();
}

/////////////////////////////////////////////////
/// Event

// Creates a new event on `device` through the device's platform.
// Raises eInvalidArgument when no platform implementation matches `device`;
// when Init fails, its error is raised and this Event keeps no impl.
Event::Event(Device device, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto event_impl = platform->CreateEvent(device);
    auto status = event_impl->Init(flags);
    if (status) {
      // Adopt the impl only once it has been initialised successfully.
      impl_ = std::move(event_impl);
    } else {
      status.error().throw_exception();
    }
  }
}

Event::Event(Device device, void* native, uint64_t flags)
    : Event(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}

// Creates an event on `device` that is initialised from an existing `native`
// handle. Raises eInvalidArgument when no platform implementation matches
// `device`; when Init fails, its error is raised and this Event keeps no impl.
Event::Event(Device device, std::shared_ptr<void> native, uint64_t flags) {
  auto platform = GetPlatformImpl(device);
  if (!platform) {
    throw_exception(eInvalidArgument);
  } else {
    auto event_impl = platform->CreateEvent(device);
    auto status = event_impl->Init(std::move(native), flags);
    if (status) {
      // Adopt the impl only once it has been initialised successfully.
      impl_ = std::move(event_impl);
    } else {
      status.error().throw_exception();
    }
  }
}

// Forwards to the impl's Query(); eInvalidArgument for an empty Event.
Result<void> Event::Query() {
  if (impl_) {
    return impl_->Query();
  }
  return Status(eInvalidArgument);
}

// Forwards to the impl's Wait(); eInvalidArgument for an empty Event.
Result<void> Event::Wait() {
  if (impl_) {
    return impl_->Wait();
  }
  return Status(eInvalidArgument);
}

// Forwards `stream` to the impl's Record(); eInvalidArgument for an empty
// Event.
Result<void> Event::Record(Stream& stream) {
  if (impl_) {
    return impl_->Record(stream);
  }
  return Status(eInvalidArgument);
}

// Native handle of the underlying impl. For an empty Event this returns
// nullptr and, if `ec` is non-null, sets `*ec` to eInvalidArgument.
void* Event::GetNative(ErrorCode* ec) {
  if (!impl_) {
    return SetError(ec, eInvalidArgument, nullptr);
  }
  return impl_->GetNative(ec);
}

// Device of the underlying impl; a default-constructed Device when empty.
Device Event::GetDevice() {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}

/////////////////////////////////////////////////
/// Kernel

// Device of the underlying impl; a default-constructed Device when empty.
Device Kernel::GetDevice() const {
  if (impl_) {
    return impl_->GetDevice();
  }
  return Device{};
}

// Native handle of the underlying impl. For an empty Kernel this returns
// nullptr and, if `ec` is non-null, sets `*ec` to eInvalidArgument.
void* Kernel::GetNative(ErrorCode* ec) {
  if (!impl_) {
    return SetError(ec, eInvalidArgument, nullptr);
  }
  return impl_->GetNative(ec);
}

/////////////////////////////////////////////////
/// PlatformRegistry

// Registers the platform produced by `creator`.
// A platform that reports id -1 is assigned the next free id (written back
// through SetPlatformId); one that reports an explicit id is accepted only
// if that id is not already taken.
// Returns 0 on success, -1 when `creator` yields an empty platform or the
// requested id is already in use.
int PlatformRegistry::Register(Creator creator) {
  Platform platform(creator());
  // An empty platform also reports id -1, which would otherwise take the
  // auto-assign path below and dereference a null impl.
  if (!platform.impl_) {
    return -1;
  }
  auto proposed_id = platform.GetPlatformId();
  std::string name = platform.GetPlatformName();
  if (proposed_id == -1) {
    proposed_id = GetNextId();
    platform.impl_->SetPlatformId(proposed_id);
  } else if (!IsAvailable(proposed_id)) {
    return -1;
  }
  // Neither local is used after this point, so hand them over instead of copying.
  entries_.push_back({std::move(name), proposed_id, std::move(platform)});
  return 0;
}

// Smallest id >= 1 that no registered entry currently uses.
int PlatformRegistry::GetNextId() {
  int candidate = 1;
  while (!IsAvailable(candidate)) {
    ++candidate;
  }
  return candidate;
}

// True when no registered entry carries `id`.
bool PlatformRegistry::IsAvailable(int id) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id == id) {
      return false;
    }
  }
  return true;
}

// Copies the first registered platform whose name equals `name` into
// `*platform`.
// Returns 0 on success, -1 when no entry matches or when either argument is
// null.
int PlatformRegistry::GetPlatform(const char* name, Platform* platform) {
  // Comparing an entry name against a null C string is undefined behaviour,
  // and a null output pointer cannot receive the result.
  if (name == nullptr || platform == nullptr) {
    return -1;
  }
  for (const auto& entry : entries_) {
    if (entry.name == name) {
      *platform = entry.platform;
      return 0;
    }
  }
  return -1;
}

// Copies the first registered platform whose id equals `id` into `*platform`.
// Returns 0 on success, -1 when no entry matches.
int PlatformRegistry::GetPlatform(int id, Platform* platform) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id != id) {
      continue;
    }
    *platform = it->platform;
    return 0;
  }
  return -1;
}
// Id of the first registered platform whose name equals `name`.
// Returns -1 when no entry matches or when `name` is null.
int PlatformRegistry::GetPlatformId(const char* name) {
  // Comparing an entry name against a null C string is undefined behaviour.
  if (name == nullptr) {
    return -1;
  }
  for (const auto& entry : entries_) {
    if (entry.name == name) {
      return entry.id;
    }
  }
  return -1;
}

// Raw pointer to the implementation of the first registered platform whose
// id equals `id`; nullptr when no entry matches.
PlatformImpl* PlatformRegistry::GetPlatformImpl(PlatformId id) {
  for (auto it = entries_.begin(); it != entries_.end(); ++it) {
    if (it->id == id) {
      return it->platform.impl_.get();
    }
  }
  return nullptr;
}

// Accessor for the single registry instance, held as a function-local static
// that is constructed on the first call.
PlatformRegistry& gPlatformRegistry() {
  static PlatformRegistry registry;
  return registry;
}

}  // namespace mmdeploy