diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 7adcb4c8..63f7161f 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -133,6 +133,9 @@ struct Cache : public std::enable_shared_from_this { /// Realization lookup table that maps a realized class name to the corresponding /// ClassRealization instance. std::unordered_map> realizations; + /// List of inherited class. We also keep the number of fields each of inherited + /// class. + std::vector> parentClasses; Class() : ast(nullptr), originalAst(nullptr) {} }; diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp index 0b1343c9..6c149a94 100644 --- a/codon/parser/visitors/simplify/simplify.cpp +++ b/codon/parser/visitors/simplify/simplify.cpp @@ -88,9 +88,10 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil } // Reserve the following static identifiers. for (auto name : {"staticlen", "compile_error", "isinstance", "hasattr", "type", - "TypeVar", "Callable", "argv", "super"}) + "TypeVar", "Callable", "argv", "super", "superf"}) stdlib->generateCanonicalName(name); stdlib->add(SimplifyItem::Var, "super", "super", true); + stdlib->add(SimplifyItem::Var, "superf", "superf", true); // This code must be placed in a preamble (these are not POD types but are // referenced by the various preamble Function.N and Tuple.N stubs) diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 1d0a8823..823301e8 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -550,6 +550,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { typeAst = ctx->bases[ctx->bases.size() - 2].ast; ctx->bases.back().selfName = name; attr.set(".changedSelf"); + attr.set(Attr::Method); } if (attr.has(Attr::C)) { @@ -796,6 +797,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { std::vector> substitutions; std::vector argSubstitutions; std::unordered_set seenMembers; + std::vector baseASTsFields; for (auto &baseClass : stmt->baseClasses) { std::string bcName; std::vector subs; @@ -836,6 +838,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { if (!extension) ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr}); } + baseASTsFields.push_back(args.size()); } // Add generics, if any, to the context. @@ -956,6 +959,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { ctx->moduleName.module); ctx->cache->classes[canonicalName].ast = N(canonicalName, args, N(), attr); + for (int i = 0; i < baseASTs.size(); i++) + ctx->cache->classes[canonicalName].parentClasses.push_back( + {baseASTs[i]->name, baseASTsFields[i]}); std::vector fns; ExprPtr codeType = ctx->bases.back().ast->clone(); std::vector magics{}; diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 92777c51..91e0aa74 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -301,6 +301,10 @@ private: const std::vector &methods, const std::vector &args); + ExprPtr transformSuper(const CallExpr *expr); + std::vector getSuperTypes(const types::ClassTypePtr &cls); + + private: types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b, bool undoOnSuccess = false); diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 3007302c..811f28a6 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -1038,20 +1038,22 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in seenNames.insert(i.name); } - if (expr->expr->isId("super")) { + if (expr->expr->isId("superf")) { if (ctx->bases.back().supers.empty()) - error("no matching super methods are available"); + error("no matching superf methods are available"); auto parentCls = ctx->bases.back().type->getFunc()->funcParent; auto m = findMatchingMethods(parentCls ? CAST(parentCls, types::ClassType) : nullptr, ctx->bases.back().supers, expr->args); if (m.empty()) - error("no matching super methods are available"); + error("no matching superf methods are available"); // LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), // m[0]->toString()); ExprPtr e = N(N(m[0]->ast->name), expr->args); return transform(e, false, true); } + if (expr->expr->isId("super")) + return transformSuper(expr); bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() && !expr->args.back().value->getEllipsis()->isPipeArg && @@ -1421,8 +1423,15 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) expr->args[1].value = transformType(expr->args[1].value, /*disableActivation*/ true); auto t = expr->args[1].value->type; - auto unifyOK = typ->unify(t.get(), nullptr) >= 0; - return {true, transform(N(unifyOK))}; + auto hierarchy = getSuperTypes(typ->getClass()); + + for (auto &tx: hierarchy) { + auto unifyOK = tx->unify(t.get(), nullptr) >= 0; + if (unifyOK) { + return {true, transform(N(true))}; + } + } + return {true, transform(N(false))}; } } } else if (val == "staticlen") { @@ -1946,5 +1955,91 @@ types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) { return typ; } +ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) { + // For now, we just support casting to the _FIRST_ overload (i.e. empty super()) + if (!expr->args.empty()) + error("super does not take arguments"); + + if (ctx->bases.empty() || !ctx->bases.back().type) + error("no parent classes available"); + auto fptyp = ctx->bases.back().type->getFunc(); + if (!fptyp || !fptyp->ast->hasAttr(Attr::Method)) + error("no parent classes available"); + if (fptyp->args.size() < 2) + error("no parent classes available"); + ClassTypePtr typ = fptyp->args[1]->getClass(); + auto &cands = ctx->cache->classes[typ->name].parentClasses; + if (cands.empty()) + error("no parent classes available"); + // if (typ->getRecord()) + // error("cannot use super on tuple types"); + + // find parent typ + // unify top N args with parent typ args + // realize & do bitcast + // call bitcast() . method + + auto name = cands[0].first; + int fields = cands[0].second; + auto val = ctx->find(name); + seqassert(val, "cannot find '{}'", name); + auto ftyp = ctx->instantiate(expr, val->type)->getClass(); + + if (typ->getRecord()) { + std::vector members; + for (int i = 0; i < fields; i++) + members.push_back(N(N(fptyp->ast->args[0].name), + ctx->cache->classes[typ->name].fields[i].name)); + ExprPtr e = transform( + N(N(format(TYPE_TUPLE "{}", members.size())), members)); + unify(e->type, ftyp); + e->type = ftyp; + return e; + } else { + for (int i = 0; i < fields; i++) { + auto t = ctx->cache->classes[typ->name].fields[i].type; + t = ctx->instantiate(expr, t, typ.get()); + + auto ft = ctx->cache->classes[name].fields[i].type; + ft = ctx->instantiate(expr, ft, ftyp.get()); + unify(t, ft); + } + + ExprPtr typExpr = N(name); + typExpr->setType(ftyp); + auto self = fptyp->ast->args[0].name; + ExprPtr e = transform( + N(N(N("__internal__"), "to_class_ptr"), + N(N(N(self), "__raw__")), typExpr)); + return e; + } +} + +std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) { + std::vector result; + if (!cls) + return result; + result.push_back(cls); + int start = 0; + for (auto &cand: ctx->cache->classes[cls->name].parentClasses) { + auto name = cand.first; + int fields = cand.second; + auto val = ctx->find(name); + seqassert(val, "cannot find '{}'", name); + auto ftyp = ctx->instantiate(nullptr, val->type)->getClass(); + for (int i = start; i < fields; i++) { + auto t = ctx->cache->classes[cls->name].fields[i].type; + t = ctx->instantiate(nullptr, t, cls.get()); + auto ft = ctx->cache->classes[name].fields[i].type; + ft = ctx->instantiate(nullptr, ft, ftyp.get()); + unify(t, ft); + } + start += fields; + for (auto &t: getSuperTypes(ftyp)) + result.push_back(t); + } + return result; +} + } // namespace ast } // namespace codon diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index f07c7e50..cfe19211 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -125,6 +125,24 @@ class __internal__: def opt_ref_invert[T](what: Optional[T]) -> T: ret i8* %what + @pure + @llvm + def to_class_ptr[T](ptr: Ptr[byte]) -> T: + %0 = bitcast i8* %ptr to {=T} + ret {=T} %0 + + @pure + def _tuple_offsetof(x, field: Static[int]): + @llvm + def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int: + %a = alloca {=T} + %b = getelementptr inbounds {=T}, {=T}* %a, i64 0, i32 {=idx} + %base = ptrtoint {=T}* %a to i64 + %elem = ptrtoint {=TE}* %b to i64 + %offset = sub i64 %elem, %base + ret i64 %offset + return _llvm_offsetof(type(x), field, type(x[field])) + def raw_type_str(p: Ptr[byte], name: str) -> str: pstr = p.__repr__() # '<[name] at [pstr]>' diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index ce02e25b..f28f1b87 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -600,6 +600,30 @@ print hasattr(int, "__getitem__") print hasattr([1, 2], "__getitem__", str) #: False +#%% isinstance_inheritance,barebones +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a +class Side: + def __init__(self): + pass +class BX[T,U](AX[T], Side): + b: U + def __init__(self, a: T, b: U): + super().__init__(a) + self.b = b +class CX[T,U](BX[T,U]): + c: int + def __init__(self, a: T, b: U): + super().__init__(a, b) + self.c = 1 +c = CX('a', False) +print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side) +#: True True True True +print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int]) +#: True False False + #%% staticlen_err,barebones print staticlen([1, 2]) #! List[int] is not a tuple type @@ -679,3 +703,83 @@ def foo(x: Callable[[1,2], 3]): pass #! unexpected static type #%% static_unify_2,barebones def foo(x: List[1]): pass #! cannot unify T and 1 + +#%% super,barebones +class A[T]: + a: T + def __init__(self, t: T): + self.a = t + def foo(self): + return f'A:{self.a}' +class B(A[str]): + b: int + def __init__(self): + super().__init__('s') + self.b = 6 + def baz(self): + return f'{super().foo()}::{self.b}' +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a + def foo(self): + return f'[AX:{self.a}]' +class BX[T,U](AX[T]): + b: U + def __init__(self, a: T, b: U): + print super().__class__ + super().__init__(a) + self.b = b + def foo(self): + return f'[BX:{super().foo()}:{self.b}]' +class CX[T,U](BX[T,U]): + c: int + def __init__(self, a: T, b: U): + print super().__class__ + super().__init__(a, b) + self.c = 1 + def foo(self): + return f'CX:{super().foo()}:{self.c}' +c = CX('a', False) +print c.__class__, c.foo() +#: BX[str,bool] +#: AX[str] +#: CX[str,bool] CX:[BX:[AX:a]:False]:1 + + +#%% super_tuple,barebones +@tuple +class A[T]: + a: T + x: int + def __new__(a: T) -> A[T]: + return (a, 1) + def foo(self): + return f'A:{self.a}' +@tuple +class B(A[str]): + b: int + def __new__() -> B: + return (*(A('s')), 6) + def baz(self): + return f'{super().foo()}::{self.b}' + +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + + +#%% super_error,barebones +class A: + def __init__(self): + super().__init__() +a = A() +#! no parent classes available +#! while realizing A.__init__:1 (arguments A.__init__:1[A]) + +#%% super_error_2,barebones +super().foo(1) #! no parent classes available diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index a3099c5c..1baa0fcd 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -299,19 +299,19 @@ def foo2(x): foo2(1) #: 2 foo2('s') #: 1 -#%% super,barebones +#%% superf,barebones class Foo: def foo(a): - # super(a) + # superf(a) print 'foo-1', a def foo(a: int): - super(a) + superf(a) print 'foo-2', a def foo(a: str): - super(a) + superf(a) print 'foo-3', a def foo(a): - super(a) + superf(a) print 'foo-4', a Foo.foo(1) #: foo-1 1 @@ -324,21 +324,21 @@ class Bear: @extend class Bear: def woof(x): - return super(x) + f' bear w--f {x}' + return superf(x) + f' bear w--f {x}' print Bear.woof('!') #: bear woof ! bear w--f ! class PolarBear(Bear): def woof(): - return 'polar ' + super('@') + return 'polar ' + superf('@') print PolarBear.woof() #: polar bear woof @ bear w--f @ -#%% super_error,barebones +#%% superf_error,barebones class Foo: def foo(a): - super(a) + superf(a) print 'foo-1', a Foo.foo(1) -#! no matching super methods are available +#! no matching superf methods are available #! while realizing Foo.foo:0