1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
2021-09-27 14:02:44 -04:00

103 lines
2.6 KiB
C++

#include "io.h"
#include <algorithm>
#include "sir/util/cloning.h"
#include "sir/util/irtools.h"
namespace seq {
namespace ir {
namespace transform {
namespace pythonic {
namespace {
/// Rewrites a `print` call whose first argument wraps a string `cat` call so
/// that the value being concatenated is handed to `print` directly.
///
/// The pattern matched is: argument 0 of `v` is a call to a function whose
/// unmangled name is `__new__` with exactly one argument, and that argument is
/// a call to the `cat` method realized on the module's string type. When the
/// pattern matches, `v` is replaced by a call to `print` (looked up in
/// "std.internal.builtin") whose argument 0 is a clone of `cat`'s first
/// argument and whose argument 1 is the empty string; the remaining arguments
/// are clones of the originals. If anything does not match, `v` is untouched.
///
/// @param v call instruction already known to target the stdlib `print`
void optimizePrint(CallInstr *v) {
  auto *M = v->getModule();

  // The rewrite reads argument 0 and overwrites arguments 0 and 1, so a call
  // with fewer than two arguments cannot be handled safely.
  if (std::distance(v->begin(), v->end()) < 2)
    return;

  auto *inner = cast<CallInstr>(v->front());
  if (!inner)
    return;

  auto *innerFunc = util::getFunc(inner->getCallee());
  if (!innerFunc || innerFunc->getUnmangledName() != "__new__" ||
      std::distance(inner->begin(), inner->end()) != 1)
    return;

  auto *cat = cast<CallInstr>(inner->front());
  if (!cat)
    return;

  auto *catFunc = util::getFunc(cat->getCallee());
  if (!catFunc || catFunc->getUnmangledName() != "cat")
    return;

  // Any function may be named "cat"; make sure it actually has an argument
  // before inspecting the first one.
  if (cat->begin() == cat->end())
    return;

  // Only proceed if the callee is exactly the string type's `cat` method for
  // this argument type. Realization may fail, in which case there is nothing
  // to compare against.
  auto *realCat =
      M->getOrRealizeMethod(M->getStringType(), "cat", {cat->front()->getType()});
  if (!realCat || realCat->getId() != catFunc->getId())
    return;

  util::CloneVisitor cv(M);
  std::vector<Value *> args;
  std::vector<types::Type *> types;
  for (auto *printArg : *v) {
    args.push_back(cv.clone(printArg));
    types.push_back(printArg->getType());
  }

  // Pass the un-concatenated value straight through as argument 0.
  args[0] = cv.clone(cat->front());
  types[0] = args[0]->getType();
  // Argument 1 becomes the empty string (presumably `print`'s separator, so the
  // pieces are emitted back-to-back as `cat` would have joined them — confirm
  // against the stdlib `print` signature).
  args[1] = M->getString("");

  auto *replacement = M->getOrRealizeFunc("print", types, {}, "std.internal.builtin");
  if (!replacement)
    return;

  v->replaceAll(util::call(replacement, args));
}
/// Rewrites a two-argument `write` call whose second argument is a string
/// `cat` call into a call to the receiver type's `__file_write_gen__` method,
/// passing an iterator over `cat`'s first argument instead of the
/// concatenated string.
///
/// The call is left untouched when it does not have exactly two arguments,
/// when the second argument is not a call to the `cat` method realized on the
/// module's string type, when no iterator can be built for `cat`'s argument,
/// or when `__file_write_gen__` cannot be realized on the first argument's
/// type.
///
/// @param v call instruction whose callee's unmangled name is `write`
void optimizeWrite(CallInstr *v) {
  auto *M = v->getModule();

  // The caller dispatches on the name "write" only, so the arity is unknown
  // here. Check before every dereference, and require exactly two arguments
  // because the replacement call forwards only (file, iterator).
  auto it = v->begin();
  const auto last = v->end();
  if (it == last)
    return;
  auto *file = *it++;
  if (it == last)
    return;
  auto *cat = cast<CallInstr>(*it++);
  if (!cat || it != last)
    return;

  auto *catFunc = util::getFunc(cat->getCallee());
  if (!catFunc || catFunc->getUnmangledName() != "cat")
    return;

  // Any function may be named "cat"; make sure it actually has an argument
  // before inspecting the first one.
  if (cat->begin() == cat->end())
    return;

  // Only proceed if the callee is exactly the string type's `cat` method for
  // this argument type. Realization may fail, in which case there is nothing
  // to compare against.
  auto *realCat =
      M->getOrRealizeMethod(M->getStringType(), "cat", {cat->front()->getType()});
  if (!realCat || realCat->getId() != catFunc->getId())
    return;

  util::CloneVisitor cv(M);
  auto *iter = cv.clone(cat->front())->iter();
  if (!iter)
    return;

  std::vector<Value *> args = {cv.clone(file), iter};
  auto *replacement = M->getOrRealizeMethod(file->getType(), "__file_write_gen__",
                                            {args[0]->getType(), args[1]->getType()});
  if (!replacement)
    return;

  v->replaceAll(util::call(replacement, args));
}
} // namespace
/// String key associated with this pass.
const std::string IOCatOptimization::KEY = "core-pythonic-io-cat-opt";

/// Routes a call instruction to the matching rewrite: calls that
/// `util::getStdlibFunc` resolves as `print` go to `optimizePrint`; any other
/// call whose function has the unmangled name `write` goes to `optimizeWrite`.
/// All remaining calls are ignored.
void IOCatOptimization::handle(CallInstr *v) {
  auto *callee = v->getCallee();

  if (util::getStdlibFunc(callee, "print")) {
    optimizePrint(v);
    return;
  }

  auto *target = cast<Func>(util::getFunc(callee));
  if (target && target->getUnmangledName() == "write")
    optimizeWrite(v);
}
} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace seq