1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Fix polymorphism

This commit is contained in:
Ibrahim Numanagić 2024-03-18 16:15:36 -07:00
parent f4fe8ec18f
commit e737536b38
23 changed files with 204 additions and 123 deletions

View File

@ -57,6 +57,10 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
} }
TranslateVisitor(cache->codegenCtx).transform(stmts); TranslateVisitor(cache->codegenCtx).transform(stmts);
for (auto &[_, f]: cache->functions)
TranslateVisitor(cache->codegenCtx).transform(f.ast);
cache->populatePythonModule(); cache->populatePythonModule();
return main; return main;
} }
@ -174,7 +178,7 @@ void TranslateVisitor::visit(StringExpr *expr) {
void TranslateVisitor::visit(IdExpr *expr) { void TranslateVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value); auto val = ctx->find(expr->value);
seqassert(val, "cannot find '{}'", expr->value); seqassert(val, "cannot find '{}'", expr->value);
if (expr->value == "__vtable_size__") { if (expr->value == "__vtable_size__.0") {
// LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2); // LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2);
result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2, result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2,
getType(expr->getType())); getType(expr->getType()));
@ -438,7 +442,6 @@ void TranslateVisitor::visit(AssignStmt *stmt) {
auto isGlobal = in(ctx->cache->globals, var); auto isGlobal = in(ctx->cache->globals, var);
ir::Var *v = nullptr; ir::Var *v = nullptr;
if (!stmt->lhs->type->isInstantiated() || (stmt->lhs->type->is("type"))) { if (!stmt->lhs->type->isInstantiated() || (stmt->lhs->type->is("type"))) {
// LOG("{} {}", getSrcInfo(), stmt->toString(0)); // LOG("{} {}", getSrcInfo(), stmt->toString(0));
return; // type aliases/fn aliases etc return; // type aliases/fn aliases etc
@ -697,9 +700,8 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt
} else { } else {
seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}", seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}",
ss[i]->getExpr()->toString()); ss[i]->getExpr()->toString());
literals.emplace_back(getType( literals.emplace_back(
ctx->cache->typeCtx->getType( getType(ctx->cache->typeCtx->getType(ss[i]->getExpr()->expr->getType())));
ss[i]->getExpr()->expr->getType())));
} }
} }
bool isDeclare = true; bool isDeclare = true;

View File

@ -357,10 +357,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
} }
// Special case: cls.__id__ // Special case: cls.__id__
if (expr->expr->type->is("type") && expr->member == "__id__") { if (expr->expr->type->is("type") && expr->member == "__id__") {
if (auto c = realize(expr->expr->type)) if (auto c = realize(getType(expr->expr))) {
return transform(N<IntExpr>(ctx->cache->classes[c->getClass()->name] return transform(N<IntExpr>(ctx->cache->classes[c->getClass()->name]
.realizations[c->getClass()->realizedName()] .realizations[c->getClass()->realizedName()]
->id)); ->id));
}
return nullptr; return nullptr;
} }
@ -618,7 +619,9 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
// If overload is ambiguous, route through a dispatch function // If overload is ambiguous, route through a dispatch function
std::string name; std::string name;
if (auto dot = expr->getDot()) { if (auto dot = expr->getDot()) {
name = ctx->cache->getMethod(getType(dot->expr)->getClass(), dot->member); auto methods = ctx->findMethod(getType(dot->expr)->getClass()->name, dot->member, false);
seqassert(!methods.empty(), "unknown method");
name = ctx->cache->functions[methods.back()->ast->name].rootName;
} else { } else {
name = expr->getId()->value; name = expr->getId()->value;
} }

View File

