#include "str.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include "sir/util/cloning.h"
#include "sir/util/irtools.h"

namespace seq {
namespace ir {
namespace transform {
namespace pythonic {
namespace {

/// Result of flattening a chain of string `__add__` calls.
struct InspectionResult {
  /// Cleared as soon as a non-string value is encountered in the chain.
  bool valid = true;
  /// The string operands of the chain, in left-to-right order.
  std::vector<Value *> args;
};

/// @return true if v's type is the module's string type.
bool isString(Value *v) {
  auto *M = v->getModule();
  return v->getType()->is(M->getStringType());
}

/// Recursively flattens a tree of binary string `__add__` calls rooted at v.
///
/// If v is a string-typed call to a function whose unmangled name is
/// `__add__`, and that call has exactly two string-typed arguments, both
/// arguments are inspected in turn, front first. This keeps r.args in
/// left-to-right order. Any other string-typed value is appended to r.args
/// as a leaf operand. A value that is not a string marks r as invalid.
///
/// @param v the value to inspect
/// @param r accumulator for the operands and the validity flag
void inspect(Value *v, InspectionResult &r) {
  if (!isString(v)) {
    r.valid = false;
    return;
  }

  // Check whether v is itself a string `__add__`; if so, descend into it.
  if (auto *c = cast<CallInstr>(v)) {
    auto *func = util::getFunc(c->getCallee());
    if (func && func->getUnmangledName() == "__add__" &&
        std::distance(c->begin(), c->end()) == 2 && isString(c->front()) &&
        isString(c->back())) {
      inspect(c->front(), r);
      inspect(c->back(), r);
      return;
    }
  }

  // A string value that is not a flattenable `__add__` is a leaf operand.
  r.args.push_back(v);
}

} // namespace

const std::string StrAdditionOptimization::KEY = "core-pythonic-str-addition-opt";

/// Replaces a chain of string additions with a single `str.cat` call.
///
/// Only calls whose callee has the unmangled name `__add__` are considered.
/// The chain is flattened with inspect(). If every operand is a string and
/// there are more than two of them, the call is replaced. The replacement
/// calls the string type's `cat` method on a tuple that holds clones of the
/// operands. Chains with only two operands are left as the original
/// `__add__` call.
///
/// @param v the call instruction being visited
void StrAdditionOptimization::handle(CallInstr *v) {
  auto *M = v->getModule();

  auto *f = util::getFunc(v->getCallee());
  if (!f || f->getUnmangledName() != "__add__")
    return;

  InspectionResult r;
  inspect(v, r);

  if (!r.valid || r.args.size() <= 2)
    return;

  // NOTE(review): the operands are cloned, presumably because the originals
  // are still part of v's argument tree, which replaceAll discards. Confirm
  // against CloneVisitor/replaceAll semantics.
  std::vector<Value *> operands;
  operands.reserve(r.args.size());
  util::CloneVisitor cv(M);
  for (auto *arg : r.args) {
    operands.push_back(cv.clone(arg));
  }

  auto *tuple = util::makeTuple(operands, M);
  auto *replacementFunc =
      M->getOrRealizeMethod(M->getStringType(), "cat", {tuple->getType()});
  seqassert(replacementFunc, "could not find cat function");

  std::vector<Value *> callArgs = {tuple};
  v->replaceAll(util::call(replacementFunc, callArgs));
}

} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace seq