From fa7278e616940d7c4978592ef9fee96b37986a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagic=CC=81?= Date: Tue, 28 Dec 2021 20:58:20 -0800 Subject: [PATCH] Support for overloaded functions [wip; base logic done] --- codon/parser/cache.cpp | 3 + codon/parser/cache.h | 30 +++-- codon/parser/peg/peg.cpp | 2 - .../parser/visitors/simplify/simplify_ctx.cpp | 22 ++-- codon/parser/visitors/simplify/simplify_ctx.h | 9 +- .../visitors/simplify/simplify_stmt.cpp | 120 +++++++++--------- codon/parser/visitors/translate/translate.cpp | 3 +- codon/parser/visitors/typecheck/typecheck.h | 2 + .../visitors/typecheck/typecheck_ctx.cpp | 19 +-- .../visitors/typecheck/typecheck_expr.cpp | 80 +++++++++--- .../visitors/typecheck/typecheck_stmt.cpp | 14 +- stdlib/internal/str.codon | 1 + test/parser/simplify_stmt.codon | 2 +- test/parser/typecheck_expr.codon | 12 +- test/parser/typecheck_stmt.codon | 12 +- test/parser/types.codon | 75 ++++++----- 16 files changed, 241 insertions(+), 165 deletions(-) diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index 4d774c1b..e94ca102 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -50,6 +50,9 @@ types::ClassTypePtr Cache::findClass(const std::string &name) const { types::FuncTypePtr Cache::findFunction(const std::string &name) const { auto f = typeCtx->find(name); + if (f && f->type && f->kind == TypecheckItem::Func) + return f->type->getFunc(); + f = typeCtx->find(name + ":0"); if (f && f->type && f->kind == TypecheckItem::Func) return f->type->getFunc(); return nullptr; diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 16db95cc..7adcb4c8 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -106,19 +106,9 @@ struct Cache : public std::enable_shared_from_this { /// Non-simplified AST. Used for base class instantiation. std::shared_ptr originalAst; - /// A class function method. - struct ClassMethod { - /// Canonical name of a method (e.g. __init__.1). - std::string name; - /// A corresponding generic function type. - types::FuncTypePtr type; - /// Method age (how many class extension were seen before a method definition). - /// Used to prevent the usage of a method before it was defined in the code. - int age; - }; - /// Class method lookup table. Each name points to a list of ClassMethod instances - /// that share the same method name (a list because methods can be overloaded). - std::unordered_map> methods; + /// Class method lookup table. Each non-canonical name points + /// to a root function name of a corresponding method. + std::unordered_map methods; /// A class field (member). struct ClassField { @@ -177,6 +167,20 @@ struct Cache : public std::enable_shared_from_this { /// corresponding Function instance. std::unordered_map functions; + + struct Overload { + /// Canonical name of an overload (e.g. Foo.__init__.1). + std::string name; + /// Overload age (how many class extension were seen before a method definition). + /// Used to prevent the usage of an overload before it was defined in the code. + /// TODO: I have no recollection of how this was supposed to work. Most likely + /// it does not work at all... + int age; + }; + /// Maps a "root" name of each function to the list of names of the function + /// overloads. + std::unordered_map> overloads; + /// Pointer to the later contexts needed for IR API access. std::shared_ptr typeCtx; std::shared_ptr codegenCtx; diff --git a/codon/parser/peg/peg.cpp b/codon/parser/peg/peg.cpp index f8b8dcc9..80ef9f49 100644 --- a/codon/parser/peg/peg.cpp +++ b/codon/parser/peg/peg.cpp @@ -52,8 +52,6 @@ std::shared_ptr initParser() { template T parseCode(Cache *cache, const std::string &file, std::string code, int line_offset, int col_offset, const std::string &rule) { - TIME("peg"); - // Initialize if (!grammar) grammar = initParser(); diff --git a/codon/parser/visitors/simplify/simplify_ctx.cpp b/codon/parser/visitors/simplify/simplify_ctx.cpp index 3f87dec2..517518db 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.cpp +++ b/codon/parser/visitors/simplify/simplify_ctx.cpp @@ -14,8 +14,9 @@ namespace codon { namespace ast { SimplifyItem::SimplifyItem(Kind k, std::string base, std::string canonicalName, - bool global) - : kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global) {} + bool global, std::string moduleName) + : kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global), + moduleName(move(moduleName)) {} SimplifyContext::SimplifyContext(std::string filename, Cache *cache) : Context(move(filename)), cache(move(cache)), @@ -31,6 +32,7 @@ std::shared_ptr SimplifyContext::add(SimplifyItem::Kind kind, bool global) { seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); auto t = std::make_shared(kind, getBase(), canonicalName, global); + t->moduleName = getModule(); Context::add(name, t); Context::add(canonicalName, t); return t; @@ -60,6 +62,14 @@ std::string SimplifyContext::getBase() const { return bases.back().name; } +std::string SimplifyContext::getModule() const { + std::string base = moduleName.status == ImportFile::STDLIB ? "std." : ""; + base += moduleName.module; + if (startswith(base, "__main__")) + base = base.substr(8); + return base; +} + std::string SimplifyContext::generateCanonicalName(const std::string &name, bool includeBase, bool zeroId) const { @@ -67,12 +77,8 @@ std::string SimplifyContext::generateCanonicalName(const std::string &name, bool alreadyGenerated = name.find('.') != std::string::npos; if (includeBase && !alreadyGenerated) { std::string base = getBase(); - if (base.empty()) { - base = moduleName.status == ImportFile::STDLIB ? "std." : ""; - base += moduleName.module; - if (startswith(base, "__main__")) - base = base.substr(8); - } + if (base.empty()) + base = getModule(); newName = (base.empty() ? "" : (base + ".")) + newName; } auto num = cache->identifierCount[newName]++; diff --git a/codon/parser/visitors/simplify/simplify_ctx.h b/codon/parser/visitors/simplify/simplify_ctx.h index 11a628a1..468ded2a 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.h +++ b/codon/parser/visitors/simplify/simplify_ctx.h @@ -32,13 +32,16 @@ struct SimplifyItem { bool global; /// Non-empty string if a variable is import variable std::string importPath; + /// Full module name + std::string moduleName; public: - SimplifyItem(Kind k, std::string base, std::string canonicalName, - bool global = false); + SimplifyItem(Kind k, std::string base, std::string canonicalName, bool global = false, + std::string moduleName = ""); /// Convenience getters. std::string getBase() const { return base; } + std::string getModule() const { return moduleName; } bool isGlobal() const { return global; } bool isVar() const { return kind == Var; } bool isFunc() const { return kind == Func; } @@ -107,6 +110,8 @@ public: /// Return a canonical name of the top-most base, or an empty string if this is a /// top-level base. std::string getBase() const; + /// Return the current module. + std::string getModule() const; /// Return the current base nesting level (note: bases, not blocks). int getLevel() const { return bases.size(); } /// Pretty-print the current context state. diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 3c4b9386..7e9f2c3b 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -437,28 +437,30 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (stmt->decorators.size() != 1) error("__attribute__ cannot be mixed with other decorators"); attr.isAttribute = true; - } else if (d->isId(Attr::LLVM)) + } else if (d->isId(Attr::LLVM)) { attr.set(Attr::LLVM); - else if (d->isId(Attr::Python)) + } else if (d->isId(Attr::Python)) { attr.set(Attr::Python); - else if (d->isId(Attr::Internal)) + } else if (d->isId(Attr::Internal)) { attr.set(Attr::Internal); - else if (d->isId(Attr::Atomic)) + } else if (d->isId(Attr::Atomic)) { attr.set(Attr::Atomic); - else if (d->isId(Attr::Property)) + } else if (d->isId(Attr::Property)) { attr.set(Attr::Property); - else if (d->isId(Attr::ForceRealize)) + } else if (d->isId(Attr::ForceRealize)) { attr.set(Attr::ForceRealize); - else { + } else { // Let's check if this is a attribute auto dt = transform(clone(d)); if (dt && dt->getId()) { auto ci = ctx->find(dt->getId()->value); if (ci && ci->kind == SimplifyItem::Func) { - if (ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute) { - attr.set(ci->canonicalName); - continue; - } + if (ctx->cache->overloads[ci->canonicalName].size() == 1) + if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name] + .ast->attributes.isAttribute) { + attr.set(ci->canonicalName); + continue; + } } } decorators.emplace_back(clone(d)); @@ -473,19 +475,22 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { } bool isClassMember = ctx->inClass(); - if (isClassMember && !endswith(stmt->name, ".dispatch") && - ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].empty()) { - generateDispatch(stmt->name); - } - auto func_name = stmt->name; - if (endswith(stmt->name, ".dispatch")) - func_name = func_name.substr(0, func_name.size() - 9); - auto canonicalName = ctx->generateCanonicalName( - func_name, true, isClassMember && !endswith(stmt->name, ".dispatch")); - if (endswith(stmt->name, ".dispatch")) { - canonicalName += ".dispatch"; - ctx->cache->reverseIdentifierLookup[canonicalName] = func_name; + std::string rootName; + if (isClassMember) { + auto &m = ctx->cache->classes[ctx->bases.back().name].methods; + auto i = m.find(stmt->name); + if (i != m.end()) + rootName = i->second; + } else if (auto c = ctx->find(stmt->name)) { + if (c->isFunc() && c->getModule() == ctx->getModule() && + c->getBase() == ctx->getBase()) + rootName = c->canonicalName; } + if (rootName.empty()) + rootName = ctx->generateCanonicalName(stmt->name, true); + auto canonicalName = + format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); + ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name; bool isEnclosedFunc = ctx->inFunction(); if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember)) @@ -495,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ctx->bases = std::vector(); if (!isClassMember) // Class members are added to class' method table - ctx->add(SimplifyItem::Func, func_name, canonicalName, ctx->isToplevel()); + ctx->add(SimplifyItem::Func, stmt->name, rootName, ctx->isToplevel()); if (isClassMember) ctx->bases.push_back(oldBases[0]); ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base... @@ -614,8 +619,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // ... set the enclosing class name... attr.parentClass = ctx->bases.back().name; // ... add the method to class' method list ... - ctx->cache->classes[ctx->bases.back().name].methods[func_name].push_back( - {canonicalName, nullptr, ctx->cache->age}); + ctx->cache->classes[ctx->bases.back().name].methods[stmt->name] = rootName; // ... and if the function references outer class variable (by definition a // generic), mark it as not static as it needs fully instantiated class to be // realized. For example, in class A[T]: def foo(): pass, A.foo() can be realized @@ -624,6 +628,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (isMethod) attr.set(Attr::Method); } + ctx->cache->overloads[rootName].push_back({canonicalName, ctx->cache->age}); std::vector partialArgs; if (!captures.empty()) { @@ -649,21 +654,21 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ExprPtr finalExpr; if (!captures.empty()) - finalExpr = N(N(func_name), partialArgs); + finalExpr = N(N(stmt->name), partialArgs); if (isClassMember && decorators.size()) error("decorators cannot be applied to class methods"); for (int j = int(decorators.size()) - 1; j >= 0; j--) { if (auto c = const_cast(decorators[j]->getCall())) { c->args.emplace(c->args.begin(), - CallExpr::Arg{"", finalExpr ? finalExpr : N(func_name)}); + CallExpr::Arg{"", finalExpr ? finalExpr : N(stmt->name)}); finalExpr = N(c->expr, c->args); } else { finalExpr = - N(decorators[j], finalExpr ? finalExpr : N(func_name)); + N(decorators[j], finalExpr ? finalExpr : N(stmt->name)); } } if (finalExpr) - resultStmt = transform(N(N(func_name), finalExpr)); + resultStmt = transform(N(N(stmt->name), finalExpr)); } void SimplifyVisitor::visit(ClassStmt *stmt) { @@ -744,6 +749,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { ClassStmt *originalAST = nullptr; auto classItem = std::make_shared(SimplifyItem::Type, "", "", ctx->isToplevel()); + classItem->moduleName = ctx->getModule(); if (!extension) { classItem->canonicalName = canonicalName = ctx->generateCanonicalName(name, !attr.has(Attr::Internal)); @@ -949,22 +955,29 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { for (int ai = 0; ai < baseASTs.size(); ai++) { // FUNCS for (auto &mm : ctx->cache->classes[baseASTs[ai]->name].methods) - for (auto &mf : mm.second) { + for (auto &mf : ctx->cache->overloads[mm.second]) { auto f = ctx->cache->functions[mf.name].ast; if (f->attributes.has("autogenerated")) continue; auto subs = substitutions[ai]; - if (ctx->cache->classes[ctx->bases.back().name] - .methods[ctx->cache->reverseIdentifierLookup[f->name]] - .empty()) - generateDispatch(ctx->cache->reverseIdentifierLookup[f->name]); - auto newName = ctx->generateCanonicalName( - ctx->cache->reverseIdentifierLookup[f->name], true); + + std::string rootName; + auto &mts = ctx->cache->classes[ctx->bases.back().name].methods; + auto it = mts.find(ctx->cache->reverseIdentifierLookup[f->name]); + if (it != mts.end()) + rootName = it->second; + else + rootName = ctx->generateCanonicalName( + ctx->cache->reverseIdentifierLookup[f->name], true); + auto newCanonicalName = + format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); + ctx->cache->reverseIdentifierLookup[newCanonicalName] = + ctx->cache->reverseIdentifierLookup[f->name]; auto nf = std::dynamic_pointer_cast( replace(std::static_pointer_cast(f), subs)); - subs[nf->name] = N(newName); - nf->name = newName; + subs[nf->name] = N(newCanonicalName); + nf->name = newCanonicalName; suite->stmts.push_back(nf); nf->attributes.parentClass = ctx->bases.back().name; @@ -972,10 +985,10 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { if (nf->attributes.has(".changedSelf")) // replace self type with new class nf->args[0].type = transformType(ctx->bases.back().ast); preamble->functions.push_back(clone(nf)); - ctx->cache->functions[newName].ast = nf; + ctx->cache->overloads[rootName].push_back({newCanonicalName, ctx->cache->age}); + ctx->cache->functions[newCanonicalName].ast = nf; ctx->cache->classes[ctx->bases.back().name] - .methods[ctx->cache->reverseIdentifierLookup[f->name]] - .push_back({newName, nullptr, ctx->cache->age}); + .methods[ctx->cache->reverseIdentifierLookup[f->name]] = rootName; } } for (auto sp : getClassMethods(stmt->suite)) @@ -1248,8 +1261,10 @@ StmtPtr SimplifyVisitor::transformCImport(const std::string &name, auto f = N(name, ret ? ret->clone() : N("void"), fnArgs, nullptr, attr); StmtPtr tf = transform(f); // Already in the preamble - if (!altName.empty()) + if (!altName.empty()) { ctx->add(altName, ctx->find(name)); + ctx->remove(name); + } return tf; } @@ -1392,10 +1407,11 @@ void SimplifyVisitor::transformNewImport(const ImportFile &file) { stmts[0] = N(); // Add a def import(): ... manually to the cache and to the preamble (it won't be // transformed here!). - ctx->cache->functions[importVar].ast = - N(importVar, nullptr, std::vector{}, N(stmts), - Attr({Attr::ForceRealize})); - preamble->functions.push_back(ctx->cache->functions[importVar].ast->clone()); + ctx->cache->overloads[importVar].push_back({importVar, ctx->cache->age}); + ctx->cache->functions[importVar + ":0"].ast = + N(importVar + ":0", nullptr, std::vector{}, + N(stmts), Attr({Attr::ForceRealize})); + preamble->functions.push_back(ctx->cache->functions[importVar + ":0"].ast->clone()); ; } } @@ -1768,15 +1784,5 @@ std::vector SimplifyVisitor::getClassMethods(const StmtPtr &s) { return v; } -void SimplifyVisitor::generateDispatch(const std::string &name) { - transform(N( - name + ".dispatch", nullptr, - std::vector{Param("*args"), Param("**kwargs")}, - N(N(N( - N(N(ctx->bases.back().name), name), - N(N("args")), N(N("kwargs"))))), - Attr({"autogenerated"}))); -} - } // namespace ast } // namespace codon diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 492211d6..7e0ba953 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -296,7 +296,8 @@ void TranslateVisitor::visit(ForStmt *stmt) { auto c = stmt->decorator->getCall(); seqassert(c, "for par is not a call: {}", stmt->decorator->toString()); auto fc = c->expr->getType()->getFunc(); - seqassert(fc && fc->ast->name == "std.openmp.for_par", "for par is not a function"); + seqassert(fc && fc->ast->name == "std.openmp.for_par:0", + "for par is not a function"); auto schedule = fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString(); bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt(); diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 97b77ac0..f8013fc6 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -313,6 +313,8 @@ private: int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false); int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop, int64_t step); + types::FuncTypePtr findDispatch(const std::string &fn); + std::string getRootName(const std::string &name); friend struct Cache; }; diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index 2e5ba2f7..aac6eb51 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -142,22 +142,23 @@ std::vector TypeContext::findMethod(const std::string &typeN auto m = cache->classes.find(typeName); if (m != cache->classes.end()) { auto t = m->second.methods.find(method); - if (t != m->second.methods.end() && !t->second.empty()) { - seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"), - "first method '{}' is not dispatch", t->second[0].name); + if (t != m->second.methods.end()) { + auto mt = cache->overloads[t->second]; std::unordered_set signatureLoci; std::vector vv; - for (int mti = int(t->second.size()) - 1; mti > 0; mti--) { - auto &mt = t->second[mti]; - if (mt.age <= age) { + for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { + auto &m = mt[mti]; + if (endswith(m.name, ":dispatch")) + continue; + if (m.age <= age) { if (hideShadowed) { - auto sig = cache->functions[mt.name].ast->signature(); + auto sig = cache->functions[m.name].ast->signature(); if (!in(signatureLoci, sig)) { signatureLoci.insert(sig); - vv.emplace_back(mt.type); + vv.emplace_back(cache->functions[m.name].type); } } else { - vv.emplace_back(mt.type); + vv.emplace_back(cache->functions[m.name].type); } } } diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 0a390879..78d4171c 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -105,6 +105,9 @@ void TypecheckVisitor::visit(IdExpr *expr) { return; } auto val = ctx->find(expr->value); + if (!val) { + val = ctx->find(expr->value + ":0"); // is it function?! + } seqassert(val, "cannot find IdExpr '{}' ({})", expr->value, expr->getSrcInfo()); auto t = ctx->instantiate(expr, val->type); @@ -725,14 +728,17 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr, ExprPtr &index) { - if (!tuple->getRecord() || - in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) + if (!tuple->getRecord()) + return nullptr; + if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL)) + // in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) // Ptr, pyobj and str are internal types and have only one overloaded __getitem__ return nullptr; - if (ctx->cache->classes[tuple->name].methods["__getitem__"].size() != 2) - // n.b.: there is dispatch as well - // TODO: be smarter! there might be a compatible getitem? - return nullptr; + // if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) { + // ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]] + // .size() != 1) + // return nullptr; + // } // Extract a static integer value from a compatible expression. auto getInt = [&](int64_t *o, const ExprPtr &e) { @@ -919,9 +925,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, } else if (methods.size() > 1) { auto m = ctx->cache->classes.find(typ->name); auto t = m->second.methods.find(expr->member); - seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"), - "first method is not dispatch"); - bestMethod = t->second[0].type; + bestMethod = findDispatch(t->second); } else { bestMethod = methods[0]; } @@ -1036,7 +1040,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ctx->bases.back().supers, expr->args); if (m.empty()) error("no matching super methods are available"); - // LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), m[0]->toString()); + // LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), + // m[0]->toString()); ExprPtr e = N(N(m[0]->ast->name), expr->args); return transform(e, false, true); } @@ -1700,7 +1705,7 @@ TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member, std::vector TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) { if (func->ast->attributes.parentClass.empty() || - endswith(func->ast->name, ".dispatch")) + endswith(func->ast->name, ":dispatch")) return {}; auto p = ctx->find(func->ast->attributes.parentClass)->type; if (!p || !p->getClass()) @@ -1711,14 +1716,13 @@ TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) { std::vector result; if (m != ctx->cache->classes.end()) { auto t = m->second.methods.find(methodName); - if (t != m->second.methods.end() && !t->second.empty()) { - seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"), - "first method '{}' is not dispatch", t->second[0].name); - for (int mti = 1; mti < t->second.size(); mti++) { - auto &mt = t->second[mti]; - if (mt.type->ast->name == func->ast->name) + if (t != m->second.methods.end()) { + for (auto &m : ctx->cache->overloads[t->second]) { + if (endswith(m.name, ":dispatch")) + continue; + if (m.name == func->ast->name) break; - result.emplace_back(mt.type); + result.emplace_back(ctx->cache->functions[m.name].type); } } } @@ -1860,5 +1864,45 @@ int64_t TypecheckVisitor::sliceAdjustIndices(int64_t length, int64_t *start, return 0; } +types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) { + for (auto &m : ctx->cache->overloads[fn]) + if (endswith(ctx->cache->functions[m.name].ast->name, ":dispatch")) + return ctx->cache->functions[m.name].type; + + // Generate dispatch and return it! + auto name = fn + ":dispatch"; + + ExprPtr root; + auto a = ctx->cache->functions[ctx->cache->overloads[fn][0].name].ast; + if (!a->attributes.parentClass.empty()) + root = N(N(a->attributes.parentClass), + ctx->cache->reverseIdentifierLookup[fn]); + else + root = N(fn); + root = N(root, N(N("args")), + N(N("kwargs"))); + auto ast = N( + name, nullptr, std::vector{Param("*args"), Param("**kwargs")}, + N(N( + N(N("isinstance"), root->clone(), N("void")), + N(root->clone()), N(root))), + Attr({"autogenerated"})); + ctx->cache->reverseIdentifierLookup[name] = ctx->cache->reverseIdentifierLookup[fn]; + + auto baseType = + ctx->instantiate(N(name).get(), ctx->find(generateFunctionStub(2))->type, + nullptr, false) + ->getRecord(); + auto typ = std::make_shared(baseType, ast.get()); + typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel)); + ctx->add(TypecheckItem::Func, name, typ); + + ctx->cache->overloads[fn].insert(ctx->cache->overloads[fn].begin(), {name, 0}); + ctx->cache->functions[name].ast = ast; + ctx->cache->functions[name].type = typ; + prependStmts->push_back(ast); + return typ; +} + } // namespace ast } // namespace codon diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index edbe2ad2..42ece717 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -474,12 +474,12 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel)); // Check if this is a class method; if so, update the class method lookup table. if (isClassMember) { - auto &methods = ctx->cache->classes[attr.parentClass] - .methods[ctx->cache->reverseIdentifierLookup[stmt->name]]; + auto m = ctx->cache->classes[attr.parentClass] + .methods[ctx->cache->reverseIdentifierLookup[stmt->name]]; bool found = false; - for (auto &i : methods) + for (auto &i : ctx->cache->overloads[m]) if (i.name == stmt->name) { - i.type = typ; + ctx->cache->functions[i.name].type = typ; found = true; break; } @@ -570,5 +570,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { stmt->done = true; } +std::string TypecheckVisitor::getRootName(const std::string &name) { + auto p = name.rfind(':'); + seqassert(p != std::string::npos, ": not found in {}", name); + return name.substr(0, p); +} + } // namespace ast } // namespace codon diff --git a/stdlib/internal/str.codon b/stdlib/internal/str.codon index 67a8dd67..ad062b5a 100644 --- a/stdlib/internal/str.codon +++ b/stdlib/internal/str.codon @@ -44,6 +44,7 @@ class str: if c == '\n': d = "\\n" elif c == '\r': d = "\\r" elif c == '\t': d = "\\t" + elif c == '\a': d = "\\a" elif c == '\\': d = "\\\\" elif c == q: d = qe else: diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index 6de479b8..d849b7dd 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -538,7 +538,7 @@ def foo() -> int: a{=a} foo() #! not a type or static expression -#! while realizing foo (arguments foo) +#! while realizing foo:0 (arguments foo:0) #%% function_llvm_err_4,barebones a = 5 diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index 1403319e..ce02e25b 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -374,7 +374,7 @@ def foo(i, j, k): return i + j + k print foo(1.1, 2.2, 3.3) #: 6.6 p = foo(6, ...) -print p.__class__ #: foo[int,...,...] +print p.__class__ #: foo:0[int,...,...] print p(2, 1) #: 9 print p(k=3, j=6) #: 15 q = p(k=1, ...) @@ -390,11 +390,11 @@ print 42 |> add_two #: 44 def moo(a, b, c=3): print a, b, c m = moo(b=2, ...) -print m.__class__ #: moo[...,int,...] +print m.__class__ #: moo:0[...,int,...] m('s', 1.1) #: s 2 1.1 # # n = m(c=2.2, ...) -print n.__class__ #: moo[...,int,float] +print n.__class__ #: moo:0[...,int,float] n('x') #: x 2 2.2 print n('y').__class__ #: void @@ -403,11 +403,11 @@ def ff(a, b, c): print ff(1.1, 2, True).__class__ #: Tuple[float,int,bool] print ff(1.1, ...)(2, True).__class__ #: Tuple[float,int,bool] y = ff(1.1, ...)(c=True, ...) -print y.__class__ #: ff[float,...,bool] +print y.__class__ #: ff:0[float,...,bool] print ff(1.1, ...)(2, ...)(True).__class__ #: Tuple[float,int,bool] print y('hei').__class__ #: Tuple[float,str,bool] z = ff(1.1, ...)(c='s', ...) -print z.__class__ #: ff[float,...,str] +print z.__class__ #: ff:0[float,...,str] #%% call_arguments_partial,barebones def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T): @@ -432,7 +432,7 @@ l = [1] def adder(a, b): return a+b doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...)) #: int int -#: adder[.. Generator[int] +#: adder:0[ Generator[int] #: 5 #: 1 Optional[int] #: 5 int diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index ace31779..a3099c5c 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -252,7 +252,7 @@ try: except MyError: print "my" except OSError as o: - print "os", o._hdr[0], len(o._hdr[1]), o._hdr[3][-20:], o._hdr[4] + print "os", o._hdr.typename, len(o._hdr.msg), o._hdr.file[-20:], o._hdr.line #: os std.internal.types.error.OSError 9 typecheck_stmt.codon 249 finally: print "whoa" #: whoa @@ -263,7 +263,7 @@ def foo(): try: foo() except MyError as e: - print e._hdr[0], e._hdr[1] #: MyError foo! + print e._hdr.typename, e._hdr.msg #: MyError foo! #%% throw_error,barebones raise 'hello' #! cannot throw non-exception (first object member must be of type ExcHeader) @@ -291,13 +291,13 @@ def foo(x): print len(x) foo(5) #: 4 -def foo(x): +def foo2(x): if isinstance(x, int): print x+1 return print len(x) -foo(1) #: 2 -foo('s') #: 1 +foo2(1) #: 2 +foo2('s') #: 1 #%% super,barebones class Foo: @@ -341,4 +341,4 @@ class Foo: print 'foo-1', a Foo.foo(1) #! no matching super methods are available -#! while realizing Foo.foo.2 +#! while realizing Foo.foo:0 diff --git a/test/parser/types.codon b/test/parser/types.codon index 764a38a9..5302905a 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -199,10 +199,10 @@ def f[T](x: T) -> T: print f(1.2).__class__ #: float print f('s').__class__ #: str -def f[T](x: T): - return f(x - 1, T) if x else 1 -print f(1) #: 1 -print f(1.1).__class__ #: int +def f2[T](x: T): + return f2(x - 1, T) if x else 1 +print f2(1) #: 1 +print f2(1.1).__class__ #: int #%% recursive_error,barebones @@ -215,7 +215,7 @@ def rec3(x, y): #- ('a, 'b) -> 'b return y rec3(1, 's') #! cannot unify str and int -#! while realizing rec3 (arguments rec3[int,str]) +#! while realizing rec3:0 (arguments rec3:0[int,str]) #%% instantiate_function_2,barebones def fx[T](x: T) -> T: @@ -447,13 +447,13 @@ def f(x): return g(x) print f(5), f('s') #: 5 s -def f[U](x: U, y): +def f2[U](x: U, y): def g[T, U](x: T, y: U): return (x, y) return g(y, x) x, y = 1, 'haha' -print f(x, y).__class__ #: Tuple[str,int] -print f('aa', 1.1, U=str).__class__ #: Tuple[float,str] +print f2(x, y).__class__ #: Tuple[str,int] +print f2('aa', 1.1, U=str).__class__ #: Tuple[float,str] #%% nested_fn_generic_error,barebones def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u] @@ -464,7 +464,7 @@ print f(1.1, 1, int).__class__ #! cannot unify float and int #%% fn_realization,barebones def ff[T](x: T, y: tuple[T]): - print ff(T=str,...).__class__ #: ff[str,Tuple[str],str] + print ff(T=str,...).__class__ #: ff:0[str,Tuple[str],str] return x x = ff(1, (1,)) print x, x.__class__ #: 1 int @@ -474,7 +474,7 @@ def fg[T](x:T): def g[T](y): z = T() return z - print fg(T=str,...).__class__ #: fg[str,str] + print fg(T=str,...).__class__ #: fg:0[str,str] print g(1, T).__class__ #: int fg(1) print fg(1).__class__ #: void @@ -515,7 +515,7 @@ class A[T]: def foo[W](t: V, u: V, v: V, w: W): return (t, u, v, w) -print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo.2[bool,bool,bool,str,str] +print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo:0[bool,bool,bool,str,str] print A.B.C.foo(1,1,1,True) #: (1, 1, 1, True) print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') @@ -734,10 +734,10 @@ def test(name, sort, key): def foo(l, f): return [f(i) for i in l] test('hi', foo, lambda x: x+1) #: hi [2, 3, 4, 5] -# TODO -# def foof(l: List[int], x, f: Callable[[int], int]): -# return [f(i)+x for i in l] -# test('qsort', foof(..., 3, ...)) + +def foof(l: List[int], x, f: Callable[[int], int]): + return [f(i)+x for i in l] +test('qsort', foof(x=3, ...), lambda x: x+1) #: qsort [5, 6, 7, 8] #%% class_fn_access,barebones class X[T]: @@ -745,8 +745,7 @@ class X[T]: return (x+x, y+y) y = X[X[int]]() print y.__class__ #: X[X[int]] -print X[float].foo(U=int, ...).__class__ #: X.foo.2[X[float],float,int,int] -# print y.foo.1[float].__class__ +print X[float].foo(U=int, ...).__class__ #: X.foo:0[X[float],float,int,int] print X[int]().foo(1, 's') #: (2, 'ss') #%% class_partial_access,barebones @@ -754,7 +753,7 @@ class X[T]: def foo[U](self, x, y: U): return (x+x, y+y) y = X[X[int]]() -print y.foo(U=float,...).__class__ #: X.foo.2[X[X[int]],...,...] +print y.foo(U=float,...).__class__ #: X.foo:0[X[X[int]],...,...] print y.foo(1, 2.2, float) #: (2, 4.4) #%% forward,barebones @@ -765,10 +764,10 @@ def bar[T](x): print x, T.__class__ foo(bar, 1) #: 1 int -#: bar[...] +#: bar:0[...] foo(bar(...), 's') #: s str -#: bar[...] +#: bar:0[...] z = bar z('s', int) #: s int @@ -786,8 +785,8 @@ def foo(f, x): def bar[T](x): print x, T.__class__ foo(bar(T=int,...), 1) -#! too many arguments for bar[T1,int] (expected maximum 2, got 2) -#! while realizing foo (arguments foo[bar[...],int]) +#! too many arguments for bar:0[T1,int] (expected maximum 2, got 2) +#! while realizing foo:0 (arguments foo:0[bar:0[...],int]) #%% sort_partial def foo(x, y): @@ -806,16 +805,16 @@ def frec(x, y): return grec(x, y) if bl(y) else 2 print frec(1, 2).__class__, frec('s', 1).__class__ #! expression with void type -#! while realizing frec (arguments frec[int,int]) +#! while realizing frec:0 (arguments frec:0[int,int]) #%% return_fn,barebones def retfn(a): def inner(b, *args, **kwargs): print a, b, args, kwargs - print inner.__class__ #: retfn.inner[...,...,int,...] + print inner.__class__ #: retfn:0.inner:0[...,...,int,...] return inner(15, ...) f = retfn(1) -print f.__class__ #: retfn.inner[int,...,int,...] +print f.__class__ #: retfn:0.inner:0[int,...,int,...] f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar') #%% decorator_manual,barebones @@ -823,7 +822,7 @@ def foo(x, *args, **kwargs): print x, args, kwargs return 1 def dec(fn, a): - print 'decorating', fn.__class__ #: decorating foo[...,...,...] + print 'decorating', fn.__class__ #: decorating foo:0[...,...,...] def inner(*args, **kwargs): print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True) return fn(a, *args, **kwargs) @@ -846,7 +845,7 @@ def dec(fn, a): return inner ff = dec(foo, 10) print ff(5.5, 's', z=True) -#: decorating foo[...,...,...] +#: decorating foo:0[...,...,...] #: decorator (5.5, 's') (z: True) #: 10 (5.5, 's') (z: True) #: 1 @@ -856,7 +855,7 @@ def zoo(e, b, *args): return f'zoo: {e}, {b}, {args}' print zoo(2, 3) print zoo('s', 3) -#: decorating zoo[...,...,...] +#: decorating zoo:0[...,...,...] #: decorator (2, 3) () #: zoo: 5, 2, (3) #: decorator ('s', 3) () @@ -869,9 +868,9 @@ def mydecorator(func): print("after") return inner @mydecorator -def foo(): +def foo2(): print("foo") -foo() +foo2() #: before #: foo #: after @@ -891,7 +890,7 @@ def factorial(num): return n factorial(10) #: 3628800 -#: time needed for factorial[...] is 3628799 +#: time needed for factorial:0[...] is 3628799 def dx1(func): def inner(): @@ -921,9 +920,9 @@ def dy2(func): return inner @dy1 @dy2 -def num(a, b): +def num2(a, b): return a+b -print(num(10, 20)) #: 3600 +print(num2(10, 20)) #: 3600 #%% hetero_iter,barebones e = (1, 2, 3, 'foo', 5, 'bar', 6) @@ -970,14 +969,14 @@ def tee(iterable, n=2): return list(gen(d) for d in deques) it = [1,2,3,4] a, b = tee(it) #! cannot typecheck the program -#! while realizing tee (arguments tee[List[int],int]) +#! while realizing tee:0 (arguments tee:0[List[int],int]) #%% new_syntax,barebones def foo[T,U](x: type, y, z: Static[int] = 10): print T.__class__, U.__class__, x.__class__, y.__class__, Int[z+1].__class__ return List[x]() -print foo(T=int,U=str,...).__class__ #: foo[T1,x,z,int,str] -print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo[T1,bool,5,int,str] +print foo(T=int,U=str,...).__class__ #: foo:0[T1,x,z,int,str] +print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo:0[T1,bool,5,int,str] print foo(float,3,T=int,U=str,z=5).__class__ #: List[float] foo(float,1,10,str,int) #: str int float int Int[11] @@ -993,11 +992,11 @@ print Foo[5,int,float,6].__class__ #: Foo[5,int,float,6] print Foo(1.1, 10i32, [False], 10u66).__class__ #: Foo[66,bool,float,32] -def foo[N: Static[int]](): +def foo2[N: Static[int]](): print Int[N].__class__, N x: Static[int] = 5 y: Static[int] = 105 - x * 2 -foo(y-x) #: Int[90] 90 +foo2(y-x) #: Int[90] 90 if 1.1+2.2 > 0: x: Static[int] = 88