1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Python compat fixes

This commit is contained in:
Ibrahim Numanagić 2023-03-03 20:36:44 -08:00
parent 12d21ff5eb
commit 3ab03b9c3b
4 changed files with 95 additions and 54 deletions

View File

@ -268,8 +268,12 @@ void Cache::populatePythonModule() {
}; };
const std::string pyWrap = "std.internal.python._PyWrap"; const std::string pyWrap = "std.internal.python._PyWrap";
for (const auto &[cn, c] : classes) for (const auto &[cn, c] : classes) {
if (c.module.empty() && startswith(cn, "Pyx")) { if (c.module.empty()) {
if (!in(c.methods, "__to_py__") || !in(c.methods, "__from_py__"))
continue;
LOG("[py] Cythonizing {}", cn);
ir::PyType py{rev(cn), c.ast->getDocstr()}; ir::PyType py{rev(cn), c.ast->getDocstr()};
auto tc = typeCtx->forceFind(cn)->type; auto tc = typeCtx->forceFind(cn)->type;
@ -284,8 +288,6 @@ void Cache::populatePythonModule() {
auto &fna = functions[fnn].ast; auto &fna = functions[fnn].ast;
fna->getFunction()->suite = N<ReturnStmt>(N<CallExpr>( fna->getFunction()->suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>(pyWrap + ".wrap_to_py:0"), N<IdExpr>(fna->args[0].name))); N<IdExpr>(pyWrap + ".wrap_to_py:0"), N<IdExpr>(fna->args[0].name)));
} else {
compilationError(fmt::format("class '{}' has no __to_py__"), rev(cn));
} }
if (auto ofnn = in(c.methods, "__from_py__")) { if (auto ofnn = in(c.methods, "__from_py__")) {
auto fnn = overloads[*ofnn].begin()->name; // default first overload! auto fnn = overloads[*ofnn].begin()->name; // default first overload!
@ -293,8 +295,6 @@ void Cache::populatePythonModule() {
fna->getFunction()->suite = fna->getFunction()->suite =
N<ReturnStmt>(N<CallExpr>(N<IdExpr>(pyWrap + ".wrap_from_py:0"), N<ReturnStmt>(N<CallExpr>(N<IdExpr>(pyWrap + ".wrap_from_py:0"),
N<IdExpr>(fna->args[0].name), N<IdExpr>(cn))); N<IdExpr>(fna->args[0].name), N<IdExpr>(cn)));
} else {
compilationError(fmt::format("class '{}' has no __from_py__"), rev(cn));
} }
for (auto &n : std::vector<std::string>{"__from_py__", "__to_py__"}) { for (auto &n : std::vector<std::string>{"__from_py__", "__to_py__"}) {
auto fnn = overloads[*in(c.methods, n)].begin()->name; auto fnn = overloads[*in(c.methods, n)].begin()->name;
@ -327,12 +327,11 @@ void Cache::populatePythonModule() {
if (overloads[ofnn].size() == 1 && if (overloads[ofnn].size() == 1 &&
functions[canonicalName].ast->hasAttr("autogenerated")) functions[canonicalName].ast->hasAttr("autogenerated"))
continue; continue;
auto fna = functions[canonicalName].ast; auto fna = functions[canonicalName].ast;
bool isMethod = fna->hasAttr(Attr::Method); bool isMethod = fna->hasAttr(Attr::Method);
std::string call = pyWrap + ".wrap_single"; std::string call = pyWrap + ".wrap_multiple";
if (fna->args.size() - isMethod > 1) if (isMethod)
call = pyWrap + ".wrap_multiple"; call += "_method";
bool isMagic = false; bool isMagic = false;
if (startswith(n, "__") && endswith(n, "__")) { if (startswith(n, "__") && endswith(n, "__")) {
if (auto i = in(classes[pyWrap].methods, if (auto i = in(classes[pyWrap].methods,
@ -347,12 +346,12 @@ void Cache::populatePythonModule() {
auto generics = std::vector<types::TypePtr>{tc}; auto generics = std::vector<types::TypePtr>{tc};
if (!isMagic) { if (!isMagic) {
generics.push_back(std::make_shared<types::StaticType>(this, n)); generics.push_back(std::make_shared<types::StaticType>(this, n));
generics.push_back(std::make_shared<types::StaticType>(this, isMethod));
} }
auto f = realizeIR(functions[fnName].type, generics); auto f = realizeIR(functions[fnName].type, generics);
if (!f) if (!f)
continue; continue;
LOG("[py] {} -> {}", n, call);
if (n == "__repr__") { if (n == "__repr__") {
py.repr = f; py.repr = f;
} else if (n == "__add__") { } else if (n == "__add__") {
@ -450,7 +449,9 @@ void Cache::populatePythonModule() {
ir::PyFunction{n, fna->getDocstr(), f, ir::PyFunction{n, fna->getDocstr(), f,
fna->hasAttr(Attr::Method) ? ir::PyFunction::Type::METHOD fna->hasAttr(Attr::Method) ? ir::PyFunction::Type::METHOD
: ir::PyFunction::Type::CLASS, : ir::PyFunction::Type::CLASS,
int(fna->args.size()) - fna->hasAttr(Attr::Method)}); // always use FASTCALL for now; works even for 0- or 1- arg methods
2
});
} }
} }
@ -480,9 +481,10 @@ void Cache::populatePythonModule() {
} }
pyModule->types.push_back(py); pyModule->types.push_back(py);
} }
}
// Handle __iternext__ wrappers // Handle __iternext__ wrappers
auto cin = "std.internal.python._PyWrap.IterWrap"; auto cin = "_PyWrap.IterWrap";
for (auto &[cn, cr] : classes[cin].realizations) { for (auto &[cn, cr] : classes[cin].realizations) {
LOG("[py] iterfn: {}", cn); LOG("[py] iterfn: {}", cn);
ir::PyType py{cn, ""}; ir::PyType py{cn, ""};
@ -519,7 +521,6 @@ void Cache::populatePythonModule() {
call = pyWrap + ".wrap_multiple"; call = pyWrap + ".wrap_multiple";
auto fnName = call + ":0"; auto fnName = call + ":0";
seqassertn(in(functions, fnName), "bad name"); seqassertn(in(functions, fnName), "bad name");
LOG("<- {}", typeCtx->forceFind(".toplevel")->type);
auto generics = std::vector<types::TypePtr>{ auto generics = std::vector<types::TypePtr>{
typeCtx->forceFind(".toplevel")->type, typeCtx->forceFind(".toplevel")->type,
std::make_shared<types::StaticType>(this, rev(f.ast->name))}; std::make_shared<types::StaticType>(this, rev(f.ast->name))};
@ -530,6 +531,10 @@ void Cache::populatePythonModule() {
} }
} }
// Handle pending realizations!
auto pr = pendingRealizations; // copy it as it might be modified
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
typeCtx->age = oldAge; typeCtx->age = oldAge;
} }

View File

@ -840,7 +840,7 @@ ExprPtr TypecheckVisitor::transformSetAttr(CallExpr *expr) {
return transform(N<StmtExpr>(N<AssignMemberStmt>(expr->args[0].value, return transform(N<StmtExpr>(N<AssignMemberStmt>(expr->args[0].value,
staticTyp->evaluate().getString(), staticTyp->evaluate().getString(),
expr->args[1].value), expr->args[1].value),
N<NoneExpr>())); N<CallExpr>(N<IdExpr>("NoneType"))));
} }
/// Raise a compiler error. /// Raise a compiler error.
@ -872,6 +872,7 @@ ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) {
ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) { ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) {
expr->markType(); expr->markType();
transform(expr->args[0].value); transform(expr->args[0].value);
unify(expr->type, expr->args[0].value->getType()); unify(expr->type, expr->args[0].value->getType());
if (!realize(expr->type)) if (!realize(expr->type))
@ -961,9 +962,13 @@ ExprPtr TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) {
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
std::vector<ExprPtr> v; std::vector<ExprPtr> v;
for (size_t i = 0; i < fn->ast->args.size(); i++) for (size_t i = 0; i < fn->ast->args.size(); i++) {
auto n = fn->ast->args[i].name;
trimStars(n);
n = ctx->cache->rev(n);
v.push_back(N<TupleExpr>(std::vector<ExprPtr>{ v.push_back(N<TupleExpr>(std::vector<ExprPtr>{
N<IntExpr>(i), N<StringExpr>(ctx->cache->rev(fn->ast->args[i].name))})); N<IntExpr>(i), N<StringExpr>(n)}));
}
return transform(N<TupleExpr>(v)); return transform(N<TupleExpr>(v));
} else { } else {
return nullptr; return nullptr;

View File

@ -224,7 +224,9 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
auto name = ctx->getStaticString(generics[1]); auto name = ctx->getStaticString(generics[1]);
seqassert(name, "bad static string"); seqassert(name, "bad static string");
if (auto n = in(ctx->cache->classes[typ->name].methods, *name)) { if (auto n = in(ctx->cache->classes[typ->name].methods, *name)) {
for (auto &method : ctx->cache->overloads[*n]) { auto &mt = ctx->cache->overloads[*n];
for (int mti = int(mt.size()) - 1; mti >= 0; mti--) {
auto &method = mt[mti];
if (endswith(method.name, ":dispatch") || if (endswith(method.name, ":dispatch") ||
!ctx->cache->functions[method.name].type) !ctx->cache->functions[method.name].type)
continue; continue;

View File

@ -1517,6 +1517,8 @@ def _____(): __pyenv__ # make it global!
import internal.static as _S import internal.static as _S
class _PyWrap: class _PyWrap:
def _wrap_arg(arg: cobj):
return pyobj(arg, steal=True)
def _wrap(args, T: type, F: Static[str], map): def _wrap(args, T: type, F: Static[str], map):
for fn in _S.fn_overloads(T, F): for fn in _S.fn_overloads(T, F):
if _S.fn_can_call(fn, *args): if _S.fn_can_call(fn, *args):
@ -1527,8 +1529,9 @@ class _PyWrap:
raise PyError("cannot dispatch " + F) raise PyError("cannot dispatch " + F)
def _wrap_unary(obj: cobj, T: type, F: Static[str]) -> cobj: def _wrap_unary(obj: cobj, T: type, F: Static[str]) -> cobj:
# print(f'[c] unary: {T.__class__.__name__} {F}')
return _PyWrap._wrap( return _PyWrap._wrap(
(pyobj(obj), ), T=T, F=F, (_PyWrap._wrap_arg(obj), ), T=T, F=F,
map=lambda f, a: f(*a).__to_py__() map=lambda f, a: f(*a).__to_py__()
) )
def wrap_magic_abs(obj: cobj, T: type): def wrap_magic_abs(obj: cobj, T: type):
@ -1552,7 +1555,7 @@ class _PyWrap:
def _wrap_hash(obj: cobj, T: type, F: Static[str]) -> i64: def _wrap_hash(obj: cobj, T: type, F: Static[str]) -> i64:
return _PyWrap._wrap( return _PyWrap._wrap(
(pyobj(obj), ), T=T, F=F, (_PyWrap._wrap_arg(obj), ), T=T, F=F,
map=lambda f, a: f(*a) map=lambda f, a: f(*a)
) )
def wrap_magic_len(obj: cobj, T: type): def wrap_magic_len(obj: cobj, T: type):
@ -1562,34 +1565,42 @@ class _PyWrap:
def wrap_magic_bool(obj: cobj, T: type) -> i32: def wrap_magic_bool(obj: cobj, T: type) -> i32:
return _PyWrap._wrap( return _PyWrap._wrap(
(pyobj(obj), ), T=T, F="__bool__", (_PyWrap._wrap_arg(obj), ), T=T, F="__bool__",
map=lambda f, a: i32(f(*a)) map=lambda f, a: i32(f(*a))
) )
def wrap_magic_del(obj: cobj, T: type): def wrap_magic_del(obj: cobj, T: type):
_PyWrap._wrap( _PyWrap._wrap(
(pyobj(obj), ), T=T, F="__del__", (_PyWrap._wrap_arg(obj), ), T=T, F="__del__",
map=lambda f, a: f(*a) map=lambda f, a: f(*a)
) )
def wrap_magic_contains(obj: cobj, arg: cobj, T: type) -> i32: def wrap_magic_contains(obj: cobj, arg: cobj, T: type) -> i32:
return _PyWrap._wrap( return _PyWrap._wrap(
(pyobj(obj), pyobj(arg)), T=T, F="__contains__", (_PyWrap._wrap_arg(obj), _PyWrap._wrap_arg(arg)), T=T, F="__contains__",
map=lambda f, a: i32(f(*a)) map=lambda f, a: i32(f(*a))
) )
def wrap_magic_init(obj: cobj, _args: cobj, _kwds: cobj, T: type) -> i32: def wrap_magic_init(obj: cobj, _args: cobj, _kwds: cobj, T: type) -> i32:
args = pyobj(_args) # print(f'[c] init: {T.__class__.__name__}')
kwds = pyobj(_kwds)
args = _PyWrap._wrap_arg(_args)
kwds = _PyWrap._wrap_arg(_kwds) if _kwds != cobj() else None
# print(f'[c] args: {args}')
# print(f'[c] kwargs: {kwds}')
for fn in _S.fn_overloads(T, "__init__"): for fn in _S.fn_overloads(T, "__init__"):
try: try:
ai = -1 ai = -1
# TODO: default values do not work # TODO: default values do not work; same for *args/**kwargs
a = tuple( a = tuple(
kwds[n] if n in kwds else args[(ai := ai + 1)] _PyWrap._wrap_arg(obj) if i == 0 else
for _, n in _S.fn_args(fn) (kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
for i, n in _S.fn_args(fn)
) )
a = (pyobj(obj), *a) if ai + 1 != args.__len__():
continue
if _S.fn_can_call(fn, *a): if _S.fn_can_call(fn, *a):
fn(*a) fn(*a)
return i32(0) return i32(0)
@ -1598,16 +1609,19 @@ class _PyWrap:
return i32(-1) return i32(-1)
def wrap_magic_call(obj: cobj, _args: cobj, _kwds: cobj, T: type) -> cobj: def wrap_magic_call(obj: cobj, _args: cobj, _kwds: cobj, T: type) -> cobj:
args = pyobj(_args) args = _PyWrap._wrap_arg(_args)
kwds = pyobj(_kwds) kwds = _PyWrap._wrap_arg(_kwds) if _kwds != cobj() else None
for fn in _S.fn_overloads(T, "__call__"): for fn in _S.fn_overloads(T, "__call__"):
try: try:
ai = -1 ai = -1
a = tuple( # TODO: default values do not work # TODO: default values do not work; same for *args/**kwargs
kwds[n] if n in kwds else args[(ai := ai + 1)] a = tuple(
for _, n in _S.fn_args(fn) _PyWrap._wrap_arg(obj) if i == 0 else
(kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
for i, n in _S.fn_args(fn)
) )
a = (pyobj(obj), *a) if ai + 1 != args.__len__():
continue
if _S.fn_can_call(fn, *a): if _S.fn_can_call(fn, *a):
return fn(*a).__to_py__() return fn(*a).__to_py__()
except PyError: except PyError:
@ -1616,7 +1630,7 @@ class _PyWrap:
def _wrap_cmp(obj: cobj, other: cobj, T: type, F: Static[str]) -> cobj: def _wrap_cmp(obj: cobj, other: cobj, T: type, F: Static[str]) -> cobj:
return _PyWrap._wrap( return _PyWrap._wrap(
(pyobj(obj), pyobj(other)), T=T, F=F, (_PyWrap._wrap_arg(obj), _PyWrap._wrap_arg(other)), T=T, F=F,
map=lambda f, a: f(*a).__to_py__() map=lambda f, a: f(*a).__to_py__()
) )
def wrap_magic_lt(obj: cobj, other: cobj, T: type): def wrap_magic_lt(obj: cobj, other: cobj, T: type):
@ -1652,14 +1666,14 @@ class _PyWrap:
if val == cobj(): if val == cobj():
try: try:
if hasattr(T, "__delitem__"): if hasattr(T, "__delitem__"):
T.__delitem__(pyobj(obj), pyobj(idx)) T.__delitem__(_PyWrap._wrap_arg(obj), _PyWrap._wrap_arg(idx))
return 0 return 0
except PyError: except PyError:
pass pass
return -1 return -1
try: try:
_PyWrap._wrap( _PyWrap._wrap(
(pyobj(obj), pyobj(idx), pyobj(val)), T=T, F="__setitem__", (_PyWrap._wrap_arg(obj), _PyWrap._wrap_arg(idx), _PyWrap._wrap_arg(val)), T=T, F="__setitem__",
map=lambda f, a: f(*a).__to_py__() map=lambda f, a: f(*a).__to_py__()
) )
return 0 return 0
@ -1696,32 +1710,46 @@ class _PyWrap:
return _PyWrap.wrap_from_py(obj, _PyWrap.IterWrap[T]) return _PyWrap.wrap_from_py(obj, _PyWrap.IterWrap[T])
def wrap_magic_iter(obj: cobj, T: type) -> cobj: def wrap_magic_iter(obj: cobj, T: type) -> cobj:
# print('[c] iter')
return _PyWrap.IterWrap._init(obj, T) return _PyWrap.IterWrap._init(obj, T)
def wrap_single(obj: cobj, arg: cobj, T: type, F: Static[str], method: Static[int]): def wrap_multiple_method(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
a = (pyobj(obj), pyobj(arg)) if method else (pyobj(arg),) # print(f'[c] method: {T.__class__.__name__} {F} {obj} {args} {nargs}')
return _PyWrap._wrap( def _err() -> pyobj:
a, T=T, F=F,
map=lambda f, a: f(*a).__to_py__()
)
def wrap_multiple(obj: cobj, args: Ptr[cobj], nargs: i32, T: type, F: Static[str], method: Static[int]):
def _err():
raise PyError("argument mismatch") raise PyError("argument mismatch")
return pyobj()
a = (pyobj(obj), ) if method else ()
for fn in _S.fn_overloads(T, F): for fn in _S.fn_overloads(T, F):
try: try:
ai = -1 ai = -1
an = ( an = tuple(
pyobj(args[i]) if i < nargs else _err() _PyWrap._wrap_arg(obj) if i == 0 else
(_PyWrap._wrap_arg(args[i]) if i < nargs else _err())
for i, _ in _S.fn_args(fn)
)
if len(an) != nargs + 1:
_err()
if _S.fn_can_call(fn, *an):
return fn(*an).__to_py__()
except PyError:
pass
PyError("cannot dispatch " + F)
def wrap_multiple(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
# print(f'[c] nonmethod: {T.__class__.__name__} {F} {obj} {args} {nargs}')
def _err() -> pyobj:
raise PyError("argument mismatch")
for fn in _S.fn_overloads(T, F):
try:
ai = -1
an = tuple(
_PyWrap._wrap_arg(args[i]) if i < nargs else _err()
for i, _ in _S.fn_args(fn) for i, _ in _S.fn_args(fn)
) )
if len(an) != nargs: if len(an) != nargs:
_err() _err()
if _S.fn_can_call(fn, (*a, *an)): if _S.fn_can_call(fn, *an):
return fn(*a, *an).__to_py__() return fn(*an).__to_py__()
except PyError: except PyError:
pass pass
PyError("cannot dispatch " + F) PyError("cannot dispatch " + F)
@ -1731,7 +1759,8 @@ class _PyWrap:
def wrap_set(obj: cobj, what: cobj, closure: cobj, T: type, S: Static[str]) -> i32: def wrap_set(obj: cobj, what: cobj, closure: cobj, T: type, S: Static[str]) -> i32:
try: try:
t = T.__from_py__(obj) t = T.__from_py__(obj)
setattr(t, S, type(getattr(t, S)).__from_py__(what)) val = type(getattr(t, S)).__from_py__(what)
setattr(t, S, val)
return i32(0) return i32(0)
except PyError: except PyError:
return i32(-1) return i32(-1)