mmdeploy/csrc/execution/schedulers/static_thread_pool.h

// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/brycelelbach/wg21_p2300_std_execution/blob/main/examples/schedulers/static_thread_pool.hpp
#ifndef MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_
#define MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <type_traits>
#include <vector>

#include "execution/execution.h"
#include "intrusive_queue.h"

namespace mmdeploy {

namespace __static_thread_pool {

// Type-erased task node: operation states derive from TaskBase so the pool's
// intrusive per-thread queues can hold them without extra allocation.
struct TaskBase {
  TaskBase* next_;
  void (*execute_)(TaskBase*) noexcept;
};

template <typename Receiver>
struct _Operation {
  struct type;
};

template <typename Receiver>
using operation_t = typename _Operation<remove_cvref_t<Receiver>>::type;

class StaticThreadPool;

struct Scheduler {
  template <typename Receiver>
  friend struct _Operation;

  struct Sender {
    using value_types = std::tuple<>;

    template <typename Receiver>
    operation_t<Receiver> MakeOperation(Receiver&& r) const {
      return {pool_, (Receiver &&) r};
    }

    template <typename Receiver>
    friend operation_t<Receiver> tag_invoke(connect_t, Sender s, Receiver&& r) {
      return s.MakeOperation((Receiver &&) r);
    }

    friend auto tag_invoke(get_completion_scheduler_t, const Sender& sender) noexcept -> Scheduler {
      return Scheduler{sender.pool_};
    }

    friend struct Scheduler;

    explicit Sender(StaticThreadPool& pool) noexcept : pool_(pool) {}

    StaticThreadPool& pool_;
  };

  Sender MakeSender_() const { return Sender{*pool_}; }

  friend class StaticThreadPool;

 public:
  explicit Scheduler(StaticThreadPool& pool) noexcept : pool_(&pool) {}

  friend bool operator==(Scheduler a, Scheduler b) noexcept { return a.pool_ == b.pool_; }
  friend bool operator!=(Scheduler a, Scheduler b) noexcept { return a.pool_ != b.pool_; }

  friend Sender tag_invoke(schedule_t, const Scheduler& self) noexcept {
    return self.MakeSender_();
  }

 private:
  StaticThreadPool* pool_{nullptr};
};

class StaticThreadPool {
  template <typename Receiver>
  friend struct _Operation;

 public:
  StaticThreadPool();
  explicit StaticThreadPool(std::uint32_t thread_count);
  ~StaticThreadPool();

  Scheduler GetScheduler() noexcept { return Scheduler{*this}; }

  void RequestStop() noexcept;

 private:
  // Per-worker task queue; each worker thread drains its own queue and steals
  // from the others when it runs dry.
  class ThreadState {
   public:
    TaskBase* try_pop();
    TaskBase* pop();
    bool try_push(TaskBase* task);
    void push(TaskBase* task);
    void request_stop();

   private:
    std::mutex mutex_;
    std::condition_variable cv_;
    intrusive_queue<&TaskBase::next_> queue_;
    bool stop_requested_{false};
  };

  void Run(std::uint32_t index) noexcept;
  void Join() noexcept;
  void Enqueue(TaskBase* task) noexcept;

  std::uint32_t thread_count_;
  std::vector<std::thread> threads_;
  std::vector<ThreadState> thread_states_;
  std::atomic<std::uint32_t> next_thread_;
};
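
// Usage sketch (comment only, not compiled with this header): construct a pool,
// obtain its Scheduler, and run work on it through the generic sender algorithms
// from execution/execution.h. The exact Schedule/Then/SyncWait spellings are
// assumed to match the CPOs provided there.
//
//   mmdeploy::StaticThreadPool pool{4};
//   auto sched = pool.GetScheduler();
//   auto sender = Then(Schedule(sched), [] { return 42; });  // runs on a pool thread
//   auto result = SyncWait(std::move(sender));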

// Operation state for Schedule(Scheduler{pool}): starting it enqueues the task;
// a worker thread later invokes execute_, which completes the receiver via SetValue.
template <typename Receiver>
struct _Operation<Receiver>::type : TaskBase {
  friend Scheduler::Sender;

  StaticThreadPool& pool_;
  Receiver receiver_;

  type(StaticThreadPool& pool, Receiver&& r) : TaskBase{}, pool_(pool), receiver_((Receiver &&) r) {
    this->execute_ = [](TaskBase* t) noexcept {
      auto& op = *static_cast<type*>(t);
      SetValue((Receiver &&) op.receiver_);
    };
  }

  void enqueue_(TaskBase* op) const { return pool_.Enqueue(op); }

  friend void tag_invoke(start_t, type& op) noexcept { op.enqueue_(&op); }
};

inline StaticThreadPool::StaticThreadPool()
    : StaticThreadPool(std::thread::hardware_concurrency()) {}

inline StaticThreadPool::StaticThreadPool(std::uint32_t thread_count)
    : thread_count_(thread_count), thread_states_(thread_count), next_thread_(0) {
  assert(thread_count_ > 0);
  threads_.reserve(thread_count_);
  try {
    for (std::uint32_t i = 0; i < thread_count_; ++i) {
      threads_.emplace_back([this, i] { Run(i); });
    }
  } catch (...) {
    RequestStop();
    Join();
    throw;
  }
}

inline StaticThreadPool::~StaticThreadPool() {
  RequestStop();
  Join();
}

inline void StaticThreadPool::RequestStop() noexcept {
  for (auto& state : thread_states_) {
    state.request_stop();
  }
}

inline void StaticThreadPool::Run(std::uint32_t index) noexcept {
  while (true) {
    TaskBase* task = nullptr;
    // Non-blocking pass over all queues, starting with this worker's own queue
    // and then stealing from the others.
    for (std::uint32_t i = 0; i < thread_count_; ++i) {
      auto queue_index = (index + i) < thread_count_ ? (index + i) : (index + i - thread_count_);
      auto& state = thread_states_[queue_index];
      task = state.try_pop();
      if (task != nullptr) {
        break;
      }
    }
    if (task == nullptr) {
      // Nothing to steal: block on our own queue until work arrives or stop is
      // requested (pop() returns nullptr on stop).
      task = thread_states_[index].pop();
      if (task == nullptr) {
        return;
      }
    }
    task->execute_(task);
  }
}

inline void StaticThreadPool::Join() noexcept {
  for (auto& t : threads_) {
    t.join();
  }
  threads_.clear();
}

inline void StaticThreadPool::Enqueue(TaskBase* task) noexcept {
  const auto thread_count = static_cast<std::uint32_t>(threads_.size());
  // Pick a starting queue round-robin, then try each queue without blocking;
  // if every try_push fails, fall back to a blocking push on the start queue.
  const std::uint32_t start_index =
      next_thread_.fetch_add(1, std::memory_order_relaxed) % thread_count;
  for (std::uint32_t i = 0; i < thread_count; ++i) {
    const auto index =
        (start_index + i) < thread_count ? (start_index + i) : (start_index + i - thread_count);
    if (thread_states_[index].try_push(task)) {
      return;
    }
  }
  thread_states_[start_index].push(task);
}

inline TaskBase* StaticThreadPool::ThreadState::try_pop() {
  std::unique_lock lock{mutex_, std::try_to_lock};
  if (!lock || queue_.empty()) {
    return nullptr;
  }
  return queue_.pop_front();
}

inline TaskBase* StaticThreadPool::ThreadState::pop() {
  std::unique_lock lock{mutex_};
  while (queue_.empty()) {
    if (stop_requested_) {
      return nullptr;
    }
    cv_.wait(lock);
  }
  return queue_.pop_front();
}

inline bool StaticThreadPool::ThreadState::try_push(TaskBase* task) {
  bool was_empty{};
  {
    std::unique_lock lock{mutex_, std::try_to_lock};
    if (!lock) {
      return false;
    }
    was_empty = queue_.empty();
    queue_.push_back(task);
  }
  if (was_empty) {
    cv_.notify_one();
  }
  return true;
}

inline void StaticThreadPool::ThreadState::push(TaskBase* task) {
  bool was_empty{};
  {
    std::lock_guard lock{mutex_};
    was_empty = queue_.empty();
    queue_.push_back(task);
  }
  if (was_empty) {
    cv_.notify_one();
  }
}

inline void StaticThreadPool::ThreadState::request_stop() {
  {
    std::lock_guard lock{mutex_};
    stop_requested_ = true;
  }
  cv_.notify_one();
}

namespace __bulk {

template <typename CvrefSender, typename Shape, typename Func, typename Receiver>
struct _Operation {
  struct type;
};

template <typename CvrefSender, typename Shape, typename Func, typename Receiver>
using operation_t = typename _Operation<CvrefSender, Shape, Func, Receiver>::type;

template <typename Receiver, typename Shape, typename Func, typename Tuple>
struct _Receiver {
  struct type;
};

template <typename Receiver, typename Shape, typename Func, typename Tuple>
using receiver_t = typename _Receiver<remove_cvref_t<Receiver>, Shape, Func, Tuple>::type;

template <typename Receiver, typename Shape, typename Func, typename Tuple>
struct _Receiver<Receiver, Shape, Func, Tuple>::type {
  // State shared by the `shape_` detached tasks launched in set_value; the last
  // task to finish (count_ reaching zero) forwards the predecessor's values to
  // the downstream receiver.
  struct State {
    Receiver receiver_;
    Shape shape_;
    Func func_;
    std::optional<Tuple> values_;
    Scheduler scheduler_;
    std::atomic<Shape> count_;
  };
  std::shared_ptr<State> state_;

  type(Receiver&& receiver, Shape shape, Func func, Scheduler scheduler)
      : state_(new State{(Receiver &&) receiver, shape, (Func &&) func, std::nullopt, scheduler,
                         shape}) {}

  template <typename... As>
  friend void tag_invoke(set_value_t, type&& self, As&&... as) noexcept {
    auto& state = self.state_;
    state->values_.emplace((As &&) as...);
    // One detached task per index, each scheduled on the pool.
    for (Shape index = {}; index < state->shape_; ++index) {
      StartDetached(Then(Schedule(state->scheduler_), [state, index] {
        std::apply([&](auto&... vals) { state->func_(index, vals...); }, state->values_.value());
        if (0 == --state->count_) {
          std::apply(
              [&](auto&... vals) { SetValue(std::move(state->receiver_), std::move(vals)...); },
              state->values_.value());
        }
        return 0;
      }));
    }
  }
};

template <typename Sender, typename Shape, typename Func>
struct _Sender {
  struct type;
};

template <typename Sender, typename Shape, typename Func>
using sender_t = typename _Sender<remove_cvref_t<Sender>, remove_cvref_t<Shape>, Func>::type;

template <typename Sender, typename Shape, typename Func>
struct _Sender<Sender, Shape, Func>::type {
  using value_types = completion_signatures_of_t<Sender>;

  template <typename Receiver>
  using _receiver_t = receiver_t<Receiver, Shape, Func, value_types>;

  Sender sender_;
  Scheduler scheduler_;
  Shape shape_;
  Func func_;

  template <typename Self, typename Receiver, _decays_to<Self, type, int> = 0>
  friend auto tag_invoke(connect_t, Self&& self, Receiver&& receiver) {
    return Connect(((Self &&) self).sender_,
                   _receiver_t<Receiver>{(Receiver &&) receiver, ((Self &&) self).shape_,
                                         ((Self &&) self).func_, ((Self &&) self).scheduler_});
  }
};

}  // namespace __bulk

// Customization of the Bulk CPO for this pool's Scheduler: each of the `shape`
// invocations of `func` is submitted to the pool as its own task.
template <typename Sender, typename Shape, typename Func>
__bulk::sender_t<Sender, Shape, Func> tag_invoke(bulk_t, Scheduler scheduler, Sender&& sender,
                                                 Shape&& shape, Func&& func) {
  return {(Sender &&) sender, scheduler, (Shape &&) shape, (Func &&) func};
}
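
// Rough usage sketch (comment only): with the pool's scheduler as the completion
// scheduler of the predecessor, Bulk fans `func` out over `shape` pool tasks, one
// per index. The Bulk/Just/Transfer/SyncWait spellings are assumptions based on
// the CPOs declared in execution/execution.h.
//
//   StaticThreadPool pool;
//   auto sched = pool.GetScheduler();
//   auto sender = Bulk(Transfer(Just(std::vector<float>(1024)), sched), 1024,
//                      [](size_t i, std::vector<float>& v) { v[i] = 2.f * i; });
//   auto result = SyncWait(std::move(sender));
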
} // namespace __static_thread_pool
using __static_thread_pool::StaticThreadPool;
} // namespace mmdeploy
#endif // MMDEPLOY_CSRC_EXPERIMENTAL_EXECUTION_STATIC_THREAD_POOL_H_