From 5672cebe1c70527b6b2a7d22ff9def5514fe77f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagic=CC=81?= Date: Fri, 7 Jan 2022 18:26:14 -0800 Subject: [PATCH] Add support for super() --- codon/parser/cache.h | 2 + .../visitors/simplify/simplify_stmt.cpp | 5 ++ codon/parser/visitors/typecheck/typecheck.h | 2 + .../visitors/typecheck/typecheck_expr.cpp | 48 +++++++++++++++++++ stdlib/internal/internal.codon | 6 +++ 5 files changed, 63 insertions(+) diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 7adcb4c8..bef3629d 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -134,6 +134,8 @@ struct Cache : public std::enable_shared_from_this { /// ClassRealization instance. std::unordered_map> realizations; + std::vector> parentClasses; + Class() : ast(nullptr), originalAst(nullptr) {} }; /// Class lookup table that maps a canonical class identifier to the corresponding diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 9664e99b..b7cb13ec 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -788,6 +788,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { std::vector> substitutions; std::vector argSubstitutions; std::unordered_set seenMembers; + std::vector baseASTsFields; for (auto &baseClass : stmt->baseClasses) { std::string bcName; std::vector subs; @@ -828,6 +829,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { if (!extension) ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr}); } + baseASTsFields.push_back(args.size()); } // Add generics, if any, to the context. @@ -909,6 +911,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { ctx->moduleName.module); ctx->cache->classes[canonicalName].ast = N(canonicalName, args, N(), attr); + for (int i = 0; i < baseASTs.size(); i++) + ctx->cache->classes[canonicalName].parentClasses.push_back( + {baseASTs[i]->name, baseASTsFields[i]}); std::vector fns; ExprPtr codeType = ctx->bases.back().ast->clone(); std::vector magics{}; diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 92777c51..a8bfca92 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -301,6 +301,8 @@ private: const std::vector &methods, const std::vector &args); + ExprPtr transformSuper(const CallExpr *expr); + private: types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b, bool undoOnSuccess = false); diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index bc547469..0213aaa9 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -1052,6 +1052,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ExprPtr e = N(N(m[0]->ast->name), expr->args); return transform(e, false, true); } + if (expr->expr->isId("super")) + return transformSuper(expr); bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() && !expr->args.back().value->getEllipsis()->isPipeArg && @@ -1946,5 +1948,51 @@ types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) { return typ; } +ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) { + // For now, we just support casting to the _FIRST_ overload (i.e. empty super()) + if (!expr->args.empty()) + error("super does not take arguments"); + + if (ctx->bases.empty()) + error("no parent classes available"); + auto fptyp = ctx->bases.back().type->getFunc(); + if (!fptyp || fptyp->ast->hasAttr(Attr::Method)) + error("no parent classes available"); + ClassTypePtr typ = fptyp->args[1]->getClass(); + auto &cands = ctx->cache->classes[typ->name].parentClasses; + if (cands.empty()) + error("no parent classes available"); + if (typ->getRecord()) + error("cannot use super on tuple types"); + + // find parent typ + // unify top N args with parent typ args + // realize & do bitcast + // call bitcast() . method + + auto name = cands[0].first; + int fields = cands[0].second; + auto val = ctx->find(name); + seqassert(val, "cannot find '{}'", name); + auto ftyp = ctx->instantiate(expr, val->type)->getClass(); + + for (int i = 0; i < fields; i++) { + auto t = ctx->cache->classes[typ->name].fields[i].type; + t = ctx->instantiate(expr, t, typ.get()); + + auto ft = ctx->cache->classes[name].fields[i].type; + ft = ctx->instantiate(expr, ft, ftyp.get()); + unify(t, ft); + } + + ExprPtr typExpr = N(name); + typExpr->setType(ftyp); + auto self = fptyp->ast->args[0].name; + ExprPtr e = transform( + N(N(N("__internal__"), "to_class_ptr"), + N(N(N(self), "__raw__")), typExpr)); + return e; +} + } // namespace ast } // namespace codon diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index f07c7e50..ba66d154 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -125,6 +125,12 @@ class __internal__: def opt_ref_invert[T](what: Optional[T]) -> T: ret i8* %what + @pure + @llvm + def to_class_ptr[T](ptr: Ptr[byte]) -> T: + %0 = bitcast i8* %ptr to {=T} + ret {=T} %0 + def raw_type_str(p: Ptr[byte], name: str) -> str: pstr = p.__repr__() # '<[name] at [pstr]>'