#include "canonical.h"

// NOTE(review): the standard header names below and the template argument
// lists throughout this file were missing from the reviewed copy and have been
// reconstructed from how each name is used -- confirm against upstream.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "sir/analyze/module/side_effect.h"
#include "sir/transform/rewrite.h"
#include "sir/util/irtools.h"

namespace seq {
namespace ir {
namespace transform {
namespace cleanup {
namespace {

/// Visitor that computes a sort key ("rank") for the node it is applied to.
struct NodeRanker : public util::Operator {
  // Nodes are ranked lexicographically by:
  //   - Whether the node is constant (constants come last)
  //   - Max node depth (deeper nodes first)
  //   - Node hash
  // The hash imposes an arbitrary but well-defined ordering
  // to ensure a single canonical representation for (most)
  // nodes.
  using Rank = std::tuple<int, int, uint64_t>;

  /// First node seen by preHook(); null until then.
  Node *root = nullptr;
  /// Largest depth() value observed so far.
  int maxDepth = 0;
  /// Running hash over the names of used variables and types.
  uint64_t hash = 0;

  // boost's hash_combine
  template <typename T> void hash_combine(const T &v) {
    std::hash<T> hasher;
    hash ^= hasher(v) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
  }

  /// Records the first visited node as the root, tracks the maximum depth and
  /// folds the names of the node's used variables and types into the hash.
  void preHook(Node *node) {
    if (!root)
      root = node;
    maxDepth = std::max(maxDepth, depth());
    for (auto *v : node->getUsedVariables()) {
      hash_combine(v->getName());
    }
    for (auto *v : node->getUsedTypes()) {
      hash_combine(v->getName());
    }
  }

  /// Builds the rank tuple: constant flag (1 for constants, -1 otherwise),
  /// negated depth so deeper nodes compare smaller, then the hash.
  Rank getRank() {
    return std::make_tuple((isA<Const>(root) ? 1 : -1), -maxDepth, hash);
  }
};

/// Runs a fresh NodeRanker over `node` and returns the resulting rank.
NodeRanker::Rank getRank(Node *node) {
  NodeRanker ranker;
  node->accept(ranker);
  return ranker.getRank();
}

/// True if `fn` is non-null and carries the "commutative" attribute.
bool isCommutativeOp(Func *fn) {
  return fn && util::hasAttribute(fn, "std.internal.attributes.commutative");
}

/// True if `fn` is non-null and carries the "associative" attribute.
bool isAssociativeOp(Func *fn) {
  return fn && util::hasAttribute(fn, "std.internal.attributes.associative");
}

/// True if `fn` is non-null and carries the "distributive" attribute.
bool isDistributiveOp(Func *fn) {
  return fn && util::hasAttribute(fn, "std.internal.attributes.distributive");
}

/// True if `fn` is non-null and its unmangled name is one of the six
/// comparison magic names (==, !=, <, <=, >, >=).
bool isInequalityOp(Func *fn) {
  static const std::unordered_set<std::string> ops = {
      Module::EQ_MAGIC_NAME, Module::NE_MAGIC_NAME, Module::LT_MAGIC_NAME,
      Module::LE_MAGIC_NAME, Module::GT_MAGIC_NAME, Module::GE_MAGIC_NAME};
  return fn && ops.find(fn->getUnmangledName()) != ops.end();
}

// c + b + a --> a + b + c
struct CanonOpChain : public RewriteRule {
  /// Flattens a tree of nested `op(type, type) -> type` method calls rooted at
  /// `v` into `result`, left operand first. Anything that is not such a call
  /// is appended as a leaf.
  static void extractAssociativeOpChain(Value *v, const std::string &op,
                                        types::Type *type,
                                        std::vector<Value *> &result) {
    if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
      auto *call = cast<CallInstr>(v);
      extractAssociativeOpChain(call->front(), op, type, result);
      extractAssociativeOpChain(call->back(), op, type, result);
    } else {
      result.push_back(v);
    }
  }

  /// Reorders `operands` in place by ascending rank.
  static void orderOperands(std::vector<Value *> &operands) {
    std::vector<std::pair<NodeRanker::Rank, Value *>> rankedOperands;
    rankedOperands.reserve(operands.size());
    for (auto *v : operands) {
      rankedOperands.emplace_back(getRank(v), v);
    }
    // Compare on rank only and sort stably: operands with equal rank keep
    // their source order instead of being ordered by pointer value, which
    // would make the canonical form depend on allocation addresses.
    std::stable_sort(
        rankedOperands.begin(), rankedOperands.end(),
        [](const auto &a, const auto &b) { return a.first < b.first; });

    operands.clear();
    for (auto &p : rankedOperands) {
      operands.push_back(p.second);
    }
  }

