mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Add support for partial functions with *args/**kwargs; Fix partial method dispatch
This commit is contained in:
parent
3d6090322d
commit
acc96aa6eb
@ -275,7 +275,7 @@ struct KeywordStarExpr : public Expr {
|
||||
struct TupleExpr : public Expr {
|
||||
std::vector<ExprPtr> items;
|
||||
|
||||
explicit TupleExpr(std::vector<ExprPtr> items);
|
||||
explicit TupleExpr(std::vector<ExprPtr> items = {});
|
||||
TupleExpr(const TupleExpr &expr);
|
||||
|
||||
std::string toString() const override;
|
||||
|
@ -954,7 +954,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
|
||||
|
||||
auto subs = substitutions[ai];
|
||||
if (ctx->cache->classes[ctx->bases.back().name]
|
||||
.methods[ctx->cache->reverseIdentifierLookup[f->name]].empty())
|
||||
.methods[ctx->cache->reverseIdentifierLookup[f->name]]
|
||||
.empty())
|
||||
generateDispatch(ctx->cache->reverseIdentifierLookup[f->name]);
|
||||
auto newName = ctx->generateCanonicalName(
|
||||
ctx->cache->reverseIdentifierLookup[f->name], true);
|
||||
@ -1765,10 +1766,11 @@ std::vector<StmtPtr> SimplifyVisitor::getClassMethods(const StmtPtr &s) {
|
||||
|
||||
void SimplifyVisitor::generateDispatch(const std::string &name) {
|
||||
transform(N<FunctionStmt>(
|
||||
name + ".dispatch", nullptr, std::vector<Param>{Param("*args")},
|
||||
N<SuiteStmt>(
|
||||
N<ReturnStmt>(N<CallExpr>(N<DotExpr>(N<IdExpr>(ctx->bases.back().name), name),
|
||||
N<StarExpr>(N<IdExpr>("args")))))));
|
||||
name + ".dispatch", nullptr,
|
||||
std::vector<Param>{Param("*args"), Param("**kwargs")},
|
||||
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
|
||||
N<DotExpr>(N<IdExpr>(ctx->bases.back().name), name),
|
||||
N<StarExpr>(N<IdExpr>("args")), N<KeywordStarExpr>(N<IdExpr>("kwargs")))))));
|
||||
}
|
||||
|
||||
} // namespace ast
|
||||
|
@ -195,15 +195,17 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
|
||||
|
||||
int starArgIndex = -1, kwstarArgIndex = -1;
|
||||
for (int i = 0; i < func->ast->args.size(); i++) {
|
||||
if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "**"))
|
||||
// if (!known.empty() && known[i] && !partial)
|
||||
// continue;
|
||||
if (startswith(func->ast->args[i].name, "**"))
|
||||
kwstarArgIndex = i, score -= 2;
|
||||
else if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "*"))
|
||||
else if (startswith(func->ast->args[i].name, "*"))
|
||||
starArgIndex = i, score -= 2;
|
||||
}
|
||||
seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
|
||||
"partial *args");
|
||||
seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
|
||||
"partial **kwargs");
|
||||
// seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
|
||||
// "partial *args");
|
||||
// seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
|
||||
// "partial **kwargs");
|
||||
|
||||
// 1. Assign positional arguments to slots
|
||||
// Each slot contains a list of arg's indices
|
||||
|
@ -952,7 +952,8 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
|
||||
if (bestMethod->ast->attributes.has(Attr::Property))
|
||||
methodArgs.pop_back();
|
||||
ExprPtr e = N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs);
|
||||
return transform(e, false, allowVoidExpr);
|
||||
auto ex = transform(e, false, allowVoidExpr);
|
||||
return ex;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1094,6 +1095,18 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
bool isPartial = false;
|
||||
int ellipsisStage = -1;
|
||||
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
|
||||
auto getPartialArg = [&](int pi) {
|
||||
auto id = transform(N<IdExpr>(partialVar));
|
||||
ExprPtr it = N<IntExpr>(pi);
|
||||
// Manual call to transformStaticTupleIndex needed because otherwise
|
||||
// IndexExpr routes this to InstantiateExpr.
|
||||
auto ex = transformStaticTupleIndex(callee.get(), id, it);
|
||||
seqassert(ex, "partial indexing failed");
|
||||
return ex;
|
||||
};
|
||||
|
||||
ExprPtr partialStarArgs = nullptr;
|
||||
ExprPtr partialKwstarArgs = nullptr;
|
||||
if (expr->ordered)
|
||||
args = expr->args;
|
||||
else
|
||||
@ -1110,17 +1123,38 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
: expr->args[slots[si][0]].value);
|
||||
typeArgCount += typeArgs.back() != nullptr;
|
||||
newMask[si] = slots[si].empty() ? 0 : 1;
|
||||
} else if (si == starArgIndex && !(partial && slots[si].empty())) {
|
||||
} else if (si == starArgIndex) {
|
||||
std::vector<ExprPtr> extra;
|
||||
if (!known.empty())
|
||||
extra.push_back(N<StarExpr>(getPartialArg(-2)));
|
||||
for (auto &e : slots[si]) {
|
||||
extra.push_back(expr->args[e].value);
|
||||
if (extra.back()->getEllipsis())
|
||||
ellipsisStage = args.size();
|
||||
}
|
||||
args.push_back({"", transform(N<TupleExpr>(extra))});
|
||||
} else if (si == kwstarArgIndex && !(partial && slots[si].empty())) {
|
||||
auto e = transform(N<TupleExpr>(extra));
|
||||
if (partial) {
|
||||
partialStarArgs = e;
|
||||
args.push_back({"", transform(N<EllipsisExpr>())});
|
||||
newMask[si] = 0;
|
||||
} else {
|
||||
args.push_back({"", e});
|
||||
}
|
||||
} else if (si == kwstarArgIndex) {
|
||||
std::vector<std::string> names;
|
||||
std::vector<CallExpr::Arg> values;
|
||||
if (!known.empty()) {
|
||||
auto e = getPartialArg(-1);
|
||||
auto t = e->getType()->getRecord();
|
||||
seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple",
|
||||
e->toString());
|
||||
auto &ff = ctx->cache->classes[t->name].fields;
|
||||
for (int i = 0; i < t->getRecord()->args.size(); i++) {
|
||||
names.emplace_back(ff[i].name);
|
||||
values.emplace_back(
|
||||
CallExpr::Arg{"", transform(N<DotExpr>(clone(e), ff[i].name))});
|
||||
}
|
||||
}
|
||||
for (auto &e : slots[si]) {
|
||||
names.emplace_back(expr->args[e].name);
|
||||
values.emplace_back(CallExpr::Arg{"", expr->args[e].value});
|
||||
@ -1128,16 +1162,17 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
ellipsisStage = args.size();
|
||||
}
|
||||
auto kwName = generateTupleStub(names.size(), "KwTuple", names);
|
||||
args.push_back({"", transform(N<CallExpr>(N<IdExpr>(kwName), values))});
|
||||
auto e = transform(N<CallExpr>(N<IdExpr>(kwName), values));
|
||||
if (partial) {
|
||||
partialKwstarArgs = e;
|
||||
args.push_back({"", transform(N<EllipsisExpr>())});
|
||||
newMask[si] = 0;
|
||||
} else {
|
||||
args.push_back({"", e});
|
||||
}
|
||||
} else if (slots[si].empty()) {
|
||||
if (!known.empty() && known[si]) {
|
||||
// Manual call to transformStaticTupleIndex needed because otherwise
|
||||
// IndexExpr routes this to InstantiateExpr.
|
||||
auto id = transform(N<IdExpr>(partialVar));
|
||||
ExprPtr it = N<IntExpr>(pi++);
|
||||
auto ex = transformStaticTupleIndex(callee.get(), id, it);
|
||||
seqassert(ex, "partial indexing failed");
|
||||
args.push_back({"", ex});
|
||||
args.push_back({"", getPartialArg(pi++)});
|
||||
} else if (partial) {
|
||||
args.push_back({"", transform(N<EllipsisExpr>())});
|
||||
newMask[si] = 0;
|
||||
@ -1165,6 +1200,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
if (isPartial) {
|
||||
deactivateUnbounds(expr->args.back().value->getType().get());
|
||||
expr->args.pop_back();
|
||||
if (!partialStarArgs)
|
||||
partialStarArgs = transform(N<TupleExpr>());
|
||||
if (!partialKwstarArgs) {
|
||||
auto kwName = generateTupleStub(0, "KwTuple", {});
|
||||
partialKwstarArgs = transform(N<CallExpr>(N<IdExpr>(kwName)));
|
||||
}
|
||||
}
|
||||
|
||||
// Typecheck given arguments with the expected (signature) types.
|
||||
@ -1257,11 +1298,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
deactivateUnbounds(pt->func.get());
|
||||
calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si];
|
||||
}
|
||||
if (auto rt = realize(calleeFn)) {
|
||||
unify(rt, std::static_pointer_cast<Type>(calleeFn));
|
||||
expr->expr = transform(expr->expr);
|
||||
if (!isPartial) {
|
||||
if (auto rt = realize(calleeFn)) {
|
||||
unify(rt, std::static_pointer_cast<Type>(calleeFn));
|
||||
expr->expr = transform(expr->expr);
|
||||
}
|
||||
}
|
||||
|
||||
expr->done &= expr->expr->done;
|
||||
|
||||
// Emit the final call.
|
||||
@ -1274,6 +1316,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
||||
for (auto &r : args)
|
||||
if (!r.value->getEllipsis())
|
||||
newArgs.push_back(r.value);
|
||||
newArgs.push_back(partialStarArgs);
|
||||
newArgs.push_back(partialKwstarArgs);
|
||||
|
||||
std::string var = ctx->cache->getTemporaryVar("partial");
|
||||
ExprPtr call = nullptr;
|
||||
@ -1535,7 +1579,8 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
|
||||
tupleSize++;
|
||||
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
|
||||
if (!ctx->find(typeName))
|
||||
generateTupleStub(tupleSize, typeName, {}, false);
|
||||
// 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them)
|
||||
generateTupleStub(tupleSize + 2, typeName, {}, false);
|
||||
return typeName;
|
||||
}
|
||||
|
||||
@ -1601,9 +1646,12 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
|
||||
auto partialTypeName = generatePartialStub(mask, fn.get());
|
||||
deactivateUnbounds(fn.get());
|
||||
std::string var = ctx->cache->getTemporaryVar("partial");
|
||||
ExprPtr call = N<StmtExpr>(
|
||||
N<AssignStmt>(N<IdExpr>(var), N<CallExpr>(N<IdExpr>(partialTypeName))),
|
||||
N<IdExpr>(var));
|
||||
auto kwName = generateTupleStub(0, "KwTuple", {});
|
||||
ExprPtr call =
|
||||
N<StmtExpr>(N<AssignStmt>(N<IdExpr>(var),
|
||||
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
|
||||
N<CallExpr>(N<IdExpr>(kwName)))),
|
||||
N<IdExpr>(var));
|
||||
call = transform(call, false, allowVoidExpr);
|
||||
seqassert(call->type->getRecord() &&
|
||||
startswith(call->type->getRecord()->name, partialTypeName) &&
|
||||
|
@ -447,16 +447,7 @@ q = p(zh=43, ...)
|
||||
q(1) #: 1 () (zh: 43)
|
||||
r = q(5, 38, ...)
|
||||
r() #: 5 (38) (zh: 43)
|
||||
|
||||
#%% call_partial_star_error,barebones
|
||||
def foo(x, *args, **kwargs):
|
||||
print x, args, kwargs
|
||||
p = foo(...)
|
||||
p(1, z=5)
|
||||
q = p(zh=43, ...)
|
||||
q(1)
|
||||
r = q(5, 38, ...)
|
||||
r(1, a=1) #! too many arguments for foo[T1,T2,T3] (expected maximum 3, got 2)
|
||||
r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1)
|
||||
|
||||
#%% call_kwargs,barebones
|
||||
def kwhatever(**kwargs):
|
||||
@ -504,6 +495,79 @@ foo(*(1,2)) #: (1, 2) ()
|
||||
foo(3, f) #: (3, (x: 6, y: True)) ()
|
||||
foo(k = 3, **f) #: () (k: 3, x: 6, y: True)
|
||||
|
||||
#%% call_partial_args_kwargs,barebones
|
||||
def foo(*args):
|
||||
print(args)
|
||||
a = foo(1, 2, ...)
|
||||
b = a(3, 4, ...)
|
||||
c = b(5, ...)
|
||||
c('zooooo')
|
||||
#: (1, 2, 3, 4, 5, 'zooooo')
|
||||
|
||||
def fox(*args, **kwargs):
|
||||
print(args, kwargs)
|
||||
xa = fox(1, 2, x=5, ...)
|
||||
xb = xa(3, 4, q=6, ...)
|
||||
xc = xb(5, ...)
|
||||
xd = xc(z=5.1, ...)
|
||||
xd('zooooo', w='lele')
|
||||
#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele')
|
||||
|
||||
class Foo:
|
||||
i: int
|
||||
def __str__(self):
|
||||
return f'#{self.i}'
|
||||
def foo(self, a):
|
||||
return f'{self}:generic'
|
||||
def foo(self, a: float):
|
||||
return f'{self}:float'
|
||||
def foo(self, a: int):
|
||||
return f'{self}:int'
|
||||
f = Foo(4)
|
||||
|
||||
def pacman(x, f):
|
||||
print f(x, '5')
|
||||
print f(x, 2.1)
|
||||
print f(x, 4)
|
||||
pacman(f, Foo.foo)
|
||||
#: #4:generic
|
||||
#: #4:float
|
||||
#: #4:int
|
||||
|
||||
def macman(f):
|
||||
print f('5')
|
||||
print f(2.1)
|
||||
print f(4)
|
||||
macman(f.foo)
|
||||
#: #4:generic
|
||||
#: #4:float
|
||||
#: #4:int
|
||||
|
||||
class Fox:
|
||||
i: int
|
||||
def __str__(self):
|
||||
return f'#{self.i}'
|
||||
def foo(self, a, b):
|
||||
return f'{self}:generic b={b}'
|
||||
def foo(self, a: float, c):
|
||||
return f'{self}:float, c={c}'
|
||||
def foo(self, a: int):
|
||||
return f'{self}:int'
|
||||
def foo(self, a: int, z, q):
|
||||
return f'{self}:int z={z} q={q}'
|
||||
ff = Fox(5)
|
||||
def maxman(f):
|
||||
print f('5', b=1)
|
||||
print f(2.1, 3)
|
||||
print f(4)
|
||||
print f(5, 1, q=3)
|
||||
maxman(ff.foo)
|
||||
#: #5:generic b=1
|
||||
#: #5:float, c=3
|
||||
#: #5:int
|
||||
#: #5:int z=1 q=3
|
||||
|
||||
|
||||
#%% call_static,barebones
|
||||
print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool)
|
||||
#: True True False
|
||||
|
Loading…
x
Reference in New Issue
Block a user