From 66f9e4fd2391dea105da0040726ca6bfd562e38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagic=CC=81?= Date: Tue, 14 Mar 2023 22:32:34 -0700 Subject: [PATCH] Fix Python calls; add staticenumerate --- codon/parser/cache.cpp | 4 +- codon/parser/visitors/simplify/call.cpp | 8 +++ codon/parser/visitors/typecheck/call.cpp | 23 ++++++- codon/parser/visitors/typecheck/loops.cpp | 13 ++-- stdlib/internal/builtin.codon | 5 ++ stdlib/internal/python.codon | 84 +++++++++++++---------- stdlib/internal/static.codon | 6 ++ 7 files changed, 95 insertions(+), 48 deletions(-) diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index 540d8cf3..de86edff 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -330,8 +330,6 @@ void Cache::populatePythonModule() { auto fna = functions[canonicalName].ast; bool isMethod = fna->hasAttr(Attr::Method); std::string call = pyWrap + ".wrap_multiple"; - if (isMethod) - call += "_method"; bool isMagic = false; if (startswith(n, "__") && endswith(n, "__")) { if (auto i = in(classes[pyWrap].methods, @@ -346,6 +344,7 @@ void Cache::populatePythonModule() { auto generics = std::vector{tc}; if (!isMagic) { generics.push_back(std::make_shared(this, n)); + generics.push_back(std::make_shared(this, (int)isMethod)); } auto f = realizeIR(functions[fnName].type, generics); if (!f) @@ -451,6 +450,7 @@ void Cache::populatePythonModule() { : ir::PyFunction::Type::CLASS, // always use FASTCALL for now; works even for 0- or 1- arg methods 2}); + py.methods.back().keywords = true; } } diff --git a/codon/parser/visitors/simplify/call.cpp b/codon/parser/visitors/simplify/call.cpp index 50a9ba03..4158575a 100644 --- a/codon/parser/visitors/simplify/call.cpp +++ b/codon/parser/visitors/simplify/call.cpp @@ -78,6 +78,9 @@ SimplifyVisitor::transformTupleGenerator(const std::vector &args) E(Error::CALL_TUPLE_COMPREHENSION, args[0].value); auto var = clone(g->loops[0].vars); auto ex = clone(g->expr); + + ctx->enterConditionalBlock(); + ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}}); if (auto i = var->getId()) { ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo()); var = transform(var); @@ -89,6 +92,11 @@ SimplifyVisitor::transformTupleGenerator(const std::vector &args) auto head = transform(N(clone(g->loops[0].vars), clone(var))); ex = N(head, transform(ex)); } + ctx->leaveConditionalBlock(); + // Dominate loop variables + for (auto &var : ctx->getBase()->getLoop()->seenVars) + ctx->findDominatingBinding(var); + ctx->getBase()->loops.pop_back(); return N( GeneratorExpr::Generator, ex, std::vector{{var, transform(g->loops[0].gen), {}}}); diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 24d39af9..1e8f13be 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -966,9 +966,30 @@ ExprPtr TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) { auto n = fn->ast->args[i].name; trimStars(n); n = ctx->cache->rev(n); - v.push_back(N(std::vector{N(i), N(n)})); + v.push_back(N(n)); } return transform(N(v)); + } else if (expr->expr->isId("std.internal.static.fn_has_default")) { + expr->staticValue.type = StaticValue::INT; + auto fn = expr->args[0].value->type->getFunc(); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->ast->args; + if (*idx < 0 || *idx >= args.size()) + error("argument out of bounds"); + return transform(N(args[*idx].defaultValue != nullptr)); + } else if (expr->expr->isId("std.internal.static.fn_get_default")) { + auto fn = expr->args[0].value->type->getFunc(); + if (!fn) + error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); + auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); + seqassert(idx, "expected a static integer"); + auto &args = fn->ast->args; + if (*idx < 0 || *idx >= args.size()) + error("argument out of bounds"); + return transform(args[*idx].defaultValue); } else { return nullptr; } diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 3e505b2f..375e319c 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -237,7 +237,7 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { } else { error("bad call to fn_overloads"); } - } else if (iter && startswith(iter->value, "std.internal.static.fn_args")) { + } else if (iter && startswith(iter->value, "std.internal.builtin.staticenumerate")) { auto &suiteVec = stmt->suite->getSuite()->stmts; int validI = 0; for (; validI < suiteVec.size(); validI++) { @@ -252,17 +252,15 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { if (auto fna = ctx->getFunctionArgs(iter->type)) { auto [generics, args] = *fna; - auto typ = ctx->extractFunction(args[0]); + auto typ = args[0]->getRecord(); if (!typ) - error("fn_args needs a function"); - for (size_t i = 0; i < typ->ast->args.size(); i++) { + error("fn_args needs a tuple"); + for (size_t i = 0; i < typ->args.size(); i++) { suiteVec[0]->getAssign()->rhs = N(i); suiteVec[0]->getAssign()->type = NT(NT("Static"), NT("int")); suiteVec[1]->getAssign()->rhs = - N(ctx->cache->rev(typ->ast->args[i].name)); - suiteVec[1]->getAssign()->type = - NT(NT("Static"), NT("str")); + N(stmt->iter->getCall()->args[0].value->clone(), N(i)); block->stmts.push_back(fn("", nullptr)); } } else { @@ -282,7 +280,6 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { transform(N(N(N(loopVar), N(true)), N(N(loopVar), block))); ctx->blockLevel--; - // LOG("-> {} :: {}", getSrcInfo(), loop->toString(2)); return loop; } diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index be18e3cf..7567ce65 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -221,6 +221,11 @@ def enumerate(x, start: int = 0): yield (i, a) i += 1 +def staticenumerate(tup): + i = -1 + return tuple(((i := i + 1), t) for t in tup) + i + def echo(x): """ Print and return argument diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index b2bb0cd6..ea58fea9 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -1443,6 +1443,17 @@ def _____(): __pyenv__ # make it global! import internal.static as _S + + +class _PyWrapError(Static[PyError]): + def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)): + super().__init__("_PyWrapError", message) + self.pytype = pytype + + def __init__(self, e: PyError): + self.__init__("_PyWrapError", e.message, e.pytype) + + class _PyWrap: def _wrap_arg(arg: cobj): return pyobj(arg, steal=True) @@ -1451,7 +1462,7 @@ class _PyWrap: if _S.fn_can_call(fn, *args): try: return map(fn, args) - except PyError: + except PyError as e: pass raise PyError("cannot dispatch " + F) @@ -1524,7 +1535,7 @@ class _PyWrap: a = tuple( _PyWrap._wrap_arg(obj) if i == 0 else (kwds[n] if kwds and n in kwds else args[(ai := ai + 1)]) - for i, n in _S.fn_args(fn) + for i, n in staticenumerate(_S.fn_args(fn)) ) if ai + 1 != args.__len__(): continue @@ -1545,7 +1556,7 @@ class _PyWrap: a = tuple( _PyWrap._wrap_arg(obj) if i == 0 else (kwds[n] if kwds and n in kwds else args[(ai := ai + 1)]) - for i, n in _S.fn_args(fn) + for i, n in staticenumerate(_S.fn_args(fn)) ) if ai + 1 != args.__len__(): continue @@ -1640,43 +1651,40 @@ class _PyWrap: # print('[c] iter') return _PyWrap.IterWrap._init(obj, T) - def wrap_multiple_method(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]): - # print(f'[c] method: {T.__class__.__name__} {F} {obj} {args} {nargs}') - def _err() -> pyobj: - raise PyError("argument mismatch") - + def wrap_multiple( + obj: cobj, _args: cobj, nargs: int, _kwds: cobj, T: type, F: Static[str], + M: Static[int] = 1 + ): + args = _PyWrap._wrap_arg(_args) + kwds = _PyWrap._wrap_arg(_kwds) if _kwds != cobj() else None for fn in _S.fn_overloads(T, F): try: - ai = -1 - an = tuple( - _PyWrap._wrap_arg(obj) if i == 0 else - (_PyWrap._wrap_arg(args[i]) if i < nargs else _err()) - for i, _ in _S.fn_args(fn) - ) - if len(an) != nargs + 1: - _err() - if _S.fn_can_call(fn, *an): - return fn(*an).__to_py__() - except PyError: - pass - raise PyError("cannot dispatch " + F) + a = [pyobj(cobj(), steal=True) for _ in _S.fn_args(fn)] + ai, ki = 0, 0 + for i, n in staticenumerate(_S.fn_args(fn)): + if M == 1 and i == 0: + print('le method') + a[i] = _PyWrap._wrap_arg(obj) + continue + if ai < nargs: + a[i] = args[ai] + ai += 1 + elif kwds and n in kwds: + a[i] = kwds[n] + ki += 1 + else: + if _S.fn_has_default(fn, i): + a[i] = _S.fn_get_default(fn, i) + else: + raise PyError("missing argument") + if ai < nargs: + raise PyError("too many *args") + if kwds and ki < kwds.__len__(): + raise PyError("too many **kwargs") + ta = tuple(unwrap(a[i]) for i, _ in staticenumerate(_S.fn_args(fn))) - def wrap_multiple(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]): - # print(f'[c] nonmethod: {T.__class__.__name__} {F} {obj} {args} {nargs}') - def _err() -> pyobj: - raise PyError("argument mismatch") - - for fn in _S.fn_overloads(T, F): - try: - ai = -1 - an = tuple( - _PyWrap._wrap_arg(args[i]) if i < nargs else _err() - for i, _ in _S.fn_args(fn) - ) - if len(an) != nargs: - _err() - if _S.fn_can_call(fn, *an): - return fn(*an).__to_py__() + if _S.fn_can_call(fn, *ta): + return fn(*ta).__to_py__() except PyError: pass raise PyError("cannot dispatch " + F) @@ -1707,3 +1715,5 @@ class _PyWrap: if obj.head.pytype != pytype: raise TypeError("Python object has incompatible type") return obj.data + + diff --git a/stdlib/internal/static.codon b/stdlib/internal/static.codon index 4b75ff12..bf043597 100644 --- a/stdlib/internal/static.codon +++ b/stdlib/internal/static.codon @@ -15,5 +15,11 @@ def fn_arg_get_type(F, i: Static[int]): def fn_can_call(F, *args, **kwargs): pass +def fn_has_default(F, i: Static[int]): + pass + +def fn_get_default(F, i: Static[int]): + pass + def class_args(T: type): pass