Merge remote-tracking branch 'origin/cancall_new' into develop

pull/561/head
Ibrahim Numanagić 2024-04-03 09:42:42 -07:00
commit e7bb5c1609
5 changed files with 41 additions and 27 deletions

View File

@ -1033,25 +1033,15 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
callArgs.back().value->setType(a.second); callArgs.back().value->setType(a.second);
} }
auto fn = expr->args[0].value->type->getFunc(); if (auto fn = expr->args[0].value->type->getFunc()) {
if (!fn) {
bool canCompile = true;
// Special case: not a function, just try compiling it!
auto ocache = *(ctx->cache);
auto octx = *ctx;
try {
transform(N<CallExpr>(clone(expr->args[0].value),
N<StarExpr>(clone(expr->args[1].value)),
N<KeywordStarExpr>(clone(expr->args[2].value))));
} catch (const exc::ParserException &e) {
// LOG("{}", e.what());
canCompile = false;
*ctx = octx;
*(ctx->cache) = ocache;
}
return {true, transform(N<BoolExpr>(canCompile))};
}
return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))}; return {true, transform(N<BoolExpr>(canCall(fn, callArgs) >= 0))};
} else if (auto pt = expr->args[0].value->type->getPartial()) {
return {true, transform(N<BoolExpr>(canCall(pt->func, callArgs, pt) >= 0))};
} else {
compilationWarning("cannot use fn_can_call on non-functions", getSrcInfo().file,
getSrcInfo().line, getSrcInfo().col);
return {true, transform(N<BoolExpr>(false))};
}
} else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) { } else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) {
expr->staticValue.type = StaticValue::INT; expr->staticValue.type = StaticValue::INT;
auto fn = ctx->extractFunction(expr->args[0].value->type); auto fn = ctx->extractFunction(expr->args[0].value->type);

View File

@ -26,6 +26,7 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, const StmtPtr &stmts) {
auto so = clone(stmts); auto so = clone(stmts);
auto s = v.inferTypes(so, true); auto s = v.inferTypes(so, true);
if (!s) { if (!s) {
// LOG("{}", so->toString(2));
v.error("cannot typecheck the program"); v.error("cannot typecheck the program");
} }
if (s->getSuite()) if (s->getSuite())
@ -251,13 +252,21 @@ public:
/// Check if a function can be called with the given arguments. /// Check if a function can be called with the given arguments.
/// See @c reorderNamedArgs for details. /// See @c reorderNamedArgs for details.
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
const std::vector<CallExpr::Arg> &args) { const std::vector<CallExpr::Arg> &args,
std::shared_ptr<types::PartialType> part) {
auto getPartialArg = [&](size_t pi) -> types::TypePtr {
if (pi < part->args.size())
return part->args[pi];
else
return nullptr;
};
std::vector<std::pair<types::TypePtr, size_t>> reordered; std::vector<std::pair<types::TypePtr, size_t>> reordered;
auto niGenerics = fn->ast->getNonInferrableGenerics(); auto niGenerics = fn->ast->getNonInferrableGenerics();
auto score = ctx->reorderNamedArgs( auto score = ctx->reorderNamedArgs(
fn.get(), args, fn.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) { [&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0, gi = 0; si < slots.size(); si++) { for (int si = 0, gi = 0, pi = 0; si < slots.size(); si++) {
if (fn->ast->args[si].status == Param::Generic) { if (fn->ast->args[si].status == Param::Generic) {
if (slots[si].empty()) { if (slots[si].empty()) {
// is this "real" type? // is this "real" type?
@ -275,15 +284,21 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
} }
gi++; gi++;
} else if (si == s || si == k || slots[si].size() != 1) { } else if (si == s || si == k || slots[si].size() != 1) {
// Partials
if (slots[si].empty() && part && part->known[si]) {
reordered.emplace_back(getPartialArg(pi++), 0);
} else {
// Ignore *args, *kwargs and default arguments // Ignore *args, *kwargs and default arguments
reordered.emplace_back(nullptr, 0); reordered.emplace_back(nullptr, 0);
}
} else { } else {
reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]); reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]);
} }
} }
return 0; return 0;
}, },
[](error::Error, const SrcInfo &, const std::string &) { return -1; }); [](error::Error, const SrcInfo &, const std::string &) { return -1; },
part ? part->known : std::vector<char>{});
int ai = 0, mai = 0, gi = 0, real_gi = 0; int ai = 0, mai = 0, gi = 0, real_gi = 0;
for (; score != -1 && ai < reordered.size(); ai++) { for (; score != -1 && ai < reordered.size(); ai++) {
auto expectTyp = fn->ast->args[ai].status == Param::Normal auto expectTyp = fn->ast->args[ai].status == Param::Normal
@ -341,6 +356,8 @@ TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ,
continue; // avoid overloads that have not been seen yet continue; // avoid overloads that have not been seen yet
auto method = ctx->instantiate(mi, typ)->getFunc(); auto method = ctx->instantiate(mi, typ)->getFunc();
int score = canCall(method, args); int score = canCall(method, args);
// LOG("{}: {} {} :: {} :: {}", getSrcInfo(), method->debugString(2), args, score,
// method->ast->getSrcInfo());
if (score != -1) { if (score != -1) {
results.push_back(mi); results.push_back(mi);
} }

View File

@ -220,7 +220,8 @@ private:
types::FuncTypePtr types::FuncTypePtr
findBestMethod(const types::ClassTypePtr &typ, const std::string &member, findBestMethod(const types::ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args); const std::vector<std::pair<std::string, types::TypePtr>> &args);
int canCall(const types::FuncTypePtr &, const std::vector<CallExpr::Arg> &); int canCall(const types::FuncTypePtr &, const std::vector<CallExpr::Arg> &,
std::shared_ptr<types::PartialType> = nullptr);
std::vector<types::FuncTypePtr> std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ, findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods, const std::vector<types::FuncTypePtr> &methods,

View File

@ -208,11 +208,14 @@ class __internal__:
def _union_call_helper(union, args, kwargs) -> Union: def _union_call_helper(union, args, kwargs) -> Union:
for tag, T in vars_types(union, with_index=1): for tag, T in vars_types(union, with_index=1):
if hasattr(T, '__call__'): if fn_can_call(T, *args, **kwargs):
if fn_can_call(__internal__.union_get_data(union, T), *args, **kwargs): if __internal__.union_get_tag(union) == tag:
return __internal__.union_get_data(union, T)(*args, **kwargs)
elif hasattr(T, '__call__'):
if fn_can_call(T.__call__, *args, **kwargs):
if __internal__.union_get_tag(union) == tag: if __internal__.union_get_tag(union) == tag:
return __internal__.union_get_data(union, T).__call__(*args, **kwargs) return __internal__.union_get_data(union, T).__call__(*args, **kwargs)
raise TypeError("cannot call union") raise TypeError("cannot call union " + union.__class__.__name__)
def union_call(union, args, kwargs): def union_call(union, args, kwargs):
t = __internal__._union_call_helper(union, args, kwargs) t = __internal__._union_call_helper(union, args, kwargs)

View File

@ -15,6 +15,8 @@ def fn_arg_has_type(F, i: Static[int]):
def fn_arg_get_type(F, i: Static[int]): def fn_arg_get_type(F, i: Static[int]):
pass pass
@no_type_wrap
@no_argument_wrap
def fn_can_call(F, *args, **kwargs): def fn_can_call(F, *args, **kwargs):
pass pass
@ -28,6 +30,7 @@ def fn_get_default(F, i: Static[int]):
pass pass
@no_type_wrap @no_type_wrap
@no_argument_wrap
def static_print(*args): def static_print(*args):
pass pass