@ -209,7 +209,7 @@ StmtPtr TypecheckVisitor::transformUpdate(AssignStmt *stmt) {
void TypecheckVisitor::visit(AssignMemberStmt *stmt) { void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
transform(stmt->lhs); transform(stmt->lhs);
if (auto lhsClass = stmt->lhs->getType()->getClass()) { if (auto lhsClass = getType(stmt->lhs)->getClass()) {
auto member = ctx->findMember(lhsClass->name, stmt->member); auto member = ctx->findMember(lhsClass->name, stmt->member);
if (!member && stmt->lhs->type->is("type")) { if (!member && stmt->lhs->type->is("type")) {

View File

@ -471,9 +471,8 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
expr->args.pop_back(); expr->args.pop_back();
if (!part.args) if (!part.args)
part.args = transform(N<TupleExpr>()); // use () part.args = transform(N<TupleExpr>()); // use ()
if (!part.kwArgs) { if (!part.kwArgs)
part.kwArgs = transform(N<CallExpr>(N<IdExpr>("NamedTuple"))); // use NamedTuple() part.kwArgs = transform(N<CallExpr>(N<IdExpr>("NamedTuple"))); // use NamedTuple()
}
} }
// Unify function type generics with the provided generics // Unify function type generics with the provided generics
@ -805,7 +804,8 @@ ExprPtr TypecheckVisitor::transformSuper() {
self->type = typ; self->type = typ;
auto typExpr = N<IdExpr>(superTyp->name); auto typExpr = N<IdExpr>(superTyp->name);
typExpr->setType(superTyp); typExpr->setType(ctx->instantiateGeneric(ctx->getType("type"), {superTyp}));
// LOG("-> {:c} : {:c} {:c}", typ, vCands[1], typExpr->type);
return transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_super"), return transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_super"),
self, typExpr, N<IntExpr>(1))); self, typExpr, N<IntExpr>(1)));
} }
@ -819,7 +819,6 @@ ExprPtr TypecheckVisitor::transformSuper() {
members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name)); members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name));
ExprPtr e = ExprPtr e =
transform(N<CallExpr>(N<IdExpr>(generateTuple(members.size())), members)); transform(N<CallExpr>(N<IdExpr>(generateTuple(members.size())), members));
auto ft = getClassFieldTypes(superTyp); auto ft = getClassFieldTypes(superTyp);
for (size_t i = 0; i < ft.size(); i++) for (size_t i = 0; i < ft.size(); i++)
unify( unify(

View File

@ -122,7 +122,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l); generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
} }
if (auto st = getStaticGeneric(a.type.get())) { if (auto st = getStaticGeneric(a.type.get())) {
if (st > 3) transform(a.type); // error check if (st > 3)
transform(a.type); // error check
generic->isStatic = st; generic->isStatic = st;
auto val = ctx->addVar(genName, varName, generic); auto val = ctx->addVar(genName, varName, generic);
val->generic = true; val->generic = true;
@ -197,8 +198,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// : ctx->generateCanonicalName(a.name); // : ctx->generateCanonicalName(a.name);
args.emplace_back(varName, transformType(clean_clone(a.type)), args.emplace_back(varName, transformType(clean_clone(a.type)),
transform(clone(a.defaultValue), true)); transform(clone(a.defaultValue), true));
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{ ctx->cache->classes[canonicalName].fields.emplace_back(
varName, types::TypePtr(nullptr), canonicalName}); Cache::Class::ClassField{varName, types::TypePtr(nullptr),
canonicalName});
} }
} }
} }
@ -248,7 +250,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
for (auto &b : staticBaseASTs) for (auto &b : staticBaseASTs)
ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name); ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name);
ctx->cache->classes[canonicalName].ast->validate(); ctx->cache->classes[canonicalName].ast->validate();
ctx->cache->classes[canonicalName].module = ctx->getModule(); ctx->cache->classes[canonicalName].module = ctx->moduleName.path;
// Codegen default magic methods // Codegen default magic methods
// __new__ must be the first // __new__ must be the first
@ -260,11 +262,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
for (auto &base : staticBaseASTs) { for (auto &base : staticBaseASTs) {
for (auto &mm : ctx->cache->classes[base->name].methods) for (auto &mm : ctx->cache->classes[base->name].methods)
for (auto &mf : ctx->cache->overloads[mm.second]) { for (auto &mf : ctx->cache->overloads[mm.second]) {
auto f = ctx->cache->functions[mf].origAst; const auto &fp = ctx->cache->functions[mf];
auto f = fp.origAst;
if (f && !f->attributes.has("autogenerated")) { if (f && !f->attributes.has("autogenerated")) {
ctx->addBlock(); ctx->addBlock();
addClassGenerics(base); addClassGenerics(base);
fnStmts.push_back(transform(clean_clone(f))); // since functions can come from other modules
// make sure to transform them in their respective module
// however makle sure to add/pop generics :/
auto cf = clean_clone(f);
if (!ctx->isStdlibLoading && fp.module != ctx->moduleName.path) {
auto ictx = ctx->cache->imports[fp.module].ctx;
TypeContext::BaseGuard br(ictx.get(), canonicalName);
ictx->getBase()->type = typ;
ictx->addBlock();
auto tv = TypecheckVisitor(ictx);
tv.addClassGenerics(typ, true);
cf = std::dynamic_pointer_cast<FunctionStmt>(tv.transform(cf));
ictx->popBlock();
} else {
cf = std::dynamic_pointer_cast<FunctionStmt>(transform(cf));
}
fnStmts.push_back(cf);
ctx->popBlock(); ctx->popBlock();
} }
} }
@ -328,13 +347,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// Debug information // Debug information
// LOG("[class] {} -> {:c} / {}", canonicalName, typ, // LOG("[class] {} -> {:c} / {}", canonicalName, typ,
// ctx->cache->classes[canonicalName].fields.size()); // ctx->cache->classes[canonicalName].fields.size());
// if (auto r = typ->getRecord())
// for (auto &tx: r->args)
// LOG(" ... {:c}", tx);
// for (auto &m : ctx->cache->classes[canonicalName].fields) // for (auto &m : ctx->cache->classes[canonicalName].fields)
// LOG(" - member: {}: {:D}", m.name, m.type); // LOG(" - member: {}: {:c}", m.name, m.type);
// for (auto &m : ctx->cache->classes[canonicalName].methods) // for (auto &m : ctx->cache->classes[canonicalName].methods)
// LOG(" - method: {}: {}", m.first, m.second); // LOG(" - method: {}: {}", m.first, m.second);
// for (auto &m : ctx->cache->classes[canonicalName].mro)
// LOG(" - mro: {:c}", m);
// LOG(""); // LOG("");
// ctx->dump(); // ctx->dump();
} catch (const exc::ParserException &) { } catch (const exc::ParserException &) {
@ -395,7 +413,10 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
name = clsTyp->name; name = clsTyp->name;
asts.push_back(clsTyp); asts.push_back(clsTyp);
Cache::Class *cachedCls = in(ctx->cache->classes, name); Cache::Class *cachedCls = in(ctx->cache->classes, name);
mro.push_back(cachedCls->mro); auto rootMro = cachedCls->mro;
for (auto &t : rootMro)
t = ctx->instantiate(t, clsTyp)->getClass();
mro.push_back(rootMro);
// Sanity checks // Sanity checks
if (attr.has(Attr::Tuple) && typeAst) if (attr.has(Attr::Tuple) && typeAst)
@ -438,9 +459,7 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
transform(clean_clone(a.defaultValue))); transform(clean_clone(a.defaultValue)));
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{ ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{
name, getType(args.back().type), name, getType(args.back().type),
ctx->cache->classes[ast->name].fields[ai].baseClass ctx->cache->classes[ast->name].fields[ai].baseClass});
}
);
ai++; ai++;
} }
} }
@ -455,7 +474,8 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
if (ctx->cache->classes[canonicalName].mro.empty()) { if (ctx->cache->classes[canonicalName].mro.empty()) {
E(Error::CLASS_BAD_MRO, getSrcInfo()); E(Error::CLASS_BAD_MRO, getSrcInfo());
} else if (ctx->cache->classes[canonicalName].mro.size() > 1) { } else if (ctx->cache->classes[canonicalName].mro.size() > 1) {
// LOG("[mro] {} -> {}", canonicalName, ctx->cache->classes[canonicalName].mro); // for (auto &t: ctx->cache->classes[canonicalName].mro)
// LOG("[mro] {} -> {:c}", canonicalName, t);
} }
} }
return asts; return asts;
@ -753,11 +773,19 @@ int TypecheckVisitor::generateKwId(const std::vector<std::string> &names) {
} }
} }
void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp) { void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp,
bool instantiate) {
auto addGen = [&](auto g) { auto addGen = [&](auto g) {
auto t = g.type; auto t = g.type;
if (instantiate)
if (auto l = g.type->getLink())
if (l->kind == LinkType::Generic) {
auto lx = std::make_shared<LinkType>(*l);
lx->kind = LinkType::Unbound;
t = lx;
}
if (t->getClass() && !t->getStatic() && !t->is("type")) if (t->getClass() && !t->getStatic() && !t->is("type"))
t = ctx->instantiateGeneric(ctx->getType("type"), {t}); t = ctx->instantiateGeneric(ctx->getType("type"), {t});
ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true; ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true;
}; };
for (auto &g : clsTyp->hiddenGenerics) for (auto &g : clsTyp->hiddenGenerics)

View File

@ -138,6 +138,10 @@ std::string TypeContext::getModule() const {
return base; return base;
} }
std::string TypeContext::getModulePath() const {
return moduleName.path;
}
void TypeContext::dump() { dump(0); } void TypeContext::dump() { dump(0); }
bool TypeContext::isCanonicalName(const std::string &name) const { bool TypeContext::isCanonicalName(const std::string &name) const {

View File

@ -205,6 +205,8 @@ public:
std::string getBaseName() const; std::string getBaseName() const;
/// Return the current module. /// Return the current module.
std::string getModule() const; std::string getModule() const;
/// Return the current module path.
std::string getModulePath() const;
/// Pretty-print the current context state. /// Pretty-print the current context state.
void dump() override; void dump() override;

View File

@ -49,6 +49,8 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
stmt->expr = partializeFunction(stmt->expr->type->getFunc()); stmt->expr = partializeFunction(stmt->expr->type->getFunc());
} }
if (!ctx->getBase()->returnType->isStaticType() && stmt->expr->type->getStatic())
stmt->expr->type = stmt->expr->type->getStatic()->getNonStaticType();
unify(ctx->getBase()->returnType, stmt->expr->type); unify(ctx->getBase()->returnType, stmt->expr->type);
} else { } else {
// Just set the expr for the translation stage. However, do not unify the return // Just set the expr for the translation stage. However, do not unify the return
@ -397,7 +399,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Make function AST and cache it for later realization // Make function AST and cache it for later realization
auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes); auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes);
ctx->cache->functions[canonicalName].module = ctx->getModule(); ctx->cache->functions[canonicalName].module = ctx->moduleName.path;
ctx->cache->functions[canonicalName].ast = f; ctx->cache->functions[canonicalName].ast = f;
ctx->cache->functions[canonicalName].origAst = stmt_clone; ctx->cache->functions[canonicalName].origAst = stmt_clone;
ctx->cache->functions[canonicalName].isToplevel = ctx->cache->functions[canonicalName].isToplevel =