  /// Rebuilds a binary method call whose operand and result types all match
  /// as a left-leaning chain; the chain is flattened first when the operator
  /// is associative and the operands are rank-ordered when it is commutative.
  void visit(CallInstr *v) override {
    auto *fn = util::getFunc(v->getCallee());
    if (!fn)
      return;

    std::string op = fn->getUnmangledName();
    types::Type *type = v->getType();
    const bool isAssociative = isAssociativeOp(fn);
    const bool isCommutative = isCommutativeOp(fn);

    if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
      std::vector<Value *> operands;
      if (isAssociative) {
        extractAssociativeOpChain(v, op, type, operands);
      } else {
        operands.push_back(v->front());
        operands.push_back(v->back());
      }
      seqassert(operands.size() >= 2, "bad call canonicalization");

      if (isCommutative)
        orderOperands(operands);

      // fold the operand list back into nested binary calls: ((a op b) op c) ...
      Value *newCall = util::call(fn, {operands[0], operands[1]});
      for (auto it = operands.begin() + 2; it != operands.end(); ++it) {
        newCall = util::call(fn, {newCall, *it});
      }

      return setResult(newCall);
    }
  }
};

// b > a --> a < b (etc.)
struct CanonInequality : public RewriteRule {
  /// Swaps the operands of a two-argument comparison (mirroring the operator)
  /// when the left operand ranks higher than the right one. The rewrite is
  /// applied only if the mirrored call exists and has the original type.
  void visit(CallInstr *v) override {
    auto *fn = util::getFunc(v->getCallee());
    if (!fn)
      return;

    std::string op = fn->getUnmangledName();
    types::Type *type = v->getType();

    // canonicalize inequalities
    if (v->numArgs() == 2 && isInequalityOp(fn)) {
      Value *newCall = nullptr;
      auto *lhs = v->front();
      auto *rhs = v->back();

      if (getRank(lhs) > getRank(rhs)) { // are we out of order?
        // re-order
        if (op == Module::EQ_MAGIC_NAME) { // lhs == rhs
          newCall = *rhs == *lhs;
        } else if (op == Module::NE_MAGIC_NAME) { // lhs != rhs
          newCall = *rhs != *lhs;
        } else if (op == Module::LT_MAGIC_NAME) { // lhs < rhs
          newCall = *rhs > *lhs;
        } else if (op == Module::LE_MAGIC_NAME) { // lhs <= rhs
          newCall = *rhs >= *lhs;
        } else if (op == Module::GT_MAGIC_NAME) { // lhs > rhs
          newCall = *rhs < *lhs;
        } else if (op == Module::GE_MAGIC_NAME) { // lhs >= rhs
          newCall = *rhs <= *lhs;
        } else {
          seqassert(false, "unknown comparison op: {}", op);
        }

        if (newCall && newCall->getType()->is(type))
          return setResult(newCall);
      }
    }
  }
};

// a*x + b*x --> (a + b) * x
struct CanonAddMul : public RewriteRule {
  /// True if both values are variable references to the same variable id.
  static bool varMatch(Value *a, Value *b) {
    auto *v1 = cast<VarValue>(a);
    auto *v2 = cast<VarValue>(b);
    return v1 && v2 && v1->getVar()->getId() == v2->getVar()->getId();
  }

  /// Function called by `v` if `v` is a call instruction, otherwise null.
  static Func *getOp(Value *v) {
    return isA<CallInstr>(v) ? util::getFunc(cast<CallInstr>(v)->getCallee())
                             : nullptr;
  }

  // (a + b) * x, or null if invalid
  static Value *addMul(Value *a, Value *b, Value *x) {
    if (!a || !b || !x)
      return nullptr;

    // try a + b first; the swapped form is accepted only if that + is commutative
    auto *y = (*a + *b);
    if (!y) {
      y = (*b + *a);
      if (y && !isCommutativeOp(getOp(y)))
        return nullptr;
    }
    if (!y)
      return nullptr;

    // same strategy for the multiplication by x
    auto *z = (*y) * (*x);
    if (!z) {
      z = (*x) * (*y);
      if (z && !isCommutativeOp(getOp(z)))
        return nullptr;
    }
    if (!z)
      return nullptr;

    return z;
  }

