mmdeploy/csrc/graph/task.cpp
lzhangzz 46bfe0ac87
[Feature] New pipeline & executor for SDK (#497)
* executor prototype

* add split/when_all

* fix GCC build

* WIP let_value

* fix let_value

* WIP ensure_started

* ensure_started & start_detached

* fix let_value + when_all combo on MSVC 142

* fix static thread pool

* generic just, then, let_value, sync_wait

* minor

* generic split and when_all

* fully generic sender adapters

* when_all: workaround for GCC7

* support legacy spdlog

* fix memleak

* bulk

* static detector

* fix bulk & first pipeline

* bulk for static thread pools

* fix on MSVC

* WIP async batch submission

* WIP collation

* async batch

* fix detector

* fix async detector

* fix

* fix

* debug

* fix cuda allocator

* WIP type erased executor

* better type erasure

* simplify C API impl

* Expand & type erase TC

* deduction guide for type erased senders

* fix GCC build

* when_all for arrays of Value senders

* WIP pipeline v2

* WIP pipeline parser

* WIP timed batch operation

* add registry

* experiment

* fix pipeline

* naming

* fix mem-leak

* fix deferred batch operation

* WIP

* WIP configurable scheduler

* WIP configurable scheduler

* add comment

* parse scheduler config

* force link schedulers

* WIP pipeable sender

* WIP CPO

* ADL isolation and dismantle headers

* type erase single thread context

* fix MSVC build

* CPO

* replace decay_t with remove_cvref_t

* structure adjustment

* structure adjustment

* apply CPOs & C API rework

* refine C API

* detector async C API

* adjust detector async C API

* # Conflicts:
#	csrc/apis/c/detector.cpp

* fix when_all for type erased senders

* support void return for Then

* async detector

* fix some CPOs

* minor

* WIP rework capture mechanism for type erased types

* minor fix

* fix MSVC build

* move expand.h to execution

* make `Expand` pipeable

* fix type erased

* un-templatize `_TypeErasedOperation`

* re-work C API

* remove async_detector C API

* fix pipeline

* add flatten & unflatten

* fix flatten & unflatten

* add async OCR demo

* config executor for nodes & better executor API

* working async OCR example

* minor

* dynamic batch via scheduler

* dynamic batch on `Value`

* fix MSVC build

* type erase dynamic batch scheduler

* sender as Python Awaitable

* naming

* naming

* add docs

* minor

* merge tmp branch

* unify C APIs

* fix ocr

* unify APIs

* fix typo

* update async OCR demo

* add v3 API text recognizer

* fix v3 API

* fix lint

* add license info & reformat

* add demo async_ocr_v2

* revert files

* revert files

* resolve link issues

* fix scheduler linkage for shared libs

* fix license header

* add docs for `mmdeploy_executor_split`

* add missing `mmdeploy_executor_transfer_just` and `mmdeploy_executor_execute`

* make `TimedSingleThreadContext` header only

* fix lint

* simplify type-erased sender
2022-06-01 14:10:43 +08:00

84 lines
2.8 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include "graph/task.h"
#include "core/operator.h"
#include "graph/common.h"
namespace mmdeploy::graph {
// Builds the sender pipeline for one invocation of this task node.
// The incoming Value is expected to be an array of inputs (one entry per
// input port). Three cases are handled:
//   1. empty input      -> short-circuit with an array of empty arrays;
//   2. batched input on a non-batched module -> fan out element-wise via Bulk;
//   3. otherwise        -> submit through the dynamic-batch scheduler.
// All work is transferred onto this task's configured scheduler (*sched_).
Sender<Value> Task::Process(Sender<Value> input) {
return LetValue(std::move(input), [this](Value& v) -> Sender<Value> {
assert(v.is_array());
// handle empty input
if (v.front().empty()) {
return TransferJust(*sched_, Value(Value::Array(v.size(), Value::kArray)));
}
if (v.front().is_array() && !is_batched_) {
// The module processes one sample at a time but the input is a batch:
// transpose the input (array-of-arrays -> per-sample values), run the
// module once per sample in a Bulk operation, then transpose back.
// NOTE: the lambdas capture `v` by reference; this is safe because
// LetValue keeps `v` alive for the lifetime of the returned sender.
auto batch_size = v.front().size();
Value output = Value::Array(batch_size);
// clang-format off
return TransferJust(*sched_, std::move(output))
| Then([&](Value&& output) -> Value {
// DistribAA: array-of-arrays -> array of per-sample argument tuples
auto input = graph::DistribAA(v).value();
return Value{std::move(input), std::move(output)};
})
| Bulk(batch_size, [&](size_t index, Value& in_out) {
// in_out[0] holds the distributed inputs, in_out[1] the outputs slot
const auto& input = in_out[0];
auto& output = in_out[1];
output[index] = module_->Process(input[index]).value();
})
| Then([](const Value& in_out) {
// transpose the collected per-sample outputs back to array-of-arrays
return graph::DistribAA(in_out[1]).value();
});
// clang-format on
} else {
// Batched module (or non-array input): let the dynamic-batch scheduler
// coalesce requests and invoke the module on whole batches.
return DynamicBatch(TransferJust(*sched_, std::move(v)), batch_context_,
[&](const Value& u) { return module_->Process(u).value(); });
}
});
}
// Parses a task node configuration into a Task instance.
//
// Populates the common node fields via NodeParser::Parse, creates the
// underlying Module from the registry ("module" key), binds a per-task
// scheduler when config["context"]["executor"] contains an entry keyed by
// this task's name (otherwise falls back to the default inline scheduler),
// and reads the optional `is_batched` / `is_thread_safe` flags.
//
// Returns the constructed Task on success, or an error status when the
// config is malformed or a sub-step fails.
Result<unique_ptr<Task>> TaskParser::Parse(const Value& config) {
  try {
    auto task = std::make_unique<Task>();
    OUTCOME_TRY(NodeParser::Parse(config, *task));
    OUTCOME_TRY(task->module_, CreateFromRegistry<Module>(config, "module"));
    bool sched_set = false;
    if (config["context"].contains("executor")) {
      auto& exec_info = config["context"]["executor"];
      for (auto it = exec_info.begin(); it != exec_info.end(); ++it) {
        if (it.key() == task->name()) {
          task->sched_ = it->get<TypeErasedScheduler<Value>>();
          sched_set = true;
          MMDEPLOY_INFO("scheduler configured for task {}", task->name());
          break;
        }
      }
    }
    if (!sched_set) {
      // no dedicated scheduler configured for this task: use the default
      // (inline) type-erased scheduler implementation
      task->sched_ =
          TypeErasedScheduler<Value>{std::make_shared<TypeErasedScheduler<Value>::Impl>()};
    }
    task->is_batched_ = config.value("is_batched", false);
    task->is_thread_safe_ = config.value("is_thread_safe", false);
    return std::move(task);
  } catch (const std::exception& e) {
    MMDEPLOY_ERROR("error parsing config: {}, error: {}", config, e.what());
    // BUG FIX: `return nullptr;` here produced a *successful* Result holding a
    // null unique_ptr, which callers dereference via .value(); propagate a
    // genuine error status instead.
    return Status(eFail);
  }
}
// Factory that creates "Task" graph nodes from a parsed pipeline config.
class TaskCreator : public Creator<Node> {
 public:
  // Registry key under which this creator is looked up.
  const char* GetName() const override { return "Task"; }
  int GetVersion() const override { return 0; }
  // Delegates to TaskParser; throws (via .value()) if parsing failed.
  std::unique_ptr<Node> Create(const Value& value) override {
    auto parsed = TaskParser::Parse(value);
    return std::move(parsed).value();
  }
};
REGISTER_MODULE(Node, TaskCreator);
} // namespace mmdeploy::graph