View File

@ -77,7 +77,7 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
transform(N<AssignStmt>( transform(N<AssignStmt>(
N<IdExpr>(name), N<IdExpr>(name),
N<CallExpr>(N<IdExpr>("Import"), N<StringExpr>(file->path), N<CallExpr>(N<IdExpr>("Import"), N<StringExpr>(file->path),
N<StringExpr>(file->path), N<StringExpr>(file->module))))); N<StringExpr>(file->module), N<StringExpr>(file->path)))));
} else if (stmt->what->isId("*")) { } else if (stmt->what->isId("*")) {
// Case: from foo import * // Case: from foo import *
seqassert(stmt->as.empty(), "renamed star-import"); seqassert(stmt->as.empty(), "renamed star-import");
@ -202,7 +202,7 @@ StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Exp
auto canonical = ctx->generateCanonicalName(name); auto canonical = ctx->generateCanonicalName(name);
auto typ = transformType(clone(type)); auto typ = transformType(clone(type));
auto val = ctx->addVar(altName.empty() ? name : altName, canonical, auto val = ctx->addVar(altName.empty() ? name : altName, canonical,
std::make_shared<types::LinkType>(typ->type->getClass())); std::make_shared<types::LinkType>(getType(typ)->getClass()));
auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, typ); auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, typ);
s->lhs->setAttr(ExprAttr::ExternVar); s->lhs->setAttr(ExprAttr::ExternVar);
s->lhs->setType(val->type); s->lhs->setType(val->type);

