Fix traits

typecheck-v2
Ibrahim Numanagić 2024-04-01 17:26:34 -07:00
parent 547c744b53
commit 07ffc62511
12 changed files with 107 additions and 68 deletions

View File

@ -25,6 +25,15 @@ int ClassType::unify(Type *typ, Unification *us) {
auto t64 = std::make_shared<IntStaticType>(cache, 64);
return generics[0].type->unify(t64.get(), us);
}
if (name == "unrealized_type" && tc->name == name) {
// instantiate + unify!
std::unordered_map<int, types::TypePtr> genericCache;
auto l = generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache);
genericCache.clear();
auto r =
tc->generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache);
return l->unify(r.get(), us);
}
// Check names.
if (name != tc->name)
return -1;

View File

@ -49,9 +49,9 @@ int LinkType::unify(Type *typ, Unification *undo) {
// Identical unbound types get a score of 1
if (id == t->id)
return 1;
// Generics must have matching IDs
if (kind != Unbound)
return -1;
// Generics must have matching IDs unless we are doing non-destructive unification
if (kind == Generic)
return undo ? -1 : 1;
// Always merge a newer type into the older type (e.g. keep the types with
// lower IDs around).
if (id < t->id)
@ -176,8 +176,8 @@ std::string LinkType::debugString(char mode) const {
}
return (genericName.empty() ? (mode ? "?" : "<unknown type>") : genericName);
}
if (mode == 2)
return ">" + type->debugString(mode);
// if (mode == 2)
// return ">" + type->debugString(mode);
return type->debugString(mode);
}

View File

