pyextension.h support

pull/335/head
Ibrahim Numanagić 2023-02-09 16:48:49 -08:00
parent d6aa7e5142
commit 14ea7127c7
3 changed files with 111 additions and 92 deletions

View File

@ -230,10 +230,11 @@ std::vector<ExprPtr> Cache::mergeC3(std::vector<std::vector<ExprPtr>> &seqs) {
return result; return result;
} }
std::shared_ptr<ir::PyModule> Cache::getPythonModule() { void Cache::populatePythonModule() {
LOG("====== module generation ======="); LOG("[py] ====== module generation =======");
auto mod = std::make_shared<ir::PyModule>(); if (!pyModule)
pyModule = std::make_shared<ir::PyModule>();
#define N std::make_shared #define N std::make_shared
using namespace ast; using namespace ast;
@ -243,46 +244,48 @@ std::shared_ptr<ir::PyModule> Cache::getPythonModule() {
if (c.module.empty() && startswith(cn, "Pyx")) { if (c.module.empty() && startswith(cn, "Pyx")) {
ir::PyType py{rev(cn), c.ast->getDocstr()}; ir::PyType py{rev(cn), c.ast->getDocstr()};
for (const auto &[n, ofnn] : c.methods) { for (const auto &[n, ofnn] : c.methods) {
// get the last auto fnn = overloads[ofnn].back().name; // last overload
auto fnn = overloads[ofnn].back().name;
LOG("{} : {} -> {}", cn, n, fnn);
// get the types
auto &fna = functions[fnn].ast; auto &fna = functions[fnn].ast;
// std::vector<TypeExpr> t;
if (fna->hasAttr("autogenerated")) if (fna->hasAttr("autogenerated"))
continue; continue;
std::vector<Param> params; std::vector<Param> params;
std::vector<ExprPtr> args; std::vector<ExprPtr> args;
auto sctx = imports[MAIN_IMPORT].ctx; auto sctx = imports[MAIN_IMPORT].ctx;
if (true) { // assume these are methods!
params = {Param{sctx->generateCanonicalName("self"), N<IdExpr>("cobj")}, bool isMethod = fna->hasAttr(Attr::Method);
Param{sctx->generateCanonicalName("args"), N<IdExpr>("cobj")}}; LOG("[py] {}: {}.{} => {}", isMethod ? "method" : "classm", cn, n, fnn);
if (fna->args.size() > 2) { params = {Param{sctx->generateCanonicalName("self"), N<IdExpr>("cobj")},
params.back().type = Param{sctx->generateCanonicalName("args"), N<IdExpr>("cobj")}};
N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type); if (fna->args.size() > 2) {
params.push_back( params.back().type = N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type);
Param{sctx->generateCanonicalName("nargs"), N<IdExpr>("int")}); params.push_back(
} Param{sctx->generateCanonicalName("nargs"), N<IdExpr>("int")});
ExprPtr po = N<IdExpr>(params[0].name); }
po = N<CallExpr>(N<DotExpr>(N<IdExpr>(cn), "__from_py__"), po); ExprPtr po = N<IdExpr>(params[0].name);
po = N<CallExpr>(N<DotExpr>(N<IdExpr>(cn), "__from_py__"), po);
if (isMethod)
args.push_back(po); args.push_back(po);
if (fna->args.size() > 2) { if (fna->args.size() > 1 + isMethod) {
for (size_t ai = 1; ai < fna->args.size(); ai++) { for (size_t ai = isMethod; ai < fna->args.size(); ai++) {
ExprPtr po = N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - 1)); ExprPtr po =
if (fna->args[ai].type) N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - isMethod));
po = N<CallExpr>(N<DotExpr>(fna->args[ai].type->clone(), "__from_py__"), if (fna->args[ai].type) {
po); po = N<CallExpr>(N<DotExpr>(fna->args[ai].type->clone(), "__from_py__"),
args.push_back(po);
}
} else if (fna->args.size() == 2) {
ExprPtr po = N<IdExpr>(params[1].name);
if (fna->args[1].type)
po = N<CallExpr>(N<DotExpr>(fna->args[1].type->clone(), "__from_py__"),
po); po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po); args.push_back(po);
} }
} else if (fna->args.size() == 1 + isMethod) {
ExprPtr po = N<IdExpr>(params[1].name);
if (fna->args[1].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[1].type->clone(), "__from_py__"), po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
} }
auto stubName = sctx->generateCanonicalName(fmt::format("_py.{}.{}", cn, n)); auto stubName = sctx->generateCanonicalName(fmt::format("_py.{}.{}", cn, n));
auto node = auto node =
@ -297,119 +300,134 @@ std::shared_ptr<ir::PyModule> Cache::getPythonModule() {
auto tnode = tv.transform(node); auto tnode = tv.transform(node);
auto rtv = tv.realize(typeCtx->forceFind(stubName)->type); auto rtv = tv.realize(typeCtx->forceFind(stubName)->type);
seqassertn(rtv, "realization of {} failed", stubName); seqassertn(rtv, "realization of {} failed", stubName);
TranslateVisitor(codegenCtx).transform(tnode);
auto pr = pendingRealizations; // copy it as it might be modified
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
auto f = auto f =
functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir;
typeCtx->age = oldAge; typeCtx->age = oldAge;
if (n == "__repr__") if (n == "__repr__") {
py.repr = f; py.repr = f;
else if (n == "__add__") } else if (n == "__add__") {
py.add = f; py.add = f;
else if (n == "__iadd__") } else if (n == "__iadd__") {
py.iadd = f; py.iadd = f;
else if (n == "__sub__") } else if (n == "__sub__") {
py.sub = f; py.sub = f;
else if (n == "__isub__") } else if (n == "__isub__") {
py.isub = f; py.isub = f;
else if (n == "__mul__") } else if (n == "__mul__") {
py.mul = f; py.mul = f;
else if (n == "__imul__") } else if (n == "__imul__") {
py.imul = f; py.imul = f;
else if (n == "__mod__") } else if (n == "__mod__") {
py.mod = f; py.mod = f;
else if (n == "__imod__") } else if (n == "__imod__") {
py.imod = f; py.imod = f;
else if (n == "__divmod__") } else if (n == "__divmod__") {
py.divmod = f; py.divmod = f;
else if (n == "__pow__") } else if (n == "__pow__") {
py.pow = f; py.pow = f;
else if (n == "__ipow__") } else if (n == "__ipow__") {
py.ipow = f; py.ipow = f;
else if (n == "__neg__") } else if (n == "__neg__") {
py.neg = f; py.neg = f;
else if (n == "__pos__") } else if (n == "__pos__") {
py.pos = f; py.pos = f;
else if (n == "__abs__") } else if (n == "__abs__") {
py.abs = f; py.abs = f;
else if (n == "__bool__") } else if (n == "__bool__") {
py.bool_ = f; py.bool_ = f;
else if (n == "__invert__") } else if (n == "__invert__") {
py.invert = f; py.invert = f;
else if (n == "__lshift__") } else if (n == "__lshift__") {
py.lshift = f; py.lshift = f;
else if (n == "__ilshift__") } else if (n == "__ilshift__") {
py.ilshift = f; py.ilshift = f;
else if (n == "__rshift__") } else if (n == "__rshift__") {
py.rshift = f; py.rshift = f;
else if (n == "__irshift__") } else if (n == "__irshift__") {
py.irshift = f; py.irshift = f;
else if (n == "__and__") } else if (n == "__and__") {
py.and_ = f; py.and_ = f;
else if (n == "__iand__") } else if (n == "__iand__") {
py.iand = f; py.iand = f;
else if (n == "__xor__") } else if (n == "__xor__") {
py.xor_ = f; py.xor_ = f;
else if (n == "__ixor__") } else if (n == "__ixor__") {
py.ixor = f; py.ixor = f;
else if (n == "__or__") } else if (n == "__or__") {
py.or_ = f; py.or_ = f;
else if (n == "__ior__") } else if (n == "__ior__") {
py.ior = f; py.ior = f;
else if (n == "__int__") } else if (n == "__int__") {
py.int_ = f; py.int_ = f;
else if (n == "__float__") } else if (n == "__float__") {
py.float_ = f; py.float_ = f;
else if (n == "__floordiv__") } else if (n == "__floordiv__") {
py.floordiv = f; py.floordiv = f;
else if (n == "__ifloordiv__") } else if (n == "__ifloordiv__") {
py.ifloordiv = f; py.ifloordiv = f;
else if (n == "__truediv__") } else if (n == "__truediv__") {
py.truediv = f; py.truediv = f;
else if (n == "__itruediv__") } else if (n == "__itruediv__") {
py.itruediv = f; py.itruediv = f;
else if (n == "__index__") } else if (n == "__index__") {
py.index = f; py.index = f;
else if (n == "__matmul__") } else if (n == "__matmul__") {
py.matmul = f; py.matmul = f;
else if (n == "__imatmul__") } else if (n == "__imatmul__") {
py.imatmul = f; py.imatmul = f;
else if (n == "__len__") } else if (n == "__len__") {
py.len = f; py.len = f;
else if (n == "__getitem__") } else if (n == "__getitem__") {
py.getitem = f; py.getitem = f;
else if (n == "__setitem__") } else if (n == "__setitem__") {
py.setitem = f; py.setitem = f;
else if (n == "__contains__") } else if (n == "__contains__") {
py.contains = f; py.contains = f;
else if (n == "__hash__") } else if (n == "__hash__") {
py.hash = f; py.hash = f;
else if (n == "__call__") } else if (n == "__call__") {
py.call = f; py.call = f;
else if (n == "__str__") } else if (n == "__str__") {
py.str = f; py.str = f;
else if (n == "__cmp__") } else if (n == "__cmp__") {
py.cmp = f; py.cmp = f;
else if (n == "__iter__") } else if (n == "__iter__") {
py.iter = f; py.iter = f;
else if (n == "__del__") } else if (n == "__del__") {
py.del = f; py.del = f;
else if (n == "__new__") } else if (n == "__new__") {
py.new_ = f; py.new_ = f;
else if (n == "__init__") } else if (n == "__init__") {
py.init = f; py.init = f;
else } else {
py.methods.push_back( py.methods.push_back(ir::PyFunction{n, fna->getDocstr(), f,
ir::PyFunction{n, fna->getDocstr(), f, ir::PyFunction::Type::METHOD}); isMethod ? ir::PyFunction::Type::METHOD
: ir::PyFunction::Type::CLASS});
LOG(">| [{}] {}", functions[stubName].realizations.size(), *f); }
// LOG(">| [{}] {}", functions[stubName].realizations.size(), *f);
} }
mod->types.push_back(py);
if (c.realizations.size() != 1)
compilationError(fmt::format("cannot pythonize generic class '{}'", cn));
auto &r = c.realizations.begin()->second;
py.type = realizeType(r->type);
for (auto &[mn, mt] : r->fields) {
py.members.push_back(ir::PyMember{mn, "",
mt->is("int") ? ir::PyMember::Type::LONGLONG
: mt->is("float")
? ir::PyMember::Type::DOUBLE
: ir::PyMember::Type::OBJECT,
true});
LOG("[py] {}: {}.{} => {}", "member", cn, mn, py.members.back().type);
}
pyModule->types.push_back(py);
} }
#undef N #undef N
return mod;
} }
} // namespace codon::ast } // namespace codon::ast

View File

@ -308,7 +308,8 @@ public:
static std::vector<ExprPtr> mergeC3(std::vector<std::vector<ExprPtr>> &); static std::vector<ExprPtr> mergeC3(std::vector<std::vector<ExprPtr>> &);
std::shared_ptr<ir::PyModule> getPythonModule(); std::shared_ptr<ir::PyModule> pyModule = nullptr;
void populatePythonModule();
}; };
} // namespace codon::ast } // namespace codon::ast

View File

@ -56,7 +56,7 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
} }
TranslateVisitor(cache->codegenCtx).transform(stmts); TranslateVisitor(cache->codegenCtx).transform(stmts);
auto _ = cache->getPythonModule(); cache->populatePythonModule();
return main; return main;
} }