1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
2021-09-27 14:02:44 -04:00

262 lines
7.8 KiB
C++

#pragma once
#include <algorithm>
#include <utility>
#include "sir/transform/pass.h"
#include "sir/transform/rewrite.h"
#include "sir/util/irtools.h"
namespace seq {
namespace ir {
namespace transform {
namespace folding {
/// Commutative, binary rule that requires a single constant.
///
/// Matches a call recognised by util::isCallOf as the method `magic` taking two
/// operands of `type` and returning `type`. If an operand in a permitted
/// position is a TemplatedConst<ConstantType> whose value equals `val`, the
/// call is replaced by the value the calculator produces for the other operand.
template <typename ConstantType>
class SingleConstantCommutativeRule : public RewriteRule {
public:
  /// Builds the replacement value; it is handed the operand that did NOT match.
  using Calculator = std::function<Value *(Value *)>;
  /// Operand position(s) in which the constant may appear.
  enum Kind { LEFT, RIGHT, COMMUTATIVE };

private:
  /// the value being matched against
  ConstantType val;
  /// the type being matched
  types::Type *type;
  /// the magic method name
  std::string magic;
  /// the calculator
  Calculator calc;
  /// left, right or commutative
  Kind kind;

public:
  /// Constructs a commutative rule whose result is a fixed constant.
  /// @param val the matched value
  /// @param newVal the value of the constant that replaces the call
  /// @param magic the magic name
  /// @param kind left, right, or commutative
  /// @param type the matched type (also used as the type of the new constant)
  SingleConstantCommutativeRule(ConstantType val, ConstantType newVal,
                                std::string magic, Kind kind, types::Type *type)
      : val(val), type(type), magic(std::move(magic)), kind(kind) {
    // Capture the constructor parameters explicitly. The replacement constant
    // must carry `newVal`; building it from `val` would silently discard the
    // caller-supplied result whenever the two differ.
    calc = [newVal, type](Value *v) -> Value * {
      return v->getModule()->N<TemplatedConst<ConstantType>>(v->getSrcInfo(), newVal,
                                                             type);
    };
  }

  /// Constructs a commutative rule whose result is computed by a calculator.
  /// @param val the matched value
  /// @param calc the calculator, invoked with the non-matching operand
  /// @param magic the magic name
  /// @param kind left, right, or commutative
  /// @param type the matched type
  SingleConstantCommutativeRule(ConstantType val, Calculator calc, std::string magic,
                                Kind kind, types::Type *type)
      : val(val), type(type), magic(std::move(magic)), calc(std::move(calc)),
        kind(kind) {}

  virtual ~SingleConstantCommutativeRule() noexcept = default;

  /// Rewrites `v` if it is a matching call with the constant in an allowed
  /// position; otherwise leaves it untouched.
  /// @param v the call instruction being visited
  void visit(CallInstr *v) override {
    if (!util::isCallOf(v, magic, {type, type}, type, /*method=*/true))
      return;

    auto *left = v->front();
    auto *right = v->back();
    // Each pointer is only used after a null check, i.e. when the operand is a
    // constant of the expected kind.
    auto *leftConst = cast<TemplatedConst<ConstantType>>(left);
    auto *rightConst = cast<TemplatedConst<ConstantType>>(right);

    // The left operand is tried first, so for COMMUTATIVE rules a match on the
    // left takes precedence over a match on the right.
    if ((kind == Kind::COMMUTATIVE || kind == Kind::LEFT) && leftConst &&
        leftConst->getVal() == val)
      return setResult(calc(right));
    if ((kind == Kind::COMMUTATIVE || kind == Kind::RIGHT) && rightConst &&
        rightConst->getVal() == val)
      return setResult(calc(left));
  }
};
/// Binary rule that requires two constants.
template <typename ConstantType, typename Func, typename OutputType = ConstantType>
class DoubleConstantBinaryRule : public RewriteRule {
private:
/// the calculator
Func f;
/// the input type
types::Type *inputType;
/// the output type
types::Type *resultType;
/// the magic method name
std::string magic;
public:
/// Constructs a binary rule.
/// @param f the calculator
/// @param magic the magic method name
/// @param inputType the input type
/// @param resultType the output type
DoubleConstantBinaryRule(Func f, std::string magic, types::Type *inputType,
types::Type *resultType)
: f(std::move(f)), inputType(inputType), resultType(resultType),
magic(std::move(magic)) {}
virtual ~DoubleConstantBinaryRule() noexcept = default;
void visit(CallInstr *v) override {
if (!util::isCallOf(v, magic, {inputType, inputType}, resultType, /*method=*/true))
return;
auto *left = v->front();
auto *right = v->back();
auto *leftConst = cast<TemplatedConst<ConstantType>>(left);
auto *rightConst = cast<TemplatedConst<ConstantType>>(right);
if (leftConst && rightConst)
return setResult(toValue(v, f(leftConst->getVal(), rightConst->getVal())));
}
private:
Value *toValue(Value *, Value *v) { return v; }
Value *toValue(Value *og, OutputType v) {
return og->getModule()->template N<TemplatedConst<OutputType>>(og->getSrcInfo(), v,
resultType);
}
};
/// Folding rule for a unary operation applied to a constant.
///
/// A call recognised by util::isCallOf as the method `magic` with signature
/// (inputType) -> resultType is replaced by the calculator's result when its
/// operand is a TemplatedConst<ConstantType>.
template <typename ConstantType, typename Func>
class SingleConstantUnaryRule : public RewriteRule {
public:
  /// Constructs a unary rule.
  /// @param f the calculator, applied to the constant's value
  /// @param magic the magic method name
  /// @param inputType the type of the operand
  /// @param resultType the type produced by the call
  SingleConstantUnaryRule(Func f, std::string magic, types::Type *inputType,
                          types::Type *resultType)
      : f(std::move(f)), inputType(inputType), resultType(resultType),
        magic(std::move(magic)) {}

