pull/185/head
Ibrahim Numanagić 2023-02-05 11:51:57 -08:00
parent b74601244d
commit 49d9097e94
5 changed files with 62 additions and 25 deletions

View File

@ -783,16 +783,31 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
.type->getStatic()
->evaluate()
.getString();
std::vector<TypePtr> args{typ};
std::vector<std::pair<std::string, TypePtr>> args{{"", typ}};
if (expr->expr->isId("hasattr:0")) {
// Case: the first hasattr overload allows passing argument types via *args
auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple");
for (auto &a : tup->items) {
transformType(a);
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back(a->getType());
args.push_back({"", a->getType()});
}
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall();
auto kwCls =
in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name);
seqassert(kwCls, "cannot find {}",
expr->args[2].value->getType()->getClass()->name);
for (size_t i = 0; i < kw->args.size(); i++) {
auto &a = kw->args[i].value;
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back({kwCls->fields[i].name, a->getType()});
}
}

View File

@ -382,9 +382,8 @@ StmtPtr TypecheckVisitor::prepareVTables() {
// def class_init_vtables():
// return __internal__.class_make_n_vtables(<NUM_REALIZATIONS> + 1)
auto &initAllVT = ctx->cache->functions[rep];
auto suite = N<SuiteStmt>(
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.class_make_n_vtables:0"),
N<IdExpr>("__vtable_size__"))));
auto suite = N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_make_n_vtables:0"), N<IdExpr>("__vtable_size__"))));
initAllVT.ast->suite = suite;
auto typ = initAllVT.realizations.begin()->second->type;
LOG_REALIZE("[poly] {} : {}", typ, *suite);
@ -416,10 +415,9 @@ StmtPtr TypecheckVisitor::prepareVTables() {
std::vector<ExprPtr>{NT<IdExpr>("cobj")}),
N<IntExpr>(vtSz + 2)))));
// __internal__.class_set_typeinfo(p[real.ID], real.ID)
suite->stmts.push_back(N<ExprStmt>(
N<CallExpr>(N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)),
N<IntExpr>(real->id))));
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)), N<IntExpr>(real->id))));
vtSz = 0;
for (auto &[base, vtable] : real->vtables) {
if (!vtable.ir) {
@ -842,10 +840,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto unionType = type->funcParent->getUnion();
seqassert(unionType, "expected union, got {}", type->funcParent);
StmtPtr suite = N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.new_union:0"),
N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
StmtPtr suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.new_union:0"), N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) {
// Special case: __internal__.new_union
@ -885,7 +882,6 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
}
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("compile_error"), N<StringExpr>("invalid union constructor"))));
LOG("-> {}", suite->toString(2));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union:0")) {
// Special case: __internal__.get_union
@ -932,21 +928,29 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto &t : unionTypes) {
auto callee =
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName);
auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1)));
auto kwargs = N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)));
std::vector<CallExpr::Arg> callArgs;
ExprPtr check =
N<CallExpr>(N<IdExpr>("hasattr"), callee->clone(), N<StringExpr>(fnName),
args->clone(), kwargs->clone());
suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag)),
N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName),
N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1))),
N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)))))));
N<BinaryExpr>(
check, "&&",
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag))),
N<ReturnStmt>(N<CallExpr>(callee, args, kwargs))));
tag++;
}
suite->stmts.push_back(
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {

View File

@ -200,6 +200,21 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
return m.empty() ? nullptr : m[0];
}
/// Select the best method indicated of an object that matches the given argument
/// types. See @c findMatchingMethods for details.
types::FuncTypePtr TypecheckVisitor::findBestMethod(
const ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args) {
std::vector<CallExpr::Arg> callArgs;
for (auto &[n, a] : args) {
callArgs.push_back({n, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
auto methods = ctx->findMethod(typ->name, member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
/// Select the best method among the provided methods given the list of arguments.
/// See @c reorderNamedArgs for details.
std::vector<types::FuncTypePtr>

View File

@ -210,6 +210,9 @@ private:
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member,
const std::vector<ExprPtr> &args);
types::FuncTypePtr
findBestMethod(const types::ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods,

View File

@ -155,7 +155,7 @@ def isinstance(obj, what):
def overload():
pass
def hasattr(obj, attr: Static[str], *args):
def hasattr(obj, attr: Static[str], *args, **kwargs):
"""Special handling"""
pass