Fix wrap_multiple

pull/335/head
Ibrahim Numanagić 2023-03-18 13:17:31 -07:00
parent 8d57bc164a
commit 2736bda827
2 changed files with 22 additions and 15 deletions

View File

@ -902,9 +902,10 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args; auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type); realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(), fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
FormatVisitor::apply(args[i].value), FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-"); args[i].value->type ? args[i].value->type->debugString(1) : "-",
args[i].value->isStatic() ? " [static]" : "");
} }
return nullptr; return nullptr;
} }
@ -937,7 +938,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))}; return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) { } else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) {
expr->staticValue.type = StaticValue::INT; expr->staticValue.type = StaticValue::INT;
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -946,7 +947,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<BoolExpr>(*idx >= 0 && *idx < args.size() && return {true, transform(N<BoolExpr>(*idx >= 0 && *idx < args.size() &&
args[*idx]->canRealize()))}; args[*idx]->canRealize()))};
} else if (expr->expr->isId("std.internal.static.fn_arg_get_type")) { } else if (expr->expr->isId("std.internal.static.fn_arg_get_type")) {
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -956,7 +957,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
error("argument does not have type"); error("argument does not have type");
return {true, transform(NT<IdExpr>(args[*idx]->realizedName()))}; return {true, transform(NT<IdExpr>(args[*idx]->realizedName()))};
} else if (expr->expr->isId("std.internal.static.fn_args")) { } else if (expr->expr->isId("std.internal.static.fn_args")) {
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
std::vector<ExprPtr> v; std::vector<ExprPtr> v;
@ -969,7 +970,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<TupleExpr>(v))}; return {true, transform(N<TupleExpr>(v))};
} else if (expr->expr->isId("std.internal.static.fn_has_default")) { } else if (expr->expr->isId("std.internal.static.fn_has_default")) {
expr->staticValue.type = StaticValue::INT; expr->staticValue.type = StaticValue::INT;
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -979,7 +980,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
error("argument out of bounds"); error("argument out of bounds");
return {true, transform(N<IntExpr>(args[*idx].defaultValue != nullptr))}; return {true, transform(N<IntExpr>(args[*idx].defaultValue != nullptr))};
} else if (expr->expr->isId("std.internal.static.fn_get_default")) { } else if (expr->expr->isId("std.internal.static.fn_get_default")) {
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -993,7 +994,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
if (!typ) if (!typ)
return {true, nullptr}; return {true, nullptr};
auto fn = expr->args[0].value->type->getFunc(); auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn) if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); error("expected a function, got '{}'", expr->args[0].value->type->prettyString());

View File

@ -1689,7 +1689,6 @@ class _PyWrap:
pargs[i + M] = args[i] pargs[i + M] = args[i]
for i in range(nargs, nargs + nkw): for i in range(nargs, nargs + nkw):
kw = kwds[i - nargs] kw = kwds[i - nargs]
__static_print__(kw)
o = args[i] o = args[i]
found = False found = False
@ -1708,13 +1707,20 @@ class _PyWrap:
if not found: if not found:
raise TypeError(F + "() got an unexpected keyword argument '" + kw + "'") raise TypeError(F + "() got an unexpected keyword argument '" + kw + "'")
def e(k): def _err(k, T: type = NoneType) -> T:
raise TypeError(F + "() missing required positional argument: '" + k + "'") raise TypeError(F + "() missing required positional argument: '" + k + "'")
ta = tuple( def _get_arg(F, p, k, i: Static[int]):
(pyobj(pargs[i], steal=True) if pargs[i] != cobj() else if _S.fn_arg_has_type(F, i):
(_S.fn_get_default(fn, i) if _S.fn_has_default(fn, i) else e(k))) return _S.fn_arg_get_type(F, i).__from_py__(p[i]) if p[i] != cobj() else (
for i, k in staticenumerate(_S.fn_args(fn)) _S.fn_get_default(F, i) if _S.fn_has_default(F, i)
) else _err(k, _S.fn_arg_get_type(F, i)) # need to appease type checker
)
else:
return pyobj(p[i], steal=True) if p[i] != cobj() else (
_S.fn_get_default(F, i) if _S.fn_has_default(F, i) else _err(k)
)
ta = tuple(_get_arg(fn, pargs, k, i) for i, k in staticenumerate(_S.fn_args(fn)))
__static_print__(ta)
if _S.fn_can_call(fn, *ta): if _S.fn_can_call(fn, *ta):
try: try:
tn = _S.fn_wrap_call_args(fn, *ta) tn = _S.fn_wrap_call_args(fn, *ta)