  /// Factors a shared variable out of a commutative two-argument addition.
  /// A side that is not itself a two-argument multiplication is treated as
  /// `side * 1`. The rewrite is applied only when the factored call's operator
  /// is distributive and its type matches the original call's type.
  void visit(CallInstr *v) override {
    auto *M = v->getModule();
    auto *fn = util::getFunc(v->getCallee());

    if (!isCommutativeOp(fn) ||
        !util::isCallOf(v, Module::ADD_MAGIC_NAME, 2, /*output=*/nullptr,
                        /*method=*/true))
      return;

    // decompose the operation
    Value *lhs = v->front();
    Value *rhs = v->back();
    Value *lhs1 = nullptr, *lhs2 = nullptr, *rhs1 = nullptr, *rhs2 = nullptr;

    if (util::isCallOf(lhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
                       /*method=*/true)) {
      auto *lhsCall = cast<CallInstr>(lhs);
      lhs1 = lhsCall->front();
      lhs2 = lhsCall->back();
    } else {
      lhs1 = lhs;
      lhs2 = M->getInt(1);
    }

    if (util::isCallOf(rhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
                       /*method=*/true)) {
      auto *rhsCall = cast<CallInstr>(rhs);
      rhs1 = rhsCall->front();
      rhs2 = rhsCall->back();
    } else {
      rhs1 = rhs;
      rhs2 = M->getInt(1);
    }

    // look for the shared variable in each of the four operand pairings
    Value *newCall = nullptr;
    if (varMatch(lhs1, rhs1)) {
      newCall = addMul(lhs2, rhs2, lhs1);
    } else if (varMatch(lhs1, rhs2)) {
      newCall = addMul(lhs2, rhs1, lhs1);
    } else if (varMatch(lhs2, rhs1)) {
      newCall = addMul(lhs1, rhs2, lhs2);
    } else if (varMatch(lhs2, rhs2)) {
      newCall = addMul(lhs1, rhs1, lhs2);
    }

    if (newCall && isDistributiveOp(getOp(newCall)) &&
        newCall->getType()->is(v->getType()))
      return setResult(newCall);
  }
};

// x - c --> x + (-c)
struct CanonConstSub : public RewriteRule {
  /// Rewrites a two-argument subtraction whose right operand is an integer or
  /// floating-point constant into an addition of the negated constant. The
  /// rewrite is applied only if the new call has the original call's type.
  void visit(CallInstr *v) override {
    auto *M = v->getModule();
    auto *type = v->getType();

    if (!util::isCallOf(v, Module::SUB_MAGIC_NAME, 2, /*output=*/nullptr,
                        /*method=*/true))
      return;

    Value *lhs = v->front();
    Value *rhs = v->back();
    Value *newCall = nullptr;

    if (util::isConst<int64_t>(rhs)) {
      auto c = util::getConst<int64_t>(rhs);
      // The minimum int64 value has no positive counterpart, so negating it
      // would overflow. Compare against numeric_limits directly: the former
      // `-(1 << 63)` spelling of this bound overflowed while computing it.
      if (c != std::numeric_limits<int64_t>::min())
        newCall = *lhs + *(M->getInt(-c));
    } else if (util::isConst<double>(rhs)) {
      auto c = util::getConst<double>(rhs);
      newCall = *lhs + *(M->getFloat(-c));
    }

    if (newCall && newCall->getType()->is(type))
      return setResult(newCall);
  }
};

} // namespace

const std::string CanonicalizationPass::KEY = "core-cleanup-canon";

/// Registers the standard rules, resets the rewriter state and runs the pass
/// over module `m`.
void CanonicalizationPass::run(Module *m) {
  registerStandardRules(m);
  Rewriter::reset();
  OperatorPass::run(m);
}

/// Applies the registered rewrite rules to `v` unless the side-effect analysis
/// reports that the call has a side effect.
void CanonicalizationPass::handle(CallInstr *v) {
  // NOTE(review): result type reconstructed from the side_effect.h include -- confirm
  auto *r = getAnalysisResult<analyze::module::SideEffectResult>(sideEffectsKey);
  if (!r->hasSideEffect(v))
    rewrite(v);
}

/// Flattens `v` in place: a nested series flow is replaced by its elements and
/// a flow instruction is replaced by its flow followed by its value.
void CanonicalizationPass::handle(SeriesFlow *v) {
  auto it = v->begin();
  while (it != v->end()) {
    if (auto *series = cast<SeriesFlow>(*it)) {
      it = v->erase(it);
      for (auto *x : *series) {
        it = v->insert(it, x);
        ++it;
      }
    } else if (auto *flowInstr = cast<FlowInstr>(*it)) {
      it = v->erase(it);
      // inserting in reverse order causes [flow, value] to be added
      it = v->insert(it, flowInstr->getValue());
      it = v->insert(it, flowInstr->getFlow());
      // don't increment; re-traverse in case a new series flow added
    } else {
      ++it;
    }
  }
}

/// Registers the four canonicalization rules defined in this file.
void CanonicalizationPass::registerStandardRules(Module *m) {
  registerRule("op-chain", std::make_unique<CanonOpChain>());
  registerRule("inequality", std::make_unique<CanonInequality>());
  registerRule("add-mul", std::make_unique<CanonAddMul>());
  registerRule("const-sub", std::make_unique<CanonConstSub>());
}

} // namespace cleanup
} // namespace transform
} // namespace ir
} // namespace seq