diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index ab5ced9b..8fc69217 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -7,6 +7,7 @@ #include #include "codon/cir/pyextension.h" +#include "codon/cir/util/irtools.h" #include "codon/parser/common.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/simplify/simplify.h" @@ -283,8 +284,6 @@ void Cache::populatePythonModule() { N(N(N(canonicalName), args), "__to_py__")))), Attr({Attr::ForceRealize})); functions[node->name].ast = node; - int oldAge = typeCtx->age; - typeCtx->age = 99999; auto tv = TypecheckVisitor(typeCtx); auto tnode = tv.transform(node); seqassertn(tnode, "blah"); @@ -297,7 +296,6 @@ void Cache::populatePythonModule() { for (auto &fn : pr) TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); auto f = functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; - typeCtx->age = oldAge; return f; }; @@ -307,11 +305,64 @@ void Cache::populatePythonModule() { pyModule = std::make_shared(); using namespace ast; + int oldAge = typeCtx->age; + typeCtx->age = 99999; + // def wrapper(self: cobj, arg: cobj) -> cobj // def wrapper(self: cobj, args: Ptr[cobj], nargs: int) -> cobj for (const auto &[cn, c] : classes) if (c.module.empty() && startswith(cn, "Pyx")) { ir::PyType py{rev(cn), c.ast->getDocstr()}; + + auto tc = typeCtx->forceFind(cn)->type; + if (!tc->canRealize()) + compilationError(fmt::format("cannot realize '{}' for Python export", rev(cn))); + tc = TypecheckVisitor(typeCtx).realize(tc); + seqassertn(tc, "cannot realize '{}'", cn); + + // fix to_py / from_py + if (auto ofnn = in(c.methods, "__to_py__")) { + auto fnn = overloads[*ofnn].begin()->name; // default first overload! + auto &fna = functions[fnn].ast; + fna->getFunction()->suite = N(N( + N("__internal__.to_py:0"), N(fna->args[0].name))); + } else { + compilationError(fmt::format("class '{}' has no __to_py__"), rev(cn)); + } + if (auto ofnn = in(c.methods, "__from_py__")) { + auto fnn = overloads[*ofnn].begin()->name; // default first overload! + auto &fna = functions[fnn].ast; + fna->getFunction()->suite = + N(N(N("__internal__.from_py:0"), + N(fna->args[0].name), N(cn))); + } else { + compilationError(fmt::format("class '{}' has no __from_py__"), rev(cn)); + } + for (auto &n : std::vector{"__from_py__", "__to_py__"}) { + auto fnn = overloads[*in(c.methods, n)].begin()->name; + ir::Func *oldIR = nullptr; + if (!functions[fnn].realizations.empty()) + oldIR = functions[fnn].realizations.begin()->second->ir; + functions[fnn].realizations.clear(); + auto tf = TypecheckVisitor(typeCtx).realize(functions[fnn].type); + seqassertn(tf, "cannot re-realize '{}'", fnn); + if (oldIR) { + std::vector args; + for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) { + args.push_back(module->Nr(*it)); + } + ir::cast(oldIR)->setBody(ir::util::series( + ir::util::call(functions[fnn].realizations.begin()->second->ir, args))); + } + } + + for (auto &[rn, r] : functions["__internal__.py_type:0"].realizations) { + if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) { + py.typePtrHook = r->ir; + break; + } + } + for (const auto &[n, ofnn] : c.methods) { auto fnn = overloads[ofnn].back().name; // last overload auto &fna = functions[fnn].ast; @@ -452,6 +503,8 @@ void Cache::populatePythonModule() { ir::PyFunction::Type::TOPLEVEL, int(f.ast->args.size())}); } + + typeCtx->age = oldAge; } } // namespace codon::ast diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index 36fd0802..586129ae 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -1,6 +1,9 @@ # Copyright (C) 2022-2023 Exaloop Inc. -from internal.gc import free, register_finalizer, seq_alloc, seq_alloc_atomic, seq_gc_add_roots +from internal.gc import ( + free, register_finalizer, seq_alloc, + seq_alloc_atomic, seq_gc_add_roots, alloc_uncollectable +) @pure @C @@ -448,3 +451,34 @@ class Function: __vtables__ = __internal__.class_init_vtables() def _____(): __vtables__ # make it global! + + +@tuple +class PyObject: + refcnt: int + pytype: Ptr[byte] + + +@tuple +class PyWrapper[T]: + head: PyObject + data: T + + +@extend +class __internal__: + def py_type(T: type) -> Ptr[byte]: + return Ptr[byte]() + + def to_py(o) -> Ptr[byte]: + pytype = __internal__.py_type(type(o)) + obj = Ptr[PyWrapper[type(o)]](alloc_uncollectable(sizeof(type(o)))) + obj[0] = PyWrapper(PyObject(1, pytype), o) + return obj.as_byte() + + def from_py(o: Ptr[byte], T: type) -> T: + obj = Ptr[PyWrapper[T]](o)[0] + pytype = __internal__.py_type(T) + if obj.head.pytype != pytype: + raise TypeError("Python object has incompatible type") + return obj.data