Add support for typePtrHook and new to/from_py hooks

pull/335/head
Ibrahim Numanagić 2023-02-12 12:20:33 -08:00
parent 19f91909a8
commit 3b0d277e30
2 changed files with 91 additions and 4 deletions

View File

@ -7,6 +7,7 @@
#include <vector>
#include "codon/cir/pyextension.h"
#include "codon/cir/util/irtools.h"
#include "codon/parser/common.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/simplify/simplify.h"
@ -283,8 +284,6 @@ void Cache::populatePythonModule() {
N<DotExpr>(N<CallExpr>(N<IdExpr>(canonicalName), args), "__to_py__")))),
Attr({Attr::ForceRealize}));
functions[node->name].ast = node;
int oldAge = typeCtx->age;
typeCtx->age = 99999;
auto tv = TypecheckVisitor(typeCtx);
auto tnode = tv.transform(node);
seqassertn(tnode, "blah");
@ -297,7 +296,6 @@ void Cache::populatePythonModule() {
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
auto f = functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir;
typeCtx->age = oldAge;
return f;
};
@ -307,11 +305,64 @@ void Cache::populatePythonModule() {
pyModule = std::make_shared<ir::PyModule>();
using namespace ast;
int oldAge = typeCtx->age;
typeCtx->age = 99999;
// def wrapper(self: cobj, arg: cobj) -> cobj
// def wrapper(self: cobj, args: Ptr[cobj], nargs: int) -> cobj
for (const auto &[cn, c] : classes)
if (c.module.empty() && startswith(cn, "Pyx")) {
ir::PyType py{rev(cn), c.ast->getDocstr()};
auto tc = typeCtx->forceFind(cn)->type;
if (!tc->canRealize())
compilationError(fmt::format("cannot realize '{}' for Python export", rev(cn)));
tc = TypecheckVisitor(typeCtx).realize(tc);
seqassertn(tc, "cannot realize '{}'", cn);
// fix to_py / from_py
if (auto ofnn = in(c.methods, "__to_py__")) {
auto fnn = overloads[*ofnn].begin()->name; // default first overload!
auto &fna = functions[fnn].ast;
fna->getFunction()->suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.to_py:0"), N<IdExpr>(fna->args[0].name)));
} else {
compilationError(fmt::format("class '{}' has no __to_py__"), rev(cn));
}
if (auto ofnn = in(c.methods, "__from_py__")) {
auto fnn = overloads[*ofnn].begin()->name; // default first overload!
auto &fna = functions[fnn].ast;
fna->getFunction()->suite =
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.from_py:0"),
N<IdExpr>(fna->args[0].name), N<IdExpr>(cn)));
} else {
compilationError(fmt::format("class '{}' has no __from_py__"), rev(cn));
}
for (auto &n : std::vector<std::string>{"__from_py__", "__to_py__"}) {
auto fnn = overloads[*in(c.methods, n)].begin()->name;
ir::Func *oldIR = nullptr;
if (!functions[fnn].realizations.empty())
oldIR = functions[fnn].realizations.begin()->second->ir;
functions[fnn].realizations.clear();
auto tf = TypecheckVisitor(typeCtx).realize(functions[fnn].type);
seqassertn(tf, "cannot re-realize '{}'", fnn);
if (oldIR) {
std::vector<ir::Value *> args;
for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) {
args.push_back(module->Nr<ir::VarValue>(*it));
}
ir::cast<ir::BodiedFunc>(oldIR)->setBody(ir::util::series(
ir::util::call(functions[fnn].realizations.begin()->second->ir, args)));
}
}
for (auto &[rn, r] : functions["__internal__.py_type:0"].realizations) {
if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) {
py.typePtrHook = r->ir;
break;
}
}
for (const auto &[n, ofnn] : c.methods) {
auto fnn = overloads[ofnn].back().name; // last overload
auto &fna = functions[fnn].ast;
@ -452,6 +503,8 @@ void Cache::populatePythonModule() {
ir::PyFunction::Type::TOPLEVEL,
int(f.ast->args.size())});
}
typeCtx->age = oldAge;
}
} // namespace codon::ast

View File

@ -1,6 +1,9 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
from internal.gc import free, register_finalizer, seq_alloc, seq_alloc_atomic, seq_gc_add_roots
from internal.gc import (
free, register_finalizer, seq_alloc,
seq_alloc_atomic, seq_gc_add_roots, alloc_uncollectable
)
@pure
@C
@ -448,3 +451,34 @@ class Function:
__vtables__ = __internal__.class_init_vtables()
def _____(): __vtables__ # make it global!
@tuple
class PyObject:
refcnt: int
pytype: Ptr[byte]
@tuple
class PyWrapper[T]:
head: PyObject
data: T
@extend
class __internal__:
def py_type(T: type) -> Ptr[byte]:
return Ptr[byte]()
def to_py(o) -> Ptr[byte]:
pytype = __internal__.py_type(type(o))
obj = Ptr[PyWrapper[type(o)]](alloc_uncollectable(sizeof(type(o))))
obj[0] = PyWrapper(PyObject(1, pytype), o)
return obj.as_byte()
def from_py(o: Ptr[byte], T: type) -> T:
obj = Ptr[PyWrapper[T]](o)[0]
pytype = __internal__.py_type(T)
if obj.head.pytype != pytype:
raise TypeError("Python object has incompatible type")
return obj.data