Fix wrap_multiple

pull/335/head
Ibrahim Numanagić 2023-03-18 13:17:31 -07:00
parent 8d57bc164a
commit 2736bda827
2 changed files with 22 additions and 15 deletions

View File

@ -902,9 +902,10 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(),
fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-");
args[i].value->type ? args[i].value->type->debugString(1) : "-",
args[i].value->isStatic() ? " [static]" : "");
}
return nullptr;
}
@ -937,7 +938,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) {
expr->staticValue.type = StaticValue::INT;
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -946,7 +947,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<BoolExpr>(*idx >= 0 && *idx < args.size() &&
args[*idx]->canRealize()))};
} else if (expr->expr->isId("std.internal.static.fn_arg_get_type")) {
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -956,7 +957,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
error("argument does not have type");
return {true, transform(NT<IdExpr>(args[*idx]->realizedName()))};
} else if (expr->expr->isId("std.internal.static.fn_args")) {
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
std::vector<ExprPtr> v;
@ -969,7 +970,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, transform(N<TupleExpr>(v))};
} else if (expr->expr->isId("std.internal.static.fn_has_default")) {
expr->staticValue.type = StaticValue::INT;
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -979,7 +980,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
error("argument out of bounds");
return {true, transform(N<IntExpr>(args[*idx].defaultValue != nullptr))};
} else if (expr->expr->isId("std.internal.static.fn_get_default")) {
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
@ -993,7 +994,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
if (!typ)
return {true, nullptr};
auto fn = expr->args[0].value->type->getFunc();
auto fn = ctx->extractFunction(expr->args[0].value->type);
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());

View File

@ -1689,7 +1689,6 @@ class _PyWrap:
pargs[i + M] = args[i]
for i in range(nargs, nargs + nkw):
kw = kwds[i - nargs]
__static_print__(kw)
o = args[i]
found = False
@ -1708,13 +1707,20 @@ class _PyWrap:
if not found:
raise TypeError(F + "() got an unexpected keyword argument '" + kw + "'")
def e(k):
def _err(k, T: type = NoneType) -> T:
raise TypeError(F + "() missing required positional argument: '" + k + "'")
ta = tuple(
(pyobj(pargs[i], steal=True) if pargs[i] != cobj() else
(_S.fn_get_default(fn, i) if _S.fn_has_default(fn, i) else e(k)))
for i, k in staticenumerate(_S.fn_args(fn))
)
def _get_arg(F, p, k, i: Static[int]):
if _S.fn_arg_has_type(F, i):
return _S.fn_arg_get_type(F, i).__from_py__(p[i]) if p[i] != cobj() else (
_S.fn_get_default(F, i) if _S.fn_has_default(F, i)
else _err(k, _S.fn_arg_get_type(F, i)) # need to appease type checker
)
else:
return pyobj(p[i], steal=True) if p[i] != cobj() else (
_S.fn_get_default(F, i) if _S.fn_has_default(F, i) else _err(k)
)
ta = tuple(_get_arg(fn, pargs, k, i) for i, k in staticenumerate(_S.fn_args(fn)))
__static_print__(ta)
if _S.fn_can_call(fn, *ta):
try:
tn = _S.fn_wrap_call_args(fn, *ta)