mirror of https://github.com/exaloop/codon.git
342 lines
9.7 KiB
C++
342 lines
9.7 KiB
C++
#include "canonical.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <tuple>
#include <unordered_set>
#include <utility>

#include "sir/analyze/module/side_effect.h"
#include "sir/transform/rewrite.h"
#include "sir/util/irtools.h"
|
|
|
|
namespace codon {
|
|
namespace ir {
|
|
namespace transform {
|
|
namespace cleanup {
|
|
namespace {
|
|
struct NodeRanker : public util::Operator {
|
|
// Nodes are ranked lexicographically by:
|
|
// - Whether the node is constant (constants come last)
|
|
// - Max node depth (deeper nodes first)
|
|
// - Node hash
|
|
// The hash imposes an arbitrary but well-defined ordering
|
|
// to ensure a single canonical representation for (most)
|
|
// nodes.
|
|
using Rank = std::tuple<int, int, uint64_t>;
|
|
Node *root = nullptr;
|
|
int maxDepth = 0;
|
|
uint64_t hash = 0;
|
|
|
|
// boost's hash_combine
|
|
template <class T> void hash_combine(const T &v) {
|
|
std::hash<T> hasher;
|
|
hash ^= hasher(v) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
|
|
}
|
|
|
|
void preHook(Node *node) {
|
|
if (!root)
|
|
root = node;
|
|
maxDepth = std::max(maxDepth, depth());
|
|
for (auto *v : node->getUsedVariables()) {
|
|
hash_combine(v->getName());
|
|
}
|
|
for (auto *v : node->getUsedTypes()) {
|
|
hash_combine(v->getName());
|
|
}
|
|
}
|
|
|
|
Rank getRank() {
|
|
return std::make_tuple((isA<Const>(root) ? 1 : -1), -maxDepth, hash);
|
|
}
|
|
};
|
|
|
|
/// Runs a fresh NodeRanker over `node` and returns the resulting rank.
NodeRanker::Rank getRank(Node *node) {
  NodeRanker visitor;
  node->accept(visitor);
  return visitor.getRank();
}
|
|
|
|
/// True iff `fn` is non-null and carries the
/// "std.internal.attributes.commutative" attribute.
bool isCommutativeOp(Func *fn) {
  if (!fn)
    return false;
  return util::hasAttribute(fn, "std.internal.attributes.commutative");
}
|
|
|
|
/// True iff `fn` is non-null and carries the
/// "std.internal.attributes.associative" attribute.
bool isAssociativeOp(Func *fn) {
  if (!fn)
    return false;
  return util::hasAttribute(fn, "std.internal.attributes.associative");
}
|
|
|
|
/// True iff `fn` is non-null and carries the
/// "std.internal.attributes.distributive" attribute.
bool isDistributiveOp(Func *fn) {
  if (!fn)
    return false;
  return util::hasAttribute(fn, "std.internal.attributes.distributive");
}
|
|
|
|
/// True iff `fn` is non-null and its unmangled name is one of the six
/// comparison magic names (==, !=, <, <=, >, >=).
bool isInequalityOp(Func *fn) {
  static const std::unordered_set<std::string> comparisonNames = {
      Module::EQ_MAGIC_NAME, Module::NE_MAGIC_NAME, Module::LT_MAGIC_NAME,
      Module::LE_MAGIC_NAME, Module::GT_MAGIC_NAME, Module::GE_MAGIC_NAME};

  if (!fn)
    return false;
  return comparisonNames.count(fn->getUnmangledName()) != 0;
}
|
|
|
|
// c + b + a --> a + b + c
|
|
struct CanonOpChain : public RewriteRule {
|
|
static void extractAssociativeOpChain(Value *v, const std::string &op,
|
|
types::Type *type,
|
|
std::vector<Value *> &result) {
|
|
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
|
|
auto *call = cast<CallInstr>(v);
|
|
extractAssociativeOpChain(call->front(), op, type, result);
|
|
extractAssociativeOpChain(call->back(), op, type, result);
|
|
} else {
|
|
result.push_back(v);
|
|
}
|
|
}
|
|
|
|
static void orderOperands(std::vector<Value *> &operands) {
|
|
std::vector<std::pair<NodeRanker::Rank, Value *>> rankedOperands;
|
|
for (auto *v : operands) {
|
|
rankedOperands.push_back({getRank(v), v});
|
|
}
|
|
std::sort(rankedOperands.begin(), rankedOperands.end());
|
|
|
|
operands.clear();
|
|
for (auto &p : rankedOperands) {
|
|
operands.push_back(std::get<1>(p));
|
|
}
|
|
}
|
|
|
|
void visit(CallInstr *v) override {
|
|
auto *fn = util::getFunc(v->getCallee());
|
|
if (!fn)
|
|
return;
|
|
|
|
std::string op = fn->getUnmangledName();
|
|
types::Type *type = v->getType();
|
|
const bool isAssociative = isAssociativeOp(fn);
|
|
const bool isCommutative = isCommutativeOp(fn);
|
|
|
|
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
|
|
std::vector<Value *> operands;
|
|
if (isAssociative) {
|
|
extractAssociativeOpChain(v, op, type, operands);
|
|
} else {
|
|
operands.push_back(v->front());
|
|
operands.push_back(v->back());
|
|
}
|
|
seqassert(operands.size() >= 2, "bad call canonicalization");
|
|
|
|
if (isCommutative)
|
|
orderOperands(operands);
|
|
|
|
Value *newCall = util::call(fn, {operands[0], operands[1]});
|
|
for (auto it = operands.begin() + 2; it != operands.end(); ++it) {
|
|
newCall = util::call(fn, {newCall, *it});
|
|
}
|
|
|
|
return setResult(newCall);
|
|
}
|
|
}
|
|
};
|
|
|
|
// b > a --> a < b (etc.)
|
|
struct CanonInequality : public RewriteRule {
|
|
void visit(CallInstr *v) override {
|
|
auto *fn = util::getFunc(v->getCallee());
|
|
if (!fn)
|
|
return;
|
|
|
|
std::string op = fn->getUnmangledName();
|
|
types::Type *type = v->getType();
|
|
|
|
// canonicalize inequalities
|
|
if (v->numArgs() == 2 && isInequalityOp(fn)) {
|
|
Value *newCall = nullptr;
|
|
auto *lhs = v->front();
|
|
auto *rhs = v->back();
|
|
if (getRank(lhs) > getRank(rhs)) { // are we out of order?
|
|
// re-order
|
|
if (op == Module::EQ_MAGIC_NAME) { // lhs == rhs
|
|
newCall = *rhs == *lhs;
|
|
} else if (op == Module::NE_MAGIC_NAME) { // lhs != rhs
|
|
newCall = *rhs != *lhs;
|
|
} else if (op == Module::LT_MAGIC_NAME) { // lhs < rhs
|
|
newCall = *rhs > *lhs;
|
|
} else if (op == Module::LE_MAGIC_NAME) { // lhs <= rhs
|
|
newCall = *rhs >= *lhs;
|
|
} else if (op == Module::GT_MAGIC_NAME) { // lhs > rhs
|
|
newCall = *rhs < *lhs;
|
|
} else if (op == Module::GE_MAGIC_NAME) { // lhs >= rhs
|
|
newCall = *rhs <= *lhs;
|
|
} else {
|
|
seqassert(false, "unknown comparison op: {}", op);
|
|
}
|
|
|
|
if (newCall && newCall->getType()->is(type))
|
|
return setResult(newCall);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// a*x + b*x --> (a + b) * x
|
|
struct CanonAddMul : public RewriteRule {
|
|
static bool varMatch(Value *a, Value *b) {
|
|
auto *v1 = cast<VarValue>(a);
|
|
auto *v2 = cast<VarValue>(b);
|
|
return v1 && v2 && v1->getVar()->getId() == v2->getVar()->getId();
|
|
}
|
|
|
|
static Func *getOp(Value *v) {
|
|
return isA<CallInstr>(v) ? util::getFunc(cast<CallInstr>(v)->getCallee()) : nullptr;
|
|
}
|
|
|
|
// (a + b) * x, or null if invalid
|
|
static Value *addMul(Value *a, Value *b, Value *x) {
|
|
if (!a || !b || !x)
|
|
return nullptr;
|
|
|
|
auto *y = (*a + *b);
|
|
if (!y) {
|
|
y = (*b + *a);
|
|
if (y && !isCommutativeOp(getOp(y)))
|
|
return nullptr;
|
|
}
|
|
if (!y)
|
|
return nullptr;
|
|
|
|
auto *z = (*y) * (*x);
|
|
if (!z) {
|
|
z = (*x) * (*y);
|
|
if (z && !isCommutativeOp(getOp(z)))
|
|
return nullptr;
|
|
}
|
|
if (!z)
|
|
return nullptr;
|
|
|
|
return z;
|
|
}
|
|
|
|
void visit(CallInstr *v) override {
|
|
auto *M = v->getModule();
|
|
auto *fn = util::getFunc(v->getCallee());
|
|
if (!isCommutativeOp(fn) ||
|
|
!util::isCallOf(v, Module::ADD_MAGIC_NAME, 2, /*output=*/nullptr,
|
|
/*method=*/true))
|
|
return;
|
|
|
|
// decompose the operation
|
|
Value *lhs = v->front();
|
|
Value *rhs = v->back();
|
|
Value *lhs1 = nullptr, *lhs2 = nullptr, *rhs1 = nullptr, *rhs2 = nullptr;
|
|
|
|
if (util::isCallOf(lhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
|
|
/*method=*/true)) {
|
|
auto *lhsCall = cast<CallInstr>(lhs);
|
|
lhs1 = lhsCall->front();
|
|
lhs2 = lhsCall->back();
|
|
} else {
|
|
lhs1 = lhs;
|
|
lhs2 = M->getInt(1);
|
|
}
|
|
|
|
if (util::isCallOf(rhs, Module::MUL_MAGIC_NAME, 2, /*output=*/nullptr,
|
|
/*method=*/true)) {
|
|
auto *rhsCall = cast<CallInstr>(rhs);
|
|
rhs1 = rhsCall->front();
|
|
rhs2 = rhsCall->back();
|
|
} else {
|
|
rhs1 = rhs;
|
|
rhs2 = M->getInt(1);
|
|
}
|
|
|
|
Value *newCall = nullptr;
|
|
if (varMatch(lhs1, rhs1)) {
|
|
newCall = addMul(lhs2, rhs2, lhs1);
|
|
} else if (varMatch(lhs1, rhs2)) {
|
|
newCall = addMul(lhs2, rhs1, lhs1);
|
|
} else if (varMatch(lhs2, rhs1)) {
|
|
newCall = addMul(lhs1, rhs2, lhs2);
|
|
} else if (varMatch(lhs2, rhs2)) {
|
|
newCall = addMul(lhs1, rhs1, lhs2);
|
|
}
|
|
|
|
if (newCall && isDistributiveOp(getOp(newCall)) &&
|
|
newCall->getType()->is(v->getType()))
|
|
return setResult(newCall);
|
|
}
|
|
};
|
|
|
|
// x - c --> x + (-c)
|
|
struct CanonConstSub : public RewriteRule {
|
|
void visit(CallInstr *v) override {
|
|
auto *M = v->getModule();
|
|
auto *type = v->getType();
|
|
|
|
if (!util::isCallOf(v, Module::SUB_MAGIC_NAME, 2, /*output=*/nullptr,
|
|
/*method=*/true))
|
|
return;
|
|
|
|
Value *lhs = v->front();
|
|
Value *rhs = v->back();
|
|
|
|
Value *newCall = nullptr;
|
|
if (util::isConst<int64_t>(rhs)) {
|
|
auto c = util::getConst<int64_t>(rhs);
|
|
if (c != -(static_cast<int64_t>(1) << 63)) // ensure no overflow
|
|
newCall = *lhs + *(M->getInt(-c));
|
|
} else if (util::isConst<double>(rhs)) {
|
|
auto c = util::getConst<double>(rhs);
|
|
newCall = *lhs + *(M->getFloat(-c));
|
|
}
|
|
|
|
if (newCall && newCall->getType()->is(type))
|
|
return setResult(newCall);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// Key string identifying the canonicalization pass.
const std::string CanonicalizationPass::KEY("core-cleanup-canon");
|
|
|
|
/// Runs the pass over module `m`: registers the standard rules, resets the
/// rewriter state, then delegates to OperatorPass::run.
void CanonicalizationPass::run(Module *m) {
  // register the rewrite rules before running
  registerStandardRules(m);

  // reset the rewriter state
  Rewriter::reset();

  // let the base operator pass process the module
  OperatorPass::run(m);
}
|
|
|
|
/// Applies the registered rewrite rules to `v`, but only when the side-effect
/// analysis reports that the call has no side effect.
void CanonicalizationPass::handle(CallInstr *v) {
  auto *sideEffects =
      getAnalysisResult<analyze::module::SideEffectResult>(sideEffectsKey);
  if (sideEffects->hasSideEffect(v))
    return;
  rewrite(v);
}
|
|
|
|
/// Flattens the direct children of series flow `v` in place:
///   - a nested SeriesFlow is replaced by its own elements, in order;
///   - a FlowInstr is replaced by its flow followed by its value.
void CanonicalizationPass::handle(SeriesFlow *v) {
  for (auto pos = v->begin(); pos != v->end();) {
    if (auto *nested = cast<SeriesFlow>(*pos)) {
      // splice the nested series' elements in where the series used to be
      pos = v->erase(pos);
      for (auto *child : *nested) {
        pos = v->insert(pos, child);
        ++pos;
      }
      continue;
    }

    if (auto *flowInstr = cast<FlowInstr>(*pos)) {
      pos = v->erase(pos);
      // inserting in reverse order causes [flow, value] to be added
      pos = v->insert(pos, flowInstr->getValue());
      pos = v->insert(pos, flowInstr->getFlow());
      // no advance here: revisit the inserted flow, which may itself be a
      // series flow that needs flattening
      continue;
    }

    ++pos;
  }
}
|
|
|
|
/// Registers the four built-in canonicalization rules under their names.
/// The module argument is not used by this function.
void CanonicalizationPass::registerStandardRules(Module *m) {
  // c + b + a --> a + b + c
  registerRule("op-chain", std::make_unique<CanonOpChain>());
  // b > a --> a < b (etc.)
  registerRule("inequality", std::make_unique<CanonInequality>());
  // a*x + b*x --> (a + b) * x
  registerRule("add-mul", std::make_unique<CanonAddMul>());
  // x - c --> x + (-c)
  registerRule("const-sub", std::make_unique<CanonConstSub>());
}
|
|
|
|
} // namespace cleanup
|
|
} // namespace transform
|
|
} // namespace ir
|
|
} // namespace codon
|