From cc634d1940c7962e2aa5cd097bbf55edc5cc0194 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagi=C4=87?= Date: Tue, 11 Jan 2022 17:39:15 -0800 Subject: [PATCH] Improved logic for handling overloaded functions (#10) * Backport seq-lang/seq@develop fixes * Backport seq-lang/seq@develop fixes * Select the last matching overload by default (remove scoring logic); Add dispatch stubs for partial overload support * Select the last matching overload by default [wip] * Fix various bugs and update tests * Add support for partial functions with *args/**kwargs; Fix partial method dispatch * Update .gitignore * Fix grammar to allow variable names that have reserved word as a prefix * Add support for super() call * Add super() tests; Allow static inheritance to inherit @extend methods * Support for overloaded functions [wip; base logic done] * Support for overloaded functions * Update .gitignore * Fix partial dots * Rename function overload 'super' to 'superf' * Add support for super() * Add tests for super() * Add tuple_offsetof * Add tuple support for super() * Add isinstance support for inherited classes; Fix review issues Co-authored-by: A. R. Shajii --- .gitignore | 2 + codon/compiler/compiler.cpp | 13 + codon/parser/ast/expr.h | 2 +- codon/parser/cache.cpp | 10 +- codon/parser/cache.h | 38 +- codon/parser/peg/grammar.peg | 10 +- codon/parser/peg/peg.cpp | 2 - codon/parser/visitors/simplify/simplify.cpp | 4 +- codon/parser/visitors/simplify/simplify.h | 3 + .../parser/visitors/simplify/simplify_ctx.cpp | 33 +- codon/parser/visitors/simplify/simplify_ctx.h | 13 +- .../visitors/simplify/simplify_expr.cpp | 16 +- .../visitors/simplify/simplify_stmt.cpp | 107 ++-- codon/parser/visitors/translate/translate.cpp | 3 +- codon/parser/visitors/typecheck/typecheck.cpp | 13 +- codon/parser/visitors/typecheck/typecheck.h | 28 +- .../visitors/typecheck/typecheck_ctx.cpp | 162 ++---- .../parser/visitors/typecheck/typecheck_ctx.h | 16 +- .../visitors/typecheck/typecheck_expr.cpp | 501 +++++++++++++++--- .../visitors/typecheck/typecheck_infer.cpp | 10 +- .../visitors/typecheck/typecheck_stmt.cpp | 24 +- codon/sir/module.cpp | 6 +- stdlib/collections.codon | 2 + stdlib/internal/internal.codon | 18 + stdlib/internal/sort.codon | 8 +- stdlib/internal/types/array.codon | 11 + stdlib/internal/types/bool.codon | 2 +- stdlib/internal/types/collections/list.codon | 24 +- stdlib/internal/types/complex.codon | 88 ++- stdlib/internal/types/float.codon | 123 +++-- stdlib/internal/types/int.codon | 180 +++++-- stdlib/internal/types/intn.codon | 80 +++ stdlib/internal/types/optional.codon | 7 + stdlib/internal/types/ptr.codon | 44 +- stdlib/internal/types/str.codon | 21 +- stdlib/statistics.codon | 16 +- test/parser/simplify_stmt.codon | 21 +- test/parser/typecheck_expr.codon | 219 +++++++- test/parser/typecheck_stmt.codon | 66 ++- test/parser/types.codon | 105 ++-- test/stdlib/datetime_test.codon | 8 +- test/transform/folding.codon | 218 +++++--- 42 files changed, 1611 insertions(+), 666 deletions(-) diff --git a/.gitignore b/.gitignore index 84b29947..3bc55219 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,5 @@ Thumbs.db extra/jupyter/share/jupyter/kernels/codon/kernel.json scratch.* +_* +.ipynb_checkpoints diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 55e8153f..2c920274 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -68,11 +68,24 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, auto transformed = ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), abspath, defines, (testFlags > 1)); t2.log(); + if (codon::getLogger().flags & codon::Logger::FLAG_USER) { + auto fo = fopen("_dump_simplify.sexp", "w"); + fmt::print(fo, "{}\n", transformed->toString(0)); + fclose(fo); + } Timer t3("typecheck"); auto typechecked = ast::TypecheckVisitor::apply(cache.get(), std::move(transformed)); t3.log(); + if (codon::getLogger().flags & codon::Logger::FLAG_USER) { + auto fo = fopen("_dump_typecheck.sexp", "w"); + fmt::print(fo, "{}\n", typechecked->toString(0)); + for (auto &f : cache->functions) + for (auto &r : f.second.realizations) + fmt::print(fo, "{}\n", r.second->ast->toString(0)); + fclose(fo); + } Timer t4("translate"); ast::TranslateVisitor::apply(cache.get(), std::move(typechecked)); diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index e1c3be71..df5716c4 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -275,7 +275,7 @@ struct KeywordStarExpr : public Expr { struct TupleExpr : public Expr { std::vector items; - explicit TupleExpr(std::vector items); + explicit TupleExpr(std::vector items = {}); TupleExpr(const TupleExpr &expr); std::string toString() const override; diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index dedaaf9c..e94ca102 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -50,20 +50,22 @@ types::ClassTypePtr Cache::findClass(const std::string &name) const { types::FuncTypePtr Cache::findFunction(const std::string &name) const { auto f = typeCtx->find(name); + if (f && f->type && f->kind == TypecheckItem::Func) + return f->type->getFunc(); + f = typeCtx->find(name + ":0"); if (f && f->type && f->kind == TypecheckItem::Func) return f->type->getFunc(); return nullptr; } -types::FuncTypePtr -Cache::findMethod(types::ClassType *typ, const std::string &member, - const std::vector> &args) { +types::FuncTypePtr Cache::findMethod(types::ClassType *typ, const std::string &member, + const std::vector &args) { auto e = std::make_shared(typ->name); e->type = typ->getClass(); seqassert(e->type, "not a class"); int oldAge = typeCtx->age; typeCtx->age = 99999; - auto f = typeCtx->findBestMethod(e.get(), member, args); + auto f = TypecheckVisitor(typeCtx).findBestMethod(e.get(), member, args); typeCtx->age = oldAge; return f; } diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 2c32c8bb..63f7161f 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -106,19 +106,9 @@ struct Cache : public std::enable_shared_from_this { /// Non-simplified AST. Used for base class instantiation. std::shared_ptr originalAst; - /// A class function method. - struct ClassMethod { - /// Canonical name of a method (e.g. __init__.1). - std::string name; - /// A corresponding generic function type. - types::FuncTypePtr type; - /// Method age (how many class extension were seen before a method definition). - /// Used to prevent the usage of a method before it was defined in the code. - int age; - }; - /// Class method lookup table. Each name points to a list of ClassMethod instances - /// that share the same method name (a list because methods can be overloaded). - std::unordered_map> methods; + /// Class method lookup table. Each non-canonical name points + /// to a root function name of a corresponding method. + std::unordered_map methods; /// A class field (member). struct ClassField { @@ -143,6 +133,9 @@ struct Cache : public std::enable_shared_from_this { /// Realization lookup table that maps a realized class name to the corresponding /// ClassRealization instance. std::unordered_map> realizations; + /// List of inherited class. We also keep the number of fields each of inherited + /// class. + std::vector> parentClasses; Class() : ast(nullptr), originalAst(nullptr) {} }; @@ -177,6 +170,20 @@ struct Cache : public std::enable_shared_from_this { /// corresponding Function instance. std::unordered_map functions; + + struct Overload { + /// Canonical name of an overload (e.g. Foo.__init__.1). + std::string name; + /// Overload age (how many class extension were seen before a method definition). + /// Used to prevent the usage of an overload before it was defined in the code. + /// TODO: I have no recollection of how this was supposed to work. Most likely + /// it does not work at all... + int age; + }; + /// Maps a "root" name of each function to the list of names of the function + /// overloads. + std::unordered_map> overloads; + /// Pointer to the later contexts needed for IR API access. std::shared_ptr typeCtx; std::shared_ptr codegenCtx; @@ -223,9 +230,8 @@ public: types::FuncTypePtr findFunction(const std::string &name) const; /// Find the class method in a given class type that best matches the given arguments. /// Returns an _uninstantiated_ type. - types::FuncTypePtr - findMethod(types::ClassType *typ, const std::string &member, - const std::vector> &args); + types::FuncTypePtr findMethod(types::ClassType *typ, const std::string &member, + const std::vector &args); /// Given a class type and the matching generic vector, instantiate the type and /// realize it. diff --git a/codon/parser/peg/grammar.peg b/codon/parser/peg/grammar.peg index ad292a2b..92f22344 100644 --- a/codon/parser/peg/grammar.peg +++ b/codon/parser/peg/grammar.peg @@ -61,12 +61,12 @@ small_stmt <- / 'break' &(SPACE / ';' / EOL) { return any(ast(LOC)); } / 'continue' &(SPACE / ';' / EOL) { return any(ast(LOC)); } / global_stmt - / yield_stmt + / yield_stmt &(SPACE / ';' / EOL) / assert_stmt / del_stmt - / return_stmt - / raise_stmt - / print_stmt + / return_stmt &(SPACE / ';' / EOL) + / raise_stmt &(SPACE / ';' / EOL) + / print_stmt / import_stmt / expressions &(_ ';' / _ EOL) { return any(ast(LOC, ac_expr(V0))); } / NAME SPACE expressions { @@ -253,7 +253,7 @@ with_stmt <- 'with' SPACE (with_parens_item / with_item) _ ':' _ suite { with_parens_item <- '(' _ tlist(',', as_item) _ ')' { return VS; } with_item <- list(',', as_item) { return VS; } as_item <- - / expression SPACE 'as' SPACE star_target &(_ (',' / ')' / ':')) { + / expression SPACE 'as' SPACE id &(_ (',' / ')' / ':')) { return pair(ac_expr(V0), ac_expr(V1)); } / expression { return pair(ac_expr(V0), (ExprPtr)nullptr); } diff --git a/codon/parser/peg/peg.cpp b/codon/parser/peg/peg.cpp index f8b8dcc9..80ef9f49 100644 --- a/codon/parser/peg/peg.cpp +++ b/codon/parser/peg/peg.cpp @@ -52,8 +52,6 @@ std::shared_ptr initParser() { template T parseCode(Cache *cache, const std::string &file, std::string code, int line_offset, int col_offset, const std::string &rule) { - TIME("peg"); - // Initialize if (!grammar) grammar = initParser(); diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp index ef171ba4..6c149a94 100644 --- a/codon/parser/visitors/simplify/simplify.cpp +++ b/codon/parser/visitors/simplify/simplify.cpp @@ -88,8 +88,10 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil } // Reserve the following static identifiers. for (auto name : {"staticlen", "compile_error", "isinstance", "hasattr", "type", - "TypeVar", "Callable", "argv"}) + "TypeVar", "Callable", "argv", "super", "superf"}) stdlib->generateCanonicalName(name); + stdlib->add(SimplifyItem::Var, "super", "super", true); + stdlib->add(SimplifyItem::Var, "superf", "superf", true); // This code must be placed in a preamble (these are not POD types but are // referenced by the various preamble Function.N and Tuple.N stubs) diff --git a/codon/parser/visitors/simplify/simplify.h b/codon/parser/visitors/simplify/simplify.h index df4a8317..52362f5d 100644 --- a/codon/parser/visitors/simplify/simplify.h +++ b/codon/parser/visitors/simplify/simplify.h @@ -488,6 +488,9 @@ private: // suite recursively, and assumes that each statement is either a function or a // doc-string. std::vector getClassMethods(const StmtPtr &s); + + // Generate dispatch method for partial overloaded calls. + void generateDispatch(const std::string &name); }; } // namespace ast diff --git a/codon/parser/visitors/simplify/simplify_ctx.cpp b/codon/parser/visitors/simplify/simplify_ctx.cpp index 711eeb9c..517518db 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.cpp +++ b/codon/parser/visitors/simplify/simplify_ctx.cpp @@ -14,8 +14,9 @@ namespace codon { namespace ast { SimplifyItem::SimplifyItem(Kind k, std::string base, std::string canonicalName, - bool global) - : kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global) {} + bool global, std::string moduleName) + : kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global), + moduleName(move(moduleName)) {} SimplifyContext::SimplifyContext(std::string filename, Cache *cache) : Context(move(filename)), cache(move(cache)), @@ -31,6 +32,7 @@ std::shared_ptr SimplifyContext::add(SimplifyItem::Kind kind, bool global) { seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); auto t = std::make_shared(kind, getBase(), canonicalName, global); + t->moduleName = getModule(); Context::add(name, t); Context::add(canonicalName, t); return t; @@ -60,22 +62,29 @@ std::string SimplifyContext::getBase() const { return bases.back().name; } +std::string SimplifyContext::getModule() const { + std::string base = moduleName.status == ImportFile::STDLIB ? "std." : ""; + base += moduleName.module; + if (startswith(base, "__main__")) + base = base.substr(8); + return base; +} + std::string SimplifyContext::generateCanonicalName(const std::string &name, - bool includeBase) const { + bool includeBase, + bool zeroId) const { std::string newName = name; - if (includeBase && name.find('.') == std::string::npos) { + bool alreadyGenerated = name.find('.') != std::string::npos; + if (includeBase && !alreadyGenerated) { std::string base = getBase(); - if (base.empty()) { - base = moduleName.status == ImportFile::STDLIB ? "std." : ""; - base += moduleName.module; - if (startswith(base, "__main__")) - base = base.substr(8); - } + if (base.empty()) + base = getModule(); newName = (base.empty() ? "" : (base + ".")) + newName; } auto num = cache->identifierCount[newName]++; - newName = num ? format("{}.{}", newName, num) : newName; - if (newName != name) + if (num) + newName = format("{}.{}", newName, num); + if (name != newName && !zeroId) cache->identifierCount[newName]++; cache->reverseIdentifierLookup[newName] = name; return newName; diff --git a/codon/parser/visitors/simplify/simplify_ctx.h b/codon/parser/visitors/simplify/simplify_ctx.h index ed5ad79a..468ded2a 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.h +++ b/codon/parser/visitors/simplify/simplify_ctx.h @@ -32,13 +32,16 @@ struct SimplifyItem { bool global; /// Non-empty string if a variable is import variable std::string importPath; + /// Full module name + std::string moduleName; public: - SimplifyItem(Kind k, std::string base, std::string canonicalName, - bool global = false); + SimplifyItem(Kind k, std::string base, std::string canonicalName, bool global = false, + std::string moduleName = ""); /// Convenience getters. std::string getBase() const { return base; } + std::string getModule() const { return moduleName; } bool isGlobal() const { return global; } bool isVar() const { return kind == Var; } bool isFunc() const { return kind == Func; } @@ -107,14 +110,16 @@ public: /// Return a canonical name of the top-most base, or an empty string if this is a /// top-level base. std::string getBase() const; + /// Return the current module. + std::string getModule() const; /// Return the current base nesting level (note: bases, not blocks). int getLevel() const { return bases.size(); } /// Pretty-print the current context state. void dump() override { dump(0); } /// Generate a unique identifier (name) for a given string. - std::string generateCanonicalName(const std::string &name, - bool includeBase = false) const; + std::string generateCanonicalName(const std::string &name, bool includeBase = false, + bool zeroId = false) const; bool inFunction() const { return getLevel() && !bases.back().isType(); } bool inClass() const { return getLevel() && bases.back().isType(); } diff --git a/codon/parser/visitors/simplify/simplify_expr.cpp b/codon/parser/visitors/simplify/simplify_expr.cpp index 053e384b..5f5299d5 100644 --- a/codon/parser/visitors/simplify/simplify_expr.cpp +++ b/codon/parser/visitors/simplify/simplify_expr.cpp @@ -385,12 +385,22 @@ void SimplifyVisitor::visit(IndexExpr *expr) { } // IndexExpr[i1, ..., iN] is internally stored as IndexExpr[TupleExpr[i1, ..., iN]] // for N > 1, so make sure to check that case. + + std::vector it; if (auto t = index->getTuple()) for (auto &i : t->items) - it.push_back(transform(i, true)); + it.push_back(i); else - it.push_back(transform(index, true)); + it.push_back(index); + for (auto &i: it) { + if (auto es = i->getStar()) + i = N(transform(es->what)); + else if (auto ek = CAST(i, KeywordStarExpr)) + i = N(transform(ek->what)); + else + i = transform(i, true); + } if (e->isType()) { resultExpr = N(e, it); resultExpr->markType(); @@ -617,7 +627,7 @@ void SimplifyVisitor::visit(DotExpr *expr) { auto s = join(chain, ".", importEnd, i + 1); val = fctx->find(s); // Make sure that we access only global imported variables. - if (val && (importName.empty() || val->isGlobal())) { + if (val && (importName.empty() || val->isType() || val->isGlobal())) { itemName = val->canonicalName; itemEnd = i + 1; if (!importName.empty()) diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index a3158edc..178fc12a 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -437,28 +437,30 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (stmt->decorators.size() != 1) error("__attribute__ cannot be mixed with other decorators"); attr.isAttribute = true; - } else if (d->isId(Attr::LLVM)) + } else if (d->isId(Attr::LLVM)) { attr.set(Attr::LLVM); - else if (d->isId(Attr::Python)) + } else if (d->isId(Attr::Python)) { attr.set(Attr::Python); - else if (d->isId(Attr::Internal)) + } else if (d->isId(Attr::Internal)) { attr.set(Attr::Internal); - else if (d->isId(Attr::Atomic)) + } else if (d->isId(Attr::Atomic)) { attr.set(Attr::Atomic); - else if (d->isId(Attr::Property)) + } else if (d->isId(Attr::Property)) { attr.set(Attr::Property); - else if (d->isId(Attr::ForceRealize)) + } else if (d->isId(Attr::ForceRealize)) { attr.set(Attr::ForceRealize); - else { + } else { // Let's check if this is a attribute auto dt = transform(clone(d)); if (dt && dt->getId()) { auto ci = ctx->find(dt->getId()->value); if (ci && ci->kind == SimplifyItem::Func) { - if (ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute) { - attr.set(ci->canonicalName); - continue; - } + if (ctx->cache->overloads[ci->canonicalName].size() == 1) + if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name] + .ast->attributes.isAttribute) { + attr.set(ci->canonicalName); + continue; + } } } decorators.emplace_back(clone(d)); @@ -472,8 +474,23 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { return; } - auto canonicalName = ctx->generateCanonicalName(stmt->name, true); bool isClassMember = ctx->inClass(); + std::string rootName; + if (isClassMember) { + auto &m = ctx->cache->classes[ctx->bases.back().name].methods; + auto i = m.find(stmt->name); + if (i != m.end()) + rootName = i->second; + } else if (auto c = ctx->find(stmt->name)) { + if (c->isFunc() && c->getModule() == ctx->getModule() && + c->getBase() == ctx->getBase()) + rootName = c->canonicalName; + } + if (rootName.empty()) + rootName = ctx->generateCanonicalName(stmt->name, true); + auto canonicalName = + format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); + ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name; bool isEnclosedFunc = ctx->inFunction(); if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember)) @@ -483,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ctx->bases = std::vector(); if (!isClassMember) // Class members are added to class' method table - ctx->add(SimplifyItem::Func, stmt->name, canonicalName, ctx->isToplevel()); + ctx->add(SimplifyItem::Func, stmt->name, rootName, ctx->isToplevel()); if (isClassMember) ctx->bases.push_back(oldBases[0]); ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base... @@ -527,6 +544,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (!typeAst && isClassMember && ia == 0 && a.name == "self") { typeAst = ctx->bases[ctx->bases.size() - 2].ast; attr.set(".changedSelf"); + attr.set(Attr::Method); } if (attr.has(Attr::C)) { @@ -602,8 +620,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // ... set the enclosing class name... attr.parentClass = ctx->bases.back().name; // ... add the method to class' method list ... - ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].push_back( - {canonicalName, nullptr, ctx->cache->age}); + ctx->cache->classes[ctx->bases.back().name].methods[stmt->name] = rootName; // ... and if the function references outer class variable (by definition a // generic), mark it as not static as it needs fully instantiated class to be // realized. For example, in class A[T]: def foo(): pass, A.foo() can be realized @@ -612,6 +629,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (isMethod) attr.set(Attr::Method); } + ctx->cache->overloads[rootName].push_back({canonicalName, ctx->cache->age}); std::vector partialArgs; if (!captures.empty()) { @@ -732,6 +750,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { ClassStmt *originalAST = nullptr; auto classItem = std::make_shared(SimplifyItem::Type, "", "", ctx->isToplevel()); + classItem->moduleName = ctx->getModule(); if (!extension) { classItem->canonicalName = canonicalName = ctx->generateCanonicalName(name, !attr.has(Attr::Internal)); @@ -770,6 +789,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { std::vector> substitutions; std::vector argSubstitutions; std::unordered_set seenMembers; + std::vector baseASTsFields; for (auto &baseClass : stmt->baseClasses) { std::string bcName; std::vector subs; @@ -810,6 +830,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { if (!extension) ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr}); } + baseASTsFields.push_back(args.size()); } // Add generics, if any, to the context. @@ -891,6 +912,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { ctx->moduleName.module); ctx->cache->classes[canonicalName].ast = N(canonicalName, args, N(), attr); + for (int i = 0; i < baseASTs.size(); i++) + ctx->cache->classes[canonicalName].parentClasses.push_back( + {baseASTs[i]->name, baseASTsFields[i]}); std::vector fns; ExprPtr codeType = ctx->bases.back().ast->clone(); std::vector magics{}; @@ -934,29 +958,45 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { suite->stmts.push_back(preamble->functions.back()); } } - for (int ai = 0; ai < baseASTs.size(); ai++) - for (auto sp : getClassMethods(baseASTs[ai]->suite)) - if (auto f = sp->getFunction()) { + for (int ai = 0; ai < baseASTs.size(); ai++) { + // FUNCS + for (auto &mm : ctx->cache->classes[baseASTs[ai]->name].methods) + for (auto &mf : ctx->cache->overloads[mm.second]) { + auto f = ctx->cache->functions[mf.name].ast; if (f->attributes.has("autogenerated")) continue; + auto subs = substitutions[ai]; - auto newName = ctx->generateCanonicalName( - ctx->cache->reverseIdentifierLookup[f->name], true); - auto nf = std::dynamic_pointer_cast(replace(sp, subs)); - subs[nf->name] = N(newName); - nf->name = newName; + + std::string rootName; + auto &mts = ctx->cache->classes[ctx->bases.back().name].methods; + auto it = mts.find(ctx->cache->reverseIdentifierLookup[f->name]); + if (it != mts.end()) + rootName = it->second; + else + rootName = ctx->generateCanonicalName( + ctx->cache->reverseIdentifierLookup[f->name], true); + auto newCanonicalName = + format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); + ctx->cache->reverseIdentifierLookup[newCanonicalName] = + ctx->cache->reverseIdentifierLookup[f->name]; + auto nf = std::dynamic_pointer_cast( + replace(std::static_pointer_cast(f), subs)); + subs[nf->name] = N(newCanonicalName); + nf->name = newCanonicalName; suite->stmts.push_back(nf); nf->attributes.parentClass = ctx->bases.back().name; // check original ast... - if (nf->attributes.has(".changedSelf")) + if (nf->attributes.has(".changedSelf")) // replace self type with new class nf->args[0].type = transformType(ctx->bases.back().ast); preamble->functions.push_back(clone(nf)); - ctx->cache->functions[newName].ast = nf; + ctx->cache->overloads[rootName].push_back({newCanonicalName, ctx->cache->age}); + ctx->cache->functions[newCanonicalName].ast = nf; ctx->cache->classes[ctx->bases.back().name] - .methods[ctx->cache->reverseIdentifierLookup[f->name]] - .push_back({newName, nullptr, ctx->cache->age}); + .methods[ctx->cache->reverseIdentifierLookup[f->name]] = rootName; } + } for (auto sp : getClassMethods(stmt->suite)) if (sp && !sp->getClass()) { transform(sp); @@ -1227,8 +1267,10 @@ StmtPtr SimplifyVisitor::transformCImport(const std::string &name, auto f = N(name, ret ? ret->clone() : N("void"), fnArgs, nullptr, attr); StmtPtr tf = transform(f); // Already in the preamble - if (!altName.empty()) + if (!altName.empty()) { ctx->add(altName, ctx->find(name)); + ctx->remove(name); + } return tf; } @@ -1371,10 +1413,11 @@ void SimplifyVisitor::transformNewImport(const ImportFile &file) { stmts[0] = N(); // Add a def import(): ... manually to the cache and to the preamble (it won't be // transformed here!). - ctx->cache->functions[importVar].ast = - N(importVar, nullptr, std::vector{}, N(stmts), - Attr({Attr::ForceRealize})); - preamble->functions.push_back(ctx->cache->functions[importVar].ast->clone()); + ctx->cache->overloads[importVar].push_back({importVar + ":0", ctx->cache->age}); + ctx->cache->functions[importVar + ":0"].ast = + N(importVar + ":0", nullptr, std::vector{}, + N(stmts), Attr({Attr::ForceRealize})); + preamble->functions.push_back(ctx->cache->functions[importVar + ":0"].ast->clone()); ; } } diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 492211d6..7e0ba953 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -296,7 +296,8 @@ void TranslateVisitor::visit(ForStmt *stmt) { auto c = stmt->decorator->getCall(); seqassert(c, "for par is not a call: {}", stmt->decorator->toString()); auto fc = c->expr->getType()->getFunc(); - seqassert(fc && fc->ast->name == "std.openmp.for_par", "for par is not a function"); + seqassert(fc && fc->ast->name == "std.openmp.for_par:0", + "for par is not a function"); auto schedule = fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString(); bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt(); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 867d7da7..77dc0ba0 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -33,16 +33,21 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, StmtPtr stmts) { return std::move(infer.second); } -TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b) { +TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess) { if (!a) return a = b; seqassert(b, "rhs is nullptr"); types::Type::Unification undo; - if (a->unify(b.get(), &undo) >= 0) + if (a->unify(b.get(), &undo) >= 0) { + if (undoOnSuccess) + undo.undo(); return a; - undo.undo(); + } else { + undo.undo(); + } // LOG("{} / {}", a->debugString(true), b->debugString(true)); - a->unify(b.get(), &undo); + if (!undoOnSuccess) + a->unify(b.get(), &undo); error("cannot unify {} and {}", a->toString(), b->toString()); return nullptr; } diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index e8344461..91e0aa74 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -283,9 +283,31 @@ private: void generateFnCall(int n); /// Make an empty partial call fn(...) for a function fn. ExprPtr partializeFunction(ExprPtr expr); + /// Picks the best method of a given expression that matches the given argument + /// types. Prefers methods whose signatures are closer to the given arguments: + /// e.g. foo(int) will match (int) better that a foo(T). + /// Also takes care of the Optional arguments. + /// If multiple equally good methods are found, return the first one. + /// Return nullptr if no methods were found. + types::FuncTypePtr findBestMethod(const Expr *expr, const std::string &member, + const std::vector &args); + types::FuncTypePtr findBestMethod(const Expr *expr, const std::string &member, + const std::vector &args); + types::FuncTypePtr findBestMethod(const std::string &fn, + const std::vector &args); + std::vector findSuperMethods(const types::FuncTypePtr &func); + std::vector + findMatchingMethods(types::ClassType *typ, + const std::vector &methods, + const std::vector &args); + + ExprPtr transformSuper(const CallExpr *expr); + std::vector getSuperTypes(const types::ClassTypePtr &cls); + private: - types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b); + types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b, + bool undoOnSuccess = false); types::TypePtr realizeType(types::ClassType *typ); types::TypePtr realizeFunc(types::FuncType *typ); std::pair inferTypes(StmtPtr stmt, bool keepLast, @@ -293,10 +315,12 @@ private: codon::ir::types::Type *getLLVMType(const types::ClassType *t); bool wrapExpr(ExprPtr &expr, types::TypePtr expectedType, - const types::FuncTypePtr &callee); + const types::FuncTypePtr &callee, bool undoOnSuccess = false); int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false); int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop, int64_t step); + types::FuncTypePtr findDispatch(const std::string &fn); + std::string getRootName(const std::string &name); friend struct Cache; }; diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index 31297307..aac6eb51 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -103,19 +103,20 @@ types::TypePtr TypeContext::instantiate(const Expr *expr, types::TypePtr type, if (auto l = i.second->getLink()) { if (l->kind != types::LinkType::Unbound) continue; - i.second->setSrcInfo(expr->getSrcInfo()); + if (expr) + i.second->setSrcInfo(expr->getSrcInfo()); if (activeUnbounds.find(i.second) == activeUnbounds.end()) { LOG_TYPECHECK("[ub] #{} -> {} (during inst of {}): {} ({})", i.first, i.second->debugString(true), type->debugString(true), - expr->toString(), activate); + expr ? expr->toString() : "", activate); if (activate && allowActivation) - activeUnbounds[i.second] = - format("{} of {} in {}", l->genericName.empty() ? "?" : l->genericName, - type->toString(), cache->getContent(expr->getSrcInfo())); + activeUnbounds[i.second] = format( + "{} of {} in {}", l->genericName.empty() ? "?" : l->genericName, + type->toString(), expr ? cache->getContent(expr->getSrcInfo()) : ""); } } } - LOG_TYPECHECK("[inst] {} -> {}", expr->toString(), t->debugString(true)); + LOG_TYPECHECK("[inst] {} -> {}", expr ? expr->toString() : "", t->debugString(true)); return t; } @@ -135,24 +136,29 @@ TypeContext::instantiateGeneric(const Expr *expr, types::TypePtr root, return instantiate(expr, root, g.get()); } -std::vector -TypeContext::findMethod(const std::string &typeName, const std::string &method) const { +std::vector TypeContext::findMethod(const std::string &typeName, + const std::string &method, + bool hideShadowed) const { auto m = cache->classes.find(typeName); if (m != cache->classes.end()) { auto t = m->second.methods.find(method); if (t != m->second.methods.end()) { - std::unordered_map signatureLoci; + auto mt = cache->overloads[t->second]; + std::unordered_set signatureLoci; std::vector vv; - for (auto &mt : t->second) { - // LOG("{}::{} @ {} vs. {}", typeName, method, age, mt.age); - if (mt.age <= age) { - auto sig = cache->functions[mt.name].ast->signature(); - auto it = signatureLoci.find(sig); - if (it != signatureLoci.end()) - vv[it->second] = mt.type; - else { - signatureLoci[sig] = vv.size(); - vv.emplace_back(mt.type); + for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { + auto &m = mt[mti]; + if (endswith(m.name, ":dispatch")) + continue; + if (m.age <= age) { + if (hideShadowed) { + auto sig = cache->functions[m.name].ast->signature(); + if (!in(signatureLoci, sig)) { + signatureLoci.insert(sig); + vv.emplace_back(cache->functions[m.name].type); + } + } else { + vv.emplace_back(cache->functions[m.name].type); } } } @@ -177,110 +183,6 @@ types::TypePtr TypeContext::findMember(const std::string &typeName, return nullptr; } -types::FuncTypePtr TypeContext::findBestMethod( - const Expr *expr, const std::string &member, - const std::vector> &args, bool checkSingle) { - auto typ = expr->getType()->getClass(); - seqassert(typ, "not a class"); - auto methods = findMethod(typ->name, member); - if (methods.empty()) - return nullptr; - if (methods.size() == 1 && !checkSingle) // methods is not overloaded - return methods[0]; - - // Calculate the unification score for each available methods and pick the one with - // highest score. - std::vector> scores; - for (int mi = 0; mi < methods.size(); mi++) { - auto method = instantiate(expr, methods[mi], typ.get(), false)->getFunc(); - std::vector reordered; - std::vector callArgs; - for (auto &a : args) { - callArgs.push_back({a.first, std::make_shared()}); // dummy expression - callArgs.back().value->setType(a.second); - } - auto score = reorderNamedArgs( - method.get(), callArgs, - [&](int s, int k, const std::vector> &slots, bool _) { - for (int si = 0; si < slots.size(); si++) { - // Ignore *args, *kwargs and default arguments - reordered.emplace_back(si == s || si == k || slots[si].size() != 1 - ? nullptr - : args[slots[si][0]].second); - } - return 0; - }, - [](const std::string &) { return -1; }); - if (score == -1) - continue; - // Scoring system for each argument: - // Generics, traits and default arguments get a score of zero (lowest priority). - // Optional unwrap gets the score of 1. - // Optional wrap gets the score of 2. - // Successful unification gets the score of 3 (highest priority). - for (int ai = 0, mi = 1, gi = 0; ai < reordered.size(); ai++) { - auto argType = reordered[ai]; - if (!argType) - continue; - auto expectedType = method->ast->args[ai].generic ? method->generics[gi++].type - : method->args[mi++]; - auto expectedClass = expectedType->getClass(); - // Ignore traits, *args/**kwargs and default arguments. - if (expectedClass && expectedClass->name == "Generator") - continue; - // LOG("<~> {} {}", argType->toString(), expectedType->toString()); - auto argClass = argType->getClass(); - - types::Type::Unification undo; - int u = argType->unify(expectedType.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u + 3; - continue; - } - if (!method->ast->args[ai].generic) { - // Unification failed: maybe we need to wrap an argument? - if (expectedClass && expectedClass->name == TYPE_OPTIONAL && argClass && - argClass->name != expectedClass->name) { - u = argType->unify(expectedClass->generics[0].type.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u + 2; - continue; - } - } - // ... or unwrap it (less ideal)? - if (argClass && argClass->name == TYPE_OPTIONAL && expectedClass && - argClass->name != expectedClass->name) { - u = argClass->generics[0].type->unify(expectedType.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u; - continue; - } - } - } - // This method cannot be selected, ignore it. - score = -1; - break; - } - // LOG("{} {} / {}", typ->toString(), method->toString(), score); - if (score >= 0) - scores.emplace_back(std::make_pair(score, mi)); - } - if (scores.empty()) - return nullptr; - // Get the best score. - sort(scores.begin(), scores.end(), std::greater<>()); - // LOG("Method: {}", methods[scores[0].second]->toString()); - // std::string x; - // for (auto &a : args) - // x += format("{}{},", a.first.empty() ? "" : a.first + ": ", - // a.second->toString()); - // LOG(" {} :: {} ( {} )", typ->toString(), member, x); - return methods[scores[0].second]; -} - int TypeContext::reorderNamedArgs(types::FuncType *func, const std::vector &args, ReorderDoneFn onDone, ReorderErrorFn onError, @@ -300,15 +202,17 @@ int TypeContext::reorderNamedArgs(types::FuncType *func, int starArgIndex = -1, kwstarArgIndex = -1; for (int i = 0; i < func->ast->args.size(); i++) { - if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "**")) + // if (!known.empty() && known[i] && !partial) + // continue; + if (startswith(func->ast->args[i].name, "**")) kwstarArgIndex = i, score -= 2; - else if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "*")) + else if (startswith(func->ast->args[i].name, "*")) starArgIndex = i, score -= 2; } - seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex], - "partial *args"); - seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex], - "partial **kwargs"); + // seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex], + // "partial *args"); + // seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex], + // "partial **kwargs"); // 1. Assign positional arguments to slots // Each slot contains a list of arg's indices diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.h b/codon/parser/visitors/typecheck/typecheck_ctx.h index f6852c5c..d4824692 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.h +++ b/codon/parser/visitors/typecheck/typecheck_ctx.h @@ -48,6 +48,8 @@ struct TypeContext : public Context { /// Map of locally realized types and functions. std::unordered_map> visitedAsts; + /// List of functions that can be accessed via super() + std::vector supers; }; std::vector bases; @@ -121,23 +123,13 @@ public: /// Returns the list of generic methods that correspond to typeName.method. std::vector findMethod(const std::string &typeName, - const std::string &method) const; + const std::string &method, + bool hideShadowed = true) const; /// Returns the generic type of typeName.member, if it exists (nullptr otherwise). /// Special cases: __elemsize__ and __atomic__. types::TypePtr findMember(const std::string &typeName, const std::string &member) const; - /// Picks the best method of a given expression that matches the given argument - /// types. Prefers methods whose signatures are closer to the given arguments: - /// e.g. foo(int) will match (int) better that a foo(T). - /// Also takes care of the Optional arguments. - /// If multiple equally good methods are found, return the first one. - /// Return nullptr if no methods were found. - types::FuncTypePtr - findBestMethod(const Expr *expr, const std::string &member, - const std::vector> &args, - bool checkSingle = false); - typedef std::function> &, bool)> ReorderDoneFn; typedef std::function ReorderErrorFn; diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 25cae4c6..811f28a6 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -104,6 +105,17 @@ void TypecheckVisitor::visit(IdExpr *expr) { return; } auto val = ctx->find(expr->value); + if (!val) { + auto i = ctx->cache->overloads.find(expr->value); + if (i != ctx->cache->overloads.end()) { + if (i->second.size() == 1) { + val = ctx->find(i->second[0].name); + } else { + auto d = findDispatch(expr->value); + val = ctx->find(d->ast->name); + } + } + } seqassert(val, "cannot find IdExpr '{}' ({})", expr->value, expr->getSrcInfo()); auto t = ctx->instantiate(expr, val->type); @@ -683,8 +695,8 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, if (isAtomic) { auto ptrlt = ctx->instantiateGeneric(expr->lexpr.get(), ctx->findInternal("Ptr"), {lt}); - method = ctx->findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic), - {{"", ptrlt}, {"", rt}}); + method = + findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic), {ptrlt, rt}); if (method) { expr->lexpr = N(expr->lexpr); if (noReturn) @@ -693,19 +705,16 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, } // Check if lt.__iop__(lt, rt) exists. if (!method && expr->inPlace) { - method = ctx->findBestMethod(expr->lexpr.get(), format("__i{}__", magic), - {{"", lt}, {"", rt}}); + method = findBestMethod(expr->lexpr.get(), format("__i{}__", magic), {lt, rt}); if (method && noReturn) *noReturn = true; } // Check if lt.__op__(lt, rt) exists. if (!method) - method = ctx->findBestMethod(expr->lexpr.get(), format("__{}__", magic), - {{"", lt}, {"", rt}}); + method = findBestMethod(expr->lexpr.get(), format("__{}__", magic), {lt, rt}); // Check if rt.__rop__(rt, lt) exists. if (!method) { - method = ctx->findBestMethod(expr->rexpr.get(), format("__r{}__", magic), - {{"", rt}, {"", lt}}); + method = findBestMethod(expr->rexpr.get(), format("__r{}__", magic), {rt, lt}); if (method) swap(expr->lexpr, expr->rexpr); } @@ -727,13 +736,17 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr, ExprPtr &index) { - if (!tuple->getRecord() || - in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) + if (!tuple->getRecord()) + return nullptr; + if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL)) + // in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) // Ptr, pyobj and str are internal types and have only one overloaded __getitem__ return nullptr; - if (ctx->cache->classes[tuple->name].methods["__getitem__"].size() != 1) - // TODO: be smarter! there might be a compatible getitem? - return nullptr; + // if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) { + // ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]] + // .size() != 1) + // return nullptr; + // } // Extract a static integer value from a compatible expression. auto getInt = [&](int64_t *o, const ExprPtr &e) { @@ -867,14 +880,16 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, // If it exists, return a simple IdExpr with that method's name. // Append a "self" variable to the front if needed. if (args) { - std::vector> argTypes; + std::vector argTypes; bool isType = expr->expr->isType(); - if (!isType) - argTypes.emplace_back(make_pair("", typ)); // self variable + if (!isType) { + ExprPtr expr = N("self"); + expr->setType(typ); + argTypes.emplace_back(CallExpr::Arg{"", expr}); + } for (const auto &a : *args) - argTypes.emplace_back(make_pair(a.name, a.value->getType())); - if (auto bestMethod = - ctx->findBestMethod(expr->expr.get(), expr->member, argTypes)) { + argTypes.emplace_back(a); + if (auto bestMethod = findBestMethod(expr->expr.get(), expr->member, argTypes)) { ExprPtr e = N(bestMethod->ast->name); auto t = ctx->instantiate(expr, bestMethod, typ.get()); unify(e->type, t); @@ -891,7 +906,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, // No method was found, print a nice error message. std::vector nice; for (auto &t : argTypes) - nice.emplace_back(format("{} = {}", t.first, t.second->toString())); + nice.emplace_back(format("{} = {}", t.name, t.value->type->toString())); error("cannot find a method '{}' in {} with arguments {}", expr->member, typ->toString(), join(nice, ", ")); } @@ -901,23 +916,25 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, auto oldType = expr->getType() ? expr->getType()->getClass() : nullptr; if (methods.size() > 1 && oldType && oldType->getFunc()) { // If old type is already a function, use its arguments to pick the best call. - std::vector> methodArgs; + std::vector methodArgs; if (!expr->expr->isType()) // self argument - methodArgs.emplace_back(make_pair("", typ)); + methodArgs.emplace_back(typ); for (auto i = 1; i < oldType->generics.size(); i++) - methodArgs.emplace_back(make_pair("", oldType->generics[i].type)); - bestMethod = ctx->findBestMethod(expr->expr.get(), expr->member, methodArgs); + methodArgs.emplace_back(oldType->generics[i].type); + bestMethod = findBestMethod(expr->expr.get(), expr->member, methodArgs); if (!bestMethod) { // Print a nice error message. std::vector nice; for (auto &t : methodArgs) - nice.emplace_back(format("{} = {}", t.first, t.second->toString())); + nice.emplace_back(format("{}", t->toString())); error("cannot find a method '{}' in {} with arguments {}", expr->member, typ->toString(), join(nice, ", ")); } + } else if (methods.size() > 1) { + auto m = ctx->cache->classes.find(typ->name); + auto t = m->second.methods.find(expr->member); + bestMethod = findDispatch(t->second); } else { - // HACK: if we still have multiple valid methods, we just use the first one. - // TODO: handle this better (maybe hold these types until they can be selected?) bestMethod = methods[0]; } @@ -947,8 +964,8 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, if (bestMethod->ast->attributes.has(Attr::Property)) methodArgs.pop_back(); ExprPtr e = N(N(bestMethod->ast->name), methodArgs); - ExprPtr r = transform(e, false, allowVoidExpr); - return r; + auto ex = transform(e, false, allowVoidExpr); + return ex; } } @@ -1004,6 +1021,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ai--; } else { // Case 3: Normal argument + // LOG("-> {}", expr->args[ai].value->toString()); expr->args[ai].value = transform(expr->args[ai].value, true); // Unbound inType might become a generator that will need to be extracted, so // don't unify it yet. @@ -1020,22 +1038,62 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in seenNames.insert(i.name); } - // Intercept dot-callees (e.g. expr.foo). Needed in order to select a proper - // overload for magic methods and to avoid dealing with partial calls - // (a non-intercepted object DotExpr (e.g. expr.foo) will get transformed into a - // partial call). - ExprPtr *lhs = &expr->expr; - // Make sure to check for instantiation DotExpr (e.g. a.b[T]) as well. - if (auto ei = const_cast(expr->expr->getIndex())) { - // A potential function instantiation - lhs = &ei->expr; - } else if (auto eii = CAST(expr->expr, InstantiateExpr)) { - // Real instantiation - lhs = &eii->typeExpr; + if (expr->expr->isId("superf")) { + if (ctx->bases.back().supers.empty()) + error("no matching superf methods are available"); + auto parentCls = ctx->bases.back().type->getFunc()->funcParent; + auto m = + findMatchingMethods(parentCls ? CAST(parentCls, types::ClassType) : nullptr, + ctx->bases.back().supers, expr->args); + if (m.empty()) + error("no matching superf methods are available"); + // LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), + // m[0]->toString()); + ExprPtr e = N(N(m[0]->ast->name), expr->args); + return transform(e, false, true); } - if (auto ed = const_cast((*lhs)->getDot())) { - if (auto edt = transformDot(ed, &expr->args)) - *lhs = edt; + if (expr->expr->isId("super")) + return transformSuper(expr); + + bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() && + !expr->args.back().value->getEllipsis()->isPipeArg && + expr->args.back().name.empty(); + if (!isPartial) { + // Intercept dot-callees (e.g. expr.foo). Needed in order to select a proper + // overload for magic methods and to avoid dealing with partial calls + // (a non-intercepted object DotExpr (e.g. expr.foo) will get transformed into a + // partial call). + ExprPtr *lhs = &expr->expr; + // Make sure to check for instantiation DotExpr (e.g. a.b[T]) as well. + if (auto ei = const_cast(expr->expr->getIndex())) { + // A potential function instantiation + lhs = &ei->expr; + } else if (auto eii = CAST(expr->expr, InstantiateExpr)) { + // Real instantiation + lhs = &eii->typeExpr; + } + if (auto ed = const_cast((*lhs)->getDot())) { + if (auto edt = transformDot(ed, &expr->args)) + *lhs = edt; + } else if (auto ei = const_cast((*lhs)->getId())) { + // check if this is an overloaded function? + auto i = ctx->cache->overloads.find(ei->value); + if (i != ctx->cache->overloads.end() && i->second.size() != 1) { + if (auto bestMethod = findBestMethod(ei->value, expr->args)) { + ExprPtr e = N(bestMethod->ast->name); + auto t = ctx->instantiate(expr, bestMethod); + unify(e->type, t); + unify(ei->type, e->type); + *lhs = e; + } else { + std::vector nice; + for (auto &t : expr->args) + nice.emplace_back(format("{} = {}", t.name, t.value->type->toString())); + error("cannot find an overload '{}' with arguments {}", ei->value, + join(nice, ", ")); + } + } + } } expr->expr = transform(expr->expr, true); @@ -1086,9 +1144,21 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in std::vector args; std::vector typeArgs; int typeArgCount = 0; - bool isPartial = false; + // bool isPartial = false; int ellipsisStage = -1; auto newMask = std::vector(calleeFn->ast->args.size(), 1); + auto getPartialArg = [&](int pi) { + auto id = transform(N(partialVar)); + ExprPtr it = N(pi); + // Manual call to transformStaticTupleIndex needed because otherwise + // IndexExpr routes this to InstantiateExpr. + auto ex = transformStaticTupleIndex(callee.get(), id, it); + seqassert(ex, "partial indexing failed"); + return ex; + }; + + ExprPtr partialStarArgs = nullptr; + ExprPtr partialKwstarArgs = nullptr; if (expr->ordered) args = expr->args; else @@ -1096,7 +1166,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in calleeFn.get(), expr->args, [&](int starArgIndex, int kwstarArgIndex, const std::vector> &slots, bool partial) { - isPartial = partial; ctx->addBlock(); // add generics for default arguments. addFunctionGenerics(calleeFn->getFunc().get()); for (int si = 0, pi = 0; si < slots.size(); si++) { @@ -1105,17 +1174,38 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in : expr->args[slots[si][0]].value); typeArgCount += typeArgs.back() != nullptr; newMask[si] = slots[si].empty() ? 0 : 1; - } else if (si == starArgIndex && !(partial && slots[si].empty())) { + } else if (si == starArgIndex) { std::vector extra; + if (!known.empty()) + extra.push_back(N(getPartialArg(-2))); for (auto &e : slots[si]) { extra.push_back(expr->args[e].value); if (extra.back()->getEllipsis()) ellipsisStage = args.size(); } - args.push_back({"", transform(N(extra))}); - } else if (si == kwstarArgIndex && !(partial && slots[si].empty())) { + auto e = transform(N(extra)); + if (partial) { + partialStarArgs = e; + args.push_back({"", transform(N())}); + newMask[si] = 0; + } else { + args.push_back({"", e}); + } + } else if (si == kwstarArgIndex) { std::vector names; std::vector values; + if (!known.empty()) { + auto e = getPartialArg(-1); + auto t = e->getType()->getRecord(); + seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple", + e->toString()); + auto &ff = ctx->cache->classes[t->name].fields; + for (int i = 0; i < t->getRecord()->args.size(); i++) { + names.emplace_back(ff[i].name); + values.emplace_back( + CallExpr::Arg{"", transform(N(clone(e), ff[i].name))}); + } + } for (auto &e : slots[si]) { names.emplace_back(expr->args[e].name); values.emplace_back(CallExpr::Arg{"", expr->args[e].value}); @@ -1123,16 +1213,17 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ellipsisStage = args.size(); } auto kwName = generateTupleStub(names.size(), "KwTuple", names); - args.push_back({"", transform(N(N(kwName), values))}); + auto e = transform(N(N(kwName), values)); + if (partial) { + partialKwstarArgs = e; + args.push_back({"", transform(N())}); + newMask[si] = 0; + } else { + args.push_back({"", e}); + } } else if (slots[si].empty()) { if (!known.empty() && known[si]) { - // Manual call to transformStaticTupleIndex needed because otherwise - // IndexExpr routes this to InstantiateExpr. - auto id = transform(N(partialVar)); - ExprPtr it = N(pi++); - auto ex = transformStaticTupleIndex(callee.get(), id, it); - seqassert(ex, "partial indexing failed"); - args.push_back({"", ex}); + args.push_back({"", getPartialArg(pi++)}); } else if (partial) { args.push_back({"", transform(N())}); newMask[si] = 0; @@ -1160,6 +1251,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in if (isPartial) { deactivateUnbounds(expr->args.back().value->getType().get()); expr->args.pop_back(); + if (!partialStarArgs) + partialStarArgs = transform(N()); + if (!partialKwstarArgs) { + auto kwName = generateTupleStub(0, "KwTuple", {}); + partialKwstarArgs = transform(N(N(kwName))); + } } // Typecheck given arguments with the expected (signature) types. @@ -1181,8 +1278,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in // Special case: function instantiation if (isPartial && typeArgCount && typeArgCount == expr->args.size()) { for (auto &a : args) { - seqassert(a.value->getEllipsis(), "expected ellipsis"); - deactivateUnbounds(a.value->getType().get()); + if (a.value->getEllipsis()) + deactivateUnbounds(a.value->getType().get()); } auto e = transform(expr->expr); unify(expr->type, e->getType()); @@ -1252,11 +1349,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in deactivateUnbounds(pt->func.get()); calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si]; } - if (auto rt = realize(calleeFn)) { - unify(rt, std::static_pointer_cast(calleeFn)); - expr->expr = transform(expr->expr); + if (!isPartial) { + if (auto rt = realize(calleeFn)) { + unify(rt, std::static_pointer_cast(calleeFn)); + expr->expr = transform(expr->expr); + } } - expr->done &= expr->expr->done; // Emit the final call. @@ -1269,6 +1367,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in for (auto &r : args) if (!r.value->getEllipsis()) newArgs.push_back(r.value); + newArgs.push_back(partialStarArgs); + newArgs.push_back(partialKwstarArgs); std::string var = ctx->cache->getTemporaryVar("partial"); ExprPtr call = nullptr; @@ -1323,8 +1423,15 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) expr->args[1].value = transformType(expr->args[1].value, /*disableActivation*/ true); auto t = expr->args[1].value->type; - auto unifyOK = typ->unify(t.get(), nullptr) >= 0; - return {true, transform(N(unifyOK))}; + auto hierarchy = getSuperTypes(typ->getClass()); + + for (auto &tx: hierarchy) { + auto unifyOK = tx->unify(t.get(), nullptr) >= 0; + if (unifyOK) { + return {true, transform(N(true))}; + } + } + return {true, transform(N(false))}; } } } else if (val == "staticlen") { @@ -1355,18 +1462,17 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) if (!typ || !expr->args[1].value->staticValue.evaluated) return {true, nullptr}; auto member = expr->args[1].value->staticValue.getString(); - std::vector> args{{std::string(), typ}}; + std::vector args{typ}; for (int i = 2; i < expr->args.size(); i++) { expr->args[i].value = transformType(expr->args[i].value); if (!expr->args[i].value->getType()->getClass()) return {true, nullptr}; - args.push_back({std::string(), expr->args[i].value->getType()}); + args.push_back(expr->args[i].value->getType()); } bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() || ctx->findMember(typ->getClass()->name, member); if (exists && args.size() > 1) - exists &= - ctx->findBestMethod(expr->args[0].value.get(), member, args, true) != nullptr; + exists &= findBestMethod(expr->args[0].value.get(), member, args) != nullptr; return {true, transform(N(exists))}; } else if (val == "compile_error") { expr->args[0].value = transform(expr->args[0].value); @@ -1531,7 +1637,8 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, tupleSize++; auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name); if (!ctx->find(typeName)) - generateTupleStub(tupleSize, typeName, {}, false); + // 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them) + generateTupleStub(tupleSize + 2, typeName, {}, false); return typeName; } @@ -1597,9 +1704,12 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { auto partialTypeName = generatePartialStub(mask, fn.get()); deactivateUnbounds(fn.get()); std::string var = ctx->cache->getTemporaryVar("partial"); - ExprPtr call = N( - N(N(var), N(N(partialTypeName))), - N(var)); + auto kwName = generateTupleStub(0, "KwTuple", {}); + ExprPtr call = + N(N(N(var), + N(N(partialTypeName), N(), + N(N(kwName)))), + N(var)); call = transform(call, false, allowVoidExpr); seqassert(call->type->getRecord() && startswith(call->type->getRecord()->name, partialTypeName) && @@ -1609,8 +1719,122 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { return call; } +types::FuncTypePtr +TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member, + const std::vector &args) { + std::vector callArgs; + for (auto &a : args) { + callArgs.push_back({"", std::make_shared()}); // dummy expression + callArgs.back().value->setType(a); + } + return findBestMethod(expr, member, callArgs); +} + +types::FuncTypePtr +TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member, + const std::vector &args) { + auto typ = expr->getType()->getClass(); + seqassert(typ, "not a class"); + auto methods = ctx->findMethod(typ->name, member, false); + auto m = findMatchingMethods(typ.get(), methods, args); + return m.empty() ? nullptr : m[0]; +} + +types::FuncTypePtr +TypecheckVisitor::findBestMethod(const std::string &fn, + const std::vector &args) { + std::vector methods; + for (auto &m : ctx->cache->overloads[fn]) + if (!endswith(m.name, ":dispatch")) + methods.push_back(ctx->cache->functions[m.name].type); + std::reverse(methods.begin(), methods.end()); + auto m = findMatchingMethods(nullptr, methods, args); + return m.empty() ? nullptr : m[0]; +} + +std::vector +TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) { + if (func->ast->attributes.parentClass.empty() || + endswith(func->ast->name, ":dispatch")) + return {}; + auto p = ctx->find(func->ast->attributes.parentClass)->type; + if (!p || !p->getClass()) + return {}; + + auto methodName = ctx->cache->reverseIdentifierLookup[func->ast->name]; + auto m = ctx->cache->classes.find(p->getClass()->name); + std::vector result; + if (m != ctx->cache->classes.end()) { + auto t = m->second.methods.find(methodName); + if (t != m->second.methods.end()) { + for (auto &m : ctx->cache->overloads[t->second]) { + if (endswith(m.name, ":dispatch")) + continue; + if (m.name == func->ast->name) + break; + result.emplace_back(ctx->cache->functions[m.name].type); + } + } + } + std::reverse(result.begin(), result.end()); + return result; +} + +std::vector +TypecheckVisitor::findMatchingMethods(types::ClassType *typ, + const std::vector &methods, + const std::vector &args) { + // Pick the last method that accepts the given arguments. + std::vector results; + for (int mi = 0; mi < methods.size(); mi++) { + auto m = ctx->instantiate(nullptr, methods[mi], typ, false)->getFunc(); + std::vector reordered; + auto score = ctx->reorderNamedArgs( + m.get(), args, + [&](int s, int k, const std::vector> &slots, bool _) { + for (int si = 0; si < slots.size(); si++) { + if (m->ast->args[si].generic) { + // Ignore type arguments + } else if (si == s || si == k || slots[si].size() != 1) { + // Ignore *args, *kwargs and default arguments + reordered.emplace_back(nullptr); + } else { + reordered.emplace_back(args[slots[si][0]].value->type); + } + } + return 0; + }, + [](const std::string &) { return -1; }); + for (int ai = 0, mi = 1, gi = 0; score != -1 && ai < reordered.size(); ai++) { + auto expectTyp = + m->ast->args[ai].generic ? m->generics[gi++].type : m->args[mi++]; + auto argType = reordered[ai]; + if (!argType) + continue; + try { + ExprPtr dummy = std::make_shared(""); + dummy->type = argType; + dummy->done = true; + wrapExpr(dummy, expectTyp, m, /*undoOnSuccess*/ true); + } catch (const exc::ParserException &) { + score = -1; + } + } + if (score != -1) { + // std::vector ar; + // for (auto &a: args) { + // if (a.first.empty()) ar.push_back(a.second->toString()); + // else ar.push_back(format("{}: {}", a.first, a.second->toString())); + // } + // LOG("- {} vs {}", m->toString(), join(ar, "; ")); + results.push_back(methods[mi]); + } + } + return results; +} + bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType, - const FuncTypePtr &callee) { + const FuncTypePtr &callee, bool undoOnSuccess) { auto expectedClass = expectedType->getClass(); auto exprClass = expr->getType()->getClass(); if (callee && expr->isType()) @@ -1637,7 +1861,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType, // Case 7: wrap raw Seq functions into Partial(...) call for easy realization. expr = partializeFunction(expr); } - unify(expr->type, expectedType); + unify(expr->type, expectedType, undoOnSuccess); return true; } @@ -1690,5 +1914,132 @@ int64_t TypecheckVisitor::sliceAdjustIndices(int64_t length, int64_t *start, return 0; } +types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) { + for (auto &m : ctx->cache->overloads[fn]) + if (endswith(ctx->cache->functions[m.name].ast->name, ":dispatch")) + return ctx->cache->functions[m.name].type; + + // Generate dispatch and return it! + auto name = fn + ":dispatch"; + + ExprPtr root; + auto a = ctx->cache->functions[ctx->cache->overloads[fn][0].name].ast; + if (!a->attributes.parentClass.empty()) + root = N(N(a->attributes.parentClass), + ctx->cache->reverseIdentifierLookup[fn]); + else + root = N(fn); + root = N(root, N(N("args")), + N(N("kwargs"))); + auto ast = N( + name, nullptr, std::vector{Param("*args"), Param("**kwargs")}, + N(N( + N(N("isinstance"), root->clone(), N("void")), + N(root->clone()), N(root))), + Attr({"autogenerated"})); + ctx->cache->reverseIdentifierLookup[name] = ctx->cache->reverseIdentifierLookup[fn]; + + auto baseType = + ctx->instantiate(N(name).get(), ctx->find(generateFunctionStub(2))->type, + nullptr, false) + ->getRecord(); + auto typ = std::make_shared(baseType, ast.get()); + typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel)); + ctx->add(TypecheckItem::Func, name, typ); + + ctx->cache->overloads[fn].insert(ctx->cache->overloads[fn].begin(), {name, 0}); + ctx->cache->functions[name].ast = ast; + ctx->cache->functions[name].type = typ; + prependStmts->push_back(ast); + // LOG("dispatch: {}", ast->toString(1)); + return typ; +} + +ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) { + // For now, we just support casting to the _FIRST_ overload (i.e. empty super()) + if (!expr->args.empty()) + error("super does not take arguments"); + + if (ctx->bases.empty() || !ctx->bases.back().type) + error("no parent classes available"); + auto fptyp = ctx->bases.back().type->getFunc(); + if (!fptyp || !fptyp->ast->hasAttr(Attr::Method)) + error("no parent classes available"); + if (fptyp->args.size() < 2) + error("no parent classes available"); + ClassTypePtr typ = fptyp->args[1]->getClass(); + auto &cands = ctx->cache->classes[typ->name].parentClasses; + if (cands.empty()) + error("no parent classes available"); + // if (typ->getRecord()) + // error("cannot use super on tuple types"); + + // find parent typ + // unify top N args with parent typ args + // realize & do bitcast + // call bitcast() . method + + auto name = cands[0].first; + int fields = cands[0].second; + auto val = ctx->find(name); + seqassert(val, "cannot find '{}'", name); + auto ftyp = ctx->instantiate(expr, val->type)->getClass(); + + if (typ->getRecord()) { + std::vector members; + for (int i = 0; i < fields; i++) + members.push_back(N(N(fptyp->ast->args[0].name), + ctx->cache->classes[typ->name].fields[i].name)); + ExprPtr e = transform( + N(N(format(TYPE_TUPLE "{}", members.size())), members)); + unify(e->type, ftyp); + e->type = ftyp; + return e; + } else { + for (int i = 0; i < fields; i++) { + auto t = ctx->cache->classes[typ->name].fields[i].type; + t = ctx->instantiate(expr, t, typ.get()); + + auto ft = ctx->cache->classes[name].fields[i].type; + ft = ctx->instantiate(expr, ft, ftyp.get()); + unify(t, ft); + } + + ExprPtr typExpr = N(name); + typExpr->setType(ftyp); + auto self = fptyp->ast->args[0].name; + ExprPtr e = transform( + N(N(N("__internal__"), "to_class_ptr"), + N(N(N(self), "__raw__")), typExpr)); + return e; + } +} + +std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) { + std::vector result; + if (!cls) + return result; + result.push_back(cls); + int start = 0; + for (auto &cand: ctx->cache->classes[cls->name].parentClasses) { + auto name = cand.first; + int fields = cand.second; + auto val = ctx->find(name); + seqassert(val, "cannot find '{}'", name); + auto ftyp = ctx->instantiate(nullptr, val->type)->getClass(); + for (int i = start; i < fields; i++) { + auto t = ctx->cache->classes[cls->name].fields[i].type; + t = ctx->instantiate(nullptr, t, cls.get()); + auto ft = ctx->cache->classes[name].fields[i].type; + ft = ctx->instantiate(nullptr, ft, ftyp.get()); + unify(t, ft); + } + start += fields; + for (auto &t: getSuperTypes(ftyp)) + result.push_back(t); + } + return result; +} + } // namespace ast } // namespace codon diff --git a/codon/parser/visitors/typecheck/typecheck_infer.cpp b/codon/parser/visitors/typecheck/typecheck_infer.cpp index 2acb96ef..eef9924b 100644 --- a/codon/parser/visitors/typecheck/typecheck_infer.cpp +++ b/codon/parser/visitors/typecheck/typecheck_infer.cpp @@ -154,7 +154,15 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { ctx->realizationDepth++; ctx->addBlock(); ctx->typecheckLevel++; - ctx->bases.push_back({type->ast->name, type->getFunc(), type->args[0]}); + + // Find parents! + ctx->bases.push_back({type->ast->name, type->getFunc(), type->args[0], + {}, findSuperMethods(type->getFunc())}); + // if (startswith(type->ast->name, "Foo")) { + // LOG(": {}", type->toString()); + // for (auto &s: ctx->bases.back().supers) + // LOG(" - {}", s->toString()); + // } auto clonedAst = ctx->cache->functions[type->ast->name].ast->clone(); auto *ast = (FunctionStmt *)clonedAst.get(); addFunctionGenerics(type); diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index 432857c0..42ece717 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -153,9 +153,9 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) { ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass}); c->args[1].value = transform(c->args[1].value); auto rhsTyp = c->args[1].value->getType()->getClass(); - if (auto method = ctx->findBestMethod( - stmt->lhs.get(), format("__atomic_{}__", c->expr->getId()->value), - {{"", ptrTyp}, {"", rhsTyp}})) { + if (auto method = findBestMethod(stmt->lhs.get(), + format("__atomic_{}__", c->expr->getId()->value), + {ptrTyp, rhsTyp})) { resultStmt = transform(N(N( N(method->ast->name), N(stmt->lhs), c->args[1].value))); return; @@ -168,8 +168,8 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) { if (stmt->isAtomic && lhsClass && rhsClass) { auto ptrType = ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass}); - if (auto m = ctx->findBestMethod(stmt->lhs.get(), "__atomic_xchg__", - {{"", ptrType}, {"", rhsClass}})) { + if (auto m = + findBestMethod(stmt->lhs.get(), "__atomic_xchg__", {ptrType, rhsClass})) { resultStmt = transform(N( N(N(m->ast->name), N(stmt->lhs), stmt->rhs))); return; @@ -474,12 +474,12 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel)); // Check if this is a class method; if so, update the class method lookup table. if (isClassMember) { - auto &methods = ctx->cache->classes[attr.parentClass] - .methods[ctx->cache->reverseIdentifierLookup[stmt->name]]; + auto m = ctx->cache->classes[attr.parentClass] + .methods[ctx->cache->reverseIdentifierLookup[stmt->name]]; bool found = false; - for (auto &i : methods) + for (auto &i : ctx->cache->overloads[m]) if (i.name == stmt->name) { - i.type = typ; + ctx->cache->functions[i.name].type = typ; found = true; break; } @@ -570,5 +570,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { stmt->done = true; } +std::string TypecheckVisitor::getRootName(const std::string &name) { + auto p = name.rfind(':'); + seqassert(p != std::string::npos, ": not found in {}", name); + return name.substr(0, p); +} + } // namespace ast } // namespace codon diff --git a/codon/sir/module.cpp b/codon/sir/module.cpp index db42b406..e01e8808 100644 --- a/codon/sir/module.cpp +++ b/codon/sir/module.cpp @@ -22,12 +22,12 @@ translateGenerics(std::vector &generics) { return ret; } -std::vector> +std::vector generateDummyNames(std::vector &types) { - std::vector> ret; + std::vector ret; for (auto *t : types) { seqassert(t->getAstType(), "{} must have an ast type", *t); - ret.emplace_back("", t->getAstType()); + ret.emplace_back(t->getAstType()); } return ret; } diff --git a/stdlib/collections.codon b/stdlib/collections.codon index fc18a627..b5a1a21a 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -339,10 +339,12 @@ class Counter[T](Dict[T,int]): result |= other return result + @extend class Dict: def __init__(self: Dict[K,int], other: Counter[K]): self._init_from(other) + def namedtuple(): # internal pass diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index f07c7e50..cfe19211 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -125,6 +125,24 @@ class __internal__: def opt_ref_invert[T](what: Optional[T]) -> T: ret i8* %what + @pure + @llvm + def to_class_ptr[T](ptr: Ptr[byte]) -> T: + %0 = bitcast i8* %ptr to {=T} + ret {=T} %0 + + @pure + def _tuple_offsetof(x, field: Static[int]): + @llvm + def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int: + %a = alloca {=T} + %b = getelementptr inbounds {=T}, {=T}* %a, i64 0, i32 {=idx} + %base = ptrtoint {=T}* %a to i64 + %elem = ptrtoint {=TE}* %b to i64 + %offset = sub i64 %elem, %base + ret i64 %offset + return _llvm_offsetof(type(x), field, type(x[field])) + def raw_type_str(p: Ptr[byte], name: str) -> str: pstr = p.__repr__() # '<[name] at [pstr]>' diff --git a/stdlib/internal/sort.codon b/stdlib/internal/sort.codon index fa46a459..5ca675cf 100644 --- a/stdlib/internal/sort.codon +++ b/stdlib/internal/sort.codon @@ -1,14 +1,14 @@ -from algorithms.timsort import tim_sort_inplace from algorithms.pdqsort import pdq_sort_inplace from algorithms.insertionsort import insertion_sort_inplace from algorithms.heapsort import heap_sort_inplace from algorithms.qsort import qsort_inplace -def sorted[T]( +def sorted( v: Generator[T], key = Optional[int](), algorithm: Optional[str] = None, - reverse: bool = False + reverse: bool = False, + T: type ): """ Return a sorted list of the elements in v @@ -27,8 +27,6 @@ def _sort_list(self, key, algorithm: str): insertion_sort_inplace(self, key) elif algorithm == 'heap': heap_sort_inplace(self, key) - #case 'tim': - # tim_sort_inplace(self, key) elif algorithm == 'quick': qsort_inplace(self, key) else: diff --git a/stdlib/internal/types/array.codon b/stdlib/internal/types/array.codon index 834db41b..6f883dd8 100644 --- a/stdlib/internal/types/array.codon +++ b/stdlib/internal/types/array.codon @@ -1,15 +1,19 @@ from internal.gc import sizeof + @extend class Array: def __new__(ptr: Ptr[T], sz: int) -> Array[T]: return (sz, ptr) + def __new__(sz: int) -> Array[T]: return (sz, Ptr[T](sz)) + def __copy__(self) -> Array[T]: p = Ptr[T](self.len) str.memcpy(p.as_byte(), self.ptr.as_byte(), self.len * sizeof(T)) return (self.len, p) + def __deepcopy__(self) -> Array[T]: p = Ptr[T](self.len) i = 0 @@ -17,14 +21,21 @@ class Array: p[i] = self.ptr[i].__deepcopy__() i += 1 return (self.len, p) + def __len__(self) -> int: return self.len + def __bool__(self) -> bool: return bool(self.len) + def __getitem__(self, index: int) -> T: return self.ptr[index] + def __setitem__(self, index: int, what: T): self.ptr[index] = what + def slice(self, s: int, e: int) -> Array[T]: return (e - s, self.ptr + s) + + array = Array diff --git a/stdlib/internal/types/bool.codon b/stdlib/internal/types/bool.codon index de66f660..6e54e3ec 100644 --- a/stdlib/internal/types/bool.codon +++ b/stdlib/internal/types/bool.codon @@ -4,7 +4,7 @@ from internal.attributes import commutative, associative class bool: def __new__() -> bool: return False - def __new__[T](what: T) -> bool: # lowest priority! + def __new__(what) -> bool: return what.__bool__() def __repr__(self) -> str: return "True" if self else "False" diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index dd490b10..364043d7 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -2,9 +2,9 @@ import internal.gc as gc @extend class List: - def __init__(self, arr: Array[T], len: int): - self.arr = arr - self.len = len + def __init__(self): + self.arr = Array[T](10) + self.len = 0 def __init__(self, it: Generator[T]): self.arr = Array[T](10) @@ -12,27 +12,27 @@ class List: for i in it: self.append(i) - def __init__(self, capacity: int): - self.arr = Array[T](capacity) - self.len = 0 - - def __init__(self): - self.arr = Array[T](10) - self.len = 0 - def __init__(self, other: List[T]): self.arr = Array[T](other.len) self.len = 0 for i in other: self.append(i) - # Dummy __init__ used for list comprehension optimization + def __init__(self, capacity: int): + self.arr = Array[T](capacity) + self.len = 0 + def __init__(self, dummy: bool, other): + """Dummy __init__ used for list comprehension optimization""" if hasattr(other, '__len__'): self.__init__(other.__len__()) else: self.__init__() + def __init__(self, arr: Array[T], len: int): + self.arr = arr + self.len = len + def __len__(self): return self.len diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index c3e3d8de..d3c031af 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -3,21 +3,15 @@ class complex: real: float imag: float - def __new__(): - return complex(0.0, 0.0) - - def __new__(real: int, imag: int): - return complex(float(real), float(imag)) - - def __new__(real: float, imag: int): - return complex(real, float(imag)) - - def __new__(real: int, imag: float): - return complex(float(real), imag) + def __new__() -> complex: + return (0.0, 0.0) def __new__(other): return other.__complex__() + def __new__(real, imag) -> complex: + return (float(real), float(imag)) + def __complex__(self): return self @@ -42,6 +36,42 @@ class complex: def __hash__(self): return self.real.__hash__() + self.imag.__hash__()*1000003 + def __add__(self, other): + return self + complex(other) + + def __sub__(self, other): + return self - complex(other) + + def __mul__(self, other): + return self * complex(other) + + def __truediv__(self, other): + return self / complex(other) + + def __eq__(self, other): + return self == complex(other) + + def __ne__(self, other): + return self != complex(other) + + def __pow__(self, other): + return self ** complex(other) + + def __radd__(self, other): + return complex(other) + self + + def __rsub__(self, other): + return complex(other) - self + + def __rmul__(self, other): + return complex(other) * self + + def __rtruediv__(self, other): + return complex(other) / self + + def __rpow__(self, other): + return complex(other) ** self + def __add__(self, other: complex): return complex(self.real + other.real, self.imag + other.imag) @@ -160,42 +190,6 @@ class complex: phase += other.imag * log(vabs) return complex(len * cos(phase), len * sin(phase)) - def __add__(self, other): - return self + complex(other) - - def __sub__(self, other): - return self - complex(other) - - def __mul__(self, other): - return self * complex(other) - - def __truediv__(self, other): - return self / complex(other) - - def __eq__(self, other): - return self == complex(other) - - def __ne__(self, other): - return self != complex(other) - - def __pow__(self, other): - return self ** complex(other) - - def __radd__(self, other): - return complex(other) + self - - def __rsub__(self, other): - return complex(other) - self - - def __rmul__(self, other): - return complex(other) * self - - def __rtruediv__(self, other): - return complex(other) / self - - def __rpow__(self, other): - return complex(other) ** self - def __repr__(self): @pure @llvm diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 54b358e0..68e59635 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -10,80 +10,105 @@ def seq_str_float(a: float) -> str: pass class float: def __new__() -> float: return 0.0 - def __new__[T](what: T): + + def __new__(what): return what.__float__() + + def __new__(s: str) -> float: + from C import strtod(cobj, Ptr[cobj]) -> float + buf = __array__[byte](32) + n = s.__len__() + need_dyn_alloc = (n >= buf.__len__()) + + p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr + str.memcpy(p, s.ptr, n) + p[n] = byte(0) + + end = cobj() + result = strtod(p, __ptr__(end)) + + if need_dyn_alloc: + free(p) + + if end != p + n: + raise ValueError("could not convert string to float: " + s) + + return result + def __repr__(self) -> str: s = seq_str_float(self) return s if s != "-nan" else "nan" + def __copy__(self) -> float: return self + def __deepcopy__(self) -> float: return self + @pure @llvm def __int__(self) -> int: %0 = fptosi double %self to i64 ret i64 %0 + def __float__(self): return self + @pure @llvm def __bool__(self) -> bool: %0 = fcmp one double %self, 0.000000e+00 %1 = zext i1 %0 to i8 ret i8 %1 + def __complex__(self): return complex(self, 0.0) + def __pos__(self) -> float: return self + @pure @llvm def __neg__(self) -> float: %0 = fneg double %self ret double %0 + @pure @commutative @llvm def __add__(a: float, b: float) -> float: %tmp = fadd double %a, %b ret double %tmp - @commutative - def __add__(self, other: int) -> float: - return self.__add__(float(other)) + @pure @llvm def __sub__(a: float, b: float) -> float: %tmp = fsub double %a, %b ret double %tmp - def __sub__(self, other: int) -> float: - return self.__sub__(float(other)) + @pure @commutative @llvm def __mul__(a: float, b: float) -> float: %tmp = fmul double %a, %b ret double %tmp - @commutative - def __mul__(self, other: int) -> float: - return self.__mul__(float(other)) + def __floordiv__(self, other: float) -> float: return self.__truediv__(other).__floor__() - def __floordiv__(self, other: int) -> float: - return self.__floordiv__(float(other)) + + @pure @llvm def __truediv__(a: float, b: float) -> float: %tmp = fdiv double %a, %b ret double %tmp - def __truediv__(self, other: int) -> float: - return self.__truediv__(float(other)) + @pure @llvm def __mod__(a: float, b: float) -> float: %tmp = frem double %a, %b ret double %tmp - def __mod__(self, other: int) -> float: - return self.__mod__(float(other)) + def __divmod__(self, other: float): mod = self % other div = (self - mod) / other @@ -103,16 +128,14 @@ class float: floordiv = (0.0).copysign(self / other) return (floordiv, mod) - def __divmod__(self, other: int): - return self.__divmod__(float(other)) + @pure @llvm def __eq__(a: float, b: float) -> bool: %tmp = fcmp oeq double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __eq__(self, other: int) -> bool: - return self.__eq__(float(other)) + @pure @llvm def __ne__(a: float, b: float) -> bool: @@ -120,174 +143,190 @@ class float: %tmp = fcmp one double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __ne__(self, other: int) -> bool: - return self.__ne__(float(other)) + @pure @llvm def __lt__(a: float, b: float) -> bool: %tmp = fcmp olt double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __lt__(self, other: int) -> bool: - return self.__lt__(float(other)) + @pure @llvm def __gt__(a: float, b: float) -> bool: %tmp = fcmp ogt double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __gt__(self, other: int) -> bool: - return self.__gt__(float(other)) + @pure @llvm def __le__(a: float, b: float) -> bool: %tmp = fcmp ole double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __le__(self, other: int) -> bool: - return self.__le__(float(other)) + @pure @llvm def __ge__(a: float, b: float) -> bool: %tmp = fcmp oge double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __ge__(self, other: int) -> bool: - return self.__ge__(float(other)) + @pure @llvm def sqrt(a: float) -> float: declare double @llvm.sqrt.f64(double %a) %tmp = call double @llvm.sqrt.f64(double %a) ret double %tmp + @pure @llvm def sin(a: float) -> float: declare double @llvm.sin.f64(double %a) %tmp = call double @llvm.sin.f64(double %a) ret double %tmp + @pure @llvm def cos(a: float) -> float: declare double @llvm.cos.f64(double %a) %tmp = call double @llvm.cos.f64(double %a) ret double %tmp + @pure @llvm def exp(a: float) -> float: declare double @llvm.exp.f64(double %a) %tmp = call double @llvm.exp.f64(double %a) ret double %tmp + @pure @llvm def exp2(a: float) -> float: declare double @llvm.exp2.f64(double %a) %tmp = call double @llvm.exp2.f64(double %a) ret double %tmp + @pure @llvm def log(a: float) -> float: declare double @llvm.log.f64(double %a) %tmp = call double @llvm.log.f64(double %a) ret double %tmp + @pure @llvm def log10(a: float) -> float: declare double @llvm.log10.f64(double %a) %tmp = call double @llvm.log10.f64(double %a) ret double %tmp + @pure @llvm def log2(a: float) -> float: declare double @llvm.log2.f64(double %a) %tmp = call double @llvm.log2.f64(double %a) ret double %tmp + @pure @llvm def __abs__(a: float) -> float: declare double @llvm.fabs.f64(double %a) %tmp = call double @llvm.fabs.f64(double %a) ret double %tmp + @pure @llvm def __floor__(a: float) -> float: declare double @llvm.floor.f64(double %a) %tmp = call double @llvm.floor.f64(double %a) ret double %tmp + @pure @llvm def __ceil__(a: float) -> float: declare double @llvm.ceil.f64(double %a) %tmp = call double @llvm.ceil.f64(double %a) ret double %tmp + @pure @llvm def __trunc__(a: float) -> float: declare double @llvm.trunc.f64(double %a) %tmp = call double @llvm.trunc.f64(double %a) ret double %tmp + @pure @llvm def rint(a: float) -> float: declare double @llvm.rint.f64(double %a) %tmp = call double @llvm.rint.f64(double %a) ret double %tmp + @pure @llvm def nearbyint(a: float) -> float: declare double @llvm.nearbyint.f64(double %a) %tmp = call double @llvm.nearbyint.f64(double %a) ret double %tmp + @pure @llvm def __round__(a: float) -> float: declare double @llvm.round.f64(double %a) %tmp = call double @llvm.round.f64(double %a) ret double %tmp + @pure @llvm def __pow__(a: float, b: float) -> float: declare double @llvm.pow.f64(double %a, double %b) %tmp = call double @llvm.pow.f64(double %a, double %b) ret double %tmp - def __pow__(self, other: int) -> float: - return self.__pow__(float(other)) + @pure @llvm def min(a: float, b: float) -> float: declare double @llvm.minnum.f64(double %a, double %b) %tmp = call double @llvm.minnum.f64(double %a, double %b) ret double %tmp + @pure @llvm def max(a: float, b: float) -> float: declare double @llvm.maxnum.f64(double %a, double %b) %tmp = call double @llvm.maxnum.f64(double %a, double %b) ret double %tmp + @pure @llvm def copysign(a: float, b: float) -> float: declare double @llvm.copysign.f64(double %a, double %b) %tmp = call double @llvm.copysign.f64(double %a, double %b) ret double %tmp + @pure @llvm def fma(a: float, b: float, c: float) -> float: declare double @llvm.fma.f64(double %a, double %b, double %c) %tmp = call double @llvm.fma.f64(double %a, double %b, double %c) ret double %tmp + @llvm def __atomic_xchg__(d: Ptr[float], b: float) -> void: %tmp = atomicrmw xchg double* %d, double %b seq_cst ret void + @llvm def __atomic_add__(d: Ptr[float], b: float) -> float: %tmp = atomicrmw fadd double* %d, double %b seq_cst ret double %tmp + @llvm def __atomic_sub__(d: Ptr[float], b: float) -> float: %tmp = atomicrmw fsub double* %d, double %b seq_cst ret double %tmp + def __hash__(self): from C import frexp(float, Ptr[Int[32]]) -> float @@ -332,31 +371,13 @@ class float: x = -2 return x - def __new__(s: str) -> float: - from C import strtod(cobj, Ptr[cobj]) -> float - buf = __array__[byte](32) - n = s.__len__() - need_dyn_alloc = (n >= buf.__len__()) - - p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr - str.memcpy(p, s.ptr, n) - p[n] = byte(0) - - end = cobj() - result = strtod(p, __ptr__(end)) - - if need_dyn_alloc: - free(p) - - if end != p + n: - raise ValueError("could not convert string to float: " + s) - - return result def __match__(self, i: float): return self == i + @property def real(self): return self + @property def imag(self): return 0.0 diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 5e7be7db..75cf047a 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -1,50 +1,72 @@ from internal.attributes import commutative, associative, distributive from internal.types.complex import complex + @pure @C def seq_str_int(a: int) -> str: pass + @pure @C def seq_str_uint(a: int) -> str: pass + @extend class int: @pure @llvm def __new__() -> int: ret i64 0 - def __new__[T](what: T) -> int: # lowest priority! + + def __new__(what) -> int: return what.__int__() + + def __new__(s: str) -> int: + return int._from_str(s, 10) + + def __new__(s: str, base: int) -> int: + return int._from_str(s, base) + def __int__(self) -> int: return self + @pure @llvm def __float__(self) -> float: %tmp = sitofp i64 %self to double ret double %tmp + def __complex__(self): return complex(float(self), 0.0) + def __index__(self): return self + def __repr__(self) -> str: return seq_str_int(self) + def __copy__(self) -> int: return self + def __deepcopy__(self) -> int: return self + def __hash__(self) -> int: return self + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i64 %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> int: return self + def __neg__(self) -> int: return 0 - self + @pure @llvm def __abs__(self) -> int: @@ -52,23 +74,19 @@ class int: %1 = sub i64 0, %self %2 = select i1 %0, i64 %self, i64 %1 ret i64 %2 + @pure @llvm def __lshift__(self, other: int) -> int: %0 = shl i64 %self, %other ret i64 %0 + @pure @llvm def __rshift__(self, other: int) -> int: %0 = ashr i64 %self, %other ret i64 %0 - @pure - @commutative - @associative - @llvm - def __add__(self, b: int) -> int: - %tmp = add i64 %self, %b - ret i64 %tmp + @pure @commutative @llvm @@ -76,17 +94,36 @@ class int: %0 = sitofp i64 %self to double %1 = fadd double %0, %other ret double %1 + @pure + @commutative + @associative @llvm - def __sub__(self, b: int) -> int: - %tmp = sub i64 %self, %b + def __add__(self, b: int) -> int: + %tmp = add i64 %self, %b ret i64 %tmp + @pure @llvm def __sub__(self, other: float) -> float: %0 = sitofp i64 %self to double %1 = fsub double %0, %other ret double %1 + + @pure + @llvm + def __sub__(self, b: int) -> int: + %tmp = sub i64 %self, %b + ret i64 %tmp + + @pure + @commutative + @llvm + def __mul__(self, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fmul double %0, %other + ret double %1 + @pure @commutative @associative @@ -95,18 +132,7 @@ class int: def __mul__(self, b: int) -> int: %tmp = mul i64 %self, %b ret i64 %tmp - @pure - @commutative - @llvm - def __mul__(self, other: float) -> float: - %0 = sitofp i64 %self to double - %1 = fmul double %0, %other - ret double %1 - @pure - @llvm - def __floordiv__(self, b: int) -> int: - %tmp = sdiv i64 %self, %b - ret i64 %tmp + @pure @llvm def __floordiv__(self, other: float) -> float: @@ -115,6 +141,20 @@ class int: %1 = fdiv double %0, %other %2 = call double @llvm.floor.f64(double %1) ret double %2 + + @pure + @llvm + def __floordiv__(self, b: int) -> int: + %tmp = sdiv i64 %self, %b + ret i64 %tmp + + @pure + @llvm + def __truediv__(self, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fdiv double %0, %other + ret double %1 + @pure @llvm def __truediv__(self, other: int) -> float: @@ -122,23 +162,20 @@ class int: %1 = sitofp i64 %other to double %2 = fdiv double %0, %1 ret double %2 - @pure - @llvm - def __truediv__(self, other: float) -> float: - %0 = sitofp i64 %self to double - %1 = fdiv double %0, %other - ret double %1 - @pure - @llvm - def __mod__(a: int, b: int) -> int: - %tmp = srem i64 %a, %b - ret i64 %tmp + @pure @llvm def __mod__(self, other: float) -> float: %0 = sitofp i64 %self to double %1 = frem double %0, %other ret double %1 + + @pure + @llvm + def __mod__(a: int, b: int) -> int: + %tmp = srem i64 %a, %b + ret i64 %tmp + def __divmod__(self, other: int): d = self // other m = self - d*other @@ -146,11 +183,13 @@ class int: m += other d -= 1 return (d, m) + @pure @llvm def __invert__(a: int) -> int: %tmp = xor i64 %a, -1 ret i64 %tmp + @pure @commutative @associative @@ -158,6 +197,7 @@ class int: def __and__(a: int, b: int) -> int: %tmp = and i64 %a, %b ret i64 %tmp + @pure @commutative @associative @@ -165,6 +205,7 @@ class int: def __or__(a: int, b: int) -> int: %tmp = or i64 %a, %b ret i64 %tmp + @pure @commutative @associative @@ -172,42 +213,42 @@ class int: def __xor__(a: int, b: int) -> int: %tmp = xor i64 %a, %b ret i64 %tmp + @pure @llvm def __bitreverse__(a: int) -> int: declare i64 @llvm.bitreverse.i64(i64 %a) %tmp = call i64 @llvm.bitreverse.i64(i64 %a) ret i64 %tmp + @pure @llvm def __bswap__(a: int) -> int: declare i64 @llvm.bswap.i64(i64 %a) %tmp = call i64 @llvm.bswap.i64(i64 %a) ret i64 %tmp + @pure @llvm def __ctpop__(a: int) -> int: declare i64 @llvm.ctpop.i64(i64 %a) %tmp = call i64 @llvm.ctpop.i64(i64 %a) ret i64 %tmp + @pure @llvm def __ctlz__(a: int) -> int: declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false) ret i64 %tmp + @pure @llvm def __cttz__(a: int) -> int: declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false) ret i64 %tmp - @pure - @llvm - def __eq__(a: int, b: int) -> bool: - %tmp = icmp eq i64 %a, %b - %res = zext i1 %tmp to i8 - ret i8 %res + @pure @llvm def __eq__(self, b: float) -> bool: @@ -215,12 +256,14 @@ class int: %1 = fcmp oeq double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __ne__(a: int, b: int) -> bool: - %tmp = icmp ne i64 %a, %b + def __eq__(a: int, b: int) -> bool: + %tmp = icmp eq i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __ne__(self, b: float) -> bool: @@ -228,12 +271,14 @@ class int: %1 = fcmp one double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __lt__(a: int, b: int) -> bool: - %tmp = icmp slt i64 %a, %b + def __ne__(a: int, b: int) -> bool: + %tmp = icmp ne i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __lt__(self, b: float) -> bool: @@ -241,12 +286,14 @@ class int: %1 = fcmp olt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __gt__(a: int, b: int) -> bool: - %tmp = icmp sgt i64 %a, %b + def __lt__(a: int, b: int) -> bool: + %tmp = icmp slt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __gt__(self, b: float) -> bool: @@ -254,12 +301,14 @@ class int: %1 = fcmp ogt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __le__(a: int, b: int) -> bool: - %tmp = icmp sle i64 %a, %b + def __gt__(a: int, b: int) -> bool: + %tmp = icmp sgt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __le__(self, b: float) -> bool: @@ -267,12 +316,14 @@ class int: %1 = fcmp ole double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __ge__(a: int, b: int) -> bool: - %tmp = icmp sge i64 %a, %b + def __le__(a: int, b: int) -> bool: + %tmp = icmp sle i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __ge__(self, b: float) -> bool: @@ -280,10 +331,17 @@ class int: %1 = fcmp oge double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 - def __new__(s: str) -> int: - return int._from_str(s, 10) - def __new__(s: str, base: int) -> int: - return int._from_str(s, base) + + @pure + @llvm + def __ge__(a: int, b: int) -> bool: + %tmp = icmp sge i64 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + def __pow__(self, exp: float): + return float(self) ** exp + def __pow__(self, exp: int): if exp < 0: return 0 @@ -296,53 +354,65 @@ class int: break self *= self return result - def __pow__(self, exp: float): - return float(self) ** exp + def popcnt(self): return Int[64](self).popcnt() + @llvm def __atomic_xchg__(d: Ptr[int], b: int) -> void: %tmp = atomicrmw xchg i64* %d, i64 %b seq_cst ret void + @llvm def __atomic_add__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw add i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_sub__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw sub i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_and__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw and i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_nand__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw nand i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_or__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw or i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def _atomic_xor(d: Ptr[int], b: int) -> int: %tmp = atomicrmw xor i64* %d, i64 %b seq_cst ret i64 %tmp + def __atomic_xor__(self, b: int) -> int: return int._atomic_xor(__ptr__(self), b) + @llvm def __atomic_min__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw min i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_max__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw max i64* %d, i64 %b seq_cst ret i64 %tmp + def __match__(self, i: int): return self == i + @property def real(self): return self + @property def imag(self): return 0 diff --git a/stdlib/internal/types/intn.codon b/stdlib/internal/types/intn.codon index 4ed993da..c38e4b78 100644 --- a/stdlib/internal/types/intn.codon +++ b/stdlib/internal/types/intn.codon @@ -1,18 +1,22 @@ from internal.attributes import commutative, associative, distributive + def check_N(N: Static[int]): if N <= 0: compile_error("N must be greater than 0") pass + @extend class Int: def __new__() -> Int[N]: check_N(N) return Int[N](0) + def __new__(what: Int[N]) -> Int[N]: check_N(N) return what + def __new__(what: int) -> Int[N]: check_N(N) if N < 64: @@ -21,10 +25,12 @@ class Int: return what else: return __internal__.int_sext(what, 64, N) + @pure @llvm def __new__(what: UInt[N]) -> Int[N]: ret i{=N} %what + def __new__(what: str) -> Int[N]: check_N(N) ret = Int[N]() @@ -39,6 +45,7 @@ class Int: ret = ret * Int[N](10) + Int[N](int(what.ptr[i]) - 48) i += 1 return sign * ret + def __int__(self) -> int: if N > 64: return __internal__.int_trunc(self, N, 64) @@ -46,37 +53,47 @@ class Int: return self else: return __internal__.int_sext(self, N, 64) + def __index__(self): return int(self) + def __copy__(self) -> Int[N]: return self + def __deepcopy__(self) -> Int[N]: return self + def __hash__(self) -> int: return int(self) + @pure @llvm def __float__(self) -> float: %0 = sitofp i{=N} %self to double ret double %0 + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i{=N} %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> Int[N]: return self + @pure @llvm def __neg__(self) -> Int[N]: %0 = sub i{=N} 0, %self ret i{=N} %0 + @pure @llvm def __invert__(self) -> Int[N]: %0 = xor i{=N} %self, -1 ret i{=N} %0 + @pure @commutative @associative @@ -84,11 +101,13 @@ class Int: def __add__(self, other: Int[N]) -> Int[N]: %0 = add i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __sub__(self, other: Int[N]) -> Int[N]: %0 = sub i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -97,11 +116,13 @@ class Int: def __mul__(self, other: Int[N]) -> Int[N]: %0 = mul i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __floordiv__(self, other: Int[N]) -> Int[N]: %0 = sdiv i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __truediv__(self, other: Int[N]) -> float: @@ -109,11 +130,13 @@ class Int: %1 = sitofp i{=N} %other to double %2 = fdiv double %0, %1 ret double %2 + @pure @llvm def __mod__(self, other: Int[N]) -> Int[N]: %0 = srem i{=N} %self, %other ret i{=N} %0 + def __divmod__(self, other: Int[N]): d = self // other m = self - d*other @@ -121,52 +144,61 @@ class Int: m += other d -= Int[N](1) return (d, m) + @pure @llvm def __lshift__(self, other: Int[N]) -> Int[N]: %0 = shl i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __rshift__(self, other: Int[N]) -> Int[N]: %0 = ashr i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __eq__(self, other: Int[N]) -> bool: %0 = icmp eq i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: Int[N]) -> bool: %0 = icmp ne i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: Int[N]) -> bool: %0 = icmp slt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: Int[N]) -> bool: %0 = icmp sgt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: Int[N]) -> bool: %0 = icmp sle i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: Int[N]) -> bool: %0 = icmp sge i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @commutative @associative @@ -174,6 +206,7 @@ class Int: def __and__(self, other: Int[N]) -> Int[N]: %0 = and i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -181,6 +214,7 @@ class Int: def __or__(self, other: Int[N]) -> Int[N]: %0 = or i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -188,6 +222,7 @@ class Int: def __xor__(self, other: Int[N]) -> Int[N]: %0 = xor i{=N} %self, %other ret i{=N} %0 + @llvm def __pickle__(self, dest: Ptr[byte]) -> void: declare i32 @gzwrite(i8*, i8*, i32) @@ -198,6 +233,7 @@ class Int: %szi = ptrtoint i{=N}* %sz to i32 %2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi) ret void + @llvm def __unpickle__(src: Ptr[byte]) -> Int[N]: declare i32 @gzread(i8*, i8*, i32) @@ -208,29 +244,37 @@ class Int: %2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi) %3 = load i{=N}, i{=N}* %0 ret i{=N} %3 + def __repr__(self) -> str: return str.cat(('Int[', seq_str_int(N), '](', seq_str_int(int(self)), ')')) + def __str__(self) -> str: return seq_str_int(int(self)) + @pure @llvm def _popcnt(self) -> Int[N]: declare i{=N} @llvm.ctpop.i{=N}(i{=N}) %0 = call i{=N} @llvm.ctpop.i{=N}(i{=N} %self) ret i{=N} %0 + def popcnt(self): return int(self._popcnt()) + def len() -> int: return N + @extend class UInt: def __new__() -> UInt[N]: check_N(N) return UInt[N](0) + def __new__(what: UInt[N]) -> UInt[N]: check_N(N) return what + def __new__(what: int) -> UInt[N]: check_N(N) if N < 64: @@ -239,13 +283,16 @@ class UInt: return UInt[N](Int[N](what)) else: return UInt[N](__internal__.int_zext(what, 64, N)) + @pure @llvm def __new__(what: Int[N]) -> UInt[N]: ret i{=N} %what + def __new__(what: str) -> UInt[N]: check_N(N) return UInt[N](Int[N](what)) + def __int__(self) -> int: if N > 64: return __internal__.int_trunc(self, N, 64) @@ -253,37 +300,47 @@ class UInt: return Int[64](self) else: return __internal__.int_zext(self, N, 64) + def __index__(self): return int(self) + def __copy__(self) -> UInt[N]: return self + def __deepcopy__(self) -> UInt[N]: return self + def __hash__(self) -> int: return int(self) + @pure @llvm def __float__(self) -> float: %0 = uitofp i{=N} %self to double ret double %0 + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i{=N} %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> UInt[N]: return self + @pure @llvm def __neg__(self) -> UInt[N]: %0 = sub i{=N} 0, %self ret i{=N} %0 + @pure @llvm def __invert__(self) -> UInt[N]: %0 = xor i{=N} %self, -1 ret i{=N} %0 + @pure @commutative @associative @@ -291,11 +348,13 @@ class UInt: def __add__(self, other: UInt[N]) -> UInt[N]: %0 = add i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __sub__(self, other: UInt[N]) -> UInt[N]: %0 = sub i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -304,11 +363,13 @@ class UInt: def __mul__(self, other: UInt[N]) -> UInt[N]: %0 = mul i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __floordiv__(self, other: UInt[N]) -> UInt[N]: %0 = udiv i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __truediv__(self, other: UInt[N]) -> float: @@ -316,59 +377,70 @@ class UInt: %1 = uitofp i{=N} %other to double %2 = fdiv double %0, %1 ret double %2 + @pure @llvm def __mod__(self, other: UInt[N]) -> UInt[N]: %0 = urem i{=N} %self, %other ret i{=N} %0 + def __divmod__(self, other: UInt[N]): return (self // other, self % other) + @pure @llvm def __lshift__(self, other: UInt[N]) -> UInt[N]: %0 = shl i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __rshift__(self, other: UInt[N]) -> UInt[N]: %0 = lshr i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __eq__(self, other: UInt[N]) -> bool: %0 = icmp eq i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: UInt[N]) -> bool: %0 = icmp ne i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: UInt[N]) -> bool: %0 = icmp ult i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: UInt[N]) -> bool: %0 = icmp ugt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: UInt[N]) -> bool: %0 = icmp ule i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: UInt[N]) -> bool: %0 = icmp uge i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @commutative @associative @@ -376,6 +448,7 @@ class UInt: def __and__(self, other: UInt[N]) -> UInt[N]: %0 = and i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -383,6 +456,7 @@ class UInt: def __or__(self, other: UInt[N]) -> UInt[N]: %0 = or i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -390,6 +464,7 @@ class UInt: def __xor__(self, other: UInt[N]) -> UInt[N]: %0 = xor i{=N} %self, %other ret i{=N} %0 + @llvm def __pickle__(self, dest: Ptr[byte]) -> void: declare i32 @gzwrite(i8*, i8*, i32) @@ -400,6 +475,7 @@ class UInt: %szi = ptrtoint i{=N}* %sz to i32 %2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi) ret void + @llvm def __unpickle__(src: Ptr[byte]) -> UInt[N]: declare i32 @gzread(i8*, i8*, i32) @@ -410,12 +486,16 @@ class UInt: %2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi) %3 = load i{=N}, i{=N}* %0 ret i{=N} %3 + def __repr__(self) -> str: return str.cat(('UInt[', seq_str_int(N), '](', seq_str_uint(int(self)), ')')) + def __str__(self) -> str: return seq_str_uint(int(self)) + def popcnt(self): return int(Int[N](self)._popcnt()) + def len() -> int: return N diff --git a/stdlib/internal/types/optional.codon b/stdlib/internal/types/optional.codon index c9d668e8..022beaa8 100644 --- a/stdlib/internal/types/optional.codon +++ b/stdlib/internal/types/optional.codon @@ -5,29 +5,36 @@ class Optional: return __internal__.opt_tuple_new(T) else: return __internal__.opt_ref_new(T) + def __new__(what: T) -> Optional[T]: if isinstance(T, ByVal): return __internal__.opt_tuple_new_arg(what, T) else: return __internal__.opt_ref_new_arg(what, T) + def __bool__(self) -> bool: if isinstance(T, ByVal): return __internal__.opt_tuple_bool(self, T) else: return __internal__.opt_ref_bool(self, T) + def __invert__(self) -> T: if isinstance(T, ByVal): return __internal__.opt_tuple_invert(self, T) else: return __internal__.opt_ref_invert(self, T) + def __str__(self) -> str: return 'None' if not self else str(~self) + def __repr__(self) -> str: return 'None' if not self else (~self).__repr__() + def __is_optional__(self, other: Optional[T]): if (not self) or (not other): return (not self) and (not other) return self.__invert__() is other.__invert__() + optional = Optional def unwrap[T](opt: Optional[T]) -> T: diff --git a/stdlib/internal/types/ptr.codon b/stdlib/internal/types/ptr.codon index 0dfc8dc6..28b3a24a 100644 --- a/stdlib/internal/types/ptr.codon +++ b/stdlib/internal/types/ptr.codon @@ -4,53 +4,63 @@ def seq_str_ptr(a: Ptr[byte]) -> str: pass @extend class Ptr: - @__internal__ - def __new__(sz: int) -> Ptr[T]: - pass @pure @llvm def __new__() -> Ptr[T]: ret {=T}* null + + @__internal__ + def __new__(sz: int) -> Ptr[T]: + pass + + @pure + @llvm + def __new__(other: Ptr[T]) -> Ptr[T]: + ret {=T}* %other + @pure @llvm def __new__(other: Ptr[byte]) -> Ptr[T]: %0 = bitcast i8* %other to {=T}* ret {=T}* %0 - @pure - @llvm - def __new__(other: Ptr[T]) -> Ptr[T]: - ret {=T}* %other + @pure @llvm def __int__(self) -> int: %0 = ptrtoint {=T}* %self to i64 ret i64 %0 + @pure @llvm def __copy__(self) -> Ptr[T]: ret {=T}* %self + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne {=T}* %self, null %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __getitem__(self, index: int) -> T: %0 = getelementptr {=T}, {=T}* %self, i64 %index %1 = load {=T}, {=T}* %0 ret {=T} %1 + @llvm def __setitem__(self, index: int, what: T) -> void: %0 = getelementptr {=T}, {=T}* %self, i64 %index store {=T} %what, {=T}* %0 ret void + @pure @llvm def __add__(self, other: int) -> Ptr[T]: %0 = getelementptr {=T}, {=T}* %self, i64 %other ret {=T}* %0 + @pure @llvm def __sub__(self, other: Ptr[T]) -> int: @@ -59,90 +69,105 @@ class Ptr: %2 = sub i64 %0, %1 %3 = sdiv exact i64 %2, ptrtoint ({=T}* getelementptr ({=T}, {=T}* null, i32 1) to i64) ret i64 %3 + @pure @llvm def __eq__(self, other: Ptr[T]) -> bool: %0 = icmp eq {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: Ptr[T]) -> bool: %0 = icmp ne {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: Ptr[T]) -> bool: %0 = icmp slt {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: Ptr[T]) -> bool: %0 = icmp sgt {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: Ptr[T]) -> bool: %0 = icmp sle {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: Ptr[T]) -> bool: %0 = icmp sge {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @llvm def __prefetch_r0__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 0, i32 1) ret void + @llvm def __prefetch_r1__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 1, i32 1) ret void + @llvm def __prefetch_r2__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 2, i32 1) ret void + @llvm def __prefetch_r3__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 3, i32 1) ret void + @llvm def __prefetch_w0__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 0, i32 1) ret void + @llvm def __prefetch_w1__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 1, i32 1) ret void + @llvm def __prefetch_w2__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 2, i32 1) ret void + @llvm def __prefetch_w3__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 3, i32 1) ret void + @pure @llvm def as_byte(self) -> Ptr[byte]: @@ -151,20 +176,25 @@ class Ptr: def __repr__(self) -> str: return seq_str_ptr(self.as_byte()) + ptr = Ptr Jar = Ptr[byte] cobj = Ptr[byte] + # Forward declarations @__internal__ @tuple class Array[T]: len: int ptr: Ptr[T] + + class List[T]: arr: Array[T] len: int + @extend class NoneType: def __new__() -> NoneType: diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index 4a9e638c..f4f2a61a 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -7,45 +7,58 @@ class str: @__internal__ def __new__(l: int, p: Ptr[byte]) -> str: pass + def __new__(p: Ptr[byte], l: int) -> str: return str(l, p) + def __new__() -> str: return str(Ptr[byte](), 0) - def __new__[T](what: T) -> str: # lowest priority! + + def __new__(what) -> str: if hasattr(what, "__str__"): return what.__str__() else: return what.__repr__() + def __str__(what: str) -> str: return what + def __len__(self) -> int: return self.len + def __bool__(self) -> bool: return self.len != 0 + def __copy__(self) -> str: return self + def __deepcopy__(self) -> str: return self + def __ptrcopy__(self) -> str: n = self.len p = cobj(n) str.memcpy(p, self.ptr, n) return str(p, n) + @llvm def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void: declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false) ret void + @llvm def memmove(dest: Ptr[byte], src: Ptr[byte], len: int) -> void: declare void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false) ret void + @llvm def memset(dest: Ptr[byte], val: byte, len: int) -> void: declare void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 0, i1 false) ret void + def __add__(self, other: str) -> str: len1 = self.len len2 = other.len @@ -54,17 +67,20 @@ class str: str.memcpy(p, self.ptr, len1) str.memcpy(p + len1, other.ptr, len2) return str(p, len3) + def c_str(self): n = self.__len__() p = cobj(n + 1) str.memcpy(p, self.ptr, n) p[n] = byte(0) return p + def from_ptr(t: cobj) -> str: n = strlen(t) p = Ptr[byte](n) str.memcpy(p, t, n) return str(p, n) + def __eq__(self, other: str): if self.len != other.len: return False @@ -74,10 +90,13 @@ class str: return False i += 1 return True + def __match__(self, other: str): return self.__eq__(other) + def __ne__(self, other: str): return not self.__eq__(other) + def cat(*args): total = 0 if staticlen(args) == 1 and hasattr(args[0], "__iter__") and hasattr(args[0], "__len__"): diff --git a/stdlib/statistics.codon b/stdlib/statistics.codon index d11622d3..0d3a440a 100644 --- a/stdlib/statistics.codon +++ b/stdlib/statistics.codon @@ -394,22 +394,10 @@ class NormalDist: self._mu = mu self._sigma = sigma - def __init__(self, mu: float, sigma: float): + def __init__(self, mu, sigma): self._init(float(mu), float(sigma)) - def __init__(self, mu: int, sigma: int): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: float, sigma: int): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: int, sigma: float): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: float): - self._init(mu, 1.0) - - def __init__(self, mu: int): + def __init__(self, mu): self._init(float(mu), 1.0) def __init__(self): diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index 3fe932bf..d849b7dd 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -538,7 +538,7 @@ def foo() -> int: a{=a} foo() #! not a type or static expression -#! while realizing foo (arguments foo) +#! while realizing foo:0 (arguments foo:0) #%% function_llvm_err_4,barebones a = 5 @@ -558,7 +558,7 @@ print f.foo() #: F class Foo: def foo(self): return 'F' -Foo.foo(1) #! cannot unify int and Foo +Foo.foo(1) #! cannot find a method 'foo' in Foo with arguments = int #%% function_nested,barebones def foo(v): @@ -941,12 +941,12 @@ print FooBarBaz[str]().foo() #: foo 0 print FooBarBaz[float]().bar() #: bar 0/float print FooBarBaz[str]().baz() #: baz! foo 0 bar /str -#%% inherit_class_2,barebones +#%% inherit_class_err_2,barebones class defdict(Dict[str,float]): def __init__(self, d: Dict[str, float]): self.__init__(d.items()) z = defdict() -z[1.1] #! cannot unify float and str +z[1.1] #! cannot find a method '__getitem__' in defdict with arguments = defdict, = float #%% inherit_tuple,barebones class Foo: @@ -982,3 +982,16 @@ class Bar: x: float class FooBar(Foo, Bar): pass #! 'x' declared twice + +#%% keyword_prefix,barebones +def foo(return_, pass_, yield_, break_, continue_, print_, assert_): + return_.append(1) + pass_.append(2) + yield_.append(3) + break_.append(4) + continue_.append(5) + print_.append(6) + assert_.append(7) + return return_, pass_, yield_, break_, continue_, print_, assert_ +print foo([1], [1], [1], [1], [1], [1], [1]) +#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7]) \ No newline at end of file diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index dedc61de..f28f1b87 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -256,21 +256,22 @@ a = [5] a.foo #! cannot find 'foo' in List[int] #%% dot_case_6,barebones +# Did heavy changes to this testcase because +# of the automatic optional wraps/unwraps and promotions class Foo: - def bar(self, a: int): - print 'normal', a - def bar(self, a: Optional[int]): - print 'optional', a - def bar[T](self, a: Optional[T]): - print 'optional generic', a, a.__class__ def bar(self, a): print 'generic', a, a.__class__ + def bar(self, a: Optional[float]): + print 'optional', a + def bar(self, a: int): + print 'normal', a f = Foo() f.bar(1) #: normal 1 -f.bar(Optional(1)) #: optional 1 -f.bar(Optional('s')) #: optional generic s Optional[str] +f.bar(1.1) #: optional 1.1 +f.bar(Optional('s')) #: generic s Optional[str] f.bar('hehe') #: generic hehe str + #%% dot_case_6b,barebones class Foo: def bar(self, a, b): @@ -305,7 +306,7 @@ class Foo: print 'foo' def method(self, a): print a -Foo().clsmethod() #! too many arguments for Foo.clsmethod (expected maximum 0, got 1) +Foo().clsmethod() #! cannot find a method 'clsmethod' in Foo with arguments = Foo #%% call,barebones def foo(a, b, c='hi'): @@ -373,7 +374,7 @@ def foo(i, j, k): return i + j + k print foo(1.1, 2.2, 3.3) #: 6.6 p = foo(6, ...) -print p.__class__ #: foo[int,...,...] +print p.__class__ #: foo:0[int,...,...] print p(2, 1) #: 9 print p(k=3, j=6) #: 15 q = p(k=1, ...) @@ -389,11 +390,11 @@ print 42 |> add_two #: 44 def moo(a, b, c=3): print a, b, c m = moo(b=2, ...) -print m.__class__ #: moo[...,int,...] +print m.__class__ #: moo:0[...,int,...] m('s', 1.1) #: s 2 1.1 # # n = m(c=2.2, ...) -print n.__class__ #: moo[...,int,float] +print n.__class__ #: moo:0[...,int,float] n('x') #: x 2 2.2 print n('y').__class__ #: void @@ -402,11 +403,11 @@ def ff(a, b, c): print ff(1.1, 2, True).__class__ #: Tuple[float,int,bool] print ff(1.1, ...)(2, True).__class__ #: Tuple[float,int,bool] y = ff(1.1, ...)(c=True, ...) -print y.__class__ #: ff[float,...,bool] +print y.__class__ #: ff:0[float,...,bool] print ff(1.1, ...)(2, ...)(True).__class__ #: Tuple[float,int,bool] print y('hei').__class__ #: Tuple[float,str,bool] z = ff(1.1, ...)(c='s', ...) -print z.__class__ #: ff[float,...,str] +print z.__class__ #: ff:0[float,...,str] #%% call_arguments_partial,barebones def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T): @@ -431,7 +432,7 @@ l = [1] def adder(a, b): return a+b doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...)) #: int int -#: adder[.. Generator[int] +#: adder:0[ Generator[int] #: 5 #: 1 Optional[int] #: 5 int @@ -446,16 +447,7 @@ q = p(zh=43, ...) q(1) #: 1 () (zh: 43) r = q(5, 38, ...) r() #: 5 (38) (zh: 43) - -#%% call_partial_star_error,barebones -def foo(x, *args, **kwargs): - print x, args, kwargs -p = foo(...) -p(1, z=5) -q = p(zh=43, ...) -q(1) -r = q(5, 38, ...) -r(1, a=1) #! too many arguments for foo[T1,T2,T3] (expected maximum 3, got 2) +r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1) #%% call_kwargs,barebones def kwhatever(**kwargs): @@ -503,6 +495,79 @@ foo(*(1,2)) #: (1, 2) () foo(3, f) #: (3, (x: 6, y: True)) () foo(k = 3, **f) #: () (k: 3, x: 6, y: True) +#%% call_partial_args_kwargs,barebones +def foo(*args): + print(args) +a = foo(1, 2, ...) +b = a(3, 4, ...) +c = b(5, ...) +c('zooooo') +#: (1, 2, 3, 4, 5, 'zooooo') + +def fox(*args, **kwargs): + print(args, kwargs) +xa = fox(1, 2, x=5, ...) +xb = xa(3, 4, q=6, ...) +xc = xb(5, ...) +xd = xc(z=5.1, ...) +xd('zooooo', w='lele') +#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele') + +class Foo: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a): + return f'{self}:generic' + def foo(self, a: float): + return f'{self}:float' + def foo(self, a: int): + return f'{self}:int' +f = Foo(4) + +def pacman(x, f): + print f(x, '5') + print f(x, 2.1) + print f(x, 4) +pacman(f, Foo.foo) +#: #4:generic +#: #4:float +#: #4:int + +def macman(f): + print f('5') + print f(2.1) + print f(4) +macman(f.foo) +#: #4:generic +#: #4:float +#: #4:int + +class Fox: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a, b): + return f'{self}:generic b={b}' + def foo(self, a: float, c): + return f'{self}:float, c={c}' + def foo(self, a: int): + return f'{self}:int' + def foo(self, a: int, z, q): + return f'{self}:int z={z} q={q}' +ff = Fox(5) +def maxman(f): + print f('5', b=1) + print f(2.1, 3) + print f(4) + print f(5, 1, q=3) +maxman(ff.foo) +#: #5:generic b=1 +#: #5:float, c=3 +#: #5:int +#: #5:int z=1 q=3 + + #%% call_static,barebones print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool) #: True True False @@ -535,6 +600,30 @@ print hasattr(int, "__getitem__") print hasattr([1, 2], "__getitem__", str) #: False +#%% isinstance_inheritance,barebones +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a +class Side: + def __init__(self): + pass +class BX[T,U](AX[T], Side): + b: U + def __init__(self, a: T, b: U): + super().__init__(a) + self.b = b +class CX[T,U](BX[T,U]): + c: int + def __init__(self, a: T, b: U): + super().__init__(a, b) + self.c = 1 +c = CX('a', False) +print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side) +#: True True True True +print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int]) +#: True False False + #%% staticlen_err,barebones print staticlen([1, 2]) #! List[int] is not a tuple type @@ -614,3 +703,83 @@ def foo(x: Callable[[1,2], 3]): pass #! unexpected static type #%% static_unify_2,barebones def foo(x: List[1]): pass #! cannot unify T and 1 + +#%% super,barebones +class A[T]: + a: T + def __init__(self, t: T): + self.a = t + def foo(self): + return f'A:{self.a}' +class B(A[str]): + b: int + def __init__(self): + super().__init__('s') + self.b = 6 + def baz(self): + return f'{super().foo()}::{self.b}' +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a + def foo(self): + return f'[AX:{self.a}]' +class BX[T,U](AX[T]): + b: U + def __init__(self, a: T, b: U): + print super().__class__ + super().__init__(a) + self.b = b + def foo(self): + return f'[BX:{super().foo()}:{self.b}]' +class CX[T,U](BX[T,U]): + c: int + def __init__(self, a: T, b: U): + print super().__class__ + super().__init__(a, b) + self.c = 1 + def foo(self): + return f'CX:{super().foo()}:{self.c}' +c = CX('a', False) +print c.__class__, c.foo() +#: BX[str,bool] +#: AX[str] +#: CX[str,bool] CX:[BX:[AX:a]:False]:1 + + +#%% super_tuple,barebones +@tuple +class A[T]: + a: T + x: int + def __new__(a: T) -> A[T]: + return (a, 1) + def foo(self): + return f'A:{self.a}' +@tuple +class B(A[str]): + b: int + def __new__() -> B: + return (*(A('s')), 6) + def baz(self): + return f'{super().foo()}::{self.b}' + +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + + +#%% super_error,barebones +class A: + def __init__(self): + super().__init__() +a = A() +#! no parent classes available +#! while realizing A.__init__:1 (arguments A.__init__:1[A]) + +#%% super_error_2,barebones +super().foo(1) #! no parent classes available diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index af6cbe3c..1baa0fcd 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -252,7 +252,7 @@ try: except MyError: print "my" except OSError as o: - print "os", o._hdr[0], len(o._hdr[1]), o._hdr[3][-20:], o._hdr[4] + print "os", o._hdr.typename, len(o._hdr.msg), o._hdr.file[-20:], o._hdr.line #: os std.internal.types.error.OSError 9 typecheck_stmt.codon 249 finally: print "whoa" #: whoa @@ -263,7 +263,7 @@ def foo(): try: foo() except MyError as e: - print e._hdr[0], e._hdr[1] #: MyError foo! + print e._hdr.typename, e._hdr.msg #: MyError foo! #%% throw_error,barebones raise 'hello' #! cannot throw non-exception (first object member must be of type ExcHeader) @@ -291,24 +291,54 @@ def foo(x): print len(x) foo(5) #: 4 -def foo(x): +def foo2(x): if isinstance(x, int): print x+1 return print len(x) -foo(1) #: 2 -foo('s') #: 1 +foo2(1) #: 2 +foo2('s') #: 1 -def foo(x, y: Static[int] = 5): - if y < 3: - if y > 1: - if isinstance(x, int): - print x+1 - return - if isinstance(x, int): - return - print len(x) -foo(1, 1) -foo(1, 2) #: 2 -foo(1) -foo('s') #: 1 +#%% superf,barebones +class Foo: + def foo(a): + # superf(a) + print 'foo-1', a + def foo(a: int): + superf(a) + print 'foo-2', a + def foo(a: str): + superf(a) + print 'foo-3', a + def foo(a): + superf(a) + print 'foo-4', a +Foo.foo(1) +#: foo-1 1 +#: foo-2 1 +#: foo-4 1 + +class Bear: + def woof(x): + return f'bear woof {x}' +@extend +class Bear: + def woof(x): + return superf(x) + f' bear w--f {x}' +print Bear.woof('!') +#: bear woof ! bear w--f ! + +class PolarBear(Bear): + def woof(): + return 'polar ' + superf('@') +print PolarBear.woof() +#: polar bear woof @ bear w--f @ + +#%% superf_error,barebones +class Foo: + def foo(a): + superf(a) + print 'foo-1', a +Foo.foo(1) +#! no matching superf methods are available +#! while realizing Foo.foo:0 diff --git a/test/parser/types.codon b/test/parser/types.codon index a90894cb..82b21a66 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -199,10 +199,10 @@ def f[T](x: T) -> T: print f(1.2).__class__ #: float print f('s').__class__ #: str -def f[T](x: T): - return f(x - 1, T) if x else 1 -print f(1) #: 1 -print f(1.1).__class__ #: int +def f2[T](x: T): + return f2(x - 1, T) if x else 1 +print f2(1) #: 1 +print f2(1.1).__class__ #: int #%% recursive_error,barebones @@ -215,7 +215,7 @@ def rec3(x, y): #- ('a, 'b) -> 'b return y rec3(1, 's') #! cannot unify str and int -#! while realizing rec3 (arguments rec3[int,str]) +#! while realizing rec3:0 (arguments rec3:0[int,str]) #%% instantiate_function_2,barebones def fx[T](x: T) -> T: @@ -298,7 +298,7 @@ print h(list(map(lambda i: i-1, map(lambda i: i+2, range(5))))) #%% func_unify_error,barebones def foo(x:int): print x -z = 1 & foo #! cannot unify foo[...] and int +z = 1 & foo #! cannot find magic 'and' in int #%% void_error,barebones def foo(): @@ -447,13 +447,13 @@ def f(x): return g(x) print f(5), f('s') #: 5 s -def f[U](x: U, y): +def f2[U](x: U, y): def g[T, U](x: T, y: U): return (x, y) return g(y, x) x, y = 1, 'haha' -print f(x, y).__class__ #: Tuple[str,int] -print f('aa', 1.1, U=str).__class__ #: Tuple[float,str] +print f2(x, y).__class__ #: Tuple[str,int] +print f2('aa', 1.1, U=str).__class__ #: Tuple[float,str] #%% nested_fn_generic_error,barebones def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u] @@ -464,7 +464,7 @@ print f(1.1, 1, int).__class__ #! cannot unify float and int #%% fn_realization,barebones def ff[T](x: T, y: tuple[T]): - print ff(T=str,...).__class__ #: ff[str,Tuple[str],str] + print ff(T=str,...).__class__ #: ff:0[str,Tuple[str],str] return x x = ff(1, (1,)) print x, x.__class__ #: 1 int @@ -474,7 +474,7 @@ def fg[T](x:T): def g[T](y): z = T() return z - print fg(T=str,...).__class__ #: fg[str,str] + print fg(T=str,...).__class__ #: fg:0[str,str] print g(1, T).__class__ #: int fg(1) print fg(1).__class__ #: void @@ -515,7 +515,7 @@ class A[T]: def foo[W](t: V, u: V, v: V, w: W): return (t, u, v, w) -print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo[bool,bool,bool,str,str] +print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo:0[bool,bool,bool,str,str] print A.B.C.foo(1,1,1,True) #: (1, 1, 1, True) print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') @@ -533,7 +533,8 @@ class A[T]: c: V def foo[W](t: V, u: V, v: V, w: W): return (t, u, v, w) -print A.B.C[str].foo(1,1,1,True) #! cannot unify int and str + +print A.B.C[str].foo(1,1,1,True) #! cannot find a method 'foo' in A.B.C[str] with arguments = int, = int, = int, = bool #%% nested_deep_class_error_2,barebones class A[T]: @@ -733,10 +734,10 @@ def test(name, sort, key): def foo(l, f): return [f(i) for i in l] test('hi', foo, lambda x: x+1) #: hi [2, 3, 4, 5] -# TODO -# def foof(l: List[int], x, f: Callable[[int], int]): -# return [f(i)+x for i in l] -# test('qsort', foof(..., 3, ...)) + +def foof(l: List[int], x, f: Callable[[int], int]): + return [f(i)+x for i in l] +test('qsort', foof(x=3, ...), lambda x: x+1) #: qsort [5, 6, 7, 8] #%% class_fn_access,barebones class X[T]: @@ -744,8 +745,7 @@ class X[T]: return (x+x, y+y) y = X[X[int]]() print y.__class__ #: X[X[int]] -print X[float].foo(U=int, ...).__class__ #: X.foo[X[float],float,int,int] -# print y.foo[float].__class__ +print X[float].foo(U=int, ...).__class__ #: X.foo:0[X[float],float,int,int] print X[int]().foo(1, 's') #: (2, 'ss') #%% class_partial_access,barebones @@ -753,7 +753,8 @@ class X[T]: def foo[U](self, x, y: U): return (x+x, y+y) y = X[X[int]]() -print y.foo(U=float,...).__class__ #: X.foo[X[X[int]],...,...] +# TODO: should this even be the case? +# print y.foo(U=float,...).__class__ -> X.foo:0[X[X[int]],...,...] print y.foo(1, 2.2, float) #: (2, 4.4) #%% forward,barebones @@ -764,10 +765,10 @@ def bar[T](x): print x, T.__class__ foo(bar, 1) #: 1 int -#: bar[...] +#: bar:0[...] foo(bar(...), 's') #: s str -#: bar[...] +#: bar:0[...] z = bar z('s', int) #: s int @@ -785,8 +786,8 @@ def foo(f, x): def bar[T](x): print x, T.__class__ foo(bar(T=int,...), 1) -#! too many arguments for bar[T1,int] (expected maximum 2, got 2) -#! while realizing foo (arguments foo[bar[...],int]) +#! too many arguments for bar:0[T1,int] (expected maximum 2, got 2) +#! while realizing foo:0 (arguments foo:0[bar:0[...],int]) #%% sort_partial def foo(x, y): @@ -805,16 +806,16 @@ def frec(x, y): return grec(x, y) if bl(y) else 2 print frec(1, 2).__class__, frec('s', 1).__class__ #! expression with void type -#! while realizing frec (arguments frec[int,int]) +#! while realizing frec:0 (arguments frec:0[int,int]) #%% return_fn,barebones def retfn(a): def inner(b, *args, **kwargs): print a, b, args, kwargs - print inner.__class__ #: retfn.inner[...,...,int,...] + print inner.__class__ #: retfn:0.inner:0[...,...,int,...] return inner(15, ...) f = retfn(1) -print f.__class__ #: retfn.inner[int,...,int,...] +print f.__class__ #: retfn:0.inner:0[int,...,int,...] f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar') #%% decorator_manual,barebones @@ -822,7 +823,7 @@ def foo(x, *args, **kwargs): print x, args, kwargs return 1 def dec(fn, a): - print 'decorating', fn.__class__ #: decorating foo[...,...,...] + print 'decorating', fn.__class__ #: decorating foo:0[...,...,...] def inner(*args, **kwargs): print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True) return fn(a, *args, **kwargs) @@ -845,7 +846,7 @@ def dec(fn, a): return inner ff = dec(foo, 10) print ff(5.5, 's', z=True) -#: decorating foo[...,...,...] +#: decorating foo:0[...,...,...] #: decorator (5.5, 's') (z: True) #: 10 (5.5, 's') (z: True) #: 1 @@ -855,7 +856,7 @@ def zoo(e, b, *args): return f'zoo: {e}, {b}, {args}' print zoo(2, 3) print zoo('s', 3) -#: decorating zoo[...,...,...] +#: decorating zoo:0[...,...,...] #: decorator (2, 3) () #: zoo: 5, 2, (3) #: decorator ('s', 3) () @@ -868,9 +869,9 @@ def mydecorator(func): print("after") return inner @mydecorator -def foo(): +def foo2(): print("foo") -foo() +foo2() #: before #: foo #: after @@ -890,7 +891,7 @@ def factorial(num): return n factorial(10) #: 3628800 -#: time needed for factorial[...] is 3628799 +#: time needed for factorial:0[...] is 3628799 def dx1(func): def inner(): @@ -920,9 +921,9 @@ def dy2(func): return inner @dy1 @dy2 -def num(a, b): +def num2(a, b): return a+b -print(num(10, 20)) #: 3600 +print(num2(10, 20)) #: 3600 #%% hetero_iter,barebones e = (1, 2, 3, 'foo', 5, 'bar', 6) @@ -969,14 +970,14 @@ def tee(iterable, n=2): return list(gen(d) for d in deques) it = [1,2,3,4] a, b = tee(it) #! cannot typecheck the program -#! while realizing tee (arguments tee[List[int],int]) +#! while realizing tee:0 (arguments tee:0[List[int],int]) #%% new_syntax,barebones def foo[T,U](x: type, y, z: Static[int] = 10): print T.__class__, U.__class__, x.__class__, y.__class__, Int[z+1].__class__ return List[x]() -print foo(T=int,U=str,...).__class__ #: foo[T1,x,z,int,str] -print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo[T1,bool,5,int,str] +print foo(T=int,U=str,...).__class__ #: foo:0[T1,x,z,int,str] +print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo:0[T1,bool,5,int,str] print foo(float,3,T=int,U=str,z=5).__class__ #: List[float] foo(float,1,10,str,int) #: str int float int Int[11] @@ -992,11 +993,11 @@ print Foo[5,int,float,6].__class__ #: Foo[5,int,float,6] print Foo(1.1, 10i32, [False], 10u66).__class__ #: Foo[66,bool,float,32] -def foo[N: Static[int]](): +def foo2[N: Static[int]](): print Int[N].__class__, N x: Static[int] = 5 y: Static[int] = 105 - x * 2 -foo(y-x) #: Int[90] 90 +foo2(y-x) #: Int[90] 90 if 1.1+2.2 > 0: x: Static[int] = 88 @@ -1107,3 +1108,27 @@ v = [1] methodcaller('append')(v, 42) print v #: [1, 42] print methodcaller('index')(v, 42) #: 1 + +#%% fn_overloads,barebones +def foo(x): + return 1, x + +def foo(x, y): + def foo(x, y): + return f'{x}_{y}' + return 2, foo(x, y) + +def foo(x): + if x == '': + return 3, 0 + return 3, 1 + foo(x[1:])[1] + +print foo('hi') #: (3, 2) +print foo('hi', 1) #: (2, 'hi_1') + +#%% fn_overloads_error,barebones +def foo(x): + return 1, x +def foo(x, y): + return 2, x, y +foo('hooooooooy!', 1, 2) #! cannot find an overload 'foo' with arguments = str, = int, = int diff --git a/test/stdlib/datetime_test.codon b/test/stdlib/datetime_test.codon index b17fb392..8c94cb2e 100644 --- a/test/stdlib/datetime_test.codon +++ b/test/stdlib/datetime_test.codon @@ -708,10 +708,10 @@ class TestDate[theclass](TestCase): iso_long_years = sorted(map(int, ISO_LONG_YEARS_TABLE.split())) L = [] for i in range(400): - d = self.theclass(2000+i, 12, 31) - d1 = self.theclass(1600+i, 12, 31) - self.assertEqual(d.isocalendar()[1:], d1.isocalendar()[1:]) - if d.isocalendar()[1] == 53: + d = self.theclass(2000+i, 12, 31).isocalendar() + d1 = self.theclass(1600+i, 12, 31).isocalendar() + self.assertEqual((d.week, d.weekday), (d1.week, d1.weekday)) + if d.week == 53: L.append(i) self.assertEqual(L, iso_long_years) diff --git a/test/transform/folding.codon b/test/transform/folding.codon index 5d5820ed..a7abb34b 100644 --- a/test/transform/folding.codon +++ b/test/transform/folding.codon @@ -10,60 +10,69 @@ class I: def __float__(self: int) -> float: %tmp = sitofp i64 %self to double ret double %tmp + @llvm def __bool__(self: int) -> bool: %0 = icmp ne i64 %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self: int) -> int: return self + def __neg__(self: int) -> int: return I.__sub__(0, self) + @llvm def __abs__(self: int) -> int: %0 = icmp sgt i64 %self, 0 %1 = sub i64 0, %self %2 = select i1 %0, i64 %self, i64 %1 ret i64 %2 + @llvm def __lshift__(self: int, other: int) -> int: %0 = shl i64 %self, %other ret i64 %0 + @llvm def __rshift__(self: int, other: int) -> int: %0 = ashr i64 %self, %other ret i64 %0 - @llvm - def __add__(self: int, b: int) -> int: - %tmp = add i64 %self, %b - ret i64 %tmp + @llvm def __add__(self: int, other: float) -> float: %0 = sitofp i64 %self to double %1 = fadd double %0, %other ret double %1 + @llvm - def __sub__(self: int, b: int) -> int: - %tmp = sub i64 %self, %b + def __add__(self: int, b: int) -> int: + %tmp = add i64 %self, %b ret i64 %tmp + @llvm def __sub__(self: int, other: float) -> float: %0 = sitofp i64 %self to double %1 = fsub double %0, %other ret double %1 + @llvm - def __mul__(self: int, b: int) -> int: - %tmp = mul i64 %self, %b + def __sub__(self: int, b: int) -> int: + %tmp = sub i64 %self, %b ret i64 %tmp + @llvm def __mul__(self: int, other: float) -> float: %0 = sitofp i64 %self to double %1 = fmul double %0, %other ret double %1 + @llvm - def __floordiv__(self: int, b: int) -> int: - %tmp = sdiv i64 %self, %b + def __mul__(self: int, b: int) -> int: + %tmp = mul i64 %self, %b ret i64 %tmp + @llvm def __floordiv__(self: int, other: float) -> float: declare double @llvm.floor.f64(double) @@ -71,141 +80,177 @@ class I: %1 = fdiv double %0, %other %2 = call double @llvm.floor.f64(double %1) ret double %2 + + @llvm + def __floordiv__(self: int, b: int) -> int: + %tmp = sdiv i64 %self, %b + ret i64 %tmp + + @llvm + def __truediv__(self: int, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fdiv double %0, %other + ret double %1 + @llvm def __truediv__(self: int, other: int) -> float: %0 = sitofp i64 %self to double %1 = sitofp i64 %other to double %2 = fdiv double %0, %1 ret double %2 - @llvm - def __truediv__(self: int, other: float) -> float: - %0 = sitofp i64 %self to double - %1 = fdiv double %0, %other - ret double %1 - @llvm - def __mod__(a: int, b: int) -> int: - %tmp = srem i64 %a, %b - ret i64 %tmp + @llvm def __mod__(self: int, other: float) -> float: %0 = sitofp i64 %self to double %1 = frem double %0, %other ret double %1 + + @llvm + def __mod__(a: int, b: int) -> int: + %tmp = srem i64 %a, %b + ret i64 %tmp + @llvm def __invert__(a: int) -> int: %tmp = xor i64 %a, -1 ret i64 %tmp + @llvm def __and__(a: int, b: int) -> int: %tmp = and i64 %a, %b ret i64 %tmp + @llvm def __or__(a: int, b: int) -> int: %tmp = or i64 %a, %b ret i64 %tmp + @llvm def __xor__(a: int, b: int) -> int: %tmp = xor i64 %a, %b ret i64 %tmp + @llvm def __shr__(a: int, b: int) -> int: %tmp = ashr i64 %a, %b ret i64 %tmp + @llvm def __shl__(a: int, b: int) -> int: %tmp = shl i64 %a, %b ret i64 %tmp + @llvm def __bitreverse__(a: int) -> int: declare i64 @llvm.bitreverse.i64(i64 %a) %tmp = call i64 @llvm.bitreverse.i64(i64 %a) ret i64 %tmp + @llvm def __bswap__(a: int) -> int: declare i64 @llvm.bswap.i64(i64 %a) %tmp = call i64 @llvm.bswap.i64(i64 %a) ret i64 %tmp + @llvm def __ctpop__(a: int) -> int: declare i64 @llvm.ctpop.i64(i64 %a) %tmp = call i64 @llvm.ctpop.i64(i64 %a) ret i64 %tmp + @llvm def __ctlz__(a: int) -> int: declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false) ret i64 %tmp + @llvm def __cttz__(a: int) -> int: declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false) ret i64 %tmp - @llvm - def __eq__(a: int, b: int) -> bool: - %tmp = icmp eq i64 %a, %b - %res = zext i1 %tmp to i8 - ret i8 %res + @llvm def __eq__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp oeq double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @llvm - def __ne__(a: int, b: int) -> bool: - %tmp = icmp ne i64 %a, %b + def __eq__(a: int, b: int) -> bool: + %tmp = icmp eq i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @llvm def __ne__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp one double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @llvm - def __lt__(a: int, b: int) -> bool: - %tmp = icmp slt i64 %a, %b + def __ne__(a: int, b: int) -> bool: + %tmp = icmp ne i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @llvm def __lt__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp olt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @llvm - def __gt__(a: int, b: int) -> bool: - %tmp = icmp sgt i64 %a, %b + def __lt__(a: int, b: int) -> bool: + %tmp = icmp slt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @llvm def __gt__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp ogt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @llvm - def __le__(a: int, b: int) -> bool: - %tmp = icmp sle i64 %a, %b + def __gt__(a: int, b: int) -> bool: + %tmp = icmp sgt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @llvm def __le__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp ole double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @llvm - def __ge__(a: int, b: int) -> bool: - %tmp = icmp sge i64 %a, %b + def __le__(a: int, b: int) -> bool: + %tmp = icmp sle i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @llvm def __ge__(self: int, b: float) -> bool: %0 = sitofp i64 %self to double %1 = fcmp oge double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + + @llvm + def __ge__(a: int, b: int) -> bool: + %tmp = icmp sge i64 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + def __pow__(self: int, exp: float): + return float(self) ** exp + def __pow__(self: int, exp: int): if exp < 0: return 0 @@ -218,8 +263,6 @@ class I: break self *= self return result - def __pow__(self: int, exp: float): - return float(self) ** exp @extend class int: @@ -227,158 +270,197 @@ class int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return self + def __float__(self) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__float__(self) + def __bool__(self) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__bool__(self) + def __pos__(self) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return self + def __neg__(self) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__neg__(self) + def __lshift__(self, other: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__lshift__(self, other) + def __rshift__(self, other: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__rshift__(self, other) - def __add__(self, b: int) -> int: - global OP_COUNT - OP_COUNT = inc(OP_COUNT) - return I.__add__(self, b) + def __add__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__add__(self, other) - def __sub__(self, b: int) -> int: + + def __add__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__sub__(self, b) + return I.__add__(self, b) + def __sub__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__sub__(self, other) - def __mul__(self, b: int) -> int: + + def __sub__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__mul__(self, b) + return I.__sub__(self, b) + def __mul__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__mul__(self, other) - def __floordiv__(self, b: int) -> int: + + def __mul__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__floordiv__(self, b) + return I.__mul__(self, b) + def __floordiv__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__floordiv__(self, other) - def __truediv__(self, other: int) -> float: + + def __floordiv__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__truediv__(self, other) + return I.__floordiv__(self, b) + def __truediv__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__truediv__(self, other) - def __mod__(self, b: int) -> int: + + def __truediv__(self, other: int) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__mod__(self, b) + return I.__truediv__(self, other) + def __mod__(self, other: float) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__mod__(self, other) + + def __mod__(self, b: int) -> int: + global OP_COUNT + OP_COUNT = inc(OP_COUNT) + return I.__mod__(self, b) + def __invert__(self) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__invert__(self) + def __and__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__and__(self, b) + def __or__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__or__(self, b) + def __xor__(self, b: int) -> int: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__xor__(self, b) - def __eq__(self, b: int) -> bool: - global OP_COUNT - OP_COUNT = inc(OP_COUNT) - return I.__eq__(self, b) + def __eq__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__eq__(self, b) - def __ne__(self, b: int) -> bool: + + def __eq__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__ne__(self, b) + return I.__eq__(self, b) + def __ne__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__ne__(self, b) - def __lt__(self, b: int) -> bool: + + def __ne__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__lt__(self, b) + return I.__ne__(self, b) + def __lt__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__lt__(self, b) - def __gt__(self, b: int) -> bool: + + def __lt__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__gt__(self, b) + return I.__lt__(self, b) + def __gt__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__gt__(self, b) - def __le__(self, b: int) -> bool: + + def __gt__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__le__(self, b) + return I.__gt__(self, b) + def __le__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__le__(self, b) - def __ge__(self, b: int) -> bool: + + def __le__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__ge__(self, b) + return I.__le__(self, b) + def __ge__(self, b: float) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__ge__(self, b) - def __pow__(self, exp: int): + + def __ge__(self, b: int) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) - return I.__pow__(self, exp) + return I.__ge__(self, b) + def __pow__(self, exp: float): global OP_COUNT OP_COUNT = inc(OP_COUNT) return I.__pow__(self, exp) + def __pow__(self, exp: int): + global OP_COUNT + OP_COUNT = inc(OP_COUNT) + return I.__pow__(self, exp) + + class F: @llvm def __int__(self: float) -> int: %0 = fptosi double %self to i64 ret i64 %0 + def __float__(self: float): return self + @llvm def __bool__(self: float) -> bool: %0 = fcmp one double %self, 0.000000e+00 @@ -391,10 +473,12 @@ class float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return F.__int__(self) + def __float__(self) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return self + def __bool__(self) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) @@ -406,10 +490,12 @@ class bool: global OP_COUNT OP_COUNT = inc(OP_COUNT) return 1 if self else 0 + def __float__(self) -> float: global OP_COUNT OP_COUNT = inc(OP_COUNT) return 1. if self else 0. + def __bool__(self) -> bool: global OP_COUNT OP_COUNT = inc(OP_COUNT)