// Copyright (c) OpenMMLab. All rights reserved. #include "graph/inference.h" #include "archive/json_archive.h" #include "archive/value_archive.h" #include "core/operator.h" #include "graph/common.h" namespace mmdeploy::graph { Inference::Inference(const Value& cfg) : BaseNode(cfg) { auto& model_value = cfg["params"]["model"]; if (model_value.is_any()) { model_ = model_value.get(); } else if (model_value.is_string()) { auto model_path = model_value.get(); model_ = Model(model_path); } else { MMDEPLOY_ERROR("unsupported model specification"); throw_exception(eInvalidArgument); } auto pipeline_json = model_.ReadFile("pipeline.json").value(); auto json = nlohmann::json::parse(pipeline_json); auto context = cfg.value("context", Value(ValueType::kObject)); context["model"] = model_; auto value = from_json(json); value["context"] = context; pipeline_ = std::make_unique(value); if (!pipeline_) { MMDEPLOY_ERROR("failed to create pipeline"); throw_exception(eFail); } } void Inference::Build(TaskGraph& graph) { pipeline_->Build(graph); } class InferenceNodeCreator : public Creator { public: const char* GetName() const override { return "Inference"; } int GetVersion() const override { return 0; } std::unique_ptr Create(const Value& value) override { return std::make_unique(value); } }; REGISTER_MODULE(Node, InferenceNodeCreator); } // namespace mmdeploy::graph