diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index dedaaf9c..d5b89cd0 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -63,7 +63,7 @@ Cache::findMethod(types::ClassType *typ, const std::string &member, seqassert(e->type, "not a class"); int oldAge = typeCtx->age; typeCtx->age = 99999; - auto f = typeCtx->findBestMethod(e.get(), member, args); + auto f = TypecheckVisitor(typeCtx).findBestMethod(e.get(), member, args); typeCtx->age = oldAge; return f; } diff --git a/codon/parser/visitors/simplify/simplify_ctx.cpp b/codon/parser/visitors/simplify/simplify_ctx.cpp index 711eeb9c..307ac1c8 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.cpp +++ b/codon/parser/visitors/simplify/simplify_ctx.cpp @@ -61,7 +61,8 @@ std::string SimplifyContext::getBase() const { } std::string SimplifyContext::generateCanonicalName(const std::string &name, - bool includeBase) const { + bool includeBase, + bool zeroId) const { std::string newName = name; if (includeBase && name.find('.') == std::string::npos) { std::string base = getBase(); @@ -74,7 +75,7 @@ std::string SimplifyContext::generateCanonicalName(const std::string &name, newName = (base.empty() ? "" : (base + ".")) + newName; } auto num = cache->identifierCount[newName]++; - newName = num ? format("{}.{}", newName, num) : newName; + newName = num || zeroId ? format("{}.{}", newName, num) : newName; if (newName != name) cache->identifierCount[newName]++; cache->reverseIdentifierLookup[newName] = name; diff --git a/codon/parser/visitors/simplify/simplify_ctx.h b/codon/parser/visitors/simplify/simplify_ctx.h index ed5ad79a..11a628a1 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.h +++ b/codon/parser/visitors/simplify/simplify_ctx.h @@ -113,8 +113,8 @@ public: void dump() override { dump(0); } /// Generate a unique identifier (name) for a given string. 
- std::string generateCanonicalName(const std::string &name, - bool includeBase = false) const; + std::string generateCanonicalName(const std::string &name, bool includeBase = false, + bool zeroId = false) const; bool inFunction() const { return getLevel() && !bases.back().isType(); } bool inClass() const { return getLevel() && bases.back().isType(); } diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index a3158edc..26e3b25d 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -472,8 +472,25 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { return; } - auto canonicalName = ctx->generateCanonicalName(stmt->name, true); bool isClassMember = ctx->inClass(); + if (isClassMember && !endswith(stmt->name, ".dispatch") && + ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].empty()) { + transform( + N(stmt->name + ".dispatch", nullptr, + std::vector{Param("*args")}, + N(N(N( + N(N(ctx->bases.back().name), stmt->name), + N(N("args"))))))); + } + auto func_name = stmt->name; + if (endswith(stmt->name, ".dispatch")) + func_name = func_name.substr(0, func_name.size() - 9); + auto canonicalName = ctx->generateCanonicalName( + func_name, true, isClassMember && !endswith(stmt->name, ".dispatch")); + if (endswith(stmt->name, ".dispatch")) { + canonicalName += ".dispatch"; + ctx->cache->reverseIdentifierLookup[canonicalName] = func_name; + } bool isEnclosedFunc = ctx->inFunction(); if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember)) @@ -483,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ctx->bases = std::vector(); if (!isClassMember) // Class members are added to class' method table - ctx->add(SimplifyItem::Func, stmt->name, canonicalName, ctx->isToplevel()); + ctx->add(SimplifyItem::Func, func_name, canonicalName, ctx->isToplevel()); if (isClassMember) ctx->bases.push_back(oldBases[0]); ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base... @@ -602,7 +619,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // ... set the enclosing class name... attr.parentClass = ctx->bases.back().name; // ... add the method to class' method list ... - ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].push_back( + ctx->cache->classes[ctx->bases.back().name].methods[func_name].push_back( {canonicalName, nullptr, ctx->cache->age}); // ... and if the function references outer class variable (by definition a // generic), mark it as not static as it needs fully instantiated class to be @@ -637,21 +654,21 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ExprPtr finalExpr; if (!captures.empty()) - finalExpr = N(N(stmt->name), partialArgs); + finalExpr = N(N(func_name), partialArgs); if (isClassMember && decorators.size()) error("decorators cannot be applied to class methods"); for (int j = int(decorators.size()) - 1; j >= 0; j--) { if (auto c = const_cast(decorators[j]->getCall())) { c->args.emplace(c->args.begin(), - CallExpr::Arg{"", finalExpr ? finalExpr : N(stmt->name)}); + CallExpr::Arg{"", finalExpr ? finalExpr : N(func_name)}); finalExpr = N(c->expr, c->args); } else { finalExpr = - N(decorators[j], finalExpr ? finalExpr : N(stmt->name)); + N(decorators[j], finalExpr ? 
finalExpr : N(func_name)); } } if (finalExpr) - resultStmt = transform(N(N(stmt->name), finalExpr)); + resultStmt = transform(N(N(func_name), finalExpr)); } void SimplifyVisitor::visit(ClassStmt *stmt) { @@ -941,7 +958,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { continue; auto subs = substitutions[ai]; auto newName = ctx->generateCanonicalName( - ctx->cache->reverseIdentifierLookup[f->name], true); + ctx->cache->reverseIdentifierLookup[f->name], true, true); auto nf = std::dynamic_pointer_cast(replace(sp, subs)); subs[nf->name] = N(newName); nf->name = newName; diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 867d7da7..77dc0ba0 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -33,16 +33,21 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, StmtPtr stmts) { return std::move(infer.second); } -TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b) { +TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess) { if (!a) return a = b; seqassert(b, "rhs is nullptr"); types::Type::Unification undo; - if (a->unify(b.get(), &undo) >= 0) + if (a->unify(b.get(), &undo) >= 0) { + if (undoOnSuccess) + undo.undo(); return a; - undo.undo(); + } else { + undo.undo(); + } // LOG("{} / {}", a->debugString(true), b->debugString(true)); - a->unify(b.get(), &undo); + if (!undoOnSuccess) + a->unify(b.get(), &undo); error("cannot unify {} and {}", a->toString(), b->toString()); return nullptr; } diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index e8344461..fc75f4f4 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -283,9 +283,19 @@ private: void generateFnCall(int n); /// Make an empty partial call fn(...) for a function fn. ExprPtr partializeFunction(ExprPtr expr); + /// Picks the best method of a given expression that matches the given argument + /// types. Prefers methods whose signatures are closer to the given arguments: + /// e.g. foo(int) will match (int) better that a foo(T). + /// Also takes care of the Optional arguments. + /// If multiple equally good methods are found, return the first one. + /// Return nullptr if no methods were found. 
+ types::FuncTypePtr + findBestMethod(const Expr *expr, const std::string &member, + const std::vector> &args); private: - types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b); + types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b, + bool undoOnSuccess = false); types::TypePtr realizeType(types::ClassType *typ); types::TypePtr realizeFunc(types::FuncType *typ); std::pair inferTypes(StmtPtr stmt, bool keepLast, @@ -293,7 +303,7 @@ private: codon::ir::types::Type *getLLVMType(const types::ClassType *t); bool wrapExpr(ExprPtr &expr, types::TypePtr expectedType, - const types::FuncTypePtr &callee); + const types::FuncTypePtr &callee, bool undoOnSuccess = false); int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false); int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop, int64_t step); diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index 31297307..d1c72784 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -141,10 +141,12 @@ TypeContext::findMethod(const std::string &typeName, const std::string &method) if (m != cache->classes.end()) { auto t = m->second.methods.find(method); if (t != m->second.methods.end()) { + seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"), + "first method is not dispatch"); std::unordered_map signatureLoci; std::vector vv; - for (auto &mt : t->second) { - // LOG("{}::{} @ {} vs. {}", typeName, method, age, mt.age); + for (int mti = 1; mti < t->second.size(); mti++) { + auto &mt = t->second[mti]; if (mt.age <= age) { auto sig = cache->functions[mt.name].ast->signature(); auto it = signatureLoci.find(sig); @@ -177,110 +179,6 @@ types::TypePtr TypeContext::findMember(const std::string &typeName, return nullptr; } -types::FuncTypePtr TypeContext::findBestMethod( - const Expr *expr, const std::string &member, - const std::vector> &args, bool checkSingle) { - auto typ = expr->getType()->getClass(); - seqassert(typ, "not a class"); - auto methods = findMethod(typ->name, member); - if (methods.empty()) - return nullptr; - if (methods.size() == 1 && !checkSingle) // methods is not overloaded - return methods[0]; - - // Calculate the unification score for each available methods and pick the one with - // highest score. - std::vector> scores; - for (int mi = 0; mi < methods.size(); mi++) { - auto method = instantiate(expr, methods[mi], typ.get(), false)->getFunc(); - std::vector reordered; - std::vector callArgs; - for (auto &a : args) { - callArgs.push_back({a.first, std::make_shared()}); // dummy expression - callArgs.back().value->setType(a.second); - } - auto score = reorderNamedArgs( - method.get(), callArgs, - [&](int s, int k, const std::vector> &slots, bool _) { - for (int si = 0; si < slots.size(); si++) { - // Ignore *args, *kwargs and default arguments - reordered.emplace_back(si == s || si == k || slots[si].size() != 1 - ? nullptr - : args[slots[si][0]].second); - } - return 0; - }, - [](const std::string &) { return -1; }); - if (score == -1) - continue; - // Scoring system for each argument: - // Generics, traits and default arguments get a score of zero (lowest priority). - // Optional unwrap gets the score of 1. - // Optional wrap gets the score of 2. - // Successful unification gets the score of 3 (highest priority). 
- for (int ai = 0, mi = 1, gi = 0; ai < reordered.size(); ai++) { - auto argType = reordered[ai]; - if (!argType) - continue; - auto expectedType = method->ast->args[ai].generic ? method->generics[gi++].type - : method->args[mi++]; - auto expectedClass = expectedType->getClass(); - // Ignore traits, *args/**kwargs and default arguments. - if (expectedClass && expectedClass->name == "Generator") - continue; - // LOG("<~> {} {}", argType->toString(), expectedType->toString()); - auto argClass = argType->getClass(); - - types::Type::Unification undo; - int u = argType->unify(expectedType.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u + 3; - continue; - } - if (!method->ast->args[ai].generic) { - // Unification failed: maybe we need to wrap an argument? - if (expectedClass && expectedClass->name == TYPE_OPTIONAL && argClass && - argClass->name != expectedClass->name) { - u = argType->unify(expectedClass->generics[0].type.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u + 2; - continue; - } - } - // ... or unwrap it (less ideal)? - if (argClass && argClass->name == TYPE_OPTIONAL && expectedClass && - argClass->name != expectedClass->name) { - u = argClass->generics[0].type->unify(expectedType.get(), &undo); - undo.undo(); - if (u >= 0) { - score += u; - continue; - } - } - } - // This method cannot be selected, ignore it. - score = -1; - break; - } - // LOG("{} {} / {}", typ->toString(), method->toString(), score); - if (score >= 0) - scores.emplace_back(std::make_pair(score, mi)); - } - if (scores.empty()) - return nullptr; - // Get the best score. - sort(scores.begin(), scores.end(), std::greater<>()); - // LOG("Method: {}", methods[scores[0].second]->toString()); - // std::string x; - // for (auto &a : args) - // x += format("{}{},", a.first.empty() ? "" : a.first + ": ", - // a.second->toString()); - // LOG(" {} :: {} ( {} )", typ->toString(), member, x); - return methods[scores[0].second]; -} - int TypeContext::reorderNamedArgs(types::FuncType *func, const std::vector &args, ReorderDoneFn onDone, ReorderErrorFn onError, diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.h b/codon/parser/visitors/typecheck/typecheck_ctx.h index 9713ae64..b1961653 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.h +++ b/codon/parser/visitors/typecheck/typecheck_ctx.h @@ -127,16 +127,6 @@ public: types::TypePtr findMember(const std::string &typeName, const std::string &member) const; - /// Picks the best method of a given expression that matches the given argument - /// types. Prefers methods whose signatures are closer to the given arguments: - /// e.g. foo(int) will match (int) better that a foo(T). - /// Also takes care of the Optional arguments. - /// If multiple equally good methods are found, return the first one. - /// Return nullptr if no methods were found. 
- types::FuncTypePtr - findBestMethod(const Expr *expr, const std::string &member, - const std::vector> &args, - bool checkSingle = false); typedef std::function> &, bool)> ReorderDoneFn; diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 4b0a619f..b9201490 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -683,8 +683,8 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, if (isAtomic) { auto ptrlt = ctx->instantiateGeneric(expr->lexpr.get(), ctx->findInternal("Ptr"), {lt}); - method = ctx->findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic), - {{"", ptrlt}, {"", rt}}); + method = findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic), + {{"", ptrlt}, {"", rt}}); if (method) { expr->lexpr = N(expr->lexpr); if (noReturn) @@ -693,19 +693,19 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, } // Check if lt.__iop__(lt, rt) exists. if (!method && expr->inPlace) { - method = ctx->findBestMethod(expr->lexpr.get(), format("__i{}__", magic), - {{"", lt}, {"", rt}}); + method = findBestMethod(expr->lexpr.get(), format("__i{}__", magic), + {{"", lt}, {"", rt}}); if (method && noReturn) *noReturn = true; } // Check if lt.__op__(lt, rt) exists. if (!method) - method = ctx->findBestMethod(expr->lexpr.get(), format("__{}__", magic), - {{"", lt}, {"", rt}}); + method = findBestMethod(expr->lexpr.get(), format("__{}__", magic), + {{"", lt}, {"", rt}}); // Check if rt.__rop__(rt, lt) exists. if (!method) { - method = ctx->findBestMethod(expr->rexpr.get(), format("__r{}__", magic), - {{"", rt}, {"", lt}}); + method = findBestMethod(expr->rexpr.get(), format("__r{}__", magic), + {{"", rt}, {"", lt}}); if (method) swap(expr->lexpr, expr->rexpr); } @@ -873,8 +873,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, argTypes.emplace_back(make_pair("", typ)); // self variable for (const auto &a : *args) argTypes.emplace_back(make_pair(a.name, a.value->getType())); - if (auto bestMethod = - ctx->findBestMethod(expr->expr.get(), expr->member, argTypes)) { + if (auto bestMethod = findBestMethod(expr->expr.get(), expr->member, argTypes)) { ExprPtr e = N(bestMethod->ast->name); auto t = ctx->instantiate(expr, bestMethod, typ.get()); unify(e->type, t); @@ -906,7 +905,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, methodArgs.emplace_back(make_pair("", typ)); for (auto i = 1; i < oldType->generics.size(); i++) methodArgs.emplace_back(make_pair("", oldType->generics[i].type)); - bestMethod = ctx->findBestMethod(expr->expr.get(), expr->member, methodArgs); + bestMethod = findBestMethod(expr->expr.get(), expr->member, methodArgs); if (!bestMethod) { // Print a nice error message. std::vector nice; @@ -916,9 +915,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, typ->toString(), join(nice, ", ")); } } else { - // HACK: if we still have multiple valid methods, we just use the first one. - // TODO: handle this better (maybe hold these types until they can be selected?) - bestMethod = methods[0]; + auto m = ctx->cache->classes.find(typ->name); + auto t = m->second.methods.find(expr->member); + seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"), + "first method is not dispatch"); + bestMethod = t->second[0].type; } // Case 7: only one valid method remaining. 
Check if this is a class method or an @@ -1004,6 +1005,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ai--; } else { // Case 3: Normal argument + // LOG("-> {}", expr->args[ai].value->toString()); expr->args[ai].value = transform(expr->args[ai].value, true); // Unbound inType might become a generator that will need to be extracted, so // don't unify it yet. @@ -1365,8 +1367,7 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() || ctx->findMember(typ->getClass()->name, member); if (exists && args.size() > 1) - exists &= - ctx->findBestMethod(expr->args[0].value.get(), member, args, true) != nullptr; + exists &= findBestMethod(expr->args[0].value.get(), member, args) != nullptr; return {true, transform(N(exists))}; } else if (val == "compile_error") { expr->args[0].value = transform(expr->args[0].value); @@ -1609,8 +1610,72 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { return call; } +types::FuncTypePtr TypecheckVisitor::findBestMethod( + const Expr *expr, const std::string &member, + const std::vector> &args) { + auto typ = expr->getType()->getClass(); + seqassert(typ, "not a class"); + + // Pick the last method that accepts the given arguments. + auto methods = ctx->findMethod(typ->name, member); + // if (methods.size() == 1) + // return methods[0]; + types::FuncTypePtr method = nullptr; + for (int mi = int(methods.size()) - 1; mi >= 0; mi--) { + auto m = ctx->instantiate(expr, methods[mi], typ.get(), false)->getFunc(); + std::vector reordered; + std::vector callArgs; + for (auto &a : args) { + callArgs.push_back({a.first, std::make_shared()}); // dummy expression + callArgs.back().value->setType(a.second); + } + auto score = ctx->reorderNamedArgs( + m.get(), callArgs, + [&](int s, int k, const std::vector> &slots, bool _) { + for (int si = 0; si < slots.size(); si++) { + if (m->ast->args[si].generic) { + // Ignore type arguments + } else if (si == s || si == k || slots[si].size() != 1) { + // Ignore *args, *kwargs and default arguments + reordered.emplace_back(nullptr); + } else { + reordered.emplace_back(args[slots[si][0]].second); + } + } + return 0; + }, + [](const std::string &) { return -1; }); + for (int ai = 0, mi = 1, gi = 0; score != -1 && ai < reordered.size(); ai++) { + auto argType = reordered[ai]; + if (!argType) + continue; + auto expectTyp = + m->ast->args[ai].generic ? m->generics[gi++].type : m->args[mi++]; + try { + ExprPtr dummy = std::make_shared(""); + dummy->type = argType; + dummy->done = true; + wrapExpr(dummy, expectTyp, m, /*undoOnSuccess*/ true); + } catch (const exc::ParserException &) { + score = -1; + } + } + if (score != -1) { + // std::vector ar; + // for (auto &a: args) { + // if (a.first.empty()) ar.push_back(a.second->toString()); + // else ar.push_back(format("{}: {}", a.first, a.second->toString())); + // } + // LOG("- {} vs {}", m->toString(), join(ar, "; ")); + method = methods[mi]; + break; + } + } + return method; +} + bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType, - const FuncTypePtr &callee) { + const FuncTypePtr &callee, bool undoOnSuccess) { auto expectedClass = expectedType->getClass(); auto exprClass = expr->getType()->getClass(); if (callee && expr->isType()) @@ -1637,7 +1702,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType, // Case 7: wrap raw Seq functions into Partial(...) call for easy realization. 
expr = partializeFunction(expr); } - unify(expr->type, expectedType); + unify(expr->type, expectedType, undoOnSuccess); return true; } diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index 432857c0..5f08fb94 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -153,9 +153,9 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) { ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass}); c->args[1].value = transform(c->args[1].value); auto rhsTyp = c->args[1].value->getType()->getClass(); - if (auto method = ctx->findBestMethod( - stmt->lhs.get(), format("__atomic_{}__", c->expr->getId()->value), - {{"", ptrTyp}, {"", rhsTyp}})) { + if (auto method = findBestMethod(stmt->lhs.get(), + format("__atomic_{}__", c->expr->getId()->value), + {{"", ptrTyp}, {"", rhsTyp}})) { resultStmt = transform(N(N( N(method->ast->name), N(stmt->lhs), c->args[1].value))); return; @@ -168,8 +168,8 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) { if (stmt->isAtomic && lhsClass && rhsClass) { auto ptrType = ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass}); - if (auto m = ctx->findBestMethod(stmt->lhs.get(), "__atomic_xchg__", - {{"", ptrType}, {"", rhsClass}})) { + if (auto m = findBestMethod(stmt->lhs.get(), "__atomic_xchg__", + {{"", ptrType}, {"", rhsClass}})) { resultStmt = transform(N( N(N(m->ast->name), N(stmt->lhs), stmt->rhs))); return; diff --git a/stdlib/internal/types/bool.codon b/stdlib/internal/types/bool.codon index de66f660..6e54e3ec 100644 --- a/stdlib/internal/types/bool.codon +++ b/stdlib/internal/types/bool.codon @@ -4,7 +4,7 @@ from internal.attributes import commutative, associative class bool: def __new__() -> bool: return False - def __new__[T](what: T) -> bool: # lowest priority! 
+ def __new__(what) -> bool: return what.__bool__() def __repr__(self) -> str: return "True" if self else "False" diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index dd490b10..364043d7 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -2,9 +2,9 @@ import internal.gc as gc @extend class List: - def __init__(self, arr: Array[T], len: int): - self.arr = arr - self.len = len + def __init__(self): + self.arr = Array[T](10) + self.len = 0 def __init__(self, it: Generator[T]): self.arr = Array[T](10) @@ -12,27 +12,27 @@ class List: for i in it: self.append(i) - def __init__(self, capacity: int): - self.arr = Array[T](capacity) - self.len = 0 - - def __init__(self): - self.arr = Array[T](10) - self.len = 0 - def __init__(self, other: List[T]): self.arr = Array[T](other.len) self.len = 0 for i in other: self.append(i) - # Dummy __init__ used for list comprehension optimization + def __init__(self, capacity: int): + self.arr = Array[T](capacity) + self.len = 0 + def __init__(self, dummy: bool, other): + """Dummy __init__ used for list comprehension optimization""" if hasattr(other, '__len__'): self.__init__(other.__len__()) else: self.__init__() + def __init__(self, arr: Array[T], len: int): + self.arr = arr + self.len = len + def __len__(self): return self.len diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index ef4502b4..7e4811c6 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -6,18 +6,12 @@ class complex: def __new__(): return complex(0.0, 0.0) - def __new__(real: int, imag: int): - return complex(float(real), float(imag)) - - def __new__(real: float, imag: int): - return complex(real, float(imag)) - - def __new__(real: int, imag: float): - return complex(float(real), imag) - def __new__(other): return other.__complex__() + def __new__(real, imag): + return complex(float(real), float(imag)) + def __complex__(self): return self @@ -42,6 +36,42 @@ class complex: def __hash__(self): return self.real.__hash__() + self.imag.__hash__()*1000003 + def __add__(self, other): + return self + complex(other) + + def __sub__(self, other): + return self - complex(other) + + def __mul__(self, other): + return self * complex(other) + + def __truediv__(self, other): + return self / complex(other) + + def __eq__(self, other): + return self == complex(other) + + def __ne__(self, other): + return self != complex(other) + + def __pow__(self, other): + return self ** complex(other) + + def __radd__(self, other): + return complex(other) + self + + def __rsub__(self, other): + return complex(other) - self + + def __rmul__(self, other): + return complex(other) * self + + def __rtruediv__(self, other): + return complex(other) / self + + def __rpow__(self, other): + return complex(other) ** self + def __add__(self, other: complex): return complex(self.real + other.real, self.imag + other.imag) @@ -160,42 +190,6 @@ class complex: phase += other.imag * log(vabs) return complex(len * cos(phase), len * sin(phase)) - def __add__(self, other): - return self + complex(other) - - def __sub__(self, other): - return self - complex(other) - - def __mul__(self, other): - return self * complex(other) - - def __truediv__(self, other): - return self / complex(other) - - def __eq__(self, other): - return self == complex(other) - - def __ne__(self, other): - return self != complex(other) - - def __pow__(self, other): - return self ** 
complex(other) - - def __radd__(self, other): - return complex(other) + self - - def __rsub__(self, other): - return complex(other) - self - - def __rmul__(self, other): - return complex(other) * self - - def __rtruediv__(self, other): - return complex(other) / self - - def __rpow__(self, other): - return complex(other) ** self - def __repr__(self): @pure @llvm diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 54b358e0..68e59635 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -10,80 +10,105 @@ def seq_str_float(a: float) -> str: pass class float: def __new__() -> float: return 0.0 - def __new__[T](what: T): + + def __new__(what): return what.__float__() + + def __new__(s: str) -> float: + from C import strtod(cobj, Ptr[cobj]) -> float + buf = __array__[byte](32) + n = s.__len__() + need_dyn_alloc = (n >= buf.__len__()) + + p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr + str.memcpy(p, s.ptr, n) + p[n] = byte(0) + + end = cobj() + result = strtod(p, __ptr__(end)) + + if need_dyn_alloc: + free(p) + + if end != p + n: + raise ValueError("could not convert string to float: " + s) + + return result + def __repr__(self) -> str: s = seq_str_float(self) return s if s != "-nan" else "nan" + def __copy__(self) -> float: return self + def __deepcopy__(self) -> float: return self + @pure @llvm def __int__(self) -> int: %0 = fptosi double %self to i64 ret i64 %0 + def __float__(self): return self + @pure @llvm def __bool__(self) -> bool: %0 = fcmp one double %self, 0.000000e+00 %1 = zext i1 %0 to i8 ret i8 %1 + def __complex__(self): return complex(self, 0.0) + def __pos__(self) -> float: return self + @pure @llvm def __neg__(self) -> float: %0 = fneg double %self ret double %0 + @pure @commutative @llvm def __add__(a: float, b: float) -> float: %tmp = fadd double %a, %b ret double %tmp - @commutative - def __add__(self, other: int) -> float: - return self.__add__(float(other)) + @pure @llvm def __sub__(a: float, b: float) -> float: %tmp = fsub double %a, %b ret double %tmp - def __sub__(self, other: int) -> float: - return self.__sub__(float(other)) + @pure @commutative @llvm def __mul__(a: float, b: float) -> float: %tmp = fmul double %a, %b ret double %tmp - @commutative - def __mul__(self, other: int) -> float: - return self.__mul__(float(other)) + def __floordiv__(self, other: float) -> float: return self.__truediv__(other).__floor__() - def __floordiv__(self, other: int) -> float: - return self.__floordiv__(float(other)) + + @pure @llvm def __truediv__(a: float, b: float) -> float: %tmp = fdiv double %a, %b ret double %tmp - def __truediv__(self, other: int) -> float: - return self.__truediv__(float(other)) + @pure @llvm def __mod__(a: float, b: float) -> float: %tmp = frem double %a, %b ret double %tmp - def __mod__(self, other: int) -> float: - return self.__mod__(float(other)) + def __divmod__(self, other: float): mod = self % other div = (self - mod) / other @@ -103,16 +128,14 @@ class float: floordiv = (0.0).copysign(self / other) return (floordiv, mod) - def __divmod__(self, other: int): - return self.__divmod__(float(other)) + @pure @llvm def __eq__(a: float, b: float) -> bool: %tmp = fcmp oeq double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __eq__(self, other: int) -> bool: - return self.__eq__(float(other)) + @pure @llvm def __ne__(a: float, b: float) -> bool: @@ -120,174 +143,190 @@ class float: %tmp = fcmp one double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __ne__(self, other: 
int) -> bool: - return self.__ne__(float(other)) + @pure @llvm def __lt__(a: float, b: float) -> bool: %tmp = fcmp olt double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __lt__(self, other: int) -> bool: - return self.__lt__(float(other)) + @pure @llvm def __gt__(a: float, b: float) -> bool: %tmp = fcmp ogt double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __gt__(self, other: int) -> bool: - return self.__gt__(float(other)) + @pure @llvm def __le__(a: float, b: float) -> bool: %tmp = fcmp ole double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __le__(self, other: int) -> bool: - return self.__le__(float(other)) + @pure @llvm def __ge__(a: float, b: float) -> bool: %tmp = fcmp oge double %a, %b %res = zext i1 %tmp to i8 ret i8 %res - def __ge__(self, other: int) -> bool: - return self.__ge__(float(other)) + @pure @llvm def sqrt(a: float) -> float: declare double @llvm.sqrt.f64(double %a) %tmp = call double @llvm.sqrt.f64(double %a) ret double %tmp + @pure @llvm def sin(a: float) -> float: declare double @llvm.sin.f64(double %a) %tmp = call double @llvm.sin.f64(double %a) ret double %tmp + @pure @llvm def cos(a: float) -> float: declare double @llvm.cos.f64(double %a) %tmp = call double @llvm.cos.f64(double %a) ret double %tmp + @pure @llvm def exp(a: float) -> float: declare double @llvm.exp.f64(double %a) %tmp = call double @llvm.exp.f64(double %a) ret double %tmp + @pure @llvm def exp2(a: float) -> float: declare double @llvm.exp2.f64(double %a) %tmp = call double @llvm.exp2.f64(double %a) ret double %tmp + @pure @llvm def log(a: float) -> float: declare double @llvm.log.f64(double %a) %tmp = call double @llvm.log.f64(double %a) ret double %tmp + @pure @llvm def log10(a: float) -> float: declare double @llvm.log10.f64(double %a) %tmp = call double @llvm.log10.f64(double %a) ret double %tmp + @pure @llvm def log2(a: float) -> float: declare double @llvm.log2.f64(double %a) %tmp = call double @llvm.log2.f64(double %a) ret double %tmp + @pure @llvm def __abs__(a: float) -> float: declare double @llvm.fabs.f64(double %a) %tmp = call double @llvm.fabs.f64(double %a) ret double %tmp + @pure @llvm def __floor__(a: float) -> float: declare double @llvm.floor.f64(double %a) %tmp = call double @llvm.floor.f64(double %a) ret double %tmp + @pure @llvm def __ceil__(a: float) -> float: declare double @llvm.ceil.f64(double %a) %tmp = call double @llvm.ceil.f64(double %a) ret double %tmp + @pure @llvm def __trunc__(a: float) -> float: declare double @llvm.trunc.f64(double %a) %tmp = call double @llvm.trunc.f64(double %a) ret double %tmp + @pure @llvm def rint(a: float) -> float: declare double @llvm.rint.f64(double %a) %tmp = call double @llvm.rint.f64(double %a) ret double %tmp + @pure @llvm def nearbyint(a: float) -> float: declare double @llvm.nearbyint.f64(double %a) %tmp = call double @llvm.nearbyint.f64(double %a) ret double %tmp + @pure @llvm def __round__(a: float) -> float: declare double @llvm.round.f64(double %a) %tmp = call double @llvm.round.f64(double %a) ret double %tmp + @pure @llvm def __pow__(a: float, b: float) -> float: declare double @llvm.pow.f64(double %a, double %b) %tmp = call double @llvm.pow.f64(double %a, double %b) ret double %tmp - def __pow__(self, other: int) -> float: - return self.__pow__(float(other)) + @pure @llvm def min(a: float, b: float) -> float: declare double @llvm.minnum.f64(double %a, double %b) %tmp = call double @llvm.minnum.f64(double %a, double %b) ret double %tmp + @pure @llvm def max(a: float, b: float) -> float: declare double 
@llvm.maxnum.f64(double %a, double %b) %tmp = call double @llvm.maxnum.f64(double %a, double %b) ret double %tmp + @pure @llvm def copysign(a: float, b: float) -> float: declare double @llvm.copysign.f64(double %a, double %b) %tmp = call double @llvm.copysign.f64(double %a, double %b) ret double %tmp + @pure @llvm def fma(a: float, b: float, c: float) -> float: declare double @llvm.fma.f64(double %a, double %b, double %c) %tmp = call double @llvm.fma.f64(double %a, double %b, double %c) ret double %tmp + @llvm def __atomic_xchg__(d: Ptr[float], b: float) -> void: %tmp = atomicrmw xchg double* %d, double %b seq_cst ret void + @llvm def __atomic_add__(d: Ptr[float], b: float) -> float: %tmp = atomicrmw fadd double* %d, double %b seq_cst ret double %tmp + @llvm def __atomic_sub__(d: Ptr[float], b: float) -> float: %tmp = atomicrmw fsub double* %d, double %b seq_cst ret double %tmp + def __hash__(self): from C import frexp(float, Ptr[Int[32]]) -> float @@ -332,31 +371,13 @@ class float: x = -2 return x - def __new__(s: str) -> float: - from C import strtod(cobj, Ptr[cobj]) -> float - buf = __array__[byte](32) - n = s.__len__() - need_dyn_alloc = (n >= buf.__len__()) - - p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr - str.memcpy(p, s.ptr, n) - p[n] = byte(0) - - end = cobj() - result = strtod(p, __ptr__(end)) - - if need_dyn_alloc: - free(p) - - if end != p + n: - raise ValueError("could not convert string to float: " + s) - - return result def __match__(self, i: float): return self == i + @property def real(self): return self + @property def imag(self): return 0.0 diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 5e7be7db..8c2e8676 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -14,37 +14,56 @@ class int: @llvm def __new__() -> int: ret i64 0 - def __new__[T](what: T) -> int: # lowest priority! 
+ + def __new__(what) -> int: return what.__int__() + + def __new__(s: str) -> int: + return int._from_str(s, 10) + + def __new__(s: str, base: int) -> int: + return int._from_str(s, base) + def __int__(self) -> int: return self + @pure @llvm def __float__(self) -> float: %tmp = sitofp i64 %self to double ret double %tmp + def __complex__(self): return complex(float(self), 0.0) + def __index__(self): return self + def __repr__(self) -> str: return seq_str_int(self) + def __copy__(self) -> int: return self + def __deepcopy__(self) -> int: return self + def __hash__(self) -> int: return self + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i64 %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> int: return self + def __neg__(self) -> int: return 0 - self + @pure @llvm def __abs__(self) -> int: @@ -52,23 +71,19 @@ class int: %1 = sub i64 0, %self %2 = select i1 %0, i64 %self, i64 %1 ret i64 %2 + @pure @llvm def __lshift__(self, other: int) -> int: %0 = shl i64 %self, %other ret i64 %0 + @pure @llvm def __rshift__(self, other: int) -> int: %0 = ashr i64 %self, %other ret i64 %0 - @pure - @commutative - @associative - @llvm - def __add__(self, b: int) -> int: - %tmp = add i64 %self, %b - ret i64 %tmp + @pure @commutative @llvm @@ -76,17 +91,36 @@ class int: %0 = sitofp i64 %self to double %1 = fadd double %0, %other ret double %1 + @pure + @commutative + @associative @llvm - def __sub__(self, b: int) -> int: - %tmp = sub i64 %self, %b + def __add__(self, b: int) -> int: + %tmp = add i64 %self, %b ret i64 %tmp + @pure @llvm def __sub__(self, other: float) -> float: %0 = sitofp i64 %self to double %1 = fsub double %0, %other ret double %1 + + @pure + @llvm + def __sub__(self, b: int) -> int: + %tmp = sub i64 %self, %b + ret i64 %tmp + + @pure + @commutative + @llvm + def __mul__(self, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fmul double %0, %other + ret double %1 + @pure @commutative @associative @@ -95,18 +129,7 @@ class int: def __mul__(self, b: int) -> int: %tmp = mul i64 %self, %b ret i64 %tmp - @pure - @commutative - @llvm - def __mul__(self, other: float) -> float: - %0 = sitofp i64 %self to double - %1 = fmul double %0, %other - ret double %1 - @pure - @llvm - def __floordiv__(self, b: int) -> int: - %tmp = sdiv i64 %self, %b - ret i64 %tmp + @pure @llvm def __floordiv__(self, other: float) -> float: @@ -115,6 +138,20 @@ class int: %1 = fdiv double %0, %other %2 = call double @llvm.floor.f64(double %1) ret double %2 + + @pure + @llvm + def __floordiv__(self, b: int) -> int: + %tmp = sdiv i64 %self, %b + ret i64 %tmp + + @pure + @llvm + def __truediv__(self, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fdiv double %0, %other + ret double %1 + @pure @llvm def __truediv__(self, other: int) -> float: @@ -122,23 +159,20 @@ class int: %1 = sitofp i64 %other to double %2 = fdiv double %0, %1 ret double %2 - @pure - @llvm - def __truediv__(self, other: float) -> float: - %0 = sitofp i64 %self to double - %1 = fdiv double %0, %other - ret double %1 - @pure - @llvm - def __mod__(a: int, b: int) -> int: - %tmp = srem i64 %a, %b - ret i64 %tmp + @pure @llvm def __mod__(self, other: float) -> float: %0 = sitofp i64 %self to double %1 = frem double %0, %other ret double %1 + + @pure + @llvm + def __mod__(a: int, b: int) -> int: + %tmp = srem i64 %a, %b + ret i64 %tmp + def __divmod__(self, other: int): d = self // other m = self - d*other @@ -146,11 +180,13 @@ class int: m += other d -= 1 return (d, m) + @pure @llvm def __invert__(a: int) -> 
int: %tmp = xor i64 %a, -1 ret i64 %tmp + @pure @commutative @associative @@ -158,6 +194,7 @@ class int: def __and__(a: int, b: int) -> int: %tmp = and i64 %a, %b ret i64 %tmp + @pure @commutative @associative @@ -165,6 +202,7 @@ class int: def __or__(a: int, b: int) -> int: %tmp = or i64 %a, %b ret i64 %tmp + @pure @commutative @associative @@ -172,42 +210,42 @@ class int: def __xor__(a: int, b: int) -> int: %tmp = xor i64 %a, %b ret i64 %tmp + @pure @llvm def __bitreverse__(a: int) -> int: declare i64 @llvm.bitreverse.i64(i64 %a) %tmp = call i64 @llvm.bitreverse.i64(i64 %a) ret i64 %tmp + @pure @llvm def __bswap__(a: int) -> int: declare i64 @llvm.bswap.i64(i64 %a) %tmp = call i64 @llvm.bswap.i64(i64 %a) ret i64 %tmp + @pure @llvm def __ctpop__(a: int) -> int: declare i64 @llvm.ctpop.i64(i64 %a) %tmp = call i64 @llvm.ctpop.i64(i64 %a) ret i64 %tmp + @pure @llvm def __ctlz__(a: int) -> int: declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false) ret i64 %tmp + @pure @llvm def __cttz__(a: int) -> int: declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef) %tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false) ret i64 %tmp - @pure - @llvm - def __eq__(a: int, b: int) -> bool: - %tmp = icmp eq i64 %a, %b - %res = zext i1 %tmp to i8 - ret i8 %res + @pure @llvm def __eq__(self, b: float) -> bool: @@ -215,12 +253,14 @@ class int: %1 = fcmp oeq double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __ne__(a: int, b: int) -> bool: - %tmp = icmp ne i64 %a, %b + def __eq__(a: int, b: int) -> bool: + %tmp = icmp eq i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __ne__(self, b: float) -> bool: @@ -228,12 +268,14 @@ class int: %1 = fcmp one double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __lt__(a: int, b: int) -> bool: - %tmp = icmp slt i64 %a, %b + def __ne__(a: int, b: int) -> bool: + %tmp = icmp ne i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __lt__(self, b: float) -> bool: @@ -241,12 +283,14 @@ class int: %1 = fcmp olt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __gt__(a: int, b: int) -> bool: - %tmp = icmp sgt i64 %a, %b + def __lt__(a: int, b: int) -> bool: + %tmp = icmp slt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __gt__(self, b: float) -> bool: @@ -254,12 +298,14 @@ class int: %1 = fcmp ogt double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __le__(a: int, b: int) -> bool: - %tmp = icmp sle i64 %a, %b + def __gt__(a: int, b: int) -> bool: + %tmp = icmp sgt i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __le__(self, b: float) -> bool: @@ -267,12 +313,14 @@ class int: %1 = fcmp ole double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 + @pure @llvm - def __ge__(a: int, b: int) -> bool: - %tmp = icmp sge i64 %a, %b + def __le__(a: int, b: int) -> bool: + %tmp = icmp sle i64 %a, %b %res = zext i1 %tmp to i8 ret i8 %res + @pure @llvm def __ge__(self, b: float) -> bool: @@ -280,10 +328,17 @@ class int: %1 = fcmp oge double %0, %b %2 = zext i1 %1 to i8 ret i8 %2 - def __new__(s: str) -> int: - return int._from_str(s, 10) - def __new__(s: str, base: int) -> int: - return int._from_str(s, base) + + @pure + @llvm + def __ge__(a: int, b: int) -> bool: + %tmp = icmp sge i64 %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + def __pow__(self, exp: float): + return float(self) ** exp + def __pow__(self, exp: int): if exp < 0: return 0 @@ -296,53 +351,65 @@ class int: break self *= self return result - def __pow__(self, exp: 
float): - return float(self) ** exp + def popcnt(self): return Int[64](self).popcnt() + @llvm def __atomic_xchg__(d: Ptr[int], b: int) -> void: %tmp = atomicrmw xchg i64* %d, i64 %b seq_cst ret void + @llvm def __atomic_add__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw add i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_sub__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw sub i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_and__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw and i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_nand__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw nand i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_or__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw or i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def _atomic_xor(d: Ptr[int], b: int) -> int: %tmp = atomicrmw xor i64* %d, i64 %b seq_cst ret i64 %tmp + def __atomic_xor__(self, b: int) -> int: return int._atomic_xor(__ptr__(self), b) + @llvm def __atomic_min__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw min i64* %d, i64 %b seq_cst ret i64 %tmp + @llvm def __atomic_max__(d: Ptr[int], b: int) -> int: %tmp = atomicrmw max i64* %d, i64 %b seq_cst ret i64 %tmp + def __match__(self, i: int): return self == i + @property def real(self): return self + @property def imag(self): return 0 diff --git a/stdlib/internal/types/intn.codon b/stdlib/internal/types/intn.codon index 4ed993da..99b774e7 100644 --- a/stdlib/internal/types/intn.codon +++ b/stdlib/internal/types/intn.codon @@ -10,9 +10,11 @@ class Int: def __new__() -> Int[N]: check_N(N) return Int[N](0) + def __new__(what: Int[N]) -> Int[N]: check_N(N) return what + def __new__(what: int) -> Int[N]: check_N(N) if N < 64: @@ -21,10 +23,12 @@ class Int: return what else: return __internal__.int_sext(what, 64, N) + @pure @llvm def __new__(what: UInt[N]) -> Int[N]: ret i{=N} %what + def __new__(what: str) -> Int[N]: check_N(N) ret = Int[N]() @@ -39,6 +43,7 @@ class Int: ret = ret * Int[N](10) + Int[N](int(what.ptr[i]) - 48) i += 1 return sign * ret + def __int__(self) -> int: if N > 64: return __internal__.int_trunc(self, N, 64) @@ -46,37 +51,47 @@ class Int: return self else: return __internal__.int_sext(self, N, 64) + def __index__(self): return int(self) + def __copy__(self) -> Int[N]: return self + def __deepcopy__(self) -> Int[N]: return self + def __hash__(self) -> int: return int(self) + @pure @llvm def __float__(self) -> float: %0 = sitofp i{=N} %self to double ret double %0 + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i{=N} %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> Int[N]: return self + @pure @llvm def __neg__(self) -> Int[N]: %0 = sub i{=N} 0, %self ret i{=N} %0 + @pure @llvm def __invert__(self) -> Int[N]: %0 = xor i{=N} %self, -1 ret i{=N} %0 + @pure @commutative @associative @@ -84,11 +99,13 @@ class Int: def __add__(self, other: Int[N]) -> Int[N]: %0 = add i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __sub__(self, other: Int[N]) -> Int[N]: %0 = sub i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -97,11 +114,13 @@ class Int: def __mul__(self, other: Int[N]) -> Int[N]: %0 = mul i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __floordiv__(self, other: Int[N]) -> Int[N]: %0 = sdiv i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __truediv__(self, other: Int[N]) -> float: @@ -109,11 +128,13 @@ class Int: %1 = sitofp i{=N} %other to double %2 = fdiv double %0, %1 ret double %2 + @pure @llvm def __mod__(self, other: Int[N]) -> Int[N]: %0 = srem i{=N} %self, 
%other ret i{=N} %0 + def __divmod__(self, other: Int[N]): d = self // other m = self - d*other @@ -121,52 +142,61 @@ class Int: m += other d -= Int[N](1) return (d, m) + @pure @llvm def __lshift__(self, other: Int[N]) -> Int[N]: %0 = shl i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __rshift__(self, other: Int[N]) -> Int[N]: %0 = ashr i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __eq__(self, other: Int[N]) -> bool: %0 = icmp eq i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: Int[N]) -> bool: %0 = icmp ne i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: Int[N]) -> bool: %0 = icmp slt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: Int[N]) -> bool: %0 = icmp sgt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: Int[N]) -> bool: %0 = icmp sle i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: Int[N]) -> bool: %0 = icmp sge i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @commutative @associative @@ -174,6 +204,7 @@ class Int: def __and__(self, other: Int[N]) -> Int[N]: %0 = and i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -181,6 +212,7 @@ class Int: def __or__(self, other: Int[N]) -> Int[N]: %0 = or i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -188,6 +220,7 @@ class Int: def __xor__(self, other: Int[N]) -> Int[N]: %0 = xor i{=N} %self, %other ret i{=N} %0 + @llvm def __pickle__(self, dest: Ptr[byte]) -> void: declare i32 @gzwrite(i8*, i8*, i32) @@ -198,6 +231,7 @@ class Int: %szi = ptrtoint i{=N}* %sz to i32 %2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi) ret void + @llvm def __unpickle__(src: Ptr[byte]) -> Int[N]: declare i32 @gzread(i8*, i8*, i32) @@ -208,18 +242,23 @@ class Int: %2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi) %3 = load i{=N}, i{=N}* %0 ret i{=N} %3 + def __repr__(self) -> str: return str.cat(('Int[', seq_str_int(N), '](', seq_str_int(int(self)), ')')) + def __str__(self) -> str: return seq_str_int(int(self)) + @pure @llvm def _popcnt(self) -> Int[N]: declare i{=N} @llvm.ctpop.i{=N}(i{=N}) %0 = call i{=N} @llvm.ctpop.i{=N}(i{=N} %self) ret i{=N} %0 + def popcnt(self): return int(self._popcnt()) + def len() -> int: return N @@ -228,9 +267,11 @@ class UInt: def __new__() -> UInt[N]: check_N(N) return UInt[N](0) + def __new__(what: UInt[N]) -> UInt[N]: check_N(N) return what + def __new__(what: int) -> UInt[N]: check_N(N) if N < 64: @@ -239,13 +280,16 @@ class UInt: return UInt[N](Int[N](what)) else: return UInt[N](__internal__.int_zext(what, 64, N)) + @pure @llvm def __new__(what: Int[N]) -> UInt[N]: ret i{=N} %what + def __new__(what: str) -> UInt[N]: check_N(N) return UInt[N](Int[N](what)) + def __int__(self) -> int: if N > 64: return __internal__.int_trunc(self, N, 64) @@ -253,37 +297,47 @@ class UInt: return Int[64](self) else: return __internal__.int_zext(self, N, 64) + def __index__(self): return int(self) + def __copy__(self) -> UInt[N]: return self + def __deepcopy__(self) -> UInt[N]: return self + def __hash__(self) -> int: return int(self) + @pure @llvm def __float__(self) -> float: %0 = uitofp i{=N} %self to double ret double %0 + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne i{=N} %self, 0 %1 = zext i1 %0 to i8 ret i8 %1 + def __pos__(self) -> UInt[N]: return self + @pure @llvm def __neg__(self) -> UInt[N]: %0 = sub i{=N} 0, %self ret i{=N} %0 + @pure @llvm def __invert__(self) 
-> UInt[N]: %0 = xor i{=N} %self, -1 ret i{=N} %0 + @pure @commutative @associative @@ -291,11 +345,13 @@ class UInt: def __add__(self, other: UInt[N]) -> UInt[N]: %0 = add i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __sub__(self, other: UInt[N]) -> UInt[N]: %0 = sub i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -304,11 +360,13 @@ class UInt: def __mul__(self, other: UInt[N]) -> UInt[N]: %0 = mul i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __floordiv__(self, other: UInt[N]) -> UInt[N]: %0 = udiv i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __truediv__(self, other: UInt[N]) -> float: @@ -316,59 +374,70 @@ class UInt: %1 = uitofp i{=N} %other to double %2 = fdiv double %0, %1 ret double %2 + @pure @llvm def __mod__(self, other: UInt[N]) -> UInt[N]: %0 = urem i{=N} %self, %other ret i{=N} %0 + def __divmod__(self, other: UInt[N]): return (self // other, self % other) + @pure @llvm def __lshift__(self, other: UInt[N]) -> UInt[N]: %0 = shl i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __rshift__(self, other: UInt[N]) -> UInt[N]: %0 = lshr i{=N} %self, %other ret i{=N} %0 + @pure @llvm def __eq__(self, other: UInt[N]) -> bool: %0 = icmp eq i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: UInt[N]) -> bool: %0 = icmp ne i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: UInt[N]) -> bool: %0 = icmp ult i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: UInt[N]) -> bool: %0 = icmp ugt i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: UInt[N]) -> bool: %0 = icmp ule i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: UInt[N]) -> bool: %0 = icmp uge i{=N} %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @commutative @associative @@ -376,6 +445,7 @@ class UInt: def __and__(self, other: UInt[N]) -> UInt[N]: %0 = and i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -383,6 +453,7 @@ class UInt: def __or__(self, other: UInt[N]) -> UInt[N]: %0 = or i{=N} %self, %other ret i{=N} %0 + @pure @commutative @associative @@ -390,6 +461,7 @@ class UInt: def __xor__(self, other: UInt[N]) -> UInt[N]: %0 = xor i{=N} %self, %other ret i{=N} %0 + @llvm def __pickle__(self, dest: Ptr[byte]) -> void: declare i32 @gzwrite(i8*, i8*, i32) @@ -400,6 +472,7 @@ class UInt: %szi = ptrtoint i{=N}* %sz to i32 %2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi) ret void + @llvm def __unpickle__(src: Ptr[byte]) -> UInt[N]: declare i32 @gzread(i8*, i8*, i32) @@ -410,12 +483,16 @@ class UInt: %2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi) %3 = load i{=N}, i{=N}* %0 ret i{=N} %3 + def __repr__(self) -> str: return str.cat(('UInt[', seq_str_int(N), '](', seq_str_uint(int(self)), ')')) + def __str__(self) -> str: return seq_str_uint(int(self)) + def popcnt(self): return int(Int[N](self)._popcnt()) + def len() -> int: return N diff --git a/stdlib/internal/types/optional.codon b/stdlib/internal/types/optional.codon index c9d668e8..022beaa8 100644 --- a/stdlib/internal/types/optional.codon +++ b/stdlib/internal/types/optional.codon @@ -5,29 +5,36 @@ class Optional: return __internal__.opt_tuple_new(T) else: return __internal__.opt_ref_new(T) + def __new__(what: T) -> Optional[T]: if isinstance(T, ByVal): return __internal__.opt_tuple_new_arg(what, T) else: return __internal__.opt_ref_new_arg(what, T) + def __bool__(self) -> bool: if isinstance(T, ByVal): return 
__internal__.opt_tuple_bool(self, T) else: return __internal__.opt_ref_bool(self, T) + def __invert__(self) -> T: if isinstance(T, ByVal): return __internal__.opt_tuple_invert(self, T) else: return __internal__.opt_ref_invert(self, T) + def __str__(self) -> str: return 'None' if not self else str(~self) + def __repr__(self) -> str: return 'None' if not self else (~self).__repr__() + def __is_optional__(self, other: Optional[T]): if (not self) or (not other): return (not self) and (not other) return self.__invert__() is other.__invert__() + optional = Optional def unwrap[T](opt: Optional[T]) -> T: diff --git a/stdlib/internal/types/ptr.codon b/stdlib/internal/types/ptr.codon index 0dfc8dc6..defdfb37 100644 --- a/stdlib/internal/types/ptr.codon +++ b/stdlib/internal/types/ptr.codon @@ -4,53 +4,63 @@ def seq_str_ptr(a: Ptr[byte]) -> str: pass @extend class Ptr: - @__internal__ - def __new__(sz: int) -> Ptr[T]: - pass @pure @llvm def __new__() -> Ptr[T]: ret {=T}* null + + @__internal__ + def __new__(sz: int) -> Ptr[T]: + pass + + @pure + @llvm + def __new__(other: Ptr[T]) -> Ptr[T]: + ret {=T}* %other + @pure @llvm def __new__(other: Ptr[byte]) -> Ptr[T]: %0 = bitcast i8* %other to {=T}* ret {=T}* %0 - @pure - @llvm - def __new__(other: Ptr[T]) -> Ptr[T]: - ret {=T}* %other + @pure @llvm def __int__(self) -> int: %0 = ptrtoint {=T}* %self to i64 ret i64 %0 + @pure @llvm def __copy__(self) -> Ptr[T]: ret {=T}* %self + @pure @llvm def __bool__(self) -> bool: %0 = icmp ne {=T}* %self, null %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __getitem__(self, index: int) -> T: %0 = getelementptr {=T}, {=T}* %self, i64 %index %1 = load {=T}, {=T}* %0 ret {=T} %1 + @llvm def __setitem__(self, index: int, what: T) -> void: %0 = getelementptr {=T}, {=T}* %self, i64 %index store {=T} %what, {=T}* %0 ret void + @pure @llvm def __add__(self, other: int) -> Ptr[T]: %0 = getelementptr {=T}, {=T}* %self, i64 %other ret {=T}* %0 + @pure @llvm def __sub__(self, other: Ptr[T]) -> int: @@ -59,90 +69,105 @@ class Ptr: %2 = sub i64 %0, %1 %3 = sdiv exact i64 %2, ptrtoint ({=T}* getelementptr ({=T}, {=T}* null, i32 1) to i64) ret i64 %3 + @pure @llvm def __eq__(self, other: Ptr[T]) -> bool: %0 = icmp eq {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ne__(self, other: Ptr[T]) -> bool: %0 = icmp ne {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __lt__(self, other: Ptr[T]) -> bool: %0 = icmp slt {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __gt__(self, other: Ptr[T]) -> bool: %0 = icmp sgt {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __le__(self, other: Ptr[T]) -> bool: %0 = icmp sle {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @pure @llvm def __ge__(self, other: Ptr[T]) -> bool: %0 = icmp sge {=T}* %self, %other %1 = zext i1 %0 to i8 ret i8 %1 + @llvm def __prefetch_r0__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 0, i32 1) ret void + @llvm def __prefetch_r1__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 1, i32 1) ret void + @llvm def __prefetch_r2__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 2, i32 1) ret void + @llvm def __prefetch_r3__(self) -> void: declare void 
@llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 0, i32 3, i32 1) ret void + @llvm def __prefetch_w0__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 0, i32 1) ret void + @llvm def __prefetch_w1__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 1, i32 1) ret void + @llvm def __prefetch_w2__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 2, i32 1) ret void + @llvm def __prefetch_w3__(self) -> void: declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32) %0 = bitcast {=T}* %self to i8* call void @llvm.prefetch(i8* %0, i32 1, i32 3, i32 1) ret void + @pure @llvm def as_byte(self) -> Ptr[byte]: diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index 229eca1c..5cd02b67 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -7,41 +7,52 @@ class str: @__internal__ def __new__(l: int, p: Ptr[byte]) -> str: pass + def __new__(p: Ptr[byte], l: int) -> str: return str(l, p) + def __new__() -> str: return str(Ptr[byte](), 0) - def __new__[T](what: T) -> str: # lowest priority! + + def __new__(what) -> str: if hasattr(what, "__str__"): return what.__str__() else: return what.__repr__() + def __str__(what: str) -> str: return what + def __len__(self) -> int: return self.len + def __bool__(self) -> bool: return self.len != 0 + def __copy__(self) -> str: n = self.len p = cobj(n) str.memcpy(p, self.ptr, n) return str(p, n) + @llvm def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void: declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false) ret void + @llvm def memmove(dest: Ptr[byte], src: Ptr[byte], len: int) -> void: declare void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false) ret void + @llvm def memset(dest: Ptr[byte], val: byte, len: int) -> void: declare void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 %align, i1 %isvolatile) call void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 0, i1 false) ret void + def __add__(self, other: str) -> str: len1 = self.len len2 = other.len @@ -50,17 +61,20 @@ class str: str.memcpy(p, self.ptr, len1) str.memcpy(p + len1, other.ptr, len2) return str(p, len3) + def c_str(self): n = self.__len__() p = cobj(n + 1) str.memcpy(p, self.ptr, n) p[n] = byte(0) return p + def from_ptr(t: cobj) -> str: n = strlen(t) p = Ptr[byte](n) str.memcpy(p, t, n) return str(p, n) + def __eq__(self, other: str): if self.len != other.len: return False @@ -70,10 +84,13 @@ class str: return False i += 1 return True + def __match__(self, other: str): return self.__eq__(other) + def __ne__(self, other: str): return not self.__eq__(other) + def cat(*args): total = 0 if staticlen(args) == 1 and hasattr(args[0], "__iter__") and hasattr(args[0], "__len__"): diff --git a/stdlib/statistics.codon b/stdlib/statistics.codon index d11622d3..0d3a440a 100644 --- a/stdlib/statistics.codon +++ b/stdlib/statistics.codon @@ -394,22 +394,10 @@ class NormalDist: 
self._mu = mu self._sigma = sigma - def __init__(self, mu: float, sigma: float): + def __init__(self, mu, sigma): self._init(float(mu), float(sigma)) - def __init__(self, mu: int, sigma: int): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: float, sigma: int): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: int, sigma: float): - self._init(float(mu), float(sigma)) - - def __init__(self, mu: float): - self._init(mu, 1.0) - - def __init__(self, mu: int): + def __init__(self, mu): self._init(float(mu), 1.0) def __init__(self):
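
Note on the stdlib rewrites above (a hypothetical illustration, not part of the patch): the untyped overloads such as bool.__new__(what), int.__new__(what), float.__new__(what), and the merged complex.__new__(real, imag) rely on the new overload machinery introduced in simplify_stmt.cpp and typecheck_expr.cpp. Every class method now gets a hidden ".dispatch" entry, and TypecheckVisitor::findBestMethod walks the overloads from the most recently defined one backwards, keeping the first candidate whose arguments pass a dry-run wrapExpr/unify (the new undoOnSuccess flag discards the trial bindings). The sketch below shows the resulting user-level behavior; the Parser class and its methods are invented for this example and assume method overloading resolves for ordinary classes as described.

    class Parser:
        def parse(self, what) -> int:
            # Generic fallback: accepts anything that provides __int__().
            return what.__int__()

        def parse(self, s: str) -> int:
            # More specific overload; unifies directly for str arguments.
            return int(s, 10)

    p = Parser()
    print(p.parse(3.7))   # -> 3, generic overload (float.__int__ truncates)
    print(p.parse("42"))  # -> 42, str overload is tried first and matches

Under the old scheme this pattern needed a generic [T](what: T) signature marked "lowest priority"; with dispatch-based resolution a plain untyped parameter suffices, and the str overload still wins for string arguments because it is defined later and unifies without any wrapping.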