1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/compiler/sir/util/matching.cpp
2021-09-27 14:02:44 -04:00

332 lines
12 KiB
C++

#include "matching.h"
#include <algorithm>
#include "sir/sir.h"
#include "visitor.h"
// VISIT(x) defines the ConstVisitor override for node type `x`.
//
// One visitor instance compares one pair of nodes via two consecutive accept()
// calls:
//  - If either node is a util::Any wildcard, the pair matches unconditionally.
//  - On the first visit, the node's type tag (the address of `x::NodeId`) and
//    the node itself (`other`) are recorded.
//  - On the second visit, the pair fails if the type tags differ, or if
//    `checkName` is set and the names differ.
//  - Otherwise the second visit dispatches to the type-specific
//    `handle(v, other)` overload, which sets `result`.
#define VISIT(x)                                                                       \
  void visit(const x *v) override {                                                    \
    if (matchAny || dynamic_cast<const util::Any *>(v)) {                              \
      result = true;                                                                   \
      matchAny = true;                                                                 \
    } else if (!nodeId) {                                                              \
      nodeId = &x::NodeId;                                                             \
      other = v;                                                                       \
    } else if (nodeId != &x::NodeId ||                                                 \
               (checkName && v->getName() != other->getName()))                        \
      result = false;                                                                  \
    else                                                                               \
      handle(v, static_cast<const x *>(other));                                        \
  }