View File

@ -394,8 +394,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
// Use NoneType as the return type when the return type is not specified and // Use NoneType as the return type when the return type is not specified and
// function has no return statement // function has no return statement
if (!ast->ret && type->getRetType()->getUnbound()) if (!ast->ret && type->getRetType()->getUnbound()) {
unify(type->getRetType(), ctx->getType("NoneType")); unify(type->getRetType(), ctx->getType("NoneType"));
}
// LOG("-> {} {}", key, ret->toString(2)); // LOG("-> {} {}", key, ret->toString(2));
} }
// Realize the return type // Realize the return type
@ -583,8 +584,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
for (auto &[_, real] : cls.realizations) { for (auto &[_, real] : cls.realizations) {
auto &vtable = real->vtables[baseCls]; auto &vtable = real->vtables[baseCls];
auto ct = auto ct = ctx->instantiate(ctx->getType(clsName), cp->getClass())->getClass();
ctx->instantiate(ctx->forceFind(clsName)->type, cp->getClass())->getClass();
std::vector<types::TypePtr> args = fp->getArgTypes(); std::vector<types::TypePtr> args = fp->getArgTypes();
args[0] = ct; args[0] = ct;
auto m = findBestMethod(ct, fnName, args); auto m = findBestMethod(ct, fnName, args);
@ -605,7 +605,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
// Thunk name: _thunk.<BASE>.<FN>.<ARGS> // Thunk name: _thunk.<BASE>.<FN>.<ARGS>
auto thunkName = auto thunkName =
format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, ".")); format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, "."));
if (in(ctx->cache->functions, thunkName)) if (in(ctx->cache->functions, thunkName+":0"))
continue; continue;
// Thunk contents: // Thunk contents:
@ -614,27 +614,24 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
// __internal__.class_base_to_derived(self, <BASE>, <DERIVED>), // __internal__.class_base_to_derived(self, <BASE>, <DERIVED>),
// <ARGS...>) // <ARGS...>)
std::vector<Param> fnArgs; std::vector<Param> fnArgs;
fnArgs.emplace_back(fp->ast->args[0].name, N<IdExpr>(cp->realizedName()), fnArgs.emplace_back("self", N<IdExpr>(cp->realizedName()), nullptr);
nullptr);
for (size_t i = 1; i < args.size(); i++) for (size_t i = 1; i < args.size(); i++)
fnArgs.emplace_back(fp->ast->args[i].name, N<IdExpr>(args[i]->realizedName()), fnArgs.emplace_back(ctx->cache->rev(fp->ast->args[i].name),
nullptr); N<IdExpr>(args[i]->realizedName()), nullptr);
std::vector<ExprPtr> callArgs; std::vector<ExprPtr> callArgs;
callArgs.emplace_back( callArgs.emplace_back(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_base_to_derived"), N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_base_to_derived"),
N<IdExpr>(fp->ast->args[0].name), N<IdExpr>(cp->realizedName()), N<IdExpr>("self"), N<IdExpr>(cp->realizedName()),
N<IdExpr>(real->type->realizedName()))); N<IdExpr>(real->type->realizedName())));
for (size_t i = 1; i < args.size(); i++) for (size_t i = 1; i < args.size(); i++)
callArgs.emplace_back(N<IdExpr>(fp->ast->args[i].name)); callArgs.emplace_back(N<IdExpr>(ctx->cache->rev(fp->ast->args[i].name)));
auto thunkAst = N<FunctionStmt>( auto thunkAst = N<FunctionStmt>(
thunkName, nullptr, fnArgs, thunkName, nullptr, fnArgs,
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(N<IdExpr>(m->ast->name), callArgs))), N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(N<IdExpr>(m->ast->name), callArgs))),
Attr({"std.internal.attributes.inline", Attr::ForceRealize})); Attr({"std.internal.attributes.inline"}));
auto &thunkFn = ctx->cache->functions[thunkAst->name]; thunkAst = std::dynamic_pointer_cast<FunctionStmt>(transform(thunkAst));
thunkFn.ast = clone(thunkAst);
transform(thunkAst); auto &thunkFn = ctx->cache->functions[thunkAst->name];
prependStmts->push_back(thunkAst);
auto ti = ctx->instantiate(thunkFn.type)->getFunc(); auto ti = ctx->instantiate(thunkFn.type)->getFunc();
auto tm = realizeFunc(ti.get(), true); auto tm = realizeFunc(ti.get(), true);
seqassert(tm, "bad thunk {}", thunkFn.type); seqassert(tm, "bad thunk {}", thunkFn.type);
@ -651,8 +648,11 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
auto realizedName = t->ClassType::realizedName(); auto realizedName = t->ClassType::realizedName();
if (!in(ctx->cache->classes[t->name].realizations, realizedName)) if (!in(ctx->cache->classes[t->name].realizations, realizedName))
realize(t->getClass()); realize(t->getClass());
if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) {
if (ctx->cache->classes[t->name].rtti)
ir::cast<ir::types::RefType>(l)->setPolymorphic();
return l; return l;
}
auto forceFindIRType = [&](const TypePtr &tt) { auto forceFindIRType = [&](const TypePtr &tt) {
auto t = tt->getClass(); auto t = tt->getClass();

View File

@ -289,7 +289,7 @@ TypecheckVisitor::transformStaticLoopCall(
if (vars.size() != 1) if (vars.size() != 1)
error("expected one item"); error("expected one item");
for (auto &a : args) { for (auto &a : args) {
stmt->rhs = a.value; stmt->rhs = transform(clean_clone(a.value));
if (auto st = stmt->rhs->type->getStatic()) { if (auto st = stmt->rhs->type->getStatic()) {
stmt->type = N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>(st->name)); stmt->type = N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>(st->name));
} else { } else {

View File

@ -543,7 +543,12 @@ void ScopingVisitor::visit(GlobalStmt *stmt) {
} }
void ScopingVisitor::visit(FunctionStmt *stmt) { void ScopingVisitor::visit(FunctionStmt *stmt) {
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo()); bool isOverload = false;
for (auto &d: stmt->decorators)
if (d->isId("overload"))
isOverload = true;
if (!isOverload)
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo());
auto c = std::make_shared<ScopingVisitor::Context>(); auto c = std::make_shared<ScopingVisitor::Context>();
c->cache = ctx->cache; c->cache = ctx->cache;

View File

@ -56,25 +56,29 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
/// @c transformBinaryInplaceMagic for details. /// @c transformBinaryInplaceMagic for details.
/// Also evaluate static expressions. See @c evaluateStaticBinary for details. /// Also evaluate static expressions. See @c evaluateStaticBinary for details.
void TypecheckVisitor::visit(BinaryExpr *expr) { void TypecheckVisitor::visit(BinaryExpr *expr) {
// Transform lexpr and rexpr. Ignore Nones for now transform(expr->lexpr, true);
if (!(startswith(expr->op, "is") && expr->lexpr->getNone())) transform(expr->rexpr, true);
transform(expr->lexpr, true);
if (!(startswith(expr->op, "is") && expr->rexpr->getNone()))
transform(expr->rexpr, true);
static std::unordered_map<int, std::unordered_set<std::string>> staticOps = { static std::unordered_map<int, std::unordered_set<std::string>> staticOps = {
{1, {1,
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&", {"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&",
"|", "^"}}, "|", "^"}},
{2, {"==", "!=", "+"}}, {2, {"==", "!=", "+"}},
{3, {3, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
{"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}}; if (expr->lexpr->type->isStaticType() && expr->rexpr->type->isStaticType()) {
if (expr->lexpr->type->isStaticType() && auto l = expr->lexpr->type->isStaticType();
expr->lexpr->type->isStaticType() == expr->rexpr->type->isStaticType() && auto r = expr->rexpr->type->isStaticType();
in(staticOps[expr->lexpr->type->isStaticType()], expr->op)) { bool isStatic = l == r && in(staticOps[l], expr->op);
// Handle static expressions if (!isStatic && ((l == 1 && r == 3) || (r == 1 && l == 3)) &&
resultExpr = evaluateStaticBinary(expr); in(staticOps[1], expr->op))
} else if (auto e = transformBinarySimple(expr)) { isStatic = true;
if (isStatic) {
resultExpr = evaluateStaticBinary(expr);
return;
}
}
if (auto e = transformBinarySimple(expr)) {
// Case: simple binary expressions // Case: simple binary expressions
resultExpr = e; resultExpr = e;
} else if (expr->lexpr->getType()->getUnbound() || } else if (expr->lexpr->getType()->getUnbound() ||
@ -264,7 +268,8 @@ void TypecheckVisitor::visit(PipeExpr *expr) {
void TypecheckVisitor::visit(IndexExpr *expr) { void TypecheckVisitor::visit(IndexExpr *expr) {
if (expr->expr->isId("Static")) { if (expr->expr->isId("Static")) {
// Special case: static types. Ensure that static is supported // Special case: static types. Ensure that static is supported
if (!expr->index->isId("int") && !expr->index->isId("str") && !expr->index->isId("bool")) if (!expr->index->isId("int") && !expr->index->isId("str") &&
!expr->index->isId("bool"))
E(Error::BAD_STATIC_TYPE, expr->index); E(Error::BAD_STATIC_TYPE, expr->index);
auto typ = ctx->getUnbound(); auto typ = ctx->getUnbound();
typ->isStatic = getStaticGeneric(expr); typ->isStatic = getStaticGeneric(expr);
@ -364,19 +369,16 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
unify(expr->type, typ); unify(expr->type, typ);
} else { } else {
for (size_t i = 0; i < expr->typeParams.size(); i++) { for (size_t i = 0; i < expr->typeParams.size(); i++) {
// transform(expr->typeParams[i]);
transformType(expr->typeParams[i]); transformType(expr->typeParams[i]);
auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(), auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
getType(expr->typeParams[i])); getType(expr->typeParams[i]));
// if (expr->typeParams[i]->type->isStaticType() && if (expr->typeParams[i]->type->isStaticType() !=
// generics[i].type->isStaticType()) { generics[i].type->isStaticType()) {
// t = ctx->instantiate(expr->typeParams[i]->type); if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
// } else { transformType(expr->typeParams[i]);
// if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` if (!expr->typeParams[i]->type->is("type"))
// transformType(expr->typeParams[i]); E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
// if (!expr->typeParams[i]->type->is("type")) }
// E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
// }
if (isUnion) if (isUnion)
typ->getUnion()->addType(t); typ->getUnion()->addType(t);
else else
@ -454,7 +456,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
value = !bool(value); value = !bool(value);
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
if (expr->op == "!") if (expr->op == "!")
return transform(N<IntExpr>(bool(value))); return transform(N<BoolExpr>(value));
else else
return transform(N<IntExpr>(value)); return transform(N<IntExpr>(value));
} else { } else {
@ -469,9 +471,10 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
/// Division and modulus implementations. /// Division and modulus implementations.
std::pair<int64_t, int64_t> divMod(const std::shared_ptr<TypeContext> &ctx, int64_t a, std::pair<int64_t, int64_t> divMod(const std::shared_ptr<TypeContext> &ctx, int64_t a,
int64_t b) { int64_t b) {
if (!b) if (!b) {
E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo()); E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo());
if (ctx->cache->pythonCompat) { return {0, 0};
} else if (ctx->cache->pythonCompat) {
// Use Python implementation. // Use Python implementation.
int64_t d = a / b; int64_t d = a / b;
int64_t m = a - d * b; int64_t m = a - d * b;
@ -511,7 +514,7 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
expr->rexpr->type->getStrStatic()->value; expr->rexpr->type->getStrStatic()->value;
bool value = expr->op == "==" ? eq : !eq; bool value = expr->op == "==" ? eq : !eq;
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
return transform(N<IntExpr>(value)); return transform(N<BoolExpr>(value));
} else { } else {
// Cannot be evaluated yet: just set the type // Cannot be evaluated yet: just set the type
expr->type->getUnbound()->isStatic = 1; expr->type->getUnbound()->isStatic = 1;
@ -522,8 +525,12 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
// Case: static integers // Case: static integers
if (expr->lexpr->type->getStatic() && expr->rexpr->type->getStatic()) { if (expr->lexpr->type->getStatic() && expr->rexpr->type->getStatic()) {
int64_t lvalue = expr->lexpr->type->getIntStatic() ? expr->lexpr->type->getIntStatic()->value : expr->lexpr->type->getBoolStatic()->value; int64_t lvalue = expr->lexpr->type->getIntStatic()
int64_t rvalue = expr->rexpr->type->getIntStatic() ? expr->rexpr->type->getIntStatic()->value : expr->rexpr->type->getBoolStatic()->value; ? expr->lexpr->type->getIntStatic()->value
: expr->lexpr->type->getBoolStatic()->value;
int64_t rvalue = expr->rexpr->type->getIntStatic()
? expr->rexpr->type->getIntStatic()->value
: expr->rexpr->type->getBoolStatic()->value;
if (expr->op == "<") if (expr->op == "<")
lvalue = lvalue < rvalue; lvalue = lvalue < rvalue;
else if (expr->op == "<=") else if (expr->op == "<=")
@ -596,7 +603,7 @@ ExprPtr TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) {
return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, "__contains__"), expr->lexpr)); return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, "__contains__"), expr->lexpr));
} else if (expr->op == "is") { } else if (expr->op == "is") {
if (expr->lexpr->getNone() && expr->rexpr->getNone()) if (expr->lexpr->getNone() && expr->rexpr->getNone())
return transform(N<IntExpr>(1)); return transform(N<BoolExpr>(true));
else if (expr->lexpr->getNone()) else if (expr->lexpr->getNone())
return transform(N<BinaryExpr>(expr->rexpr, "is", expr->lexpr)); return transform(N<BinaryExpr>(expr->rexpr, "is", expr->lexpr));
} else if (expr->op == "is not") { } else if (expr->op == "is not") {
@ -613,17 +620,17 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
// Case: `is None` expressions // Case: `is None` expressions
if (expr->rexpr->getNone()) { if (expr->rexpr->getNone()) {
if (expr->lexpr->getType()->is("NoneType")) if (expr->lexpr->getType()->is("NoneType"))
return transform(N<IntExpr>(1)); return transform(N<BoolExpr>(true));
if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) { if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) {
// lhs is not optional: `return False` // lhs is not optional: `return False`
return transform(N<IntExpr>(0)); return transform(N<BoolExpr>(false));
} else { } else {
// Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType // Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType
auto g = expr->lexpr->getType()->getClass(); auto g = expr->lexpr->getType()->getClass();
for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass()) for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass())
; ;
if (g->generics[0].type->is("NoneType")) if (g->generics[0].type->is("NoneType"))
return transform(N<IntExpr>(1)); return transform(N<BoolExpr>(true));
// lhs is optional: `return lhs.__has__().__invert__()` // lhs is optional: `return lhs.__has__().__invert__()`
return transform(N<CallExpr>( return transform(N<CallExpr>(
@ -640,7 +647,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
return nullptr; return nullptr;
} }
if (expr->lexpr->type->is("type") && expr->rexpr->type->is("type")) if (expr->lexpr->type->is("type") && expr->rexpr->type->is("type"))
return transform(N<IntExpr>(lc->realizedName() == rc->realizedName())); return transform(N<BoolExpr>(lc->realizedName() == rc->realizedName()));
if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) { if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) {
// Both reference types: `return lhs.__raw__() == rhs.__raw__()` // Both reference types: `return lhs.__raw__() == rhs.__raw__()`
return transform( return transform(
@ -659,7 +666,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
} }
if (lc->realizedName() != rc->realizedName()) { if (lc->realizedName() != rc->realizedName()) {
// tuple names do not match: `return False` // tuple names do not match: `return False`
return transform(N<IntExpr>(0)); return transform(N<BoolExpr>(false));
} }
// Same tuple types: `return lhs == rhs` // Same tuple types: `return lhs == rhs`
return transform(N<BinaryExpr>(expr->lexpr, "==", expr->rexpr)); return transform(N<BinaryExpr>(expr->lexpr, "==", expr->rexpr));

View File

@ -579,6 +579,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
} else { } else {
expr = p; expr = p;
} }
} else if (expectedClass && expectedClass->name == "Function" && exprClass &&
exprClass->getPartial() &&
exprClass->generics[2].type->getClass()->generics.size() == 1 &&
exprClass->generics[2]
.type->getClass()
->generics[0]
.type->getClass()
->generics.empty() &&
exprClass->generics[3]
.type->getClass()
->generics[0]
.type->getClass()
->generics.empty()) {
expr = transform(N<IdExpr>(exprClass->getPartialFunc()->ast->name));
} else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass && } else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass &&
!expectedClass->getUnion()) { !expectedClass->getUnion()) {
// Extract union types via __internal__.get_union // Extract union types via __internal__.get_union
@ -696,7 +710,8 @@ types::TypePtr TypecheckVisitor::getType(const ExprPtr &e) {
return t; return t;
} }
std::vector<types::TypePtr> TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) { std::vector<types::TypePtr>
TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) {
std::vector<types::TypePtr> result; std::vector<types::TypePtr> result;
ctx->addBlock(); ctx->addBlock();
addClassGenerics(cls); addClassGenerics(cls);

View File

@ -247,7 +247,7 @@ private: // Node typechecking rules
bool); bool);
std::string generateTuple(size_t); std::string generateTuple(size_t);
int generateKwId(const std::vector<std::string> & = {}); int generateKwId(const std::vector<std::string> & = {});
void addClassGenerics(const types::ClassTypePtr &); void addClassGenerics(const types::ClassTypePtr &, bool instantiate = false);
/* The rest (typecheck.cpp) */ /* The rest (typecheck.cpp) */
void visit(SuiteStmt *) override; void visit(SuiteStmt *) override;

View File

@ -15,7 +15,11 @@ from internal.types.float import *
from internal.types.byte import * from internal.types.byte import *
from internal.types.generator import * from internal.types.generator import *
from internal.types.optional import * from internal.types.optional import *
import internal.c_stubs as _C
from internal.format import *
from internal.internal import * from internal.internal import *
from internal.types.slice import * from internal.types.slice import *
from internal.types.range import * from internal.types.range import *
from internal.types.complex import * from internal.types.complex import *
@ -28,26 +32,21 @@ from internal.types.collections.set import *
from internal.types.collections.dict import * from internal.types.collections.dict import *
from internal.types.collections.tuple import * from internal.types.collections.tuple import *
# Extended core library
import internal.c_stubs as _C
from internal.format import *
from internal.builtin import * from internal.builtin import *
from internal.builtin import _jit_display from internal.builtin import _jit_display
from internal.str import * from internal.str import *
from internal.sort import sorted from internal.sort import sorted
# # from openmp import Ident as __OMPIdent, for_par from openmp import Ident as __OMPIdent, for_par
# # from gpu import _gpu_loop_outline_template from gpu import _gpu_loop_outline_template
# from internal.file import File, gzFile, open, gzopen from internal.file import File, gzFile, open, gzopen
# from pickle import pickle, unpickle from pickle import pickle, unpickle
# from internal.dlopen import dlsym as _dlsym from internal.dlopen import dlsym as _dlsym
# import internal.python import internal.python
# from internal.python import PyError from internal.python import PyError
# # if __py_numerics__: if __py_numerics__:
# # import internal.pynumerics import internal.pynumerics
# # if __py_extension__: if __py_extension__:
# # internal.python.ensure_initialized() internal.python.ensure_initialized()

View File

@ -174,7 +174,7 @@ class Import:
P: Static[str] P: Static[str]
@llvm @llvm
def __new__(P: Static[str], name: str, path: str) -> Import[P]: def __new__(path: str, name: str, P: Static[str]) -> Import[P]:
%0 = insertvalue { {=str}, {=str} } undef, {=str} %path, 0 %0 = insertvalue { {=str}, {=str} } undef, {=str} %path, 0
%1 = insertvalue { {=str}, {=str} } %0, {=str} %name, 1 %1 = insertvalue { {=str}, {=str} } %0, {=str} %name, 1
ret { {=str}, {=str} } %1 ret { {=str}, {=str} } %1

View File

@ -68,6 +68,10 @@ class __internal__:
__vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj]))) __vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])))
__internal__.class_populate_vtables() __internal__.class_populate_vtables()
# def print(a):
# from C import seq_print(str)
# seq_print(a.__repr__())
def class_populate_vtables() -> None: def class_populate_vtables() -> None:
""" """
Populate content of vtables. Compiler generated. Populate content of vtables. Compiler generated.
@ -91,7 +95,8 @@ class __internal__:
def class_set_rtti_vtable(id: int, sz: int, T: type): def class_set_rtti_vtable(id: int, sz: int, T: type):
if not __has_rtti__(T): if not __has_rtti__(T):
compile_error("class is not polymorphic") compile_error("class is not polymorphic")
__vtables__[id] = Ptr[cobj](sz + 1) p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj))
__vtables__[id] = Ptr[cobj](p)
__internal__.class_set_typeinfo(__vtables__[id], id) __internal__.class_set_typeinfo(__vtables__[id], id)
def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type): def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type):

