1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Add support for super()

This commit is contained in:
Ibrahim Numanagić 2022-01-07 18:26:14 -08:00
parent cb0a6ea443
commit 5672cebe1c
5 changed files with 63 additions and 0 deletions

View File

@ -134,6 +134,8 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// ClassRealization instance.
std::unordered_map<std::string, std::shared_ptr<ClassRealization>> realizations;
std::vector<std::pair<std::string, int>> parentClasses;
Class() : ast(nullptr), originalAst(nullptr) {}
};
/// Class lookup table that maps a canonical class identifier to the corresponding

View File

@ -788,6 +788,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
std::vector<std::unordered_map<std::string, ExprPtr>> substitutions;
std::vector<int> argSubstitutions;
std::unordered_set<std::string> seenMembers;
std::vector<int> baseASTsFields;
for (auto &baseClass : stmt->baseClasses) {
std::string bcName;
std::vector<ExprPtr> subs;
@ -828,6 +829,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
if (!extension)
ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr});
}
baseASTsFields.push_back(args.size());
}
// Add generics, if any, to the context.
@ -909,6 +911,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
ctx->moduleName.module);
ctx->cache->classes[canonicalName].ast =
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), attr);
for (int i = 0; i < baseASTs.size(); i++)
ctx->cache->classes[canonicalName].parentClasses.push_back(
{baseASTs[i]->name, baseASTsFields[i]});
std::vector<StmtPtr> fns;
ExprPtr codeType = ctx->bases.back().ast->clone();
std::vector<std::string> magics{};

View File

@ -301,6 +301,8 @@ private:
const std::vector<types::FuncTypePtr> &methods,
const std::vector<CallExpr::Arg> &args);
ExprPtr transformSuper(const CallExpr *expr);
private:
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b,
bool undoOnSuccess = false);

View File

@ -1052,6 +1052,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ExprPtr e = N<CallExpr>(N<IdExpr>(m[0]->ast->name), expr->args);
return transform(e, false, true);
}
if (expr->expr->isId("super"))
return transformSuper(expr);
bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() &&
!expr->args.back().value->getEllipsis()->isPipeArg &&
@ -1946,5 +1948,51 @@ types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) {
return typ;
}
ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) {
// For now, we just support casting to the _FIRST_ overload (i.e. empty super())
if (!expr->args.empty())
error("super does not take arguments");
if (ctx->bases.empty())
error("no parent classes available");
auto fptyp = ctx->bases.back().type->getFunc();
if (!fptyp || fptyp->ast->hasAttr(Attr::Method))
error("no parent classes available");
ClassTypePtr typ = fptyp->args[1]->getClass();
auto &cands = ctx->cache->classes[typ->name].parentClasses;
if (cands.empty())
error("no parent classes available");
if (typ->getRecord())
error("cannot use super on tuple types");
// find parent typ
// unify top N args with parent typ args
// realize & do bitcast
// call bitcast() . method
auto name = cands[0].first;
int fields = cands[0].second;
auto val = ctx->find(name);
seqassert(val, "cannot find '{}'", name);
auto ftyp = ctx->instantiate(expr, val->type)->getClass();
for (int i = 0; i < fields; i++) {
auto t = ctx->cache->classes[typ->name].fields[i].type;
t = ctx->instantiate(expr, t, typ.get());
auto ft = ctx->cache->classes[name].fields[i].type;
ft = ctx->instantiate(expr, ft, ftyp.get());
unify(t, ft);
}
ExprPtr typExpr = N<IdExpr>(name);
typExpr->setType(ftyp);
auto self = fptyp->ast->args[0].name;
ExprPtr e = transform(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "to_class_ptr"),
N<CallExpr>(N<DotExpr>(N<IdExpr>(self), "__raw__")), typExpr));
return e;
}
} // namespace ast
} // namespace codon

View File

@ -125,6 +125,12 @@ class __internal__:
def opt_ref_invert[T](what: Optional[T]) -> T:
ret i8* %what
@pure
@llvm
def to_class_ptr[T](ptr: Ptr[byte]) -> T:
%0 = bitcast i8* %ptr to {=T}
ret {=T} %0
def raw_type_str(p: Ptr[byte], name: str) -> str:
pstr = p.__repr__()
# '<[name] at [pstr]>'