From 07ffc62511c0f11f3787e4fe648c110981005db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagic=CC=81?= Date: Mon, 1 Apr 2024 17:26:34 -0700 Subject: [PATCH] Fix traits --- codon/parser/ast/types/class.cpp | 9 +++ codon/parser/ast/types/link.cpp | 10 ++-- codon/parser/ast/types/traits.cpp | 9 ++- codon/parser/ast/types/union.cpp | 2 +- codon/parser/visitors/typecheck/call.cpp | 56 +++++++++---------- codon/parser/visitors/typecheck/class.cpp | 28 ++++++---- codon/parser/visitors/typecheck/function.cpp | 20 +++++-- codon/parser/visitors/typecheck/infer.cpp | 9 +-- codon/parser/visitors/typecheck/op.cpp | 2 +- codon/parser/visitors/typecheck/typecheck.cpp | 6 +- stdlib/collections.codon | 12 ++-- stdlib/internal/types/collections/dict.codon | 12 ++-- 12 files changed, 107 insertions(+), 68 deletions(-) diff --git a/codon/parser/ast/types/class.cpp b/codon/parser/ast/types/class.cpp index 8fbfeb7c..d2141bf1 100644 --- a/codon/parser/ast/types/class.cpp +++ b/codon/parser/ast/types/class.cpp @@ -25,6 +25,15 @@ int ClassType::unify(Type *typ, Unification *us) { auto t64 = std::make_shared(cache, 64); return generics[0].type->unify(t64.get(), us); } + if (name == "unrealized_type" && tc->name == name) { + // instantiate + unify! + std::unordered_map genericCache; + auto l = generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache); + genericCache.clear(); + auto r = + tc->generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache); + return l->unify(r.get(), us); + } // Check names. if (name != tc->name) return -1; diff --git a/codon/parser/ast/types/link.cpp b/codon/parser/ast/types/link.cpp index 2865292a..fdd6430d 100644 --- a/codon/parser/ast/types/link.cpp +++ b/codon/parser/ast/types/link.cpp @@ -49,9 +49,9 @@ int LinkType::unify(Type *typ, Unification *undo) { // Identical unbound types get a score of 1 if (id == t->id) return 1; - // Generics must have matching IDs - if (kind != Unbound) - return -1; + // Generics must have matching IDs unless we are doing non-destructive unification + if (kind == Generic) + return undo ? -1 : 1; // Always merge a newer type into the older type (e.g. keep the types with // lower IDs around). if (id < t->id) @@ -176,8 +176,8 @@ std::string LinkType::debugString(char mode) const { } return (genericName.empty() ? (mode ? "?" : "") : genericName); } - if (mode == 2) - return ">" + type->debugString(mode); + // if (mode == 2) + // return ">" + type->debugString(mode); return type->debugString(mode); } diff --git a/codon/parser/ast/types/traits.cpp b/codon/parser/ast/types/traits.cpp index ae22c7a7..0ddd9060 100644 --- a/codon/parser/ast/types/traits.cpp +++ b/codon/parser/ast/types/traits.cpp @@ -120,7 +120,8 @@ int CallableTrait::unify(Type *typ, Unification *us) { return -1; } if (kwStar < trInArgs->generics.size()) { - TypePtr tt = cache->typeCtx->getType(TypecheckVisitor(cache->typeCtx).generateTuple(0)); + TypePtr tt = + cache->typeCtx->getType(TypecheckVisitor(cache->typeCtx).generateTuple(0)); size_t id = 0; if (auto tp = tr->getPartial()) { auto ts = tp->generics[2].type->getClass(); @@ -145,6 +146,7 @@ int CallableTrait::unify(Type *typ, Unification *us) { if (args[1]->unify(pf->getRetType().get(), us) == -1) return -1; } + // LOG("- {} vs {}: ok", debugString(2), typ->debugString(2)); return 1; } else if (auto tl = typ->getLink()) { if (tl->kind == LinkType::Link) @@ -191,7 +193,10 @@ std::string CallableTrait::debugString(char mode) const { TypeTrait::TypeTrait(TypePtr typ) : Trait(typ), type(std::move(typ)) {} -int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); } +int TypeTrait::unify(Type *typ, Unification *us) { + if (typ->getClass()) // does not make sense otherwise and results in infinite cycles + return typ->unify(type.get(), us); +} TypePtr TypeTrait::generalize(int atLevel) { auto c = std::make_shared(type->generalize(atLevel)); diff --git a/codon/parser/ast/types/union.cpp b/codon/parser/ast/types/union.cpp index 37bce02c..c1f42932 100644 --- a/codon/parser/ast/types/union.cpp +++ b/codon/parser/ast/types/union.cpp @@ -153,7 +153,7 @@ void UnionType::seal() { std::vector typeSet(pendingTypes.begin(), pendingTypes.begin() + i); auto name = tv.generateTuple(typeSet.size()); auto t = cache->typeCtx->instantiateGeneric( - cache->typeCtx->forceFind(name)->type->getClass(), typeSet); + cache->typeCtx->getType(name)->getClass(), typeSet); Unification us; generics[0].type->unify(t.get(), &us); } diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 5451bc8a..ce5b92f8 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -90,6 +90,7 @@ void TypecheckVisitor::visit(CallExpr *expr) { if (auto f = expr->expr->type->getFunc()) addFunctionGenerics(f.get()); auto a = transformCallArgs(expr->args); + ctx->popBlock(); if (!a) return; @@ -162,7 +163,6 @@ void TypecheckVisitor::visit(CallExpr *expr) { newArgs.push_back(part.args); auto partialCall = generatePartialCall(part.known, calleeFn->getFunc().get(), N(newArgs), part.kwArgs); - std::string var = ctx->cache->getTemporaryVar("part"); ExprPtr call = nullptr; if (!part.var.empty()) { @@ -541,40 +541,40 @@ bool TypecheckVisitor::typecheckCallArgs(const FuncTypePtr &calleeFn, if (calleeFn->ast->args[i].status == Param::Generic) continue; - if (startswith(calleeFn->ast->args[i].name, "*") && calleeFn->ast->args[i].type && - args[si].value->getCall()) { + if (startswith(calleeFn->ast->args[i].name, "*") && calleeFn->ast->args[i].type) { // Special case: `*args: type` and `**kwargs: type` - auto typ = ctx->getType(transform(clone(calleeFn->ast->args[i].type))->type); - auto callExpr = args[si].value; - if (startswith(calleeFn->ast->args[i].name, "**")) - callExpr = args[si].value->getCall()->args[0].value; - for (auto &ca : callExpr->getCall()->args) { - if (wrapExpr(ca.value, typ, calleeFn)) { - unify(ca.value->type, typ); + if (args[si].value->getCall()) { + auto typ = ctx->getType(transform(clone(calleeFn->ast->args[i].type))->type); + auto callExpr = args[si].value; + if (startswith(calleeFn->ast->args[i].name, "**")) + callExpr = args[si].value->getCall()->args[0].value; + for (auto &ca : callExpr->getCall()->args) { + if (wrapExpr(ca.value, typ, calleeFn)) { + unify(ca.value->type, typ); + } else { + wrappingDone = false; + } + } + auto name = callExpr->type->getClass()->name; + auto tup = transform(N(N(name), callExpr->getCall()->args)); + if (startswith(calleeFn->ast->args[i].name, "**")) { + args[si].value = + transform(N(N(N("NamedTuple"), "__new__"), tup, + N(args[si] + .value->type->getClass() + ->generics[0] + .type->getIntStatic() + ->value))); } else { - wrappingDone = false; + args[si].value = tup; } } - auto name = callExpr->type->getClass()->name; - auto tup = transform(N(N(name), callExpr->getCall()->args)); - - if (startswith(calleeFn->ast->args[i].name, "**")) { - args[si].value = - transform(N(N(N("NamedTuple"), "__new__"), tup, - N(args[si] - .value->type->getClass() - ->generics[0] - .type->getIntStatic() - ->value))); - } else { - args[si].value = tup; - } replacements.push_back(args[si].value->type); + // else this is empty and is a partial call; leave it for later } else { if (calleeFn->ast->args[i].type && !calleeFn->getArgTypes()[si]->canRealize()) { - auto t = ctx->instantiate(ctx->getType(calleeFn->ast->args[i].type->type)->generalize(0)); - // calleeFn->ast->args[i].type-> - // ctx->getType(transform(clean_clone(calleeFn->ast->args[i].type))->type); + auto t = ctx->instantiate( + ctx->getType(calleeFn->ast->args[i].type->type)->generalize(0)); unify(calleeFn->getArgTypes()[si], t); } if (wrapExpr(args[si].value, calleeFn->getArgTypes()[si], calleeFn)) { diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index 8c955869..ef8e06c5 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -35,7 +35,10 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { ctx->generateCanonicalName(name, !stmt->attributes.has(Attr::Internal), /* noSuffix*/ stmt->attributes.has(Attr::Internal)); - typ = std::make_shared(ctx->cache, canonicalName, name); + if (canonicalName == "Union") + typ = std::make_shared(ctx->cache); + else + typ = std::make_shared(ctx->cache, canonicalName, name); if (stmt->isRecord()) typ->isTuple = true; // if (stmt->isRecord() && stmt->hasAttr("__notuple__")) @@ -112,15 +115,6 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { auto defType = transformType(clone(a.defaultValue)); generic->defaultType = getType(defType); } - if (auto ti = CAST(a.type, InstantiateExpr)) { - // Parse TraitVar - seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation"); - auto l = transformType(ti->typeParams[0])->type; - if (l->getLink() && l->getLink()->trait) - generic->getLink()->trait = l->getLink()->trait; - else - generic->getLink()->trait = std::make_shared(l); - } if (auto st = getStaticGeneric(a.type.get())) { if (st > 3) transform(a.type); // error check @@ -128,10 +122,22 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { auto val = ctx->addVar(genName, varName, generic); val->generic = true; } else { + if (a.type->getIndex()) { // Parse TraitVar + transform(a.type); + auto ti = a.type->getInstantiate(); + seqassert(ti && ti->typeExpr->isId(TYPE_TYPEVAR), + "not a TypeVar instantiation: {}", a.type); + auto l = getType(ti->typeParams[0]); + if (l->getLink() && l->getLink()->trait) + generic->getLink()->trait = l->getLink()->trait; + else + generic->getLink()->trait = std::make_shared(l); + } ctx->addType(genName, varName, generic)->generic = true; } ClassType::Generic g(varName, genName, generic->generalize(ctx->typecheckLevel), typId, generic->isStatic); + if (a.status == Param::Generic) { typ->generics.push_back(g); } else { @@ -787,6 +793,8 @@ void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp, if (t->getClass() && !t->getStatic() && !t->is("type")) t = ctx->instantiateGeneric(ctx->getType("type"), {t}); ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true; + // LOG("=[g]=> {}: {} {:c} {}", clsTyp, g.name, t, + // t->getLink() && t->getLink()->trait ? "OK" : "-"); }; for (auto &g : clsTyp->hiddenGenerics) addGen(g); diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index 35a5c493..604881f3 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -206,13 +206,13 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { kw = stmt->args.back(); stmt->args.pop_back(); } - std::array op {"", "int", "str", "bool"}; + std::array op{"", "int", "str", "bool"}; for (auto &[c, v] : captures) { if (v->isType()) stmt->args.emplace_back(c, N("type")); else if (auto si = v->isStatic()) - stmt->args.emplace_back(c, N(N("Static"), - N(op[si]))); + stmt->args.emplace_back(c, + N(N("Static"), N(op[si]))); else stmt->args.emplace_back(c); partialArgs.emplace_back(c, N(v->canonicalName)); @@ -279,6 +279,15 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { generic->defaultType = getType(defType); } } else { + if (auto ti = CAST(a.type, InstantiateExpr)) { + // Parse TraitVar + seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation"); + auto l = transformType(ti->typeParams[0])->type; + if (l->getLink() && l->getLink()->trait) + generic->getLink()->trait = l->getLink()->trait; + else + generic->getLink()->trait = std::make_shared(l); + } auto val = ctx->addType(varName, name, generic); val->generic = true; if (a.defaultValue) { @@ -327,7 +336,10 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { // Parse arguments to the context. Needs to be done after adding generics // to support cases like `foo(a: T, T: type)` for (auto &a : args) { + // if (a.status == Param::Normal || a.type->is ) // todo)) makes typevar work! need to check why... a.type = transformType(a.type, false); + // if (a.type && a.type->type->getLink() && a.type->type->getLink()->trait) + // LOG("-> {:c}", a.type->type->getLink()->trait); a.defaultValue = transform(a.defaultValue, true); } @@ -413,7 +425,6 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { f->setDone(); // Construct the type - // g = ctx->instantiateGeneric(ctx->getType("type"), {g}); auto funcTyp = std::make_shared( baseType, ctx->cache->functions[canonicalName].ast.get(), explicits); funcTyp->setSrcInfo(getSrcInfo()); @@ -423,6 +434,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { funcTyp = std::static_pointer_cast( funcTyp->generalize(ctx->typecheckLevel)); ctx->cache->functions[canonicalName].type = funcTyp; + // LOG("-> {:c}", funcTyp); ctx->addFunc(stmt->name, rootName, funcTyp); ctx->addFunc(canonicalName, canonicalName, funcTyp); diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 0ef004ad..e36698b2 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -194,7 +194,6 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { if (!type || !type->canRealize()) return nullptr; // type->_rn = type->ClassType::realizedName(); - // Check if the type fields are all initialized // (sometimes that's not the case: e.g., `class X: x: List[X]`) for (auto &field : ctx->cache->classes[type->name].fields) { @@ -202,6 +201,10 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { return nullptr; } + // generalize generics to ensure that they do not get unified later! + if (type->is("unrealized_type")) + type->generics[0].type = type->generics[0].type->generalize(0); + // Check if the type was already realized auto rn = type->ClassType::realizedName(); if (auto r = in(ctx->cache->classes[type->name].realizations, rn)) { @@ -218,9 +221,7 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { } // Realize generics - if (type->is("unrealized_type")) - type->generics[0].type->generalize(ctx->typecheckLevel); - else + if (!type->is("unrealized_type")) for (auto &e : realized->generics) { if (!realize(e.type)) return nullptr; diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 60f68074..ec84fb54 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -372,7 +372,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) { transformType(expr->typeParams[i]); auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(), getType(expr->typeParams[i])); - if (expr->typeParams[i]->type->isStaticType() != + if (isUnion || expr->typeParams[i]->type->isStaticType() != generics[i].type->isStaticType()) { if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` transformType(expr->typeParams[i]); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index ddbbf113..fc0bd96f 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -62,7 +62,7 @@ StmtPtr TypecheckVisitor::apply( ScopingVisitor::apply(cache, suite); auto n = tv.inferTypes(suite, true); if (!n) { - LOG("[error=>] {}", suite->toString(2)); + // LOG("[error=>] {}", suite->toString(2)); tv.error("cannot typecheck the program"); } @@ -238,6 +238,10 @@ ExprPtr TypecheckVisitor::transformType(ExprPtr &expr, bool allowTypeOf) { !expr->type->getUnbound()->genericName.empty()) { // generic! expr->setType(ctx->instantiate(expr->getType())); + } else if (expr->type->getUnbound() && + expr->type->getUnbound()->trait) { + // generic (is type)! + expr->setType(ctx->instantiate(expr->getType())); } else { E(Error::EXPECTED_TYPE, expr, "type"); } diff --git a/stdlib/collections.codon b/stdlib/collections.codon index 3133ca93..e716b19a 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -362,17 +362,17 @@ class defaultdict(Static[Dict[K,V]]): V: type S: TypeVar[Callable[[], V]] - # def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]): - # super().__init__() - # self.default_factory = lambda: VV() + def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]): + super().__init__() + self.default_factory = lambda: VV() def __init__(self, f: S): super().__init__() self.default_factory = f - # def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]): - # super().__init__(other) - # self.default_factory = lambda: VV() + def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]): + super().__init__(other) + self.default_factory = lambda: VV() def __init__(self, f: S, other: Dict[K, V]): super().__init__(other) diff --git a/stdlib/internal/types/collections/dict.codon b/stdlib/internal/types/collections/dict.codon index 12f22e86..18a1b622 100644 --- a/stdlib/internal/types/collections/dict.codon +++ b/stdlib/internal/types/collections/dict.codon @@ -191,12 +191,12 @@ class Dict: self._vals[x] = op(dflt if ret != 0 else self._vals[x], other) def update(self, other): - if isinstance(other, Dict[K, V]): - for k, v in other.items(): - self[k] = v - else: - for k, v in other: - self[k] = v + for k, v in other: + self[k] = v + + def update(self, other: Dict[K, V]): + for k, v in other.items(): + self[k] = v def pop(self, key: K) -> V: x = self._kh_get(key)