View File

@ -192,8 +192,9 @@ Jar = Ptr[byte]
@extend @extend
class NoneType: class NoneType:
@llvm
def __new__() -> NoneType: def __new__() -> NoneType:
return () ret {} {}
def __eq__(self, other: NoneType): def __eq__(self, other: NoneType):
return True return True

View File

@ -466,3 +466,15 @@ def test_mandelbrot():
return (MAX, N, pixels, scale(N, -2, 0.4)) return (MAX, N, pixels, scale(N, -2, 0.4))
k(pixels, grid=(N*N)//1024, block=1024) k(pixels, grid=(N*N)//1024, block=1024)
test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4) test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4)
#%% id_shadow_overload_call,barebones
def foo():
def bar():
return -1
def xo():
return bar()
@overload # w/o this this fails because xo cannot capture bar
def bar(a):
return a
bar(1)
foo()

View File

@ -92,12 +92,6 @@ class F[T: Static[float]]:
pass pass
#! expected 'int' or 'str' (only integers and strings can be static) #! expected 'int' or 'str' (only integers and strings can be static)
#%% class_err_10,barebones
def foo[T]():
class A:
x: T
#! name 'T' cannot be captured
#%% class_err_11,barebones #%% class_err_11,barebones
def foo(x): def foo(x):
class A: class A:

