1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/compiler/sir/flow.cpp
2021-09-27 14:02:44 -04:00

307 lines
7.0 KiB
C++

#include "flow.h"
#include "util/iterators.h"
#include "util/fmt/ostream.h"
#include "module.h"
namespace seq {
namespace ir {
namespace {
/// Walks `values` and overwrites every entry whose id equals `id` with `newVal`.
///
/// @param id id of the value(s) to replace
/// @param newVal pointer stored in place of each matching entry
/// @param values list that is modified in place
/// @return how many entries were overwritten
int findAndReplace(id_t id, seq::ir::Value *newVal,
                   std::list<seq::ir::Value *> &values) {
  int numReplaced = 0;
  for (auto it = values.begin(); it != values.end(); ++it) {
    if ((*it)->getId() != id)
      continue;
    *it = newVal;
    numReplaced += 1;
  }
  return numReplaced;
}
} // namespace
/// Out-of-class definition of the static member Flow::NodeId, initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char Flow::NodeId = 0;
/// Reports the type of a flow: whatever getVoidType() of the owning module
/// returns.
types::Type *Flow::doGetType() const {
  auto *owner = getModule();
  return owner->getVoidType();
}
/// Out-of-class definition of the static member SeriesFlow::NodeId, initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char SeriesFlow::NodeId = 0;
/// Replaces every element of `series` whose id equals `id` with `newValue`.
/// @return number of elements replaced
int SeriesFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  const int numReplaced = findAndReplace(id, newValue, series);
  return numReplaced;
}
/// Out-of-class definition of the static member WhileFlow::NodeId, initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char WhileFlow::NodeId = 0;
/// Substitutes `newValue` for the condition and/or the body when their id
/// equals `id`.
///
/// A replacement for the body must be a Flow; this is checked with seqassert.
/// @return number of operands replaced (0, 1 or 2)
int WhileFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;

  // Condition: any Value is acceptable.
  if (cond->getId() == id) {
    cond = newValue;
    numReplaced += 1;
  }

  // Body: the new value has to be a Flow.
  if (body->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    body = f;
    numReplaced += 1;
  }