  virtual ~SingleConstantUnaryRule() noexcept = default;

  /// Folds `v` when it is the targeted call and its operand is a constant.
  /// @param v the call instruction being visited
  void visit(CallInstr *v) override {
    const bool isTarget =
        util::isCallOf(v, magic, {inputType}, resultType, /*method=*/true);
    if (!isTarget)
      return;

    auto *operand = cast<TemplatedConst<ConstantType>>(v->front());
    // Only a constant operand can be folded.
    if (!operand)
      return;

    setResult(toValue(v, f(operand->getVal())));
  }

private:
  /// Pass-through used when the calculator already yields an IR value.
  Value *toValue(Value *, Value *v) { return v; }

  /// Wraps a raw result of any other type in a new constant of `resultType`,
  /// created in the module of `og` and carrying the source info of `og`.
  template <typename NewType> Value *toValue(Value *og, NewType v) {
    return og->getModule()->template N<TemplatedConst<NewType>>(og->getSrcInfo(), v,
                                                                resultType);
  }

  /// the calculator
  Func f;
  /// the input type
  types::Type *inputType;
  /// the output type
  types::Type *resultType;
  /// the magic method name
  std::string magic;
};
/// Unary rule that places no constant requirement on its operand.
///
/// Any call recognised by util::isCallOf as the method `magic` with signature
/// (inputType) -> inputType is replaced by whatever the calculator returns for
/// the call's operand.
template <typename Func> class UnaryRule : public RewriteRule {
public:
  /// Constructs a unary rule.
  /// @param f the calculator, applied to the operand itself
  /// @param magic the magic method name
  /// @param inputType the operand type, which is also the expected result type
  UnaryRule(Func f, std::string magic, types::Type *inputType)
      : f(std::move(f)), inputType(inputType), magic(std::move(magic)) {}

  virtual ~UnaryRule() noexcept = default;

  /// Rewrites `v` when it is the targeted call.
  /// @param v the call instruction being visited
  void visit(CallInstr *v) override {
    const bool isTarget =
        util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true);
    if (!isTarget)
      return;

    auto *operand = v->front();
    setResult(f(operand));
  }

private:
  /// the calculator
  Func f;
  /// the input type
  types::Type *inputType;
  /// the magic method name
  std::string magic;
};
/// Rule that removes an operation with no effect, like "+x".
///
/// A call recognised by util::isCallOf as the method `magic` with signature
/// (inputType) -> inputType is replaced by its own operand.
class NoOpRule : public RewriteRule {
public:
  /// Constructs a no-op rule.
  /// @param magic the magic method name
  /// @param inputType the operand type, which is also the expected result type
  NoOpRule(std::string magic, types::Type *inputType)
      : inputType(inputType), magic(std::move(magic)) {}

  virtual ~NoOpRule() noexcept = default;

  /// Replaces `v` with its operand when it is the targeted call.
  /// @param v the call instruction being visited
  void visit(CallInstr *v) override {
    const bool isTarget =
        util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true);
    if (isTarget)
      setResult(v->front());
  }

private:
  /// the input type
  types::Type *inputType;
  /// the magic method name
  std::string magic;
};
/// Rule that removes an operation applied twice in a row, like "-(-x)".
///
/// When both the visited call and its operand are recognised by util::isCallOf
/// as the method `magic` with signature (inputType) -> inputType, the outer
/// call is replaced by the operand of the inner call.
class DoubleApplicationNoOpRule : public RewriteRule {
public:
  /// Constructs a double-application no-op rule.
  /// @param magic the magic method name
  /// @param inputType the operand type, which is also the expected result type
  DoubleApplicationNoOpRule(std::string magic, types::Type *inputType)
      : inputType(inputType), magic(std::move(magic)) {}

  virtual ~DoubleApplicationNoOpRule() noexcept = default;

  /// Collapses `v` to the innermost operand when it is a double application.
  /// @param v the call instruction being visited
  void visit(CallInstr *v) override {
    // The outer call must be the targeted operation...
    const bool outerIsTarget =
        util::isCallOf(v, magic, {inputType}, inputType, /*method=*/true);
    if (!outerIsTarget)
      return;

    // ...and so must the call it is applied to.
    auto *inner = v->front();
    const bool innerIsTarget =
        util::isCallOf(inner, magic, {inputType}, inputType, /*method=*/true);
    if (innerIsTarget)
      setResult(cast<CallInstr>(inner)->front());
  }

private:
  /// the input type
  types::Type *inputType;
  /// the magic method name
  std::string magic;
};
} // namespace folding
} // namespace transform
} // namespace ir
} // namespace seq