From acc96aa6eb99d5500c48af67c62f43c7a1f53141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagic=CC=81?= Date: Thu, 16 Dec 2021 13:23:25 -0800 Subject: [PATCH] Add support for partial functions with *args/**kwargs; Fix partial method dispatch --- codon/parser/ast/expr.h | 2 +- .../visitors/simplify/simplify_stmt.cpp | 12 +-- .../visitors/typecheck/typecheck_ctx.cpp | 14 +-- .../visitors/typecheck/typecheck_expr.cpp | 88 ++++++++++++++----- test/parser/typecheck_expr.codon | 84 +++++++++++++++--- 5 files changed, 158 insertions(+), 42 deletions(-) diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index e1c3be71..df5716c4 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -275,7 +275,7 @@ struct KeywordStarExpr : public Expr { struct TupleExpr : public Expr { std::vector items; - explicit TupleExpr(std::vector items); + explicit TupleExpr(std::vector items = {}); TupleExpr(const TupleExpr &expr); std::string toString() const override; diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 20dd5282..f736e081 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -954,7 +954,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { auto subs = substitutions[ai]; if (ctx->cache->classes[ctx->bases.back().name] - .methods[ctx->cache->reverseIdentifierLookup[f->name]].empty()) + .methods[ctx->cache->reverseIdentifierLookup[f->name]] + .empty()) generateDispatch(ctx->cache->reverseIdentifierLookup[f->name]); auto newName = ctx->generateCanonicalName( ctx->cache->reverseIdentifierLookup[f->name], true); @@ -1765,10 +1766,11 @@ std::vector SimplifyVisitor::getClassMethods(const StmtPtr &s) { void SimplifyVisitor::generateDispatch(const std::string &name) { transform(N( - name + ".dispatch", nullptr, std::vector{Param("*args")}, - N( - N(N(N(N(ctx->bases.back().name), name), - N(N("args"))))))); + name + ".dispatch", nullptr, + std::vector{Param("*args"), Param("**kwargs")}, + N(N(N( + N(N(ctx->bases.back().name), name), + N(N("args")), N(N("kwargs"))))))); } } // namespace ast diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index e463618e..a84b02d9 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -195,15 +195,17 @@ int TypeContext::reorderNamedArgs(types::FuncType *func, int starArgIndex = -1, kwstarArgIndex = -1; for (int i = 0; i < func->ast->args.size(); i++) { - if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "**")) + // if (!known.empty() && known[i] && !partial) + // continue; + if (startswith(func->ast->args[i].name, "**")) kwstarArgIndex = i, score -= 2; - else if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "*")) + else if (startswith(func->ast->args[i].name, "*")) starArgIndex = i, score -= 2; } - seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex], - "partial *args"); - seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex], - "partial **kwargs"); + // seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex], + // "partial *args"); + // seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex], + // "partial **kwargs"); // 1. Assign positional arguments to slots // Each slot contains a list of arg's indices diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 09f63c99..8be7976c 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -952,7 +952,8 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, if (bestMethod->ast->attributes.has(Attr::Property)) methodArgs.pop_back(); ExprPtr e = N(N(bestMethod->ast->name), methodArgs); - return transform(e, false, allowVoidExpr); + auto ex = transform(e, false, allowVoidExpr); + return ex; } } @@ -1094,6 +1095,18 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in bool isPartial = false; int ellipsisStage = -1; auto newMask = std::vector(calleeFn->ast->args.size(), 1); + auto getPartialArg = [&](int pi) { + auto id = transform(N(partialVar)); + ExprPtr it = N(pi); + // Manual call to transformStaticTupleIndex needed because otherwise + // IndexExpr routes this to InstantiateExpr. + auto ex = transformStaticTupleIndex(callee.get(), id, it); + seqassert(ex, "partial indexing failed"); + return ex; + }; + + ExprPtr partialStarArgs = nullptr; + ExprPtr partialKwstarArgs = nullptr; if (expr->ordered) args = expr->args; else @@ -1110,17 +1123,38 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in : expr->args[slots[si][0]].value); typeArgCount += typeArgs.back() != nullptr; newMask[si] = slots[si].empty() ? 0 : 1; - } else if (si == starArgIndex && !(partial && slots[si].empty())) { + } else if (si == starArgIndex) { std::vector extra; + if (!known.empty()) + extra.push_back(N(getPartialArg(-2))); for (auto &e : slots[si]) { extra.push_back(expr->args[e].value); if (extra.back()->getEllipsis()) ellipsisStage = args.size(); } - args.push_back({"", transform(N(extra))}); - } else if (si == kwstarArgIndex && !(partial && slots[si].empty())) { + auto e = transform(N(extra)); + if (partial) { + partialStarArgs = e; + args.push_back({"", transform(N())}); + newMask[si] = 0; + } else { + args.push_back({"", e}); + } + } else if (si == kwstarArgIndex) { std::vector names; std::vector values; + if (!known.empty()) { + auto e = getPartialArg(-1); + auto t = e->getType()->getRecord(); + seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple", + e->toString()); + auto &ff = ctx->cache->classes[t->name].fields; + for (int i = 0; i < t->getRecord()->args.size(); i++) { + names.emplace_back(ff[i].name); + values.emplace_back( + CallExpr::Arg{"", transform(N(clone(e), ff[i].name))}); + } + } for (auto &e : slots[si]) { names.emplace_back(expr->args[e].name); values.emplace_back(CallExpr::Arg{"", expr->args[e].value}); @@ -1128,16 +1162,17 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ellipsisStage = args.size(); } auto kwName = generateTupleStub(names.size(), "KwTuple", names); - args.push_back({"", transform(N(N(kwName), values))}); + auto e = transform(N(N(kwName), values)); + if (partial) { + partialKwstarArgs = e; + args.push_back({"", transform(N())}); + newMask[si] = 0; + } else { + args.push_back({"", e}); + } } else if (slots[si].empty()) { if (!known.empty() && known[si]) { - // Manual call to transformStaticTupleIndex needed because otherwise - // IndexExpr routes this to InstantiateExpr. - auto id = transform(N(partialVar)); - ExprPtr it = N(pi++); - auto ex = transformStaticTupleIndex(callee.get(), id, it); - seqassert(ex, "partial indexing failed"); - args.push_back({"", ex}); + args.push_back({"", getPartialArg(pi++)}); } else if (partial) { args.push_back({"", transform(N())}); newMask[si] = 0; @@ -1165,6 +1200,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in if (isPartial) { deactivateUnbounds(expr->args.back().value->getType().get()); expr->args.pop_back(); + if (!partialStarArgs) + partialStarArgs = transform(N()); + if (!partialKwstarArgs) { + auto kwName = generateTupleStub(0, "KwTuple", {}); + partialKwstarArgs = transform(N(N(kwName))); + } } // Typecheck given arguments with the expected (signature) types. @@ -1257,11 +1298,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in deactivateUnbounds(pt->func.get()); calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si]; } - if (auto rt = realize(calleeFn)) { - unify(rt, std::static_pointer_cast(calleeFn)); - expr->expr = transform(expr->expr); + if (!isPartial) { + if (auto rt = realize(calleeFn)) { + unify(rt, std::static_pointer_cast(calleeFn)); + expr->expr = transform(expr->expr); + } } - expr->done &= expr->expr->done; // Emit the final call. @@ -1274,6 +1316,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in for (auto &r : args) if (!r.value->getEllipsis()) newArgs.push_back(r.value); + newArgs.push_back(partialStarArgs); + newArgs.push_back(partialKwstarArgs); std::string var = ctx->cache->getTemporaryVar("partial"); ExprPtr call = nullptr; @@ -1535,7 +1579,8 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, tupleSize++; auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name); if (!ctx->find(typeName)) - generateTupleStub(tupleSize, typeName, {}, false); + // 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them) + generateTupleStub(tupleSize + 2, typeName, {}, false); return typeName; } @@ -1601,9 +1646,12 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { auto partialTypeName = generatePartialStub(mask, fn.get()); deactivateUnbounds(fn.get()); std::string var = ctx->cache->getTemporaryVar("partial"); - ExprPtr call = N( - N(N(var), N(N(partialTypeName))), - N(var)); + auto kwName = generateTupleStub(0, "KwTuple", {}); + ExprPtr call = + N(N(N(var), + N(N(partialTypeName), N(), + N(N(kwName)))), + N(var)); call = transform(call, false, allowVoidExpr); seqassert(call->type->getRecord() && startswith(call->type->getRecord()->name, partialTypeName) && diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index faeb6497..1403319e 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -447,16 +447,7 @@ q = p(zh=43, ...) q(1) #: 1 () (zh: 43) r = q(5, 38, ...) r() #: 5 (38) (zh: 43) - -#%% call_partial_star_error,barebones -def foo(x, *args, **kwargs): - print x, args, kwargs -p = foo(...) -p(1, z=5) -q = p(zh=43, ...) -q(1) -r = q(5, 38, ...) -r(1, a=1) #! too many arguments for foo[T1,T2,T3] (expected maximum 3, got 2) +r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1) #%% call_kwargs,barebones def kwhatever(**kwargs): @@ -504,6 +495,79 @@ foo(*(1,2)) #: (1, 2) () foo(3, f) #: (3, (x: 6, y: True)) () foo(k = 3, **f) #: () (k: 3, x: 6, y: True) +#%% call_partial_args_kwargs,barebones +def foo(*args): + print(args) +a = foo(1, 2, ...) +b = a(3, 4, ...) +c = b(5, ...) +c('zooooo') +#: (1, 2, 3, 4, 5, 'zooooo') + +def fox(*args, **kwargs): + print(args, kwargs) +xa = fox(1, 2, x=5, ...) +xb = xa(3, 4, q=6, ...) +xc = xb(5, ...) +xd = xc(z=5.1, ...) +xd('zooooo', w='lele') +#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele') + +class Foo: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a): + return f'{self}:generic' + def foo(self, a: float): + return f'{self}:float' + def foo(self, a: int): + return f'{self}:int' +f = Foo(4) + +def pacman(x, f): + print f(x, '5') + print f(x, 2.1) + print f(x, 4) +pacman(f, Foo.foo) +#: #4:generic +#: #4:float +#: #4:int + +def macman(f): + print f('5') + print f(2.1) + print f(4) +macman(f.foo) +#: #4:generic +#: #4:float +#: #4:int + +class Fox: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a, b): + return f'{self}:generic b={b}' + def foo(self, a: float, c): + return f'{self}:float, c={c}' + def foo(self, a: int): + return f'{self}:int' + def foo(self, a: int, z, q): + return f'{self}:int z={z} q={q}' +ff = Fox(5) +def maxman(f): + print f('5', b=1) + print f(2.1, 3) + print f(4) + print f(5, 1, q=3) +maxman(ff.foo) +#: #5:generic b=1 +#: #5:float, c=3 +#: #5:int +#: #5:int z=1 q=3 + + #%% call_static,barebones print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool) #: True True False