pyextension.h support

pull/335/head
Ibrahim Numanagić 2023-02-09 16:48:49 -08:00
parent d6aa7e5142
commit 14ea7127c7
3 changed files with 111 additions and 92 deletions

View File

@ -230,10 +230,11 @@ std::vector<ExprPtr> Cache::mergeC3(std::vector<std::vector<ExprPtr>> &seqs) {
return result;
}
std::shared_ptr<ir::PyModule> Cache::getPythonModule() {
LOG("====== module generation =======");
void Cache::populatePythonModule() {
LOG("[py] ====== module generation =======");
auto mod = std::make_shared<ir::PyModule>();
if (!pyModule)
pyModule = std::make_shared<ir::PyModule>();
#define N std::make_shared
using namespace ast;
@ -243,46 +244,48 @@ std::shared_ptr<ir::PyModule> Cache::getPythonModule() {
if (c.module.empty() && startswith(cn, "Pyx")) {
ir::PyType py{rev(cn), c.ast->getDocstr()};
for (const auto &[n, ofnn] : c.methods) {
// get the last
auto fnn = overloads[ofnn].back().name;
LOG("{} : {} -> {}", cn, n, fnn);
// get the types
auto fnn = overloads[ofnn].back().name; // last overload
auto &fna = functions[fnn].ast;
// std::vector<TypeExpr> t;
if (fna->hasAttr("autogenerated"))
continue;
std::vector<Param> params;
std::vector<ExprPtr> args;
auto sctx = imports[MAIN_IMPORT].ctx;
if (true) { // assume these are methods!
bool isMethod = fna->hasAttr(Attr::Method);
LOG("[py] {}: {}.{} => {}", isMethod ? "method" : "classm", cn, n, fnn);
params = {Param{sctx->generateCanonicalName("self"), N<IdExpr>("cobj")},
Param{sctx->generateCanonicalName("args"), N<IdExpr>("cobj")}};
if (fna->args.size() > 2) {
params.back().type =
N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type);
params.back().type = N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type);
params.push_back(
Param{sctx->generateCanonicalName("nargs"), N<IdExpr>("int")});
}
ExprPtr po = N<IdExpr>(params[0].name);
po = N<CallExpr>(N<DotExpr>(N<IdExpr>(cn), "__from_py__"), po);
if (isMethod)
args.push_back(po);
if (fna->args.size() > 2) {
for (size_t ai = 1; ai < fna->args.size(); ai++) {
ExprPtr po = N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - 1));
if (fna->args[ai].type)
if (fna->args.size() > 1 + isMethod) {
for (size_t ai = isMethod; ai < fna->args.size(); ai++) {
ExprPtr po =
N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - isMethod));
if (fna->args[ai].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[ai].type->clone(), "__from_py__"),
po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
} else if (fna->args.size() == 2) {
} else if (fna->args.size() == 1 + isMethod) {
ExprPtr po = N<IdExpr>(params[1].name);
if (fna->args[1].type)
po = N<CallExpr>(N<DotExpr>(fna->args[1].type->clone(), "__from_py__"),
po);
args.push_back(po);
if (fna->args[1].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[1].type->clone(), "__from_py__"), po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
auto stubName = sctx->generateCanonicalName(fmt::format("_py.{}.{}", cn, n));
auto node =
@ -297,119 +300,134 @@ std::shared_ptr<ir::PyModule> Cache::getPythonModule() {
auto tnode = tv.transform(node);
auto rtv = tv.realize(typeCtx->forceFind(stubName)->type);
seqassertn(rtv, "realization of {} failed", stubName);
TranslateVisitor(codegenCtx).transform(tnode);
auto pr = pendingRealizations; // copy it as it might be modified
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
auto f =
functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir;
typeCtx->age = oldAge;
if (n == "__repr__")
if (n == "__repr__") {
py.repr = f;
else if (n == "__add__")
} else if (n == "__add__") {
py.add = f;
else if (n == "__iadd__")
} else if (n == "__iadd__") {
py.iadd = f;
else if (n == "__sub__")
} else if (n == "__sub__") {
py.sub = f;
else if (n == "__isub__")
} else if (n == "__isub__") {
py.isub = f;
else if (n == "__mul__")
} else if (n == "__mul__") {
py.mul = f;
else if (n == "__imul__")
} else if (n == "__imul__") {
py.imul = f;
else if (n == "__mod__")
} else if (n == "__mod__") {
py.mod = f;
else if (n == "__imod__")
} else if (n == "__imod__") {
py.imod = f;
else if (n == "__divmod__")
} else if (n == "__divmod__") {
py.divmod = f;
else if (n == "__pow__")
} else if (n == "__pow__") {
py.pow = f;
else if (n == "__ipow__")
} else if (n == "__ipow__") {
py.ipow = f;
else if (n == "__neg__")
} else if (n == "__neg__") {
py.neg = f;
else if (n == "__pos__")
} else if (n == "__pos__") {
py.pos = f;
else if (n == "__abs__")
} else if (n == "__abs__") {
py.abs = f;
else if (n == "__bool__")
} else if (n == "__bool__") {
py.bool_ = f;
else if (n == "__invert__")
} else if (n == "__invert__") {
py.invert = f;
else if (n == "__lshift__")
} else if (n == "__lshift__") {
py.lshift = f;
else if (n == "__ilshift__")
} else if (n == "__ilshift__") {
py.ilshift = f;
else if (n == "__rshift__")
} else if (n == "__rshift__") {
py.rshift = f;
else if (n == "__irshift__")
} else if (n == "__irshift__") {
py.irshift = f;
else if (n == "__and__")
} else if (n == "__and__") {
py.and_ = f;
else if (n == "__iand__")
} else if (n == "__iand__") {
py.iand = f;
else if (n == "__xor__")
} else if (n == "__xor__") {
py.xor_ = f;
else if (n == "__ixor__")
} else if (n == "__ixor__") {
py.ixor = f;
else if (n == "__or__")
} else if (n == "__or__") {
py.or_ = f;
else if (n == "__ior__")
} else if (n == "__ior__") {
py.ior = f;
else if (n == "__int__")
} else if (n == "__int__") {
py.int_ = f;
else if (n == "__float__")
} else if (n == "__float__") {
py.float_ = f;
else if (n == "__floordiv__")
} else if (n == "__floordiv__") {
py.floordiv = f;
else if (n == "__ifloordiv__")
} else if (n == "__ifloordiv__") {
py.ifloordiv = f;
else if (n == "__truediv__")
} else if (n == "__truediv__") {
py.truediv = f;
else if (n == "__itruediv__")
} else if (n == "__itruediv__") {
py.itruediv = f;
else if (n == "__index__")
} else if (n == "__index__") {
py.index = f;
else if (n == "__matmul__")
} else if (n == "__matmul__") {
py.matmul = f;
else if (n == "__imatmul__")
} else if (n == "__imatmul__") {
py.imatmul = f;
else if (n == "__len__")
} else if (n == "__len__") {
py.len = f;
else if (n == "__getitem__")
} else if (n == "__getitem__") {
py.getitem = f;
else if (n == "__setitem__")
} else if (n == "__setitem__") {
py.setitem = f;
else if (n == "__contains__")
} else if (n == "__contains__") {
py.contains = f;
else if (n == "__hash__")
} else if (n == "__hash__") {
py.hash = f;
else if (n == "__call__")
} else if (n == "__call__") {
py.call = f;
else if (n == "__str__")
} else if (n == "__str__") {
py.str = f;
else if (n == "__cmp__")
} else if (n == "__cmp__") {
py.cmp = f;
else if (n == "__iter__")
} else if (n == "__iter__") {
py.iter = f;
else if (n == "__del__")
} else if (n == "__del__") {
py.del = f;
else if (n == "__new__")
} else if (n == "__new__") {
py.new_ = f;
else if (n == "__init__")
} else if (n == "__init__") {
py.init = f;
else
py.methods.push_back(
ir::PyFunction{n, fna->getDocstr(), f, ir::PyFunction::Type::METHOD});
LOG(">| [{}] {}", functions[stubName].realizations.size(), *f);
} else {
py.methods.push_back(ir::PyFunction{n, fna->getDocstr(), f,
isMethod ? ir::PyFunction::Type::METHOD
: ir::PyFunction::Type::CLASS});
}
mod->types.push_back(py);
// LOG(">| [{}] {}", functions[stubName].realizations.size(), *f);
}
if (c.realizations.size() != 1)
compilationError(fmt::format("cannot pythonize generic class '{}'", cn));
auto &r = c.realizations.begin()->second;
py.type = realizeType(r->type);
for (auto &[mn, mt] : r->fields) {
py.members.push_back(ir::PyMember{mn, "",
mt->is("int") ? ir::PyMember::Type::LONGLONG
: mt->is("float")
? ir::PyMember::Type::DOUBLE
: ir::PyMember::Type::OBJECT,
true});
LOG("[py] {}: {}.{} => {}", "member", cn, mn, py.members.back().type);
}
pyModule->types.push_back(py);
}
#undef N
return mod;
}
} // namespace codon::ast

View File

@ -308,7 +308,8 @@ public:
static std::vector<ExprPtr> mergeC3(std::vector<std::vector<ExprPtr>> &);
std::shared_ptr<ir::PyModule> getPythonModule();
std::shared_ptr<ir::PyModule> pyModule = nullptr;
void populatePythonModule();
};
} // namespace codon::ast

View File

@ -56,7 +56,7 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
}
TranslateVisitor(cache->codegenCtx).transform(stmts);
auto _ = cache->getPythonModule();
cache->populatePythonModule();
return main;
}