  return numReplaced;
}
/// Out-of-class definition of the static member ForFlow::NodeId, initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char ForFlow::NodeId = 0;
std::vector<Value *> ForFlow::doGetUsedValues() const {
std::vector<Value *> ret;
if (isParallel())
ret = getSchedule()->getUsedValues();
ret.push_back(iter);
ret.push_back(body);
return ret;
}
/// Replaces uses of the value with id `id` by `newValue` in the schedule (when
/// isParallel() is true), in `iter`, and in `body`.
///
/// A replacement for the body must be a Flow; this is checked with seqassert.
/// @return total number of replacements made
int ForFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;

  // The schedule is only consulted when isParallel() reports true.
  if (isParallel()) {
    numReplaced += getSchedule()->replaceUsedValue(id, newValue);
  }

  if (iter->getId() == id) {
    iter = newValue;
    numReplaced += 1;
  }

  if (body->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    body = f;
    numReplaced += 1;
  }

  return numReplaced;
}
/// Points `var` at `newVar` when the current variable's id equals `id`.
/// @return 1 if the variable was replaced, 0 otherwise
int ForFlow::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (var->getId() != id)
    return 0;

  var = newVar;
  return 1;
}
/// Out-of-class definition of the static member ImperativeForFlow::NodeId,
/// initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char ImperativeForFlow::NodeId = 0;
std::vector<Value *> ImperativeForFlow::doGetUsedValues() const {
std::vector<Value *> ret;
if (isParallel())
ret = getSchedule()->getUsedValues();
ret.push_back(start);
ret.push_back(end);
ret.push_back(body);
return ret;
}
/// Replaces uses of the value with id `id` by `newValue` in the schedule (when
/// isParallel() is true), in `body`, in `start`, and in `end`.
///
/// A replacement for the body must be a Flow; this is checked with seqassert.
/// @return total number of replacements made
int ImperativeForFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;

  // The schedule is only consulted when isParallel() reports true.
  if (isParallel()) {
    numReplaced += getSchedule()->replaceUsedValue(id, newValue);
  }

  if (body->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    body = f;
    numReplaced += 1;
  }

  // Loop bounds accept any Value.
  if (start->getId() == id) {
    start = newValue;
    numReplaced += 1;
  }
  if (end->getId() == id) {
    end = newValue;
    numReplaced += 1;
  }

  return numReplaced;
}
/// Points `var` at `newVar` when the current variable's id equals `id`.
/// @return 1 if the variable was replaced, 0 otherwise
int ImperativeForFlow::doReplaceUsedVariable(id_t id, Var *newVar) {
  if (var->getId() != id)
    return 0;

  var = newVar;
  return 1;
}
/// Out-of-class definition of the static member IfFlow::NodeId, initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char IfFlow::NodeId = 0;
std::vector<Value *> IfFlow::doGetUsedValues() const {
std::vector<Value *> ret = {cond, trueBranch};
if (falseBranch)
ret.push_back(falseBranch);
return ret;
}
/// Substitutes `newValue` for the condition, the true branch and/or the false
/// branch when their id equals `id`. The false branch is skipped when null.
///
/// A replacement for either branch must be a Flow; this is checked with
/// seqassert.
/// @return number of operands replaced
int IfFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;

  // Condition: any Value is acceptable.
  if (cond->getId() == id) {
    cond = newValue;
    numReplaced += 1;
  }

  if (trueBranch->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    trueBranch = f;
    numReplaced += 1;
  }

  // The false branch is optional.
  if (falseBranch && falseBranch->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    falseBranch = f;
    numReplaced += 1;
  }

  return numReplaced;
}
/// Out-of-class definition of the static member TryCatchFlow::NodeId,
/// initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char TryCatchFlow::NodeId = 0;
std::vector<Value *> TryCatchFlow::doGetUsedValues() const {
std::vector<Value *> ret = {body};
if (finally)
ret.push_back(finally);
for (auto &c : catches)
ret.push_back(const_cast<Value *>(static_cast<const Value *>(c.getHandler())));
return ret;
}
/// Substitutes `newValue` for the body, the finally block (when non-null) and
/// any catch handler whose id equals `id`.
///
/// Every replacement must be a Flow; this is checked with seqassert.
/// @return number of operands replaced
int TryCatchFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;

  if (body->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    body = f;
    numReplaced += 1;
  }

  // The finally block is optional.
  if (finally && finally->getId() == id) {
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    finally = f;
    numReplaced += 1;
  }

  for (auto &clause : catches) {
    if (clause.getHandler()->getId() != id)
      continue;
    auto *f = cast<Flow>(newValue);
    seqassert(f, "{} is not a flow", *newValue);
    clause.setHandler(f);
    numReplaced += 1;
  }

  return numReplaced;
}
std::vector<types::Type *> TryCatchFlow::doGetUsedTypes() const {
std::vector<types::Type *> ret;
for (auto &c : catches) {
if (auto *t = c.getType())
ret.push_back(const_cast<types::Type *>(t));
}
return ret;
}
/// Replaces the type of every catch whose current type is named `name` with
/// `newType`.
///
/// Catches whose type is null are skipped: doGetUsedTypes() already treats a
/// null catch type as a valid state, so dereferencing it here unconditionally
/// would crash on such a catch.
///
/// @param name name of the type to replace
/// @param newType type to install on matching catches
/// @return number of catches updated
int TryCatchFlow::doReplaceUsedType(const std::string &name, types::Type *newType) {
  auto count = 0;
  for (auto &c : catches) {
    auto *caughtType = c.getType();
    if (caughtType && caughtType->getName() == name) {
      c.setType(newType);
      ++count;
    }
  }
  return count;
}
std::vector<Var *> TryCatchFlow::doGetUsedVariables() const {
std::vector<Var *> ret;
for (auto &c : catches) {
if (auto *t = c.getVar())
ret.push_back(const_cast<Var *>(t));
}
return ret;
}
/// Replaces the variable of every catch whose current variable has id `id`
/// with `newVar`.
///
/// Catches whose variable is null are skipped: doGetUsedVariables() already
/// treats a null catch variable as a valid state, so dereferencing it here
/// unconditionally would crash on such a catch.
///
/// @param id id of the variable to replace
/// @param newVar variable to install on matching catches
/// @return number of catches updated
int TryCatchFlow::doReplaceUsedVariable(id_t id, Var *newVar) {
  auto count = 0;
  for (auto &c : catches) {
    auto *caughtVar = c.getVar();
    if (caughtVar && caughtVar->getId() == id) {
      c.setVar(newVar);
      ++count;
    }
  }
  return count;
}
/// Out-of-class definition of the static member PipelineFlow::NodeId,
/// initialized to 0.
// NOTE(review): presumably only the address of this member matters (as a
// per-class identity tag for cast<>/isa<>) — confirm against the header.
const char PipelineFlow::NodeId = 0;
/// Computes the type this stage produces.
///
/// With no arguments the callee's own type is returned. Otherwise the callee's
/// type must be a function type (checked with seqassert) and its return type is
/// returned.
types::Type *PipelineFlow::Stage::getOutputType() const {
  // No arguments: the callee's type is the output type as-is.
  if (args.empty())
    return callee->getType();

  auto *funcType = cast<types::FuncType>(callee->getType());
  seqassert(funcType, "{} is not a function type", *callee->getType());
  return funcType->getReturnType();
}
/// Computes the element type this stage feeds to the next one.
///
/// For a non-generator stage this is simply getOutputType(). For a generator
/// stage the output type must be a generator type, and its base type is
/// returned.
///
/// The generator-type assertion now covers both the argument-less and the
/// with-arguments case; previously the argument-less path dereferenced the
/// result of the cast before the assertion could run.
types::Type *PipelineFlow::Stage::getOutputElementType() const {
  if (!isGenerator())
    return getOutputType();

  // getOutputType() yields the callee's type when there are no arguments and
  // the (asserted) function return type otherwise — exactly the type that has
  // to be a generator here.
  types::GeneratorType *genType = cast<types::GeneratorType>(getOutputType());
  seqassert(genType, "generator type not found");
  return genType->getBase();
}
std::vector<Value *> PipelineFlow::doGetUsedValues() const {
std::vector<Value *> ret;
for (auto &s : stages) {
ret.push_back(const_cast<Value *>(s.getCallee()));
for (auto *arg : s.args)
if (arg)
ret.push_back(arg);
}
return ret;
}
/// Substitutes `newValue` for every stage callee and every non-null stage
/// argument whose id equals `id`.
/// @return number of replacements made across all stages
int PipelineFlow::doReplaceUsedValue(id_t id, Value *newValue) {
  int numReplaced = 0;
  for (auto &stage : stages) {
    if (stage.getCallee()->getId() == id) {
      stage.setCallee(newValue);
      numReplaced += 1;
    }
    for (auto &arg : stage.args) {
      // Null argument slots are left alone.
      if (!arg || arg->getId() != id)
        continue;
      arg = newValue;
      numReplaced += 1;
    }
  }
  return numReplaced;
}
} // namespace ir
} // namespace seq