@ -120,7 +120,8 @@ int CallableTrait::unify(Type *typ, Unification *us) {
return -1;
}
if (kwStar < trInArgs->generics.size()) {
TypePtr tt = cache->typeCtx->getType(TypecheckVisitor(cache->typeCtx).generateTuple(0));
TypePtr tt =
cache->typeCtx->getType(TypecheckVisitor(cache->typeCtx).generateTuple(0));
size_t id = 0;
if (auto tp = tr->getPartial()) {
auto ts = tp->generics[2].type->getClass();
@ -145,6 +146,7 @@ int CallableTrait::unify(Type *typ, Unification *us) {
if (args[1]->unify(pf->getRetType().get(), us) == -1)
return -1;
}
// LOG("- {} vs {}: ok", debugString(2), typ->debugString(2));
return 1;
} else if (auto tl = typ->getLink()) {
if (tl->kind == LinkType::Link)
@ -191,7 +193,10 @@ std::string CallableTrait::debugString(char mode) const {
TypeTrait::TypeTrait(TypePtr typ) : Trait(typ), type(std::move(typ)) {}
int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); }
int TypeTrait::unify(Type *typ, Unification *us) {
if (typ->getClass()) // does not make sense otherwise and results in infinite cycles
return typ->unify(type.get(), us);
}
TypePtr TypeTrait::generalize(int atLevel) {
auto c = std::make_shared<TypeTrait>(type->generalize(atLevel));

View File

@ -153,7 +153,7 @@ void UnionType::seal() {
std::vector<TypePtr> typeSet(pendingTypes.begin(), pendingTypes.begin() + i);
auto name = tv.generateTuple(typeSet.size());
auto t = cache->typeCtx->instantiateGeneric(
cache->typeCtx->forceFind(name)->type->getClass(), typeSet);
cache->typeCtx->getType(name)->getClass(), typeSet);
Unification us;
generics[0].type->unify(t.get(), &us);
}

View File

@ -90,6 +90,7 @@ void TypecheckVisitor::visit(CallExpr *expr) {
if (auto f = expr->expr->type->getFunc())
addFunctionGenerics(f.get());
auto a = transformCallArgs(expr->args);
ctx->popBlock();
if (!a)
return;
@ -162,7 +163,6 @@ void TypecheckVisitor::visit(CallExpr *expr) {
newArgs.push_back(part.args);
auto partialCall = generatePartialCall(part.known, calleeFn->getFunc().get(),
N<TupleExpr>(newArgs), part.kwArgs);
std::string var = ctx->cache->getTemporaryVar("part");
ExprPtr call = nullptr;
if (!part.var.empty()) {
@ -541,40 +541,40 @@ bool TypecheckVisitor::typecheckCallArgs(const FuncTypePtr &calleeFn,
if (calleeFn->ast->args[i].status == Param::Generic)
continue;
if (startswith(calleeFn->ast->args[i].name, "*") && calleeFn->ast->args[i].type &&
args[si].value->getCall()) {
if (startswith(calleeFn->ast->args[i].name, "*") && calleeFn->ast->args[i].type) {
// Special case: `*args: type` and `**kwargs: type`
auto typ = ctx->getType(transform(clone(calleeFn->ast->args[i].type))->type);
auto callExpr = args[si].value;
if (startswith(calleeFn->ast->args[i].name, "**"))
callExpr = args[si].value->getCall()->args[0].value;
for (auto &ca : callExpr->getCall()->args) {
if (wrapExpr(ca.value, typ, calleeFn)) {
unify(ca.value->type, typ);
if (args[si].value->getCall()) {
auto typ = ctx->getType(transform(clone(calleeFn->ast->args[i].type))->type);
auto callExpr = args[si].value;
if (startswith(calleeFn->ast->args[i].name, "**"))
callExpr = args[si].value->getCall()->args[0].value;
for (auto &ca : callExpr->getCall()->args) {
if (wrapExpr(ca.value, typ, calleeFn)) {
unify(ca.value->type, typ);
} else {
wrappingDone = false;
}
}
auto name = callExpr->type->getClass()->name;
auto tup = transform(N<CallExpr>(N<IdExpr>(name), callExpr->getCall()->args));
if (startswith(calleeFn->ast->args[i].name, "**")) {
args[si].value =
transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("NamedTuple"), "__new__"), tup,
N<IntExpr>(args[si]
.value->type->getClass()
->generics[0]
.type->getIntStatic()
->value)));
} else {
wrappingDone = false;
args[si].value = tup;
}
}
auto name = callExpr->type->getClass()->name;
auto tup = transform(N<CallExpr>(N<IdExpr>(name), callExpr->getCall()->args));
if (startswith(calleeFn->ast->args[i].name, "**")) {
args[si].value =
transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("NamedTuple"), "__new__"), tup,
N<IntExpr>(args[si]
.value->type->getClass()
->generics[0]
.type->getIntStatic()
->value)));
} else {
args[si].value = tup;
}
replacements.push_back(args[si].value->type);
// else this is empty and is a partial call; leave it for later
} else {
if (calleeFn->ast->args[i].type && !calleeFn->getArgTypes()[si]->canRealize()) {
auto t = ctx->instantiate(ctx->getType(calleeFn->ast->args[i].type->type)->generalize(0));
// calleeFn->ast->args[i].type->
// ctx->getType(transform(clean_clone(calleeFn->ast->args[i].type))->type);
auto t = ctx->instantiate(
ctx->getType(calleeFn->ast->args[i].type->type)->generalize(0));
unify(calleeFn->getArgTypes()[si], t);
}
if (wrapExpr(args[si].value, calleeFn->getArgTypes()[si], calleeFn)) {

View File

@ -35,7 +35,10 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
ctx->generateCanonicalName(name, !stmt->attributes.has(Attr::Internal),
/* noSuffix*/ stmt->attributes.has(Attr::Internal));
typ = std::make_shared<types::ClassType>(ctx->cache, canonicalName, name);
if (canonicalName == "Union")
typ = std::make_shared<types::UnionType>(ctx->cache);
else
typ = std::make_shared<types::ClassType>(ctx->cache, canonicalName, name);
if (stmt->isRecord())
typ->isTuple = true;
// if (stmt->isRecord() && stmt->hasAttr("__notuple__"))
@ -112,15 +115,6 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
auto defType = transformType(clone(a.defaultValue));
generic->defaultType = getType(defType);
}
if (auto ti = CAST(a.type, InstantiateExpr)) {
// Parse TraitVar
seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation");
auto l = transformType(ti->typeParams[0])->type;
if (l->getLink() && l->getLink()->trait)
generic->getLink()->trait = l->getLink()->trait;
else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
if (auto st = getStaticGeneric(a.type.get())) {
if (st > 3)
transform(a.type); // error check
@ -128,10 +122,22 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
auto val = ctx->addVar(genName, varName, generic);
val->generic = true;
} else {
if (a.type->getIndex()) { // Parse TraitVar
transform(a.type);
auto ti = a.type->getInstantiate();
seqassert(ti && ti->typeExpr->isId(TYPE_TYPEVAR),
"not a TypeVar instantiation: {}", a.type);
auto l = getType(ti->typeParams[0]);
if (l->getLink() && l->getLink()->trait)
generic->getLink()->trait = l->getLink()->trait;
else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
ctx->addType(genName, varName, generic)->generic = true;
}
ClassType::Generic g(varName, genName, generic->generalize(ctx->typecheckLevel),
typId, generic->isStatic);
if (a.status == Param::Generic) {
typ->generics.push_back(g);
} else {
@ -787,6 +793,8 @@ void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp,
if (t->getClass() && !t->getStatic() && !t->is("type"))
t = ctx->instantiateGeneric(ctx->getType("type"), {t});
ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true;
// LOG("=[g]=> {}: {} {:c} {}", clsTyp, g.name, t,
// t->getLink() && t->getLink()->trait ? "OK" : "-");
};
for (auto &g : clsTyp->hiddenGenerics)
addGen(g);

View File

@ -206,13 +206,13 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
kw = stmt->args.back();
stmt->args.pop_back();
}
std::array<const char*, 4> op {"", "int", "str", "bool"};
std::array<const char *, 4> op{"", "int", "str", "bool"};
for (auto &[c, v] : captures) {
if (v->isType())
stmt->args.emplace_back(c, N<IdExpr>("type"));
else if (auto si = v->isStatic())
stmt->args.emplace_back(c, N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>(op[si])));
stmt->args.emplace_back(c,
N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>(op[si])));
else
stmt->args.emplace_back(c);
partialArgs.emplace_back(c, N<IdExpr>(v->canonicalName));
@ -279,6 +279,15 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
generic->defaultType = getType(defType);
}
} else {
if (auto ti = CAST(a.type, InstantiateExpr)) {
// Parse TraitVar
seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation");
auto l = transformType(ti->typeParams[0])->type;
if (l->getLink() && l->getLink()->trait)
generic->getLink()->trait = l->getLink()->trait;
else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
auto val = ctx->addType(varName, name, generic);
val->generic = true;
if (a.defaultValue) {
@ -327,7 +336,10 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Parse arguments to the context. Needs to be done after adding generics
// to support cases like `foo(a: T, T: type)`
for (auto &a : args) {
// if (a.status == Param::Normal || a.type->is ) // todo)) makes typevar work! need to check why...
a.type = transformType(a.type, false);
// if (a.type && a.type->type->getLink() && a.type->type->getLink()->trait)
// LOG("-> {:c}", a.type->type->getLink()->trait);
a.defaultValue = transform(a.defaultValue, true);
}
@ -413,7 +425,6 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
f->setDone();
// Construct the type
// g = ctx->instantiateGeneric(ctx->getType("type"), {g});
auto funcTyp = std::make_shared<types::FuncType>(
baseType, ctx->cache->functions[canonicalName].ast.get(), explicits);
funcTyp->setSrcInfo(getSrcInfo());
@ -423,6 +434,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
funcTyp = std::static_pointer_cast<types::FuncType>(
funcTyp->generalize(ctx->typecheckLevel));
ctx->cache->functions[canonicalName].type = funcTyp;
// LOG("-> {:c}", funcTyp);
ctx->addFunc(stmt->name, rootName, funcTyp);
ctx->addFunc(canonicalName, canonicalName, funcTyp);

View File

@ -194,7 +194,6 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
if (!type || !type->canRealize())
return nullptr;
// type->_rn = type->ClassType::realizedName();
// Check if the type fields are all initialized
// (sometimes that's not the case: e.g., `class X: x: List[X]`)
for (auto &field : ctx->cache->classes[type->name].fields) {
@ -202,6 +201,10 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
return nullptr;
}
// generalize generics to ensure that they do not get unified later!
if (type->is("unrealized_type"))
type->generics[0].type = type->generics[0].type->generalize(0);
// Check if the type was already realized
auto rn = type->ClassType::realizedName();
if (auto r = in(ctx->cache->classes[type->name].realizations, rn)) {
@ -218,9 +221,7 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
}
// Realize generics
if (type->is("unrealized_type"))
type->generics[0].type->generalize(ctx->typecheckLevel);
else
if (!type->is("unrealized_type"))
for (auto &e : realized->generics) {
if (!realize(e.type))
return nullptr;

View File

@ -372,7 +372,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
transformType(expr->typeParams[i]);
auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
getType(expr->typeParams[i]));
if (expr->typeParams[i]->type->isStaticType() !=
if (isUnion || expr->typeParams[i]->type->isStaticType() !=
generics[i].type->isStaticType()) {
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
transformType(expr->typeParams[i]);

View File

@ -62,7 +62,7 @@ StmtPtr TypecheckVisitor::apply(
ScopingVisitor::apply(cache, suite);
auto n = tv.inferTypes(suite, true);
if (!n) {
LOG("[error=>] {}", suite->toString(2));
// LOG("[error=>] {}", suite->toString(2));
tv.error("cannot typecheck the program");
}
@ -238,6 +238,10 @@ ExprPtr TypecheckVisitor::transformType(ExprPtr &expr, bool allowTypeOf) {
!expr->type->getUnbound()->genericName.empty()) {
// generic!
expr->setType(ctx->instantiate(expr->getType()));
} else if (expr->type->getUnbound() &&
expr->type->getUnbound()->trait) {
// generic (is type)!
expr->setType(ctx->instantiate(expr->getType()));
} else {
E(Error::EXPECTED_TYPE, expr, "type");
}

View File

@ -362,17 +362,17 @@ class defaultdict(Static[Dict[K,V]]):
V: type
S: TypeVar[Callable[[], V]]
# def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]):
# super().__init__()
# self.default_factory = lambda: VV()
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]):
super().__init__()
self.default_factory = lambda: VV()
def __init__(self, f: S):
super().__init__()
self.default_factory = f
# def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]):
# super().__init__(other)
# self.default_factory = lambda: VV()
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]):
super().__init__(other)
self.default_factory = lambda: VV()
def __init__(self, f: S, other: Dict[K, V]):
super().__init__(other)

View File

@ -191,12 +191,12 @@ class Dict:
self._vals[x] = op(dflt if ret != 0 else self._vals[x], other)
def update(self, other):
if isinstance(other, Dict[K, V]):
for k, v in other.items():
self[k] = v
else:
for k, v in other:
self[k] = v
for k, v in other:
self[k] = v
def update(self, other: Dict[K, V]):
for k, v in other.items():
self[k] = v
def pop(self, key: K) -> V:
x = self._kh_get(key)