diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 9219f67f..d54ff29f 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -57,6 +57,10 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) { } TranslateVisitor(cache->codegenCtx).transform(stmts); + + for (auto &[_, f]: cache->functions) + TranslateVisitor(cache->codegenCtx).transform(f.ast); + cache->populatePythonModule(); return main; } @@ -174,7 +178,7 @@ void TranslateVisitor::visit(StringExpr *expr) { void TranslateVisitor::visit(IdExpr *expr) { auto val = ctx->find(expr->value); seqassert(val, "cannot find '{}'", expr->value); - if (expr->value == "__vtable_size__") { + if (expr->value == "__vtable_size__.0") { // LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2); result = make(expr, ctx->cache->classRealizationCnt + 2, getType(expr->getType())); @@ -438,7 +442,6 @@ void TranslateVisitor::visit(AssignStmt *stmt) { auto isGlobal = in(ctx->cache->globals, var); ir::Var *v = nullptr; - if (!stmt->lhs->type->isInstantiated() || (stmt->lhs->type->is("type"))) { // LOG("{} {}", getSrcInfo(), stmt->toString(0)); return; // type aliases/fn aliases etc @@ -697,9 +700,8 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt } else { seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}", ss[i]->getExpr()->toString()); - literals.emplace_back(getType( - ctx->cache->typeCtx->getType( - ss[i]->getExpr()->expr->getType()))); + literals.emplace_back( + getType(ctx->cache->typeCtx->getType(ss[i]->getExpr()->expr->getType()))); } } bool isDeclare = true; diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index c3002d84..10b70dc2 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -357,10 +357,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, } // Special case: cls.__id__ if (expr->expr->type->is("type") && expr->member == "__id__") { - if (auto c = realize(expr->expr->type)) + if (auto c = realize(getType(expr->expr))) { return transform(N(ctx->cache->classes[c->getClass()->name] .realizations[c->getClass()->realizedName()] ->id)); + } return nullptr; } @@ -618,7 +619,9 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, // If overload is ambiguous, route through a dispatch function std::string name; if (auto dot = expr->getDot()) { - name = ctx->cache->getMethod(getType(dot->expr)->getClass(), dot->member); + auto methods = ctx->findMethod(getType(dot->expr)->getClass()->name, dot->member, false); + seqassert(!methods.empty(), "unknown method"); + name = ctx->cache->functions[methods.back()->ast->name].rootName; } else { name = expr->getId()->value; } diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index d3ab4cdc..8f54924c 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -209,7 +209,7 @@ StmtPtr TypecheckVisitor::transformUpdate(AssignStmt *stmt) { void TypecheckVisitor::visit(AssignMemberStmt *stmt) { transform(stmt->lhs); - if (auto lhsClass = stmt->lhs->getType()->getClass()) { + if (auto lhsClass = getType(stmt->lhs)->getClass()) { auto member = ctx->findMember(lhsClass->name, stmt->member); if (!member && stmt->lhs->type->is("type")) { diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index fd33fe26..fca02e2b 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -471,9 +471,8 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e expr->args.pop_back(); if (!part.args) part.args = transform(N()); // use () - if (!part.kwArgs) { + if (!part.kwArgs) part.kwArgs = transform(N(N("NamedTuple"))); // use NamedTuple() - } } // Unify function type generics with the provided generics @@ -805,7 +804,8 @@ ExprPtr TypecheckVisitor::transformSuper() { self->type = typ; auto typExpr = N(superTyp->name); - typExpr->setType(superTyp); + typExpr->setType(ctx->instantiateGeneric(ctx->getType("type"), {superTyp})); + // LOG("-> {:c} : {:c} {:c}", typ, vCands[1], typExpr->type); return transform(N(N(N("__internal__"), "class_super"), self, typExpr, N(1))); } @@ -819,7 +819,6 @@ ExprPtr TypecheckVisitor::transformSuper() { members.push_back(N(N(funcTyp->ast->args[0].name), field.name)); ExprPtr e = transform(N(N(generateTuple(members.size())), members)); - auto ft = getClassFieldTypes(superTyp); for (size_t i = 0; i < ft.size(); i++) unify( diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index 13c5798c..47e79b21 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -122,7 +122,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { generic->getLink()->trait = std::make_shared(l); } if (auto st = getStaticGeneric(a.type.get())) { - if (st > 3) transform(a.type); // error check + if (st > 3) + transform(a.type); // error check generic->isStatic = st; auto val = ctx->addVar(genName, varName, generic); val->generic = true; @@ -197,8 +198,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { // : ctx->generateCanonicalName(a.name); args.emplace_back(varName, transformType(clean_clone(a.type)), transform(clone(a.defaultValue), true)); - ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{ - varName, types::TypePtr(nullptr), canonicalName}); + ctx->cache->classes[canonicalName].fields.emplace_back( + Cache::Class::ClassField{varName, types::TypePtr(nullptr), + canonicalName}); } } } @@ -248,7 +250,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { for (auto &b : staticBaseASTs) ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name); ctx->cache->classes[canonicalName].ast->validate(); - ctx->cache->classes[canonicalName].module = ctx->getModule(); + ctx->cache->classes[canonicalName].module = ctx->moduleName.path; // Codegen default magic methods // __new__ must be the first @@ -260,11 +262,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { for (auto &base : staticBaseASTs) { for (auto &mm : ctx->cache->classes[base->name].methods) for (auto &mf : ctx->cache->overloads[mm.second]) { - auto f = ctx->cache->functions[mf].origAst; + const auto &fp = ctx->cache->functions[mf]; + auto f = fp.origAst; if (f && !f->attributes.has("autogenerated")) { ctx->addBlock(); addClassGenerics(base); - fnStmts.push_back(transform(clean_clone(f))); + // since functions can come from other modules + // make sure to transform them in their respective module + // however makle sure to add/pop generics :/ + auto cf = clean_clone(f); + if (!ctx->isStdlibLoading && fp.module != ctx->moduleName.path) { + auto ictx = ctx->cache->imports[fp.module].ctx; + TypeContext::BaseGuard br(ictx.get(), canonicalName); + ictx->getBase()->type = typ; + ictx->addBlock(); + auto tv = TypecheckVisitor(ictx); + tv.addClassGenerics(typ, true); + cf = std::dynamic_pointer_cast(tv.transform(cf)); + ictx->popBlock(); + } else { + cf = std::dynamic_pointer_cast(transform(cf)); + } + fnStmts.push_back(cf); ctx->popBlock(); } } @@ -328,13 +347,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { // Debug information // LOG("[class] {} -> {:c} / {}", canonicalName, typ, // ctx->cache->classes[canonicalName].fields.size()); - // if (auto r = typ->getRecord()) - // for (auto &tx: r->args) - // LOG(" ... {:c}", tx); // for (auto &m : ctx->cache->classes[canonicalName].fields) - // LOG(" - member: {}: {:D}", m.name, m.type); + // LOG(" - member: {}: {:c}", m.name, m.type); // for (auto &m : ctx->cache->classes[canonicalName].methods) // LOG(" - method: {}: {}", m.first, m.second); + // for (auto &m : ctx->cache->classes[canonicalName].mro) + // LOG(" - mro: {:c}", m); // LOG(""); // ctx->dump(); } catch (const exc::ParserException &) { @@ -395,7 +413,10 @@ TypecheckVisitor::parseBaseClasses(std::vector &baseClasses, name = clsTyp->name; asts.push_back(clsTyp); Cache::Class *cachedCls = in(ctx->cache->classes, name); - mro.push_back(cachedCls->mro); + auto rootMro = cachedCls->mro; + for (auto &t : rootMro) + t = ctx->instantiate(t, clsTyp)->getClass(); + mro.push_back(rootMro); // Sanity checks if (attr.has(Attr::Tuple) && typeAst) @@ -438,9 +459,7 @@ TypecheckVisitor::parseBaseClasses(std::vector &baseClasses, transform(clean_clone(a.defaultValue))); ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{ name, getType(args.back().type), - ctx->cache->classes[ast->name].fields[ai].baseClass - } - ); + ctx->cache->classes[ast->name].fields[ai].baseClass}); ai++; } } @@ -455,7 +474,8 @@ TypecheckVisitor::parseBaseClasses(std::vector &baseClasses, if (ctx->cache->classes[canonicalName].mro.empty()) { E(Error::CLASS_BAD_MRO, getSrcInfo()); } else if (ctx->cache->classes[canonicalName].mro.size() > 1) { - // LOG("[mro] {} -> {}", canonicalName, ctx->cache->classes[canonicalName].mro); + // for (auto &t: ctx->cache->classes[canonicalName].mro) + // LOG("[mro] {} -> {:c}", canonicalName, t); } } return asts; @@ -753,11 +773,19 @@ int TypecheckVisitor::generateKwId(const std::vector &names) { } } -void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp) { +void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp, + bool instantiate) { auto addGen = [&](auto g) { auto t = g.type; + if (instantiate) + if (auto l = g.type->getLink()) + if (l->kind == LinkType::Generic) { + auto lx = std::make_shared(*l); + lx->kind = LinkType::Unbound; + t = lx; + } if (t->getClass() && !t->getStatic() && !t->is("type")) - t = ctx->instantiateGeneric(ctx->getType("type"), {t}); + t = ctx->instantiateGeneric(ctx->getType("type"), {t}); ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true; }; for (auto &g : clsTyp->hiddenGenerics) diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 52de747a..e40cda19 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -138,6 +138,10 @@ std::string TypeContext::getModule() const { return base; } +std::string TypeContext::getModulePath() const { + return moduleName.path; +} + void TypeContext::dump() { dump(0); } bool TypeContext::isCanonicalName(const std::string &name) const { diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 39629d27..9dd82af7 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -205,6 +205,8 @@ public: std::string getBaseName() const; /// Return the current module. std::string getModule() const; + /// Return the current module path. + std::string getModulePath() const; /// Pretty-print the current context state. void dump() override; diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index 319d4695..d4c13990 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -49,6 +49,8 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { stmt->expr = partializeFunction(stmt->expr->type->getFunc()); } + if (!ctx->getBase()->returnType->isStaticType() && stmt->expr->type->getStatic()) + stmt->expr->type = stmt->expr->type->getStatic()->getNonStaticType(); unify(ctx->getBase()->returnType, stmt->expr->type); } else { // Just set the expr for the translation stage. However, do not unify the return @@ -397,7 +399,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { // Make function AST and cache it for later realization auto f = N(canonicalName, ret, args, suite, stmt->attributes); - ctx->cache->functions[canonicalName].module = ctx->getModule(); + ctx->cache->functions[canonicalName].module = ctx->moduleName.path; ctx->cache->functions[canonicalName].ast = f; ctx->cache->functions[canonicalName].origAst = stmt_clone; ctx->cache->functions[canonicalName].isToplevel = diff --git a/codon/parser/visitors/typecheck/import.cpp b/codon/parser/visitors/typecheck/import.cpp index ccd37150..46e13bde 100644 --- a/codon/parser/visitors/typecheck/import.cpp +++ b/codon/parser/visitors/typecheck/import.cpp @@ -77,7 +77,7 @@ void TypecheckVisitor::visit(ImportStmt *stmt) { transform(N( N(name), N(N("Import"), N(file->path), - N(file->path), N(file->module))))); + N(file->module), N(file->path))))); } else if (stmt->what->isId("*")) { // Case: from foo import * seqassert(stmt->as.empty(), "renamed star-import"); @@ -202,7 +202,7 @@ StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Exp auto canonical = ctx->generateCanonicalName(name); auto typ = transformType(clone(type)); auto val = ctx->addVar(altName.empty() ? name : altName, canonical, - std::make_shared(typ->type->getClass())); + std::make_shared(getType(typ)->getClass())); auto s = N(N(canonical), nullptr, typ); s->lhs->setAttr(ExprAttr::ExternVar); s->lhs->setType(val->type); diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 42447e58..d1278a99 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -394,8 +394,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) // Use NoneType as the return type when the return type is not specified and // function has no return statement - if (!ast->ret && type->getRetType()->getUnbound()) + if (!ast->ret && type->getRetType()->getUnbound()) { unify(type->getRetType(), ctx->getType("NoneType")); + } // LOG("-> {} {}", key, ret->toString(2)); } // Realize the return type @@ -583,8 +584,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType for (auto &[_, real] : cls.realizations) { auto &vtable = real->vtables[baseCls]; - auto ct = - ctx->instantiate(ctx->forceFind(clsName)->type, cp->getClass())->getClass(); + auto ct = ctx->instantiate(ctx->getType(clsName), cp->getClass())->getClass(); std::vector args = fp->getArgTypes(); args[0] = ct; auto m = findBestMethod(ct, fnName, args); @@ -605,7 +605,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType // Thunk name: _thunk... auto thunkName = format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, ".")); - if (in(ctx->cache->functions, thunkName)) + if (in(ctx->cache->functions, thunkName+":0")) continue; // Thunk contents: @@ -614,27 +614,24 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType // __internal__.class_base_to_derived(self, , ), // ) std::vector fnArgs; - fnArgs.emplace_back(fp->ast->args[0].name, N(cp->realizedName()), - nullptr); + fnArgs.emplace_back("self", N(cp->realizedName()), nullptr); for (size_t i = 1; i < args.size(); i++) - fnArgs.emplace_back(fp->ast->args[i].name, N(args[i]->realizedName()), - nullptr); + fnArgs.emplace_back(ctx->cache->rev(fp->ast->args[i].name), + N(args[i]->realizedName()), nullptr); std::vector callArgs; callArgs.emplace_back( N(N(N("__internal__"), "class_base_to_derived"), - N(fp->ast->args[0].name), N(cp->realizedName()), + N("self"), N(cp->realizedName()), N(real->type->realizedName()))); for (size_t i = 1; i < args.size(); i++) - callArgs.emplace_back(N(fp->ast->args[i].name)); + callArgs.emplace_back(N(ctx->cache->rev(fp->ast->args[i].name))); auto thunkAst = N( thunkName, nullptr, fnArgs, N(N(N(N(m->ast->name), callArgs))), - Attr({"std.internal.attributes.inline", Attr::ForceRealize})); - auto &thunkFn = ctx->cache->functions[thunkAst->name]; - thunkFn.ast = clone(thunkAst); + Attr({"std.internal.attributes.inline"})); + thunkAst = std::dynamic_pointer_cast(transform(thunkAst)); - transform(thunkAst); - prependStmts->push_back(thunkAst); + auto &thunkFn = ctx->cache->functions[thunkAst->name]; auto ti = ctx->instantiate(thunkFn.type)->getFunc(); auto tm = realizeFunc(ti.get(), true); seqassert(tm, "bad thunk {}", thunkFn.type); @@ -651,8 +648,11 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { auto realizedName = t->ClassType::realizedName(); if (!in(ctx->cache->classes[t->name].realizations, realizedName)) realize(t->getClass()); - if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) + if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) { + if (ctx->cache->classes[t->name].rtti) + ir::cast(l)->setPolymorphic(); return l; + } auto forceFindIRType = [&](const TypePtr &tt) { auto t = tt->getClass(); diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 41d5d3a5..5f4891ed 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -289,7 +289,7 @@ TypecheckVisitor::transformStaticLoopCall( if (vars.size() != 1) error("expected one item"); for (auto &a : args) { - stmt->rhs = a.value; + stmt->rhs = transform(clean_clone(a.value)); if (auto st = stmt->rhs->type->getStatic()) { stmt->type = N(N("Static"), N(st->name)); } else { diff --git a/codon/parser/visitors/typecheck/names.cpp b/codon/parser/visitors/typecheck/names.cpp index 7a08fb2d..3383c358 100644 --- a/codon/parser/visitors/typecheck/names.cpp +++ b/codon/parser/visitors/typecheck/names.cpp @@ -543,7 +543,12 @@ void ScopingVisitor::visit(GlobalStmt *stmt) { } void ScopingVisitor::visit(FunctionStmt *stmt) { - visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo()); + bool isOverload = false; + for (auto &d: stmt->decorators) + if (d->isId("overload")) + isOverload = true; + if (!isOverload) + visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo()); auto c = std::make_shared(); c->cache = ctx->cache; diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 54573d43..60f68074 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -56,25 +56,29 @@ void TypecheckVisitor::visit(UnaryExpr *expr) { /// @c transformBinaryInplaceMagic for details. /// Also evaluate static expressions. See @c evaluateStaticBinary for details. void TypecheckVisitor::visit(BinaryExpr *expr) { - // Transform lexpr and rexpr. Ignore Nones for now - if (!(startswith(expr->op, "is") && expr->lexpr->getNone())) - transform(expr->lexpr, true); - if (!(startswith(expr->op, "is") && expr->rexpr->getNone())) - transform(expr->rexpr, true); + transform(expr->lexpr, true); + transform(expr->rexpr, true); static std::unordered_map> staticOps = { {1, {"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&", "|", "^"}}, {2, {"==", "!=", "+"}}, - {3, - {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}}; - if (expr->lexpr->type->isStaticType() && - expr->lexpr->type->isStaticType() == expr->rexpr->type->isStaticType() && - in(staticOps[expr->lexpr->type->isStaticType()], expr->op)) { - // Handle static expressions - resultExpr = evaluateStaticBinary(expr); - } else if (auto e = transformBinarySimple(expr)) { + {3, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}}; + if (expr->lexpr->type->isStaticType() && expr->rexpr->type->isStaticType()) { + auto l = expr->lexpr->type->isStaticType(); + auto r = expr->rexpr->type->isStaticType(); + bool isStatic = l == r && in(staticOps[l], expr->op); + if (!isStatic && ((l == 1 && r == 3) || (r == 1 && l == 3)) && + in(staticOps[1], expr->op)) + isStatic = true; + if (isStatic) { + resultExpr = evaluateStaticBinary(expr); + return; + } + } + + if (auto e = transformBinarySimple(expr)) { // Case: simple binary expressions resultExpr = e; } else if (expr->lexpr->getType()->getUnbound() || @@ -264,7 +268,8 @@ void TypecheckVisitor::visit(PipeExpr *expr) { void TypecheckVisitor::visit(IndexExpr *expr) { if (expr->expr->isId("Static")) { // Special case: static types. Ensure that static is supported - if (!expr->index->isId("int") && !expr->index->isId("str") && !expr->index->isId("bool")) + if (!expr->index->isId("int") && !expr->index->isId("str") && + !expr->index->isId("bool")) E(Error::BAD_STATIC_TYPE, expr->index); auto typ = ctx->getUnbound(); typ->isStatic = getStaticGeneric(expr); @@ -364,19 +369,16 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) { unify(expr->type, typ); } else { for (size_t i = 0; i < expr->typeParams.size(); i++) { - // transform(expr->typeParams[i]); transformType(expr->typeParams[i]); auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(), getType(expr->typeParams[i])); - // if (expr->typeParams[i]->type->isStaticType() && - // generics[i].type->isStaticType()) { - // t = ctx->instantiate(expr->typeParams[i]->type); - // } else { - // if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` - // transformType(expr->typeParams[i]); - // if (!expr->typeParams[i]->type->is("type")) - // E(Error::EXPECTED_TYPE, expr->typeParams[i], "type"); - // } + if (expr->typeParams[i]->type->isStaticType() != + generics[i].type->isStaticType()) { + if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` + transformType(expr->typeParams[i]); + if (!expr->typeParams[i]->type->is("type")) + E(Error::EXPECTED_TYPE, expr->typeParams[i], "type"); + } if (isUnion) typ->getUnion()->addType(t); else @@ -454,7 +456,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { value = !bool(value); LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); if (expr->op == "!") - return transform(N(bool(value))); + return transform(N(value)); else return transform(N(value)); } else { @@ -469,9 +471,10 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { /// Division and modulus implementations. std::pair divMod(const std::shared_ptr &ctx, int64_t a, int64_t b) { - if (!b) + if (!b) { E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo()); - if (ctx->cache->pythonCompat) { + return {0, 0}; + } else if (ctx->cache->pythonCompat) { // Use Python implementation. int64_t d = a / b; int64_t m = a - d * b; @@ -511,7 +514,7 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { expr->rexpr->type->getStrStatic()->value; bool value = expr->op == "==" ? eq : !eq; LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); - return transform(N(value)); + return transform(N(value)); } else { // Cannot be evaluated yet: just set the type expr->type->getUnbound()->isStatic = 1; @@ -522,8 +525,12 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { // Case: static integers if (expr->lexpr->type->getStatic() && expr->rexpr->type->getStatic()) { - int64_t lvalue = expr->lexpr->type->getIntStatic() ? expr->lexpr->type->getIntStatic()->value : expr->lexpr->type->getBoolStatic()->value; - int64_t rvalue = expr->rexpr->type->getIntStatic() ? expr->rexpr->type->getIntStatic()->value : expr->rexpr->type->getBoolStatic()->value; + int64_t lvalue = expr->lexpr->type->getIntStatic() + ? expr->lexpr->type->getIntStatic()->value + : expr->lexpr->type->getBoolStatic()->value; + int64_t rvalue = expr->rexpr->type->getIntStatic() + ? expr->rexpr->type->getIntStatic()->value + : expr->rexpr->type->getBoolStatic()->value; if (expr->op == "<") lvalue = lvalue < rvalue; else if (expr->op == "<=") @@ -596,7 +603,7 @@ ExprPtr TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) { return transform(N(N(expr->rexpr, "__contains__"), expr->lexpr)); } else if (expr->op == "is") { if (expr->lexpr->getNone() && expr->rexpr->getNone()) - return transform(N(1)); + return transform(N(true)); else if (expr->lexpr->getNone()) return transform(N(expr->rexpr, "is", expr->lexpr)); } else if (expr->op == "is not") { @@ -613,17 +620,17 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { // Case: `is None` expressions if (expr->rexpr->getNone()) { if (expr->lexpr->getType()->is("NoneType")) - return transform(N(1)); + return transform(N(true)); if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) { // lhs is not optional: `return False` - return transform(N(0)); + return transform(N(false)); } else { // Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType auto g = expr->lexpr->getType()->getClass(); for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass()) ; if (g->generics[0].type->is("NoneType")) - return transform(N(1)); + return transform(N(true)); // lhs is optional: `return lhs.__has__().__invert__()` return transform(N( @@ -640,7 +647,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { return nullptr; } if (expr->lexpr->type->is("type") && expr->rexpr->type->is("type")) - return transform(N(lc->realizedName() == rc->realizedName())); + return transform(N(lc->realizedName() == rc->realizedName())); if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) { // Both reference types: `return lhs.__raw__() == rhs.__raw__()` return transform( @@ -659,7 +666,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { } if (lc->realizedName() != rc->realizedName()) { // tuple names do not match: `return False` - return transform(N(0)); + return transform(N(false)); } // Same tuple types: `return lhs == rhs` return transform(N(expr->lexpr, "==", expr->rexpr)); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 2e116e47..1b2b8013 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -579,6 +579,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType, } else { expr = p; } + } else if (expectedClass && expectedClass->name == "Function" && exprClass && + exprClass->getPartial() && + exprClass->generics[2].type->getClass()->generics.size() == 1 && + exprClass->generics[2] + .type->getClass() + ->generics[0] + .type->getClass() + ->generics.empty() && + exprClass->generics[3] + .type->getClass() + ->generics[0] + .type->getClass() + ->generics.empty()) { + expr = transform(N(exprClass->getPartialFunc()->ast->name)); } else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass && !expectedClass->getUnion()) { // Extract union types via __internal__.get_union @@ -696,7 +710,8 @@ types::TypePtr TypecheckVisitor::getType(const ExprPtr &e) { return t; } -std::vector TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) { +std::vector +TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) { std::vector result; ctx->addBlock(); addClassGenerics(cls); diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index a6fa8853..726bacd7 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -247,7 +247,7 @@ private: // Node typechecking rules bool); std::string generateTuple(size_t); int generateKwId(const std::vector & = {}); - void addClassGenerics(const types::ClassTypePtr &); + void addClassGenerics(const types::ClassTypePtr &, bool instantiate = false); /* The rest (typecheck.cpp) */ void visit(SuiteStmt *) override; diff --git a/stdlib/internal/__init__.codon b/stdlib/internal/__init__.codon index 83bbd992..75f9dcfb 100644 --- a/stdlib/internal/__init__.codon +++ b/stdlib/internal/__init__.codon @@ -15,7 +15,11 @@ from internal.types.float import * from internal.types.byte import * from internal.types.generator import * from internal.types.optional import * + +import internal.c_stubs as _C +from internal.format import * from internal.internal import * + from internal.types.slice import * from internal.types.range import * from internal.types.complex import * @@ -28,26 +32,21 @@ from internal.types.collections.set import * from internal.types.collections.dict import * from internal.types.collections.tuple import * -# Extended core library - -import internal.c_stubs as _C -from internal.format import * from internal.builtin import * - from internal.builtin import _jit_display from internal.str import * from internal.sort import sorted -# # from openmp import Ident as __OMPIdent, for_par -# # from gpu import _gpu_loop_outline_template -# from internal.file import File, gzFile, open, gzopen -# from pickle import pickle, unpickle -# from internal.dlopen import dlsym as _dlsym -# import internal.python -# from internal.python import PyError +from openmp import Ident as __OMPIdent, for_par +from gpu import _gpu_loop_outline_template +from internal.file import File, gzFile, open, gzopen +from pickle import pickle, unpickle +from internal.dlopen import dlsym as _dlsym +import internal.python +from internal.python import PyError -# # if __py_numerics__: -# # import internal.pynumerics -# # if __py_extension__: -# # internal.python.ensure_initialized() +if __py_numerics__: + import internal.pynumerics +if __py_extension__: + internal.python.ensure_initialized() diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index 296be42f..757e0dad 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -174,7 +174,7 @@ class Import: P: Static[str] @llvm - def __new__(P: Static[str], name: str, path: str) -> Import[P]: + def __new__(path: str, name: str, P: Static[str]) -> Import[P]: %0 = insertvalue { {=str}, {=str} } undef, {=str} %path, 0 %1 = insertvalue { {=str}, {=str} } %0, {=str} %name, 1 ret { {=str}, {=str} } %1 diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index 8a6a03d7..715e1c51 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -68,6 +68,10 @@ class __internal__: __vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj]))) __internal__.class_populate_vtables() + # def print(a): + # from C import seq_print(str) + # seq_print(a.__repr__()) + def class_populate_vtables() -> None: """ Populate content of vtables. Compiler generated. @@ -91,7 +95,8 @@ class __internal__: def class_set_rtti_vtable(id: int, sz: int, T: type): if not __has_rtti__(T): compile_error("class is not polymorphic") - __vtables__[id] = Ptr[cobj](sz + 1) + p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj)) + __vtables__[id] = Ptr[cobj](p) __internal__.class_set_typeinfo(__vtables__[id], id) def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type): diff --git a/stdlib/internal/types/ptr.codon b/stdlib/internal/types/ptr.codon index c19c9ae2..c136327d 100644 --- a/stdlib/internal/types/ptr.codon +++ b/stdlib/internal/types/ptr.codon @@ -192,8 +192,9 @@ Jar = Ptr[byte] @extend class NoneType: + @llvm def __new__() -> NoneType: - return () + ret {} {} def __eq__(self, other: NoneType): return True diff --git a/test/parser/typecheck/test_access.codon b/test/parser/typecheck/test_access.codon index d4a533c4..5e7bb3bb 100644 --- a/test/parser/typecheck/test_access.codon +++ b/test/parser/typecheck/test_access.codon @@ -466,3 +466,15 @@ def test_mandelbrot(): return (MAX, N, pixels, scale(N, -2, 0.4)) k(pixels, grid=(N*N)//1024, block=1024) test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4) + +#%% id_shadow_overload_call,barebones +def foo(): + def bar(): + return -1 + def xo(): + return bar() + @overload # w/o this this fails because xo cannot capture bar + def bar(a): + return a + bar(1) +foo() diff --git a/test/parser/typecheck/test_class.codon b/test/parser/typecheck/test_class.codon index 87dc7e1f..0f7059e1 100644 --- a/test/parser/typecheck/test_class.codon +++ b/test/parser/typecheck/test_class.codon @@ -92,12 +92,6 @@ class F[T: Static[float]]: pass #! expected 'int' or 'str' (only integers and strings can be static) -#%% class_err_10,barebones -def foo[T](): - class A: - x: T -#! name 'T' cannot be captured - #%% class_err_11,barebones def foo(x): class A: diff --git a/test/parser/typecheck/test_loops.codon b/test/parser/typecheck/test_loops.codon index 502a60c6..3e0ca99e 100644 --- a/test/parser/typecheck/test_loops.codon +++ b/test/parser/typecheck/test_loops.codon @@ -103,7 +103,7 @@ for i in range(10): #%% for_error,barebones for i in 1: pass -#! 'int' object has no attribute '__iter__' +#! '1' object has no attribute '__iter__' #%% for_void,barebones def foo(): yield diff --git a/test/parser/typecheck/test_op.codon b/test/parser/typecheck/test_op.codon index 14052b61..97ad5bcd 100644 --- a/test/parser/typecheck/test_op.codon +++ b/test/parser/typecheck/test_op.codon @@ -3,7 +3,7 @@ a, b = False, 1 print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1 -#%% binary,barebones +#%% binary_simple,barebones x, y = 1, 0 c = [1, 2, 3] @@ -356,16 +356,19 @@ print Foo[int, 3, 4](), Foo[int, 5, 4]() #%% static_int,barebones def foo(n: Static[int]): print n +@overload +def foo(n: Static[bool]): + print n a: Static[int] = 5 -foo(a < 1) #: 0 -foo(a <= 1) #: 0 -foo(a > 1) #: 1 -foo(a >= 1) #: 1 -foo(a == 1) #: 0 -foo(a != 1) #: 1 -foo(a and 1) #: 1 -foo(a or 1) #: 1 +foo(a < 1) #: False +foo(a <= 1) #: False +foo(a > 1) #: True +foo(a >= 1) #: True +foo(a == 1) #: False +foo(a != 1) #: True +foo(a and 1) #: True +foo(a or 1) #: True foo(a + 1) #: 6 foo(a - 1) #: 4 foo(a * 1) #: 5