#include "const_fold.h"

#include <cmath>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "sir/util/cloning.h"
#include "sir/util/irtools.h"

// Wrap an operator token in a generic lambda so it can be handed to a rule as its
// folding function. Both macros are #undef'd at the end of this file.
#define BINOP(o)                                                                       \
  [](auto x, auto y) -> auto { return x o y; }

#define UNOP(o)                                                                        \
  [](auto x) -> auto { return o x; }

namespace seq {
namespace ir {
namespace transform {
namespace folding {
namespace {

/// Folds a two-argument magic-method call whose operands are one int constant and
/// one float constant (in either order) into a single constant.
///
/// The int operand is converted to double before `f` is applied, so `f` is
/// always invoked as `f(double, double)`.
///
/// @tparam Func callable computing the folded value from two doubles
/// @tparam Out  value type of the produced constant (`TemplatedConst<Out>`)
template <typename Func, typename Out> class IntFloatBinaryRule : public RewriteRule {
private:
  /// Operation that computes the folded value.
  Func f;
  /// Name of the magic method this rule matches.
  std::string magic;
  /// IR type given to the folded constant.
  types::Type *out;

public:
  IntFloatBinaryRule(Func f, std::string magic, types::Type *out)
      : f(std::move(f)), magic(std::move(magic)), out(out) {}

  virtual ~IntFloatBinaryRule() noexcept = default;

  void visit(CallInstr *v) override {
    if (!util::isCallOf(v, magic, 2, /*output=*/nullptr, /*method=*/true))
      return;

    auto *leftConst = cast<Const>(v->front());
    auto *rightConst = cast<Const>(v->back());
    if (!leftConst || !rightConst)
      return;

    auto *M = v->getModule();
    if (isA<FloatConst>(leftConst) && isA<IntConst>(rightConst)) {
      // float <op> int: promote the right-hand operand.
      auto left = cast<FloatConst>(leftConst)->getVal();
      auto right = static_cast<double>(cast<IntConst>(rightConst)->getVal());
      return setResult(
          M->template N<TemplatedConst<Out>>(v->getSrcInfo(), f(left, right), out));
    } else if (isA<IntConst>(leftConst) && isA<FloatConst>(rightConst)) {
      // int <op> float: promote the left-hand operand.
      auto left = static_cast<double>(cast<IntConst>(leftConst)->getVal());
      auto right = cast<FloatConst>(rightConst)->getVal();
      return setResult(
          M->template N<TemplatedConst<Out>>(v->getSrcInfo(), f(left, right), out));
    }
    // Any other operand combination is left untouched.
  }
};
template class DoubleConstantBinaryRuleExcludeRHSZero : public DoubleConstantBinaryRule { public: DoubleConstantBinaryRuleExcludeRHSZero(Func f, std::string magic, types::Type *inputType, types::Type *resultType) : DoubleConstantBinaryRule(f, magic, inputType, resultType) {} virtual ~DoubleConstantBinaryRuleExcludeRHSZero() noexcept = default; void visit(CallInstr *v) override { if (v->numArgs() == 2) { auto *rightConst = cast>(v->back()); if (rightConst && rightConst->getVal() == ConstantType()) return; } DoubleConstantBinaryRule::visit(v); } }; auto id_val(Module *m) { return [=](Value *v) -> Value * { util::CloneVisitor cv(m); return cv.clone(v); }; } int64_t int_pow(int64_t base, int64_t exp) { if (exp < 0) return 0; int64_t result = 1; while (true) { if (exp & 1) { result *= base; } exp = exp >> 1; if (!exp) break; base = base * base; } return result; } template To convert(From x) { return To(x); } template auto intSingleRule(Module *m, Args &&...args) { return std::make_unique>( std::forward(args)..., m->getIntType()); } auto intNoOp(Module *m, std::string magic) { return std::make_unique(std::move(magic), m->getIntType()); } auto intDoubleApplyNoOp(Module *m, std::string magic) { return std::make_unique(std::move(magic), m->getIntType()); } template auto intToIntBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getIntType(), m->getIntType()); } template auto intToIntBinaryNoZeroRHS(Module *m, Func f, std::string magic) { return std::make_unique< DoubleConstantBinaryRuleExcludeRHSZero>( std::move(f), std::move(magic), m->getIntType(), m->getIntType()); } template auto intToBoolBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getIntType(), m->getBoolType()); } template auto boolToBoolBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getBoolType(), m->getBoolType()); } template auto 
floatToFloatBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getFloatType(), m->getFloatType()); } template auto floatToBoolBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getFloatType(), m->getBoolType()); } template auto intFloatToFloatBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getFloatType()); } template auto intFloatToBoolBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getBoolType()); } template auto intToIntUnary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getIntType(), m->getIntType()); } template auto floatToFloatUnary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getFloatType(), m->getFloatType()); } template auto boolToBoolUnary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getBoolType(), m->getBoolType()); } auto identityConvert(Module *m, std::string magic, types::Type *type) { return std::make_unique>(id_val(m), std::move(magic), type); } template auto typeConvert(Module *m, std::string magic, types::Type *fromType, types::Type *toType) { return std::make_unique< SingleConstantUnaryRule)>>>( convert, std::move(magic), fromType, toType); } } // namespace const std::string FoldingPass::KEY = "core-folding-const-fold"; void FoldingPass::run(Module *m) { registerStandardRules(m); Rewriter::reset(); OperatorPass::run(m); } void FoldingPass::handle(CallInstr *v) { rewrite(v); } void FoldingPass::registerStandardRules(Module *m) { // binary, single constant, int->int using Kind = SingleConstantCommutativeRule::Kind; registerRule("int-multiply-by-zero", intSingleRule(m, 0, 0, Module::MUL_MAGIC_NAME, Kind::COMMUTATIVE)); registerRule( "int-multiply-by-one", 
intSingleRule(m, 1, id_val(m), Module::MUL_MAGIC_NAME, Kind::COMMUTATIVE)); registerRule("int-subtract-zero", intSingleRule(m, 0, id_val(m), Module::SUB_MAGIC_NAME, Kind::RIGHT)); registerRule("int-add-zero", intSingleRule(m, 0, id_val(m), Module::ADD_MAGIC_NAME, Kind::COMMUTATIVE)); registerRule( "int-floor-div-by-one", intSingleRule(m, 1, id_val(m), Module::FLOOR_DIV_MAGIC_NAME, Kind::RIGHT)); registerRule("int-zero-floor-div", intSingleRule(m, 0, 0, Module::FLOOR_DIV_MAGIC_NAME, Kind::LEFT)); registerRule("int-pos", intNoOp(m, Module::POS_MAGIC_NAME)); registerRule("int-double-neg", intDoubleApplyNoOp(m, Module::NEG_MAGIC_NAME)); registerRule("int-double-inv", intDoubleApplyNoOp(m, Module::INVERT_MAGIC_NAME)); // binary, double constant, int->int registerRule("int-constant-addition", intToIntBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("int-constant-subtraction", intToIntBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); registerRule("int-constant-floor-div", intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME)); registerRule("int-constant-mul", intToIntBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); registerRule("int-constant-lshift", intToIntBinary(m, BINOP(<<), Module::LSHIFT_MAGIC_NAME)); registerRule("int-constant-rshift", intToIntBinary(m, BINOP(>>), Module::RSHIFT_MAGIC_NAME)); registerRule("int-constant-pow", intToIntBinary(m, int_pow, Module::POW_MAGIC_NAME)); registerRule("int-constant-xor", intToIntBinary(m, BINOP(^), Module::XOR_MAGIC_NAME)); registerRule("int-constant-or", intToIntBinary(m, BINOP(|), Module::OR_MAGIC_NAME)); registerRule("int-constant-and", intToIntBinary(m, BINOP(&), Module::AND_MAGIC_NAME)); registerRule("int-constant-mod", intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME)); // binary, double constant, int->bool registerRule("int-constant-eq", intToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME)); registerRule("int-constant-ne", intToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME)); 
registerRule("int-constant-gt", intToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME)); registerRule("int-constant-ge", intToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME)); registerRule("int-constant-lt", intToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME)); registerRule("int-constant-le", intToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME)); // binary, double constant, bool->bool registerRule("bool-constant-xor", boolToBoolBinary(m, BINOP(^), Module::XOR_MAGIC_NAME)); registerRule("bool-constant-or", boolToBoolBinary(m, BINOP(|), Module::OR_MAGIC_NAME)); registerRule("bool-constant-and", boolToBoolBinary(m, BINOP(&), Module::AND_MAGIC_NAME)); // unary, single constant, int->int registerRule("int-constant-pos", intToIntUnary(m, UNOP(+), Module::POS_MAGIC_NAME)); registerRule("int-constant-neg", intToIntUnary(m, UNOP(-), Module::NEG_MAGIC_NAME)); registerRule("int-constant-inv", intToIntUnary(m, UNOP(~), Module::INVERT_MAGIC_NAME)); // unary, singe constant, float->float registerRule("float-constant-pos", floatToFloatUnary(m, UNOP(+), Module::POS_MAGIC_NAME)); registerRule("float-constant-neg", floatToFloatUnary(m, UNOP(-), Module::NEG_MAGIC_NAME)); // unary, single constant, bool->bool registerRule("bool-constant-inv", boolToBoolUnary(m, UNOP(!), Module::INVERT_MAGIC_NAME)); // binary, double constant, float->float registerRule("float-constant-addition", floatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("float-constant-subtraction", floatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); registerRule("float-constant-floor-div", floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); registerRule("float-constant-mul", floatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); registerRule( "float-constant-pow", floatToFloatBinary( m, [](auto a, auto b) { return std::pow(a, b); }, Module::POW_MAGIC_NAME)); // binary, double constant, float->bool registerRule("float-constant-eq", floatToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME)); 
registerRule("float-constant-ne", floatToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME)); registerRule("float-constant-gt", floatToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME)); registerRule("float-constant-ge", floatToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME)); registerRule("float-constant-lt", floatToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME)); registerRule("float-constant-le", floatToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME)); // binary, double constant, int,float->float registerRule("int-float-constant-addition", intFloatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("int-float-constant-subtraction", intFloatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); registerRule("int-float-constant-floor-div", intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); registerRule("int-float-constant-mul", intFloatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); // binary, double constant, int,float->bool registerRule("int-float-constant-eq", intFloatToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME)); registerRule("int-float-constant-ne", intFloatToBoolBinary(m, BINOP(!=), Module::NE_MAGIC_NAME)); registerRule("int-float-constant-gt", intFloatToBoolBinary(m, BINOP(>), Module::GT_MAGIC_NAME)); registerRule("int-float-constant-ge", intFloatToBoolBinary(m, BINOP(>=), Module::GE_MAGIC_NAME)); registerRule("int-float-constant-lt", intFloatToBoolBinary(m, BINOP(<), Module::LT_MAGIC_NAME)); registerRule("int-float-constant-le", intFloatToBoolBinary(m, BINOP(<=), Module::LE_MAGIC_NAME)); // type conversions, identity registerRule("int-constant-int", identityConvert(m, Module::INT_MAGIC_NAME, m->getIntType())); registerRule("float-constant-float", identityConvert(m, Module::FLOAT_MAGIC_NAME, m->getFloatType())); registerRule("bool-constant-bool", identityConvert(m, Module::BOOL_MAGIC_NAME, m->getBoolType())); // type conversions, distinct registerRule("float-constant-int", typeConvert(m, Module::INT_MAGIC_NAME, 
m->getFloatType(), m->getIntType())); registerRule("bool-constant-int", typeConvert(m, Module::INT_MAGIC_NAME, m->getBoolType(), m->getIntType())); registerRule("int-constant-float", typeConvert(m, Module::FLOAT_MAGIC_NAME, m->getIntType(), m->getFloatType())); registerRule("bool-constant-float", typeConvert(m, Module::FLOAT_MAGIC_NAME, m->getBoolType(), m->getFloatType())); registerRule("int-constant-bool", typeConvert(m, Module::BOOL_MAGIC_NAME, m->getIntType(), m->getBoolType())); registerRule("float-constant-bool", typeConvert(m, Module::BOOL_MAGIC_NAME, m->getFloatType(), m->getBoolType())); } } // namespace folding } // namespace transform } // namespace ir } // namespace seq #undef BINOP #undef UNOP