// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/brycelelbach/wg21_p2300_std_execution/blob/main/examples/schedulers/static_thread_pool.hpp

#ifndef MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_
#define MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <vector>

#include "execution/execution.h"
#include "intrusive_queue.h"

namespace mmdeploy {

namespace __static_thread_pool {

// A unit of work: an intrusive list node plus a type-erased callback.
struct TaskBase {
  TaskBase* next_;
  void (*execute_)(TaskBase*) noexcept;
};

template <class Receiver>
struct _Operation {
  struct type;
};
template <class Receiver>
using operation_t = typename _Operation<remove_cvref_t<Receiver>>::type;

class StaticThreadPool;

struct Scheduler {
  template <class Receiver>
  friend struct _Operation;

  struct Sender {
    using value_types = std::tuple<>;

    template <class Receiver>
    operation_t<Receiver> MakeOperation(Receiver&& r) const {
      return {pool_, (Receiver &&) r};
    }

    template <class Receiver>
    friend operation_t<Receiver> tag_invoke(connect_t, Sender s, Receiver&& r) {
      return s.MakeOperation((Receiver &&) r);
    }

    friend auto tag_invoke(get_completion_scheduler_t, const Sender& sender) noexcept
        -> Scheduler {
      return Scheduler{sender.pool_};
    }

    friend struct Scheduler;

    explicit Sender(StaticThreadPool& pool) noexcept : pool_(pool) {}

    StaticThreadPool& pool_;
  };

  Sender MakeSender_() const { return Sender{*pool_}; }

  friend class StaticThreadPool;

 public:
  explicit Scheduler(StaticThreadPool& pool) noexcept : pool_(&pool) {}

  friend bool operator==(Scheduler a, Scheduler b) noexcept { return a.pool_ == b.pool_; }
  friend bool operator!=(Scheduler a, Scheduler b) noexcept { return a.pool_ != b.pool_; }

  friend Sender tag_invoke(schedule_t, const Scheduler& self) noexcept {
    return self.MakeSender_();
  }

 private:
  StaticThreadPool* pool_{nullptr};
};

class StaticThreadPool {
  template <class Receiver>
  friend struct _Operation;

 public:
  StaticThreadPool();
  explicit StaticThreadPool(std::uint32_t thread_count);
  ~StaticThreadPool();

  Scheduler GetScheduler() noexcept { return Scheduler{*this}; }

  void RequestStop() noexcept;

 private:
  // Per-thread run queue guarded by a mutex; an idle worker blocks on `cv_`.
  class ThreadState {
   public:
    TaskBase* try_pop();
    TaskBase* pop();
    bool try_push(TaskBase* task);
    void push(TaskBase* task);
    void request_stop();

   private:
    std::mutex mutex_;
    std::condition_variable cv_;
    intrusive_queue<&TaskBase::next_> queue_;
    bool stop_requested_{false};
  };

  void Run(std::uint32_t index) noexcept;
  void Join() noexcept;
  void Enqueue(TaskBase* task) noexcept;

  std::uint32_t thread_count_;
  std::vector<std::thread> threads_;
  std::vector<ThreadState> thread_states_;
  std::atomic<std::uint32_t> next_thread_;
};

// Operation state for `Schedule`: a TaskBase whose callback invokes SetValue
// on the stored receiver from a pool thread.
template <class Receiver>
struct _Operation<Receiver>::type : TaskBase {
  friend Scheduler::Sender;

  StaticThreadPool& pool_;
  Receiver receiver_;

  type(StaticThreadPool& pool, Receiver&& r)
      : TaskBase{}, pool_(pool), receiver_((Receiver &&) r) {
    this->execute_ = [](TaskBase* t) noexcept {
      auto& op = *static_cast<type*>(t);
      SetValue((Receiver &&) op.receiver_);
    };
  }

  void enqueue_(TaskBase* op) const { return pool_.Enqueue(op); }

  friend void tag_invoke(start_t, type& op) noexcept { op.enqueue_(&op); }
};
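// Illustration (not part of the original header): how the pieces above fit
// together. `Schedule` on the pool's scheduler yields a `Scheduler::Sender`,
// `Connect` builds an `operation_t<R>` (itself a TaskBase), and starting the
// operation enqueues it; a pool thread later runs `execute_`, which calls
// SetValue on the receiver. A minimal sketch, assuming a `Start` CPO matching
// the `start_t` tag used below; `ExampleReceiver` is hypothetical:
//
//   struct ExampleReceiver {
//     friend void tag_invoke(set_value_t, ExampleReceiver&&) noexcept {
//       // runs on a pool thread
//     }
//   };
//
//   StaticThreadPool pool{2};
//   auto op = Connect(Schedule(pool.GetScheduler()), ExampleReceiver{});
//   Start(op);  // enqueues; completion happens on a pool thread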
inline StaticThreadPool::StaticThreadPool()
    : StaticThreadPool(std::thread::hardware_concurrency()) {}

inline StaticThreadPool::StaticThreadPool(std::uint32_t thread_count)
    : thread_count_(thread_count), thread_states_(thread_count), next_thread_(0) {
  assert(thread_count_ > 0);
  threads_.reserve(thread_count_);
  try {
    for (std::uint32_t i = 0; i < thread_count_; ++i) {
      threads_.emplace_back([this, i] { Run(i); });
    }
  } catch (...) {
    RequestStop();
    Join();
    throw;
  }
}

inline StaticThreadPool::~StaticThreadPool() {
  RequestStop();
  Join();
}

inline void StaticThreadPool::RequestStop() noexcept {
  for (auto& state : thread_states_) {
    state.request_stop();
  }
}

inline void StaticThreadPool::Run(std::uint32_t index) noexcept {
  while (true) {
    TaskBase* task = nullptr;
    // First poll every queue without blocking, starting with this thread's own.
    for (std::uint32_t i = 0; i < thread_count_; ++i) {
      auto queue_index = (index + i) < thread_count_ ? (index + i) : (index + i - thread_count_);
      auto& state = thread_states_[queue_index];
      task = state.try_pop();
      if (task != nullptr) {
        break;
      }
    }
    if (task == nullptr) {
      // Nothing found; block on this thread's own queue until work arrives or
      // stop is requested.
      task = thread_states_[index].pop();
      if (task == nullptr) {
        return;
      }
    }
    task->execute_(task);
  }
}

inline void StaticThreadPool::Join() noexcept {
  for (auto& t : threads_) {
    t.join();
  }
  threads_.clear();
}

inline void StaticThreadPool::Enqueue(TaskBase* task) noexcept {
  const auto thread_count = static_cast<std::uint32_t>(threads_.size());
  // Pick a round-robin starting queue, then try each queue without blocking.
  const std::uint32_t start_index =
      next_thread_.fetch_add(1, std::memory_order_relaxed) % thread_count;
  for (std::uint32_t i = 0; i < thread_count; ++i) {
    const auto index =
        (start_index + i) < thread_count ? (start_index + i) : (start_index + i - thread_count);
    if (thread_states_[index].try_push(task)) {
      return;
    }
  }
  // Every queue was contended; block on the starting queue.
  thread_states_[start_index].push(task);
}

inline TaskBase* StaticThreadPool::ThreadState::try_pop() {
  std::unique_lock lock{mutex_, std::try_to_lock};
  if (!lock || queue_.empty()) {
    return nullptr;
  }
  return queue_.pop_front();
}

inline TaskBase* StaticThreadPool::ThreadState::pop() {
  std::unique_lock lock{mutex_};
  while (queue_.empty()) {
    if (stop_requested_) {
      return nullptr;
    }
    cv_.wait(lock);
  }
  return queue_.pop_front();
}

inline bool StaticThreadPool::ThreadState::try_push(TaskBase* task) {
  bool was_empty{};
  {
    std::unique_lock lock{mutex_, std::try_to_lock};
    if (!lock) {
      return false;
    }
    was_empty = queue_.empty();
    queue_.push_back(task);
  }
  if (was_empty) {
    cv_.notify_one();
  }
  return true;
}

inline void StaticThreadPool::ThreadState::push(TaskBase* task) {
  bool was_empty{};
  {
    std::lock_guard lock{mutex_};
    was_empty = queue_.empty();
    queue_.push_back(task);
  }
  if (was_empty) {
    cv_.notify_one();
  }
}

inline void StaticThreadPool::ThreadState::request_stop() {
  {
    std::lock_guard lock{mutex_};
    stop_requested_ = true;
  }
  cv_.notify_one();
}
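// Illustration (not part of the original header): typical fire-and-forget use
// of the pool, built from the Then/StartDetached/Schedule algorithms that this
// header itself relies on. A minimal sketch:
//
//   StaticThreadPool pool;  // defaults to std::thread::hardware_concurrency()
//   auto sched = pool.GetScheduler();
//   StartDetached(Then(Schedule(sched), [] {
//     // runs on one of the pool threads
//     return 0;
//   }));
//   // ~StaticThreadPool() requests stop and joins the workers; make sure
//   // detached work has completed before the pool is destroyed.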
namespace __bulk {

// Note: parameter names for _Operation/operation_t below are reconstructed;
// the alias is kept for parity with the upstream example but is unused here.
template <class CvrefSender, class Shape, class Func, class Receiver>
struct _Operation {
  struct type;
};
template <class CvrefSender, class Shape, class Func, class Receiver>
using operation_t = typename _Operation<CvrefSender, Shape, Func, Receiver>::type;

template <class Receiver, class Shape, class Func, class Tuple>
struct _Receiver {
  struct type;
};
template <class Receiver, class Shape, class Func, class Tuple>
using receiver_t = typename _Receiver<remove_cvref_t<Receiver>, Shape, Func, Tuple>::type;

template <class Receiver, class Shape, class Func, class Tuple>
struct _Receiver<Receiver, Shape, Func, Tuple>::type {
  // Shared by the `shape_` scheduled sub-tasks; `count_` tracks how many are
  // still pending.
  struct State {
    Receiver receiver_;
    Shape shape_;
    Func func_;
    std::optional<Tuple> values_;
    Scheduler scheduler_;
    std::atomic<Shape> count_;
  };

  std::shared_ptr<State> state_;

  type(Receiver&& receiver, Shape shape, Func func, Scheduler scheduler)
      : state_(new State{(Receiver &&) receiver, shape, (Func &&) func, std::nullopt, scheduler,
                         shape}) {}

  template <class... As>
  friend void tag_invoke(set_value_t, type&& self, As&&... as) noexcept {
    auto& state = self.state_;
    state->values_.emplace((As &&) as...);
    // With an empty shape there are no sub-tasks; complete immediately so the
    // downstream receiver is not left waiting forever.
    if (state->shape_ == Shape{}) {
      std::apply(
          [&](auto&... vals) { SetValue(std::move(state->receiver_), std::move(vals)...); },
          state->values_.value());
      return;
    }
    // Launch one detached task per index; the last one to finish forwards the
    // values to the downstream receiver.
    for (Shape index = {}; index < state->shape_; ++index) {
      StartDetached(Then(Schedule(state->scheduler_), [state, index] {
        std::apply([&](auto&... vals) { state->func_(index, vals...); }, state->values_.value());
        if (0 == --state->count_) {
          std::apply(
              [&](auto&... vals) { SetValue(std::move(state->receiver_), std::move(vals)...); },
              state->values_.value());
        }
        return 0;
      }));
    }
  }
};

template <class Sender, class Shape, class Func>
struct _Sender {
  struct type;
};
template <class Sender, class Shape, class Func>
using sender_t = typename _Sender<remove_cvref_t<Sender>, remove_cvref_t<Shape>, Func>::type;

template <class Sender, class Shape, class Func>
struct _Sender<Sender, Shape, Func>::type {
  using value_types = completion_signatures_of_t<Sender>;

  template <class Receiver>
  using _receiver_t = receiver_t<Receiver, Shape, Func, value_types>;

  Sender sender_;
  Scheduler scheduler_;
  Shape shape_;
  Func func_;

  template <class Self, class Receiver, _decays_to<Self, type, int> = 0>
  friend auto tag_invoke(connect_t, Self&& self, Receiver&& receiver) {
    return Connect(((Self &&) self).sender_,
                   _receiver_t<Receiver>{(Receiver &&) receiver, ((Self &&) self).shape_,
                                         ((Self &&) self).func_, ((Self &&) self).scheduler_});
  }
};

}  // namespace __bulk

template <class Sender, class Shape, class Func>
__bulk::sender_t<Sender, Shape, Func> tag_invoke(bulk_t, Scheduler scheduler, Sender&& sender,
                                                 Shape&& shape, Func&& func) {
  return {(Sender &&) sender, scheduler, (Shape &&) shape, (Func &&) func};
}

}  // namespace __static_thread_pool

using __static_thread_pool::StaticThreadPool;

}  // namespace mmdeploy

#endif  // MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_
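// Illustration (not part of the original header): Bulk on this scheduler fans
// a single upstream completion out to `shape` sub-tasks, each scheduled on the
// pool, and the last sub-task to finish forwards the upstream values
// downstream. A sketch, assuming Just/Transfer/Bulk/SyncWait algorithms are
// provided by "execution/execution.h":
//
//   mmdeploy::StaticThreadPool pool{4};
//   auto sched = pool.GetScheduler();
//   std::vector<float> data(1024, 1.f);
//   auto sender = Bulk(Transfer(Just(), sched), data.size(),
//                      [&data](std::size_t i) { data[i] *= 2.f; });
//   SyncWait(std::move(sender));  // all 1024 sub-tasks ran on the pool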