1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Add support for partial functions with *args/**kwargs; Fix partial method dispatch

This commit is contained in:
Ibrahim Numanagić 2021-12-16 13:23:25 -08:00
parent 3d6090322d
commit acc96aa6eb
5 changed files with 158 additions and 42 deletions

View File

@ -275,7 +275,7 @@ struct KeywordStarExpr : public Expr {
struct TupleExpr : public Expr {
std::vector<ExprPtr> items;
explicit TupleExpr(std::vector<ExprPtr> items);
explicit TupleExpr(std::vector<ExprPtr> items = {});
TupleExpr(const TupleExpr &expr);
std::string toString() const override;

View File

@ -954,7 +954,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
auto subs = substitutions[ai];
if (ctx->cache->classes[ctx->bases.back().name]
.methods[ctx->cache->reverseIdentifierLookup[f->name]].empty())
.methods[ctx->cache->reverseIdentifierLookup[f->name]]
.empty())
generateDispatch(ctx->cache->reverseIdentifierLookup[f->name]);
auto newName = ctx->generateCanonicalName(
ctx->cache->reverseIdentifierLookup[f->name], true);
@ -1765,10 +1766,11 @@ std::vector<StmtPtr> SimplifyVisitor::getClassMethods(const StmtPtr &s) {
void SimplifyVisitor::generateDispatch(const std::string &name) {
transform(N<FunctionStmt>(
name + ".dispatch", nullptr, std::vector<Param>{Param("*args")},
N<SuiteStmt>(
N<ReturnStmt>(N<CallExpr>(N<DotExpr>(N<IdExpr>(ctx->bases.back().name), name),
N<StarExpr>(N<IdExpr>("args")))))));
name + ".dispatch", nullptr,
std::vector<Param>{Param("*args"), Param("**kwargs")},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(ctx->bases.back().name), name),
N<StarExpr>(N<IdExpr>("args")), N<KeywordStarExpr>(N<IdExpr>("kwargs")))))));
}
} // namespace ast

View File

@ -195,15 +195,17 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
int starArgIndex = -1, kwstarArgIndex = -1;
for (int i = 0; i < func->ast->args.size(); i++) {
if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "**"))
// if (!known.empty() && known[i] && !partial)
// continue;
if (startswith(func->ast->args[i].name, "**"))
kwstarArgIndex = i, score -= 2;
else if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "*"))
else if (startswith(func->ast->args[i].name, "*"))
starArgIndex = i, score -= 2;
}
seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
"partial *args");
seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
"partial **kwargs");
// seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
// "partial *args");
// seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
// "partial **kwargs");
// 1. Assign positional arguments to slots
// Each slot contains a list of arg's indices

View File

@ -952,7 +952,8 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
if (bestMethod->ast->attributes.has(Attr::Property))
methodArgs.pop_back();
ExprPtr e = N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs);
return transform(e, false, allowVoidExpr);
auto ex = transform(e, false, allowVoidExpr);
return ex;
}
}
@ -1094,6 +1095,18 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
bool isPartial = false;
int ellipsisStage = -1;
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
auto getPartialArg = [&](int pi) {
auto id = transform(N<IdExpr>(partialVar));
ExprPtr it = N<IntExpr>(pi);
// Manual call to transformStaticTupleIndex needed because otherwise
// IndexExpr routes this to InstantiateExpr.
auto ex = transformStaticTupleIndex(callee.get(), id, it);
seqassert(ex, "partial indexing failed");
return ex;
};
ExprPtr partialStarArgs = nullptr;
ExprPtr partialKwstarArgs = nullptr;
if (expr->ordered)
args = expr->args;
else
@ -1110,17 +1123,38 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
: expr->args[slots[si][0]].value);
typeArgCount += typeArgs.back() != nullptr;
newMask[si] = slots[si].empty() ? 0 : 1;
} else if (si == starArgIndex && !(partial && slots[si].empty())) {
} else if (si == starArgIndex) {
std::vector<ExprPtr> extra;
if (!known.empty())
extra.push_back(N<StarExpr>(getPartialArg(-2)));
for (auto &e : slots[si]) {
extra.push_back(expr->args[e].value);
if (extra.back()->getEllipsis())
ellipsisStage = args.size();
}
args.push_back({"", transform(N<TupleExpr>(extra))});
} else if (si == kwstarArgIndex && !(partial && slots[si].empty())) {
auto e = transform(N<TupleExpr>(extra));
if (partial) {
partialStarArgs = e;
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
} else {
args.push_back({"", e});
}
} else if (si == kwstarArgIndex) {
std::vector<std::string> names;
std::vector<CallExpr::Arg> values;
if (!known.empty()) {
auto e = getPartialArg(-1);
auto t = e->getType()->getRecord();
seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple",
e->toString());
auto &ff = ctx->cache->classes[t->name].fields;
for (int i = 0; i < t->getRecord()->args.size(); i++) {
names.emplace_back(ff[i].name);
values.emplace_back(
CallExpr::Arg{"", transform(N<DotExpr>(clone(e), ff[i].name))});
}
}
for (auto &e : slots[si]) {
names.emplace_back(expr->args[e].name);
values.emplace_back(CallExpr::Arg{"", expr->args[e].value});
@ -1128,16 +1162,17 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ellipsisStage = args.size();
}
auto kwName = generateTupleStub(names.size(), "KwTuple", names);
args.push_back({"", transform(N<CallExpr>(N<IdExpr>(kwName), values))});
auto e = transform(N<CallExpr>(N<IdExpr>(kwName), values));
if (partial) {
partialKwstarArgs = e;
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
} else {
args.push_back({"", e});
}
} else if (slots[si].empty()) {
if (!known.empty() && known[si]) {
// Manual call to transformStaticTupleIndex needed because otherwise
// IndexExpr routes this to InstantiateExpr.
auto id = transform(N<IdExpr>(partialVar));
ExprPtr it = N<IntExpr>(pi++);
auto ex = transformStaticTupleIndex(callee.get(), id, it);
seqassert(ex, "partial indexing failed");
args.push_back({"", ex});
args.push_back({"", getPartialArg(pi++)});
} else if (partial) {
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
@ -1165,6 +1200,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
if (isPartial) {
deactivateUnbounds(expr->args.back().value->getType().get());
expr->args.pop_back();
if (!partialStarArgs)
partialStarArgs = transform(N<TupleExpr>());
if (!partialKwstarArgs) {
auto kwName = generateTupleStub(0, "KwTuple", {});
partialKwstarArgs = transform(N<CallExpr>(N<IdExpr>(kwName)));
}
}
// Typecheck given arguments with the expected (signature) types.
@ -1257,11 +1298,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
deactivateUnbounds(pt->func.get());
calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si];
}
if (auto rt = realize(calleeFn)) {
unify(rt, std::static_pointer_cast<Type>(calleeFn));
expr->expr = transform(expr->expr);
if (!isPartial) {
if (auto rt = realize(calleeFn)) {
unify(rt, std::static_pointer_cast<Type>(calleeFn));
expr->expr = transform(expr->expr);
}
}
expr->done &= expr->expr->done;
// Emit the final call.
@ -1274,6 +1316,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
for (auto &r : args)
if (!r.value->getEllipsis())
newArgs.push_back(r.value);
newArgs.push_back(partialStarArgs);
newArgs.push_back(partialKwstarArgs);
std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = nullptr;
@ -1535,7 +1579,8 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
tupleSize++;
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
if (!ctx->find(typeName))
generateTupleStub(tupleSize, typeName, {}, false);
// 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them)
generateTupleStub(tupleSize + 2, typeName, {}, false);
return typeName;
}
@ -1601,9 +1646,12 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
auto partialTypeName = generatePartialStub(mask, fn.get());
deactivateUnbounds(fn.get());
std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = N<StmtExpr>(
N<AssignStmt>(N<IdExpr>(var), N<CallExpr>(N<IdExpr>(partialTypeName))),
N<IdExpr>(var));
auto kwName = generateTupleStub(0, "KwTuple", {});
ExprPtr call =
N<StmtExpr>(N<AssignStmt>(N<IdExpr>(var),
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
N<CallExpr>(N<IdExpr>(kwName)))),
N<IdExpr>(var));
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getRecord() &&
startswith(call->type->getRecord()->name, partialTypeName) &&

View File

@ -447,16 +447,7 @@ q = p(zh=43, ...)
q(1) #: 1 () (zh: 43)
r = q(5, 38, ...)
r() #: 5 (38) (zh: 43)
#%% call_partial_star_error,barebones
def foo(x, *args, **kwargs):
print x, args, kwargs
p = foo(...)
p(1, z=5)
q = p(zh=43, ...)
q(1)
r = q(5, 38, ...)
r(1, a=1) #! too many arguments for foo[T1,T2,T3] (expected maximum 3, got 2)
r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1)
#%% call_kwargs,barebones
def kwhatever(**kwargs):
@ -504,6 +495,79 @@ foo(*(1,2)) #: (1, 2) ()
foo(3, f) #: (3, (x: 6, y: True)) ()
foo(k = 3, **f) #: () (k: 3, x: 6, y: True)
#%% call_partial_args_kwargs,barebones
def foo(*args):
print(args)
a = foo(1, 2, ...)
b = a(3, 4, ...)
c = b(5, ...)
c('zooooo')
#: (1, 2, 3, 4, 5, 'zooooo')
def fox(*args, **kwargs):
print(args, kwargs)
xa = fox(1, 2, x=5, ...)
xb = xa(3, 4, q=6, ...)
xc = xb(5, ...)
xd = xc(z=5.1, ...)
xd('zooooo', w='lele')
#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele')
class Foo:
i: int
def __str__(self):
return f'#{self.i}'
def foo(self, a):
return f'{self}:generic'
def foo(self, a: float):
return f'{self}:float'
def foo(self, a: int):
return f'{self}:int'
f = Foo(4)
def pacman(x, f):
print f(x, '5')
print f(x, 2.1)
print f(x, 4)
pacman(f, Foo.foo)
#: #4:generic
#: #4:float
#: #4:int
def macman(f):
print f('5')
print f(2.1)
print f(4)
macman(f.foo)
#: #4:generic
#: #4:float
#: #4:int
class Fox:
i: int
def __str__(self):
return f'#{self.i}'
def foo(self, a, b):
return f'{self}:generic b={b}'
def foo(self, a: float, c):
return f'{self}:float, c={c}'
def foo(self, a: int):
return f'{self}:int'
def foo(self, a: int, z, q):
return f'{self}:int z={z} q={q}'
ff = Fox(5)
def maxman(f):
print f('5', b=1)
print f(2.1, 3)
print f(4)
print f(5, 1, q=3)
maxman(ff.foo)
#: #5:generic b=1
#: #5:float, c=3
#: #5:int
#: #5:int z=1 q=3
#%% call_static,barebones
print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool)
#: True True False