View File

@ -103,7 +103,7 @@ for i in range(10):
#%% for_error,barebones #%% for_error,barebones
for i in 1: for i in 1:
pass pass
#! 'int' object has no attribute '__iter__' #! '1' object has no attribute '__iter__'
#%% for_void,barebones #%% for_void,barebones
def foo(): yield def foo(): yield

View File

@ -3,7 +3,7 @@
a, b = False, 1 a, b = False, 1
print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1 print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1
#%% binary,barebones #%% binary_simple,barebones
x, y = 1, 0 x, y = 1, 0
c = [1, 2, 3] c = [1, 2, 3]
@ -356,16 +356,19 @@ print Foo[int, 3, 4](), Foo[int, 5, 4]()
#%% static_int,barebones #%% static_int,barebones
def foo(n: Static[int]): def foo(n: Static[int]):
print n print n
@overload
def foo(n: Static[bool]):
print n
a: Static[int] = 5 a: Static[int] = 5
foo(a < 1) #: 0 foo(a < 1) #: False
foo(a <= 1) #: 0 foo(a <= 1) #: False
foo(a > 1) #: 1 foo(a > 1) #: True
foo(a >= 1) #: 1 foo(a >= 1) #: True
foo(a == 1) #: 0 foo(a == 1) #: False
foo(a != 1) #: 1 foo(a != 1) #: True
foo(a and 1) #: 1 foo(a and 1) #: True
foo(a or 1) #: 1 foo(a or 1) #: True
foo(a + 1) #: 6 foo(a + 1) #: 6
foo(a - 1) #: 4 foo(a - 1) #: 4
foo(a * 1) #: 5 foo(a * 1) #: 5