#include "flow.h" #include "util/iterators.h" #include "util/fmt/ostream.h" #include "module.h" namespace seq { namespace ir { namespace { int findAndReplace(id_t id, seq::ir::Value *newVal, std::list &values) { auto replacements = 0; for (auto &value : values) { if (value->getId() == id) { value = newVal; ++replacements; } } return replacements; } } // namespace const char Flow::NodeId = 0; types::Type *Flow::doGetType() const { return getModule()->getVoidType(); } const char SeriesFlow::NodeId = 0; int SeriesFlow::doReplaceUsedValue(id_t id, Value *newValue) { return findAndReplace(id, newValue, series); } const char WhileFlow::NodeId = 0; int WhileFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (cond->getId() == id) { cond = newValue; ++replacements; } if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++replacements; } return replacements; } const char ForFlow::NodeId = 0; std::vector ForFlow::doGetUsedValues() const { std::vector ret; if (isParallel()) ret = getSchedule()->getUsedValues(); ret.push_back(iter); ret.push_back(body); return ret; } int ForFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto count = 0; if (isParallel()) count += getSchedule()->replaceUsedValue(id, newValue); if (iter->getId() == id) { iter = newValue; ++count; } if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++count; } return count; } int ForFlow::doReplaceUsedVariable(id_t id, Var *newVar) { if (var->getId() == id) { var = newVar; return 1; } return 0; } const char ImperativeForFlow::NodeId = 0; std::vector ImperativeForFlow::doGetUsedValues() const { std::vector ret; if (isParallel()) ret = getSchedule()->getUsedValues(); ret.push_back(start); ret.push_back(end); ret.push_back(body); return ret; } int ImperativeForFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto count = 0; if (isParallel()) count += getSchedule()->replaceUsedValue(id, newValue); if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++count; } if (start->getId() == id) { start = newValue; ++count; } if (end->getId() == id) { end = newValue; ++count; } return count; } int ImperativeForFlow::doReplaceUsedVariable(id_t id, Var *newVar) { if (var->getId() == id) { var = newVar; return 1; } return 0; } const char IfFlow::NodeId = 0; std::vector IfFlow::doGetUsedValues() const { std::vector ret = {cond, trueBranch}; if (falseBranch) ret.push_back(falseBranch); return ret; } int IfFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (cond->getId() == id) { cond = newValue; ++replacements; } if (trueBranch->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); trueBranch = f; ++replacements; } if (falseBranch && falseBranch->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); falseBranch = f; ++replacements; } return replacements; } const char TryCatchFlow::NodeId = 0; std::vector TryCatchFlow::doGetUsedValues() const { std::vector ret = {body}; if (finally) ret.push_back(finally); for (auto &c : catches) ret.push_back(const_cast(static_cast(c.getHandler()))); return ret; } int TryCatchFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; if (body->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); body = f; ++replacements; } if (finally && finally->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); finally = f; ++replacements; } for (auto &c : catches) { if (c.getHandler()->getId() == id) { auto *f = cast(newValue); seqassert(f, "{} is not a flow", *newValue); c.setHandler(f); ++replacements; } } return replacements; } std::vector TryCatchFlow::doGetUsedTypes() const { std::vector ret; for (auto &c : catches) { if (auto *t = c.getType()) ret.push_back(const_cast(t)); } return ret; } int TryCatchFlow::doReplaceUsedType(const std::string &name, types::Type *newType) { auto count = 0; for (auto &c : catches) { if (c.getType()->getName() == name) { c.setType(newType); ++count; } } return count; } std::vector TryCatchFlow::doGetUsedVariables() const { std::vector ret; for (auto &c : catches) { if (auto *t = c.getVar()) ret.push_back(const_cast(t)); } return ret; } int TryCatchFlow::doReplaceUsedVariable(id_t id, Var *newVar) { auto count = 0; for (auto &c : catches) { if (c.getVar()->getId() == id) { c.setVar(newVar); ++count; } } return count; } const char PipelineFlow::NodeId = 0; types::Type *PipelineFlow::Stage::getOutputType() const { if (args.empty()) { return callee->getType(); } else { auto *funcType = cast(callee->getType()); seqassert(funcType, "{} is not a function type", *callee->getType()); return funcType->getReturnType(); } } types::Type *PipelineFlow::Stage::getOutputElementType() const { if (isGenerator()) { types::GeneratorType *genType = nullptr; if (args.empty()) { genType = cast(callee->getType()); return genType->getBase(); } else { auto *funcType = cast(callee->getType()); seqassert(funcType, "{} is not a function type", *callee->getType()); genType = cast(funcType->getReturnType()); } seqassert(genType, "generator type not found"); return genType->getBase(); } else if (args.empty()) { return callee->getType(); } else { auto *funcType = cast(callee->getType()); seqassert(funcType, "{} is not a function type", *callee->getType()); return funcType->getReturnType(); } } std::vector PipelineFlow::doGetUsedValues() const { std::vector ret; for (auto &s : stages) { ret.push_back(const_cast(s.getCallee())); for (auto *arg : s.args) if (arg) ret.push_back(arg); } return ret; } int PipelineFlow::doReplaceUsedValue(id_t id, Value *newValue) { auto replacements = 0; for (auto &c : stages) { if (c.getCallee()->getId() == id) { c.setCallee(newValue); ++replacements; } for (auto &s : c.args) if (s && s->getId() == id) { s = newValue; ++replacements; } } return replacements; } } // namespace ir } // namespace seq