namespace seq {
namespace ir {
namespace util {
namespace {
class MatchVisitor : public util::ConstVisitor {
private:
bool matchAny = false;
bool checkName;
const char *nodeId = nullptr;
bool result = false;
const Node *other = nullptr;
bool varIdMatch;
public:
explicit MatchVisitor(bool checkName = false, bool varIdMatch = false)
: checkName(checkName), varIdMatch(varIdMatch) {}
VISIT(Var);
void handle(const Var *x, const Var *y) { result = compareVars(x, y); }
VISIT(Func);
void handle(const Func *x, const Func *y) {}
VISIT(BodiedFunc);
void handle(const BodiedFunc *x, const BodiedFunc *y) {
result = compareFuncs(x, y) &&
std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto *x, auto *y) { return process(x, y); }) &&
process(x->getBody(), y->getBody()) && x->isBuiltin() == y->isBuiltin();
}
VISIT(ExternalFunc);
void handle(const ExternalFunc *x, const ExternalFunc *y) {
result = x->getUnmangledName() == y->getUnmangledName() && compareFuncs(x, y);
}
VISIT(InternalFunc);
void handle(const InternalFunc *x, const InternalFunc *y) {
result = x->getParentType() == y->getParentType() && compareFuncs(x, y);
}
VISIT(LLVMFunc);
void handle(const LLVMFunc *x, const LLVMFunc *y) {
result = std::equal(x->literal_begin(), x->literal_end(), y->literal_begin(),
y->literal_end(),
[this](auto &x, auto &y) {
if (x.isStatic() && y.isStatic())
return x.getStaticValue() == y.getStaticValue();
else if (x.isType() && y.isType())
return process(x.getTypeValue(), y.getTypeValue());
return false;
}) &&
x->getLLVMDeclarations() == y->getLLVMDeclarations() &&
x->getLLVMBody() == y->getLLVMBody() && compareFuncs(x, y);
}
VISIT(Value);
void handle(const Value *x, const Value *y) {}
VISIT(VarValue);
void handle(const VarValue *x, const VarValue *y) {
result = compareVars(x->getVar(), y->getVar());
}
VISIT(PointerValue);
void handle(const PointerValue *x, const PointerValue *y) {
result = compareVars(x->getVar(), y->getVar());
}
VISIT(Flow);
void handle(const Flow *x, const Flow *y) {}
VISIT(SeriesFlow);
void handle(const SeriesFlow *x, const SeriesFlow *y) {
result = std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto *x, auto *y) { return process(x, y); });
}
VISIT(IfFlow);
void handle(const IfFlow *x, const IfFlow *y) {
result = process(x->getCond(), y->getCond()) &&
process(x->getTrueBranch(), y->getTrueBranch()) &&
process(x->getFalseBranch(), y->getFalseBranch());
}
VISIT(WhileFlow);
void handle(const WhileFlow *x, const WhileFlow *y) {
result = process(x->getCond(), y->getCond()) && process(x->getBody(), y->getBody());
}
VISIT(ForFlow);
void handle(const ForFlow *x, const ForFlow *y) {
result = process(x->getIter(), y->getIter()) &&
process(x->getBody(), y->getBody()) && process(x->getVar(), y->getVar());
}
VISIT(ImperativeForFlow);
void handle(const ImperativeForFlow *x, const ImperativeForFlow *y) {
result = process(x->getVar(), y->getVar()) && process(x->getBody(), y->getBody()) &&
process(x->getStart(), y->getStart()) && x->getStep() == y->getStep() &&
process(x->getEnd(), y->getEnd());
}
VISIT(TryCatchFlow);
void handle(const TryCatchFlow *x, const TryCatchFlow *y) {
result = result && process(x->getFinally(), y->getFinally()) &&
process(x->getBody(), y->getBody()) &&
std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto &x, auto &y) {
return process(x.getHandler(), y.getHandler()) &&
process(x.getType(), y.getType()) &&
process(x.getVar(), y.getVar());
});
}
VISIT(PipelineFlow);
void handle(const PipelineFlow *x, const PipelineFlow *y) {
result = std::equal(
x->begin(), x->end(), y->begin(), y->end(), [this](auto &x, auto &y) {
return process(x.getCallee(), y.getCallee()) &&
std::equal(x.begin(), x.end(), y.begin(), y.end(),
[this](auto *x, auto *y) { return process(x, y); }) &&
x.isGenerator() == y.isGenerator() && x.isParallel() == y.isParallel();
});
}
VISIT(dsl::CustomFlow);
void handle(const dsl::CustomFlow *x, const dsl::CustomFlow *y) {
result = x->match(y);
}
VISIT(IntConst);
void handle(const IntConst *x, const IntConst *y) {
result = process(x->getType(), y->getType()) && x->getVal() == y->getVal();
}
VISIT(FloatConst);
void handle(const FloatConst *x, const FloatConst *y) {
result = process(x->getType(), y->getType()) && x->getVal() == y->getVal();
}
VISIT(BoolConst);
void handle(const BoolConst *x, const BoolConst *y) {
result = process(x->getType(), y->getType()) && x->getVal() == y->getVal();
}
VISIT(StringConst);
void handle(const StringConst *x, const StringConst *y) {
result = process(x->getType(), y->getType()) && x->getVal() == y->getVal();
}
VISIT(dsl::CustomConst);
void handle(const dsl::CustomConst *x, const dsl::CustomConst *y) {
result = x->match(y);
}
VISIT(AssignInstr);
void handle(const AssignInstr *x, const AssignInstr *y) {
result = process(x->getLhs(), y->getLhs()) && process(x->getRhs(), y->getRhs());
}
VISIT(ExtractInstr);
void handle(const ExtractInstr *x, const ExtractInstr *y) {
result = process(x->getVal(), y->getVal()) && x->getField() == y->getField();
}
VISIT(InsertInstr);
void handle(const InsertInstr *x, const InsertInstr *y) {
result = process(x->getLhs(), y->getLhs()) && x->getField() == y->getField() &&
process(x->getRhs(), y->getRhs());
}
VISIT(CallInstr);
void handle(const CallInstr *x, const CallInstr *y) {
result = process(x->getCallee(), y->getCallee()) &&
std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto *x, auto *y) { return process(x, y); });
}
VISIT(StackAllocInstr);
void handle(const StackAllocInstr *x, const StackAllocInstr *y) {
result = x->getCount() == y->getCount() && process(x->getType(), y->getType());
}
VISIT(TypePropertyInstr);
void handle(const TypePropertyInstr *x, const TypePropertyInstr *y) {
result = x->getProperty() == y->getProperty() &&
process(x->getInspectType(), y->getInspectType());
}
VISIT(YieldInInstr);
void handle(const YieldInInstr *x, const YieldInInstr *y) {
result = process(x->getType(), y->getType());
}
VISIT(TernaryInstr);
void handle(const TernaryInstr *x, const TernaryInstr *y) {
result = process(x->getCond(), y->getCond()) &&
process(x->getTrueValue(), y->getTrueValue()) &&
process(x->getFalseValue(), y->getFalseValue());
}
VISIT(BreakInstr);
void handle(const BreakInstr *x, const BreakInstr *y) {
result = process(x->getLoop(), y->getLoop());
}
VISIT(ContinueInstr);
void handle(const ContinueInstr *x, const ContinueInstr *y) {
result = process(x->getLoop(), y->getLoop());
}
VISIT(ReturnInstr);
void handle(const ReturnInstr *x, const ReturnInstr *y) {
result = process(x->getValue(), y->getValue());
}
VISIT(YieldInstr);
void handle(const YieldInstr *x, const YieldInstr *y) {
result = process(x->getValue(), y->getValue());
}
VISIT(ThrowInstr);
void handle(const ThrowInstr *x, const ThrowInstr *y) {
result = process(x->getValue(), y->getValue());
}
VISIT(FlowInstr);
void handle(const FlowInstr *x, const FlowInstr *y) {
result =
process(x->getFlow(), y->getFlow()) && process(x->getValue(), y->getValue());
}
VISIT(dsl::CustomInstr);
void handle(const dsl::CustomInstr *x, const dsl::CustomInstr *y) {
result = x->match(y);
}
VISIT(types::Type);
void handle(const types::Type *x, const types::Type *y) {}
VISIT(types::IntType);
void handle(const types::IntType *, const types::IntType *) { result = true; }
VISIT(types::FloatType);
void handle(const types::FloatType *, const types::FloatType *) { result = true; }
VISIT(types::BoolType);
void handle(const types::BoolType *, const types::BoolType *) { result = true; }
VISIT(types::ByteType);
void handle(const types::ByteType *, const types::ByteType *) { result = true; }
VISIT(types::VoidType);
void handle(const types::VoidType *, const types::VoidType *) { result = true; }
VISIT(types::RecordType);
void handle(const types::RecordType *x, const types::RecordType *y) {
result = std::equal(
x->begin(), x->end(), y->begin(), y->end(), [this](auto &x, auto &y) {
return x.getName() == y.getName() && process(x.getType(), y.getType());
});
}
VISIT(types::RefType);
void handle(const types::RefType *x, const types::RefType *y) {
result = process(x->getContents(), y->getContents());
}
VISIT(types::FuncType);
void handle(const types::FuncType *x, const types::FuncType *y) {
result = process(x->getReturnType(), y->getReturnType()) &&
std::equal(x->begin(), x->end(), y->begin(), y->end(),
[this](auto *x, auto *y) { return process(x, y); });
}
VISIT(types::OptionalType);
void handle(const types::OptionalType *x, const types::OptionalType *y) {
result = process(x->getBase(), y->getBase());
}
VISIT(types::PointerType);
void handle(const types::PointerType *x, const types::PointerType *y) {
result = process(x->getBase(), y->getBase());
}
VISIT(types::GeneratorType);
void handle(const types::GeneratorType *x, const types::GeneratorType *y) {
result = process(x->getBase(), y->getBase());
}
VISIT(types::IntNType);
void handle(const types::IntNType *x, const types::IntNType *y) {
result = x->getLen() == y->getLen() && x->isSigned() == y->isSigned();
}
VISIT(dsl::types::CustomType);
void handle(const dsl::types::CustomType *x, const dsl::types::CustomType *y) {
result = x->match(y);
}
bool process(const Node *x, const Node *y) const {
if (!x && !y)
return true;
else if ((!x && y) || (x && !y))
return false;
MatchVisitor v(checkName);
x->accept(v);
y->accept(v);
return v.result;
}
private:
bool compareVars(const Var *x, const Var *y) const {
return process(x->getType(), y->getType()) &&
(!varIdMatch || x->getId() == y->getId());
}
bool compareFuncs(const Func *x, const Func *y) const {
if (!compareVars(x, y))
return false;
if (!std::equal(x->arg_begin(), x->arg_end(), y->arg_begin(), y->arg_end(),
[this](auto *x, auto *y) { return process(x, y); }))
return false;
return true;
}
};
} // namespace
// Out-of-line definitions of the static NodeId members of the util::Any*
// wildcard node classes. Each is value-initialized, i.e. to zero.
const char AnyType::NodeId{};
const char AnyValue::NodeId{};
const char AnyFlow::NodeId{};
const char AnyVar::NodeId{};
const char AnyFunc::NodeId{};
/// Structurally compares two SIR nodes.
/// @param a first node (may be null)
/// @param b second node (may be null)
/// @param checkNames name-comparison flag, forwarded to MatchVisitor
/// @param varIdMatch forwarded to MatchVisitor; when set, matched variables
///                   must also have equal IDs
/// @return true if the nodes match; util::Any wildcards match anything, and
///         two null nodes match each other
bool match(Node *a, Node *b, bool checkNames, bool varIdMatch) {
  // Both flags must reach the visitor; previously `varIdMatch` was dropped.
  return MatchVisitor(checkNames, varIdMatch).process(a, b);
}
} // namespace util
} // namespace ir
} // namespace seq
#undef VISIT