1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Select the last matching overload by default (remove scoring logic); Add dispatch stubs for partial overload support

This commit is contained in:
Ibrahim Numanagić 2021-12-11 11:19:14 -08:00
parent 721531c409
commit ab14cf9fc7
20 changed files with 526 additions and 344 deletions

View File

@ -63,7 +63,7 @@ Cache::findMethod(types::ClassType *typ, const std::string &member,
seqassert(e->type, "not a class");
int oldAge = typeCtx->age;
typeCtx->age = 99999;
auto f = typeCtx->findBestMethod(e.get(), member, args);
auto f = TypecheckVisitor(typeCtx).findBestMethod(e.get(), member, args);
typeCtx->age = oldAge;
return f;
}

View File

@ -61,7 +61,8 @@ std::string SimplifyContext::getBase() const {
}
std::string SimplifyContext::generateCanonicalName(const std::string &name,
bool includeBase) const {
bool includeBase,
bool zeroId) const {
std::string newName = name;
if (includeBase && name.find('.') == std::string::npos) {
std::string base = getBase();
@ -74,7 +75,7 @@ std::string SimplifyContext::generateCanonicalName(const std::string &name,
newName = (base.empty() ? "" : (base + ".")) + newName;
}
auto num = cache->identifierCount[newName]++;
newName = num ? format("{}.{}", newName, num) : newName;
newName = num || zeroId ? format("{}.{}", newName, num) : newName;
if (newName != name)
cache->identifierCount[newName]++;
cache->reverseIdentifierLookup[newName] = name;

View File

@ -113,8 +113,8 @@ public:
void dump() override { dump(0); }
/// Generate a unique identifier (name) for a given string.
std::string generateCanonicalName(const std::string &name,
bool includeBase = false) const;
std::string generateCanonicalName(const std::string &name, bool includeBase = false,
bool zeroId = false) const;
bool inFunction() const { return getLevel() && !bases.back().isType(); }
bool inClass() const { return getLevel() && bases.back().isType(); }

View File

@ -472,8 +472,25 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
return;
}
auto canonicalName = ctx->generateCanonicalName(stmt->name, true);
bool isClassMember = ctx->inClass();
if (isClassMember && !endswith(stmt->name, ".dispatch") &&
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].empty()) {
transform(
N<FunctionStmt>(stmt->name + ".dispatch", nullptr,
std::vector<Param>{Param("*args")},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(ctx->bases.back().name), stmt->name),
N<StarExpr>(N<IdExpr>("args")))))));
}
auto func_name = stmt->name;
if (endswith(stmt->name, ".dispatch"))
func_name = func_name.substr(0, func_name.size() - 9);
auto canonicalName = ctx->generateCanonicalName(
func_name, true, isClassMember && !endswith(stmt->name, ".dispatch"));
if (endswith(stmt->name, ".dispatch")) {
canonicalName += ".dispatch";
ctx->cache->reverseIdentifierLookup[canonicalName] = func_name;
}
bool isEnclosedFunc = ctx->inFunction();
if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember))
@ -483,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ctx->bases = std::vector<SimplifyContext::Base>();
if (!isClassMember)
// Class members are added to class' method table
ctx->add(SimplifyItem::Func, stmt->name, canonicalName, ctx->isToplevel());
ctx->add(SimplifyItem::Func, func_name, canonicalName, ctx->isToplevel());
if (isClassMember)
ctx->bases.push_back(oldBases[0]);
ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base...
@ -602,7 +619,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// ... set the enclosing class name...
attr.parentClass = ctx->bases.back().name;
// ... add the method to class' method list ...
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].push_back(
ctx->cache->classes[ctx->bases.back().name].methods[func_name].push_back(
{canonicalName, nullptr, ctx->cache->age});
// ... and if the function references outer class variable (by definition a
// generic), mark it as not static as it needs fully instantiated class to be
@ -637,21 +654,21 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ExprPtr finalExpr;
if (!captures.empty())
finalExpr = N<CallExpr>(N<IdExpr>(stmt->name), partialArgs);
finalExpr = N<CallExpr>(N<IdExpr>(func_name), partialArgs);
if (isClassMember && decorators.size())
error("decorators cannot be applied to class methods");
for (int j = int(decorators.size()) - 1; j >= 0; j--) {
if (auto c = const_cast<CallExpr *>(decorators[j]->getCall())) {
c->args.emplace(c->args.begin(),
CallExpr::Arg{"", finalExpr ? finalExpr : N<IdExpr>(stmt->name)});
CallExpr::Arg{"", finalExpr ? finalExpr : N<IdExpr>(func_name)});
finalExpr = N<CallExpr>(c->expr, c->args);
} else {
finalExpr =
N<CallExpr>(decorators[j], finalExpr ? finalExpr : N<IdExpr>(stmt->name));
N<CallExpr>(decorators[j], finalExpr ? finalExpr : N<IdExpr>(func_name));
}
}
if (finalExpr)
resultStmt = transform(N<AssignStmt>(N<IdExpr>(stmt->name), finalExpr));
resultStmt = transform(N<AssignStmt>(N<IdExpr>(func_name), finalExpr));
}
void SimplifyVisitor::visit(ClassStmt *stmt) {
@ -941,7 +958,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
continue;
auto subs = substitutions[ai];
auto newName = ctx->generateCanonicalName(
ctx->cache->reverseIdentifierLookup[f->name], true);
ctx->cache->reverseIdentifierLookup[f->name], true, true);
auto nf = std::dynamic_pointer_cast<FunctionStmt>(replace(sp, subs));
subs[nf->name] = N<IdExpr>(newName);
nf->name = newName;

View File

@ -33,16 +33,21 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, StmtPtr stmts) {
return std::move(infer.second);
}
TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b) {
TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess) {
if (!a)
return a = b;
seqassert(b, "rhs is nullptr");
types::Type::Unification undo;
if (a->unify(b.get(), &undo) >= 0)
if (a->unify(b.get(), &undo) >= 0) {
if (undoOnSuccess)
undo.undo();
return a;
undo.undo();
} else {
undo.undo();
}
// LOG("{} / {}", a->debugString(true), b->debugString(true));
a->unify(b.get(), &undo);
if (!undoOnSuccess)
a->unify(b.get(), &undo);
error("cannot unify {} and {}", a->toString(), b->toString());
return nullptr;
}

View File

@ -283,9 +283,19 @@ private:
void generateFnCall(int n);
/// Make an empty partial call fn(...) for a function fn.
ExprPtr partializeFunction(ExprPtr expr);
/// Picks the best method of a given expression that matches the given argument
/// types. Prefers methods whose signatures are closer to the given arguments:
/// e.g. foo(int) will match (int) better that a foo(T).
/// Also takes care of the Optional arguments.
/// If multiple equally good methods are found, return the first one.
/// Return nullptr if no methods were found.
types::FuncTypePtr
findBestMethod(const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
private:
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b);
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b,
bool undoOnSuccess = false);
types::TypePtr realizeType(types::ClassType *typ);
types::TypePtr realizeFunc(types::FuncType *typ);
std::pair<int, StmtPtr> inferTypes(StmtPtr stmt, bool keepLast,
@ -293,7 +303,7 @@ private:
codon::ir::types::Type *getLLVMType(const types::ClassType *t);
bool wrapExpr(ExprPtr &expr, types::TypePtr expectedType,
const types::FuncTypePtr &callee);
const types::FuncTypePtr &callee, bool undoOnSuccess = false);
int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false);
int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop,
int64_t step);

View File

@ -141,10 +141,12 @@ TypeContext::findMethod(const std::string &typeName, const std::string &method)
if (m != cache->classes.end()) {
auto t = m->second.methods.find(method);
if (t != m->second.methods.end()) {
seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"),
"first method is not dispatch");
std::unordered_map<std::string, int> signatureLoci;
std::vector<types::FuncTypePtr> vv;
for (auto &mt : t->second) {
// LOG("{}::{} @ {} vs. {}", typeName, method, age, mt.age);
for (int mti = 1; mti < t->second.size(); mti++) {
auto &mt = t->second[mti];
if (mt.age <= age) {
auto sig = cache->functions[mt.name].ast->signature();
auto it = signatureLoci.find(sig);
@ -177,110 +179,6 @@ types::TypePtr TypeContext::findMember(const std::string &typeName,
return nullptr;
}
types::FuncTypePtr TypeContext::findBestMethod(
const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args, bool checkSingle) {
auto typ = expr->getType()->getClass();
seqassert(typ, "not a class");
auto methods = findMethod(typ->name, member);
if (methods.empty())
return nullptr;
if (methods.size() == 1 && !checkSingle) // methods is not overloaded
return methods[0];
// Calculate the unification score for each available methods and pick the one with
// highest score.
std::vector<std::pair<int, int>> scores;
for (int mi = 0; mi < methods.size(); mi++) {
auto method = instantiate(expr, methods[mi], typ.get(), false)->getFunc();
std::vector<types::TypePtr> reordered;
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args) {
callArgs.push_back({a.first, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a.second);
}
auto score = reorderNamedArgs(
method.get(), callArgs,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(si == s || si == k || slots[si].size() != 1
? nullptr
: args[slots[si][0]].second);
}
return 0;
},
[](const std::string &) { return -1; });
if (score == -1)
continue;
// Scoring system for each argument:
// Generics, traits and default arguments get a score of zero (lowest priority).
// Optional unwrap gets the score of 1.
// Optional wrap gets the score of 2.
// Successful unification gets the score of 3 (highest priority).
for (int ai = 0, mi = 1, gi = 0; ai < reordered.size(); ai++) {
auto argType = reordered[ai];
if (!argType)
continue;
auto expectedType = method->ast->args[ai].generic ? method->generics[gi++].type
: method->args[mi++];
auto expectedClass = expectedType->getClass();
// Ignore traits, *args/**kwargs and default arguments.
if (expectedClass && expectedClass->name == "Generator")
continue;
// LOG("<~> {} {}", argType->toString(), expectedType->toString());
auto argClass = argType->getClass();
types::Type::Unification undo;
int u = argType->unify(expectedType.get(), &undo);
undo.undo();
if (u >= 0) {
score += u + 3;
continue;
}
if (!method->ast->args[ai].generic) {
// Unification failed: maybe we need to wrap an argument?
if (expectedClass && expectedClass->name == TYPE_OPTIONAL && argClass &&
argClass->name != expectedClass->name) {
u = argType->unify(expectedClass->generics[0].type.get(), &undo);
undo.undo();
if (u >= 0) {
score += u + 2;
continue;
}
}
// ... or unwrap it (less ideal)?
if (argClass && argClass->name == TYPE_OPTIONAL && expectedClass &&
argClass->name != expectedClass->name) {
u = argClass->generics[0].type->unify(expectedType.get(), &undo);
undo.undo();
if (u >= 0) {
score += u;
continue;
}
}
}
// This method cannot be selected, ignore it.
score = -1;
break;
}
// LOG("{} {} / {}", typ->toString(), method->toString(), score);
if (score >= 0)
scores.emplace_back(std::make_pair(score, mi));
}
if (scores.empty())
return nullptr;
// Get the best score.
sort(scores.begin(), scores.end(), std::greater<>());
// LOG("Method: {}", methods[scores[0].second]->toString());
// std::string x;
// for (auto &a : args)
// x += format("{}{},", a.first.empty() ? "" : a.first + ": ",
// a.second->toString());
// LOG(" {} :: {} ( {} )", typ->toString(), member, x);
return methods[scores[0].second];
}
int TypeContext::reorderNamedArgs(types::FuncType *func,
const std::vector<CallExpr::Arg> &args,
ReorderDoneFn onDone, ReorderErrorFn onError,

View File

@ -127,16 +127,6 @@ public:
types::TypePtr findMember(const std::string &typeName,
const std::string &member) const;
/// Picks the best method of a given expression that matches the given argument
/// types. Prefers methods whose signatures are closer to the given arguments:
/// e.g. foo(int) will match (int) better that a foo(T).
/// Also takes care of the Optional arguments.
/// If multiple equally good methods are found, return the first one.
/// Return nullptr if no methods were found.
types::FuncTypePtr
findBestMethod(const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args,
bool checkSingle = false);
typedef std::function<int(int, int, const std::vector<std::vector<int>> &, bool)>
ReorderDoneFn;

View File

@ -683,8 +683,8 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
if (isAtomic) {
auto ptrlt =
ctx->instantiateGeneric(expr->lexpr.get(), ctx->findInternal("Ptr"), {lt});
method = ctx->findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic),
{{"", ptrlt}, {"", rt}});
method = findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic),
{{"", ptrlt}, {"", rt}});
if (method) {
expr->lexpr = N<PtrExpr>(expr->lexpr);
if (noReturn)
@ -693,19 +693,19 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
}
// Check if lt.__iop__(lt, rt) exists.
if (!method && expr->inPlace) {
method = ctx->findBestMethod(expr->lexpr.get(), format("__i{}__", magic),
{{"", lt}, {"", rt}});
method = findBestMethod(expr->lexpr.get(), format("__i{}__", magic),
{{"", lt}, {"", rt}});
if (method && noReturn)
*noReturn = true;
}
// Check if lt.__op__(lt, rt) exists.
if (!method)
method = ctx->findBestMethod(expr->lexpr.get(), format("__{}__", magic),
{{"", lt}, {"", rt}});
method = findBestMethod(expr->lexpr.get(), format("__{}__", magic),
{{"", lt}, {"", rt}});
// Check if rt.__rop__(rt, lt) exists.
if (!method) {
method = ctx->findBestMethod(expr->rexpr.get(), format("__r{}__", magic),
{{"", rt}, {"", lt}});
method = findBestMethod(expr->rexpr.get(), format("__r{}__", magic),
{{"", rt}, {"", lt}});
if (method)
swap(expr->lexpr, expr->rexpr);
}
@ -873,8 +873,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
argTypes.emplace_back(make_pair("", typ)); // self variable
for (const auto &a : *args)
argTypes.emplace_back(make_pair(a.name, a.value->getType()));
if (auto bestMethod =
ctx->findBestMethod(expr->expr.get(), expr->member, argTypes)) {
if (auto bestMethod = findBestMethod(expr->expr.get(), expr->member, argTypes)) {
ExprPtr e = N<IdExpr>(bestMethod->ast->name);
auto t = ctx->instantiate(expr, bestMethod, typ.get());
unify(e->type, t);
@ -906,7 +905,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
methodArgs.emplace_back(make_pair("", typ));
for (auto i = 1; i < oldType->generics.size(); i++)
methodArgs.emplace_back(make_pair("", oldType->generics[i].type));
bestMethod = ctx->findBestMethod(expr->expr.get(), expr->member, methodArgs);
bestMethod = findBestMethod(expr->expr.get(), expr->member, methodArgs);
if (!bestMethod) {
// Print a nice error message.
std::vector<std::string> nice;
@ -916,9 +915,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
typ->toString(), join(nice, ", "));
}
} else {
// HACK: if we still have multiple valid methods, we just use the first one.
// TODO: handle this better (maybe hold these types until they can be selected?)
bestMethod = methods[0];
auto m = ctx->cache->classes.find(typ->name);
auto t = m->second.methods.find(expr->member);
seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"),
"first method is not dispatch");
bestMethod = t->second[0].type;
}
// Case 7: only one valid method remaining. Check if this is a class method or an
@ -1004,6 +1005,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ai--;
} else {
// Case 3: Normal argument
// LOG("-> {}", expr->args[ai].value->toString());
expr->args[ai].value = transform(expr->args[ai].value, true);
// Unbound inType might become a generator that will need to be extracted, so
// don't unify it yet.
@ -1365,8 +1367,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() ||
ctx->findMember(typ->getClass()->name, member);
if (exists && args.size() > 1)
exists &=
ctx->findBestMethod(expr->args[0].value.get(), member, args, true) != nullptr;
exists &= findBestMethod(expr->args[0].value.get(), member, args) != nullptr;
return {true, transform(N<BoolExpr>(exists))};
} else if (val == "compile_error") {
expr->args[0].value = transform(expr->args[0].value);
@ -1609,8 +1610,72 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
return call;
}
types::FuncTypePtr TypecheckVisitor::findBestMethod(
const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args) {
auto typ = expr->getType()->getClass();
seqassert(typ, "not a class");
// Pick the last method that accepts the given arguments.
auto methods = ctx->findMethod(typ->name, member);
// if (methods.size() == 1)
// return methods[0];
types::FuncTypePtr method = nullptr;
for (int mi = int(methods.size()) - 1; mi >= 0; mi--) {
auto m = ctx->instantiate(expr, methods[mi], typ.get(), false)->getFunc();
std::vector<types::TypePtr> reordered;
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args) {
callArgs.push_back({a.first, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a.second);
}
auto score = ctx->reorderNamedArgs(
m.get(), callArgs,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
if (m->ast->args[si].generic) {
// Ignore type arguments
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(nullptr);
} else {
reordered.emplace_back(args[slots[si][0]].second);
}
}
return 0;
},
[](const std::string &) { return -1; });
for (int ai = 0, mi = 1, gi = 0; score != -1 && ai < reordered.size(); ai++) {
auto argType = reordered[ai];
if (!argType)
continue;
auto expectTyp =
m->ast->args[ai].generic ? m->generics[gi++].type : m->args[mi++];
try {
ExprPtr dummy = std::make_shared<IdExpr>("");
dummy->type = argType;
dummy->done = true;
wrapExpr(dummy, expectTyp, m, /*undoOnSuccess*/ true);
} catch (const exc::ParserException &) {
score = -1;
}
}
if (score != -1) {
// std::vector<std::string> ar;
// for (auto &a: args) {
// if (a.first.empty()) ar.push_back(a.second->toString());
// else ar.push_back(format("{}: {}", a.first, a.second->toString()));
// }
// LOG("- {} vs {}", m->toString(), join(ar, "; "));
method = methods[mi];
break;
}
}
return method;
}
bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType,
const FuncTypePtr &callee) {
const FuncTypePtr &callee, bool undoOnSuccess) {
auto expectedClass = expectedType->getClass();
auto exprClass = expr->getType()->getClass();
if (callee && expr->isType())
@ -1637,7 +1702,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType,
// Case 7: wrap raw Seq functions into Partial(...) call for easy realization.
expr = partializeFunction(expr);
}
unify(expr->type, expectedType);
unify(expr->type, expectedType, undoOnSuccess);
return true;
}

View File

@ -153,9 +153,9 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) {
ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass});
c->args[1].value = transform(c->args[1].value);
auto rhsTyp = c->args[1].value->getType()->getClass();
if (auto method = ctx->findBestMethod(
stmt->lhs.get(), format("__atomic_{}__", c->expr->getId()->value),
{{"", ptrTyp}, {"", rhsTyp}})) {
if (auto method = findBestMethod(stmt->lhs.get(),
format("__atomic_{}__", c->expr->getId()->value),
{{"", ptrTyp}, {"", rhsTyp}})) {
resultStmt = transform(N<ExprStmt>(N<CallExpr>(
N<IdExpr>(method->ast->name), N<PtrExpr>(stmt->lhs), c->args[1].value)));
return;
@ -168,8 +168,8 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) {
if (stmt->isAtomic && lhsClass && rhsClass) {
auto ptrType =
ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass});
if (auto m = ctx->findBestMethod(stmt->lhs.get(), "__atomic_xchg__",
{{"", ptrType}, {"", rhsClass}})) {
if (auto m = findBestMethod(stmt->lhs.get(), "__atomic_xchg__",
{{"", ptrType}, {"", rhsClass}})) {
resultStmt = transform(N<ExprStmt>(
N<CallExpr>(N<IdExpr>(m->ast->name), N<PtrExpr>(stmt->lhs), stmt->rhs)));
return;

View File

@ -4,7 +4,7 @@ from internal.attributes import commutative, associative
class bool:
def __new__() -> bool:
return False
def __new__[T](what: T) -> bool: # lowest priority!
def __new__(what) -> bool:
return what.__bool__()
def __repr__(self) -> str:
return "True" if self else "False"

View File

@ -2,9 +2,9 @@ import internal.gc as gc
@extend
class List:
def __init__(self, arr: Array[T], len: int):
self.arr = arr
self.len = len
def __init__(self):
self.arr = Array[T](10)
self.len = 0
def __init__(self, it: Generator[T]):
self.arr = Array[T](10)
@ -12,27 +12,27 @@ class List:
for i in it:
self.append(i)
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
self.len = 0
def __init__(self):
self.arr = Array[T](10)
self.len = 0
def __init__(self, other: List[T]):
self.arr = Array[T](other.len)
self.len = 0
for i in other:
self.append(i)
# Dummy __init__ used for list comprehension optimization
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
self.len = 0
def __init__(self, dummy: bool, other):
"""Dummy __init__ used for list comprehension optimization"""
if hasattr(other, '__len__'):
self.__init__(other.__len__())
else:
self.__init__()
def __init__(self, arr: Array[T], len: int):
self.arr = arr
self.len = len
def __len__(self):
return self.len

View File

@ -6,18 +6,12 @@ class complex:
def __new__():
return complex(0.0, 0.0)
def __new__(real: int, imag: int):
return complex(float(real), float(imag))
def __new__(real: float, imag: int):
return complex(real, float(imag))
def __new__(real: int, imag: float):
return complex(float(real), imag)
def __new__(other):
return other.__complex__()
def __new__(real, imag):
return complex(float(real), float(imag))
def __complex__(self):
return self
@ -42,6 +36,42 @@ class complex:
def __hash__(self):
return self.real.__hash__() + self.imag.__hash__()*1000003
def __add__(self, other):
return self + complex(other)
def __sub__(self, other):
return self - complex(other)
def __mul__(self, other):
return self * complex(other)
def __truediv__(self, other):
return self / complex(other)
def __eq__(self, other):
return self == complex(other)
def __ne__(self, other):
return self != complex(other)
def __pow__(self, other):
return self ** complex(other)
def __radd__(self, other):
return complex(other) + self
def __rsub__(self, other):
return complex(other) - self
def __rmul__(self, other):
return complex(other) * self
def __rtruediv__(self, other):
return complex(other) / self
def __rpow__(self, other):
return complex(other) ** self
def __add__(self, other: complex):
return complex(self.real + other.real, self.imag + other.imag)
@ -160,42 +190,6 @@ class complex:
phase += other.imag * log(vabs)
return complex(len * cos(phase), len * sin(phase))
def __add__(self, other):
return self + complex(other)
def __sub__(self, other):
return self - complex(other)
def __mul__(self, other):
return self * complex(other)
def __truediv__(self, other):
return self / complex(other)
def __eq__(self, other):
return self == complex(other)
def __ne__(self, other):
return self != complex(other)
def __pow__(self, other):
return self ** complex(other)
def __radd__(self, other):
return complex(other) + self
def __rsub__(self, other):
return complex(other) - self
def __rmul__(self, other):
return complex(other) * self
def __rtruediv__(self, other):
return complex(other) / self
def __rpow__(self, other):
return complex(other) ** self
def __repr__(self):
@pure
@llvm

View File

@ -10,80 +10,105 @@ def seq_str_float(a: float) -> str: pass
class float:
def __new__() -> float:
return 0.0
def __new__[T](what: T):
def __new__(what):
return what.__float__()
def __new__(s: str) -> float:
from C import strtod(cobj, Ptr[cobj]) -> float
buf = __array__[byte](32)
n = s.__len__()
need_dyn_alloc = (n >= buf.__len__())
p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)
end = cobj()
result = strtod(p, __ptr__(end))
if need_dyn_alloc:
free(p)
if end != p + n:
raise ValueError("could not convert string to float: " + s)
return result
def __repr__(self) -> str:
s = seq_str_float(self)
return s if s != "-nan" else "nan"
def __copy__(self) -> float:
return self
def __deepcopy__(self) -> float:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi double %self to i64
ret i64 %0
def __float__(self):
return self
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp one double %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
def __complex__(self):
return complex(self, 0.0)
def __pos__(self) -> float:
return self
@pure
@llvm
def __neg__(self) -> float:
%0 = fneg double %self
ret double %0
@pure
@commutative
@llvm
def __add__(a: float, b: float) -> float:
%tmp = fadd double %a, %b
ret double %tmp
@commutative
def __add__(self, other: int) -> float:
return self.__add__(float(other))
@pure
@llvm
def __sub__(a: float, b: float) -> float:
%tmp = fsub double %a, %b
ret double %tmp
def __sub__(self, other: int) -> float:
return self.__sub__(float(other))
@pure
@commutative
@llvm
def __mul__(a: float, b: float) -> float:
%tmp = fmul double %a, %b
ret double %tmp
@commutative
def __mul__(self, other: int) -> float:
return self.__mul__(float(other))
def __floordiv__(self, other: float) -> float:
return self.__truediv__(other).__floor__()
def __floordiv__(self, other: int) -> float:
return self.__floordiv__(float(other))
@pure
@llvm
def __truediv__(a: float, b: float) -> float:
%tmp = fdiv double %a, %b
ret double %tmp
def __truediv__(self, other: int) -> float:
return self.__truediv__(float(other))
@pure
@llvm
def __mod__(a: float, b: float) -> float:
%tmp = frem double %a, %b
ret double %tmp
def __mod__(self, other: int) -> float:
return self.__mod__(float(other))
def __divmod__(self, other: float):
mod = self % other
div = (self - mod) / other
@ -103,16 +128,14 @@ class float:
floordiv = (0.0).copysign(self / other)
return (floordiv, mod)
def __divmod__(self, other: int):
return self.__divmod__(float(other))
@pure
@llvm
def __eq__(a: float, b: float) -> bool:
%tmp = fcmp oeq double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __eq__(self, other: int) -> bool:
return self.__eq__(float(other))
@pure
@llvm
def __ne__(a: float, b: float) -> bool:
@ -120,174 +143,190 @@ class float:
%tmp = fcmp one double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __ne__(self, other: int) -> bool:
return self.__ne__(float(other))
@pure
@llvm
def __lt__(a: float, b: float) -> bool:
%tmp = fcmp olt double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __lt__(self, other: int) -> bool:
return self.__lt__(float(other))
@pure
@llvm
def __gt__(a: float, b: float) -> bool:
%tmp = fcmp ogt double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __gt__(self, other: int) -> bool:
return self.__gt__(float(other))
@pure
@llvm
def __le__(a: float, b: float) -> bool:
%tmp = fcmp ole double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __le__(self, other: int) -> bool:
return self.__le__(float(other))
@pure
@llvm
def __ge__(a: float, b: float) -> bool:
%tmp = fcmp oge double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __ge__(self, other: int) -> bool:
return self.__ge__(float(other))
@pure
@llvm
def sqrt(a: float) -> float:
declare double @llvm.sqrt.f64(double %a)
%tmp = call double @llvm.sqrt.f64(double %a)
ret double %tmp
@pure
@llvm
def sin(a: float) -> float:
declare double @llvm.sin.f64(double %a)
%tmp = call double @llvm.sin.f64(double %a)
ret double %tmp
@pure
@llvm
def cos(a: float) -> float:
declare double @llvm.cos.f64(double %a)
%tmp = call double @llvm.cos.f64(double %a)
ret double %tmp
@pure
@llvm
def exp(a: float) -> float:
declare double @llvm.exp.f64(double %a)
%tmp = call double @llvm.exp.f64(double %a)
ret double %tmp
@pure
@llvm
def exp2(a: float) -> float:
declare double @llvm.exp2.f64(double %a)
%tmp = call double @llvm.exp2.f64(double %a)
ret double %tmp
@pure
@llvm
def log(a: float) -> float:
declare double @llvm.log.f64(double %a)
%tmp = call double @llvm.log.f64(double %a)
ret double %tmp
@pure
@llvm
def log10(a: float) -> float:
declare double @llvm.log10.f64(double %a)
%tmp = call double @llvm.log10.f64(double %a)
ret double %tmp
@pure
@llvm
def log2(a: float) -> float:
declare double @llvm.log2.f64(double %a)
%tmp = call double @llvm.log2.f64(double %a)
ret double %tmp
@pure
@llvm
def __abs__(a: float) -> float:
declare double @llvm.fabs.f64(double %a)
%tmp = call double @llvm.fabs.f64(double %a)
ret double %tmp
@pure
@llvm
def __floor__(a: float) -> float:
declare double @llvm.floor.f64(double %a)
%tmp = call double @llvm.floor.f64(double %a)
ret double %tmp
@pure
@llvm
def __ceil__(a: float) -> float:
declare double @llvm.ceil.f64(double %a)
%tmp = call double @llvm.ceil.f64(double %a)
ret double %tmp
@pure
@llvm
def __trunc__(a: float) -> float:
declare double @llvm.trunc.f64(double %a)
%tmp = call double @llvm.trunc.f64(double %a)
ret double %tmp
@pure
@llvm
def rint(a: float) -> float:
declare double @llvm.rint.f64(double %a)
%tmp = call double @llvm.rint.f64(double %a)
ret double %tmp
@pure
@llvm
def nearbyint(a: float) -> float:
declare double @llvm.nearbyint.f64(double %a)
%tmp = call double @llvm.nearbyint.f64(double %a)
ret double %tmp
@pure
@llvm
def __round__(a: float) -> float:
declare double @llvm.round.f64(double %a)
%tmp = call double @llvm.round.f64(double %a)
ret double %tmp
@pure
@llvm
def __pow__(a: float, b: float) -> float:
declare double @llvm.pow.f64(double %a, double %b)
%tmp = call double @llvm.pow.f64(double %a, double %b)
ret double %tmp
def __pow__(self, other: int) -> float:
return self.__pow__(float(other))
@pure
@llvm
def min(a: float, b: float) -> float:
declare double @llvm.minnum.f64(double %a, double %b)
%tmp = call double @llvm.minnum.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def max(a: float, b: float) -> float:
declare double @llvm.maxnum.f64(double %a, double %b)
%tmp = call double @llvm.maxnum.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def copysign(a: float, b: float) -> float:
declare double @llvm.copysign.f64(double %a, double %b)
%tmp = call double @llvm.copysign.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def fma(a: float, b: float, c: float) -> float:
declare double @llvm.fma.f64(double %a, double %b, double %c)
%tmp = call double @llvm.fma.f64(double %a, double %b, double %c)
ret double %tmp
@llvm
def __atomic_xchg__(d: Ptr[float], b: float) -> void:
%tmp = atomicrmw xchg double* %d, double %b seq_cst
ret void
@llvm
def __atomic_add__(d: Ptr[float], b: float) -> float:
%tmp = atomicrmw fadd double* %d, double %b seq_cst
ret double %tmp
@llvm
def __atomic_sub__(d: Ptr[float], b: float) -> float:
%tmp = atomicrmw fsub double* %d, double %b seq_cst
ret double %tmp
def __hash__(self):
from C import frexp(float, Ptr[Int[32]]) -> float
@ -332,31 +371,13 @@ class float:
x = -2
return x
def __new__(s: str) -> float:
from C import strtod(cobj, Ptr[cobj]) -> float
buf = __array__[byte](32)
n = s.__len__()
need_dyn_alloc = (n >= buf.__len__())
p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)
end = cobj()
result = strtod(p, __ptr__(end))
if need_dyn_alloc:
free(p)
if end != p + n:
raise ValueError("could not convert string to float: " + s)
return result
def __match__(self, i: float):
return self == i
@property
def real(self):
return self
@property
def imag(self):
return 0.0

View File

@ -14,37 +14,56 @@ class int:
@llvm
def __new__() -> int:
ret i64 0
def __new__[T](what: T) -> int: # lowest priority!
def __new__(what) -> int:
return what.__int__()
def __new__(s: str) -> int:
return int._from_str(s, 10)
def __new__(s: str, base: int) -> int:
return int._from_str(s, base)
def __int__(self) -> int:
return self
@pure
@llvm
def __float__(self) -> float:
%tmp = sitofp i64 %self to double
ret double %tmp
def __complex__(self):
return complex(float(self), 0.0)
def __index__(self):
return self
def __repr__(self) -> str:
return seq_str_int(self)
def __copy__(self) -> int:
return self
def __deepcopy__(self) -> int:
return self
def __hash__(self) -> int:
return self
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i64 %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> int:
return self
def __neg__(self) -> int:
return 0 - self
@pure
@llvm
def __abs__(self) -> int:
@ -52,23 +71,19 @@ class int:
%1 = sub i64 0, %self
%2 = select i1 %0, i64 %self, i64 %1
ret i64 %2
@pure
@llvm
def __lshift__(self, other: int) -> int:
%0 = shl i64 %self, %other
ret i64 %0
@pure
@llvm
def __rshift__(self, other: int) -> int:
%0 = ashr i64 %self, %other
ret i64 %0
@pure
@commutative
@associative
@llvm
def __add__(self, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
@ -76,17 +91,36 @@ class int:
%0 = sitofp i64 %self to double
%1 = fadd double %0, %other
ret double %1
@pure
@commutative
@associative
@llvm
def __sub__(self, b: int) -> int:
%tmp = sub i64 %self, %b
def __add__(self, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __sub__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fsub double %0, %other
ret double %1
@pure
@llvm
def __sub__(self, b: int) -> int:
%tmp = sub i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
def __mul__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fmul double %0, %other
ret double %1
@pure
@commutative
@associative
@ -95,18 +129,7 @@ class int:
def __mul__(self, b: int) -> int:
%tmp = mul i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
def __mul__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fmul double %0, %other
ret double %1
@pure
@llvm
def __floordiv__(self, b: int) -> int:
%tmp = sdiv i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __floordiv__(self, other: float) -> float:
@ -115,6 +138,20 @@ class int:
%1 = fdiv double %0, %other
%2 = call double @llvm.floor.f64(double %1)
ret double %2
@pure
@llvm
def __floordiv__(self, b: int) -> int:
%tmp = sdiv i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __truediv__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@pure
@llvm
def __truediv__(self, other: int) -> float:
@ -122,23 +159,20 @@ class int:
%1 = sitofp i64 %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __truediv__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@pure
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
@pure
@llvm
def __mod__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = frem double %0, %other
ret double %1
@pure
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
def __divmod__(self, other: int):
d = self // other
m = self - d*other
@ -146,11 +180,13 @@ class int:
m += other
d -= 1
return (d, m)
@pure
@llvm
def __invert__(a: int) -> int:
%tmp = xor i64 %a, -1
ret i64 %tmp
@pure
@commutative
@associative
@ -158,6 +194,7 @@ class int:
def __and__(a: int, b: int) -> int:
%tmp = and i64 %a, %b
ret i64 %tmp
@pure
@commutative
@associative
@ -165,6 +202,7 @@ class int:
def __or__(a: int, b: int) -> int:
%tmp = or i64 %a, %b
ret i64 %tmp
@pure
@commutative
@associative
@ -172,42 +210,42 @@ class int:
def __xor__(a: int, b: int) -> int:
%tmp = xor i64 %a, %b
ret i64 %tmp
@pure
@llvm
def __bitreverse__(a: int) -> int:
declare i64 @llvm.bitreverse.i64(i64 %a)
%tmp = call i64 @llvm.bitreverse.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __bswap__(a: int) -> int:
declare i64 @llvm.bswap.i64(i64 %a)
%tmp = call i64 @llvm.bswap.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __ctpop__(a: int) -> int:
declare i64 @llvm.ctpop.i64(i64 %a)
%tmp = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __ctlz__(a: int) -> int:
declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false)
ret i64 %tmp
@pure
@llvm
def __cttz__(a: int) -> int:
declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false)
ret i64 %tmp
@pure
@llvm
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __eq__(self, b: float) -> bool:
@ -215,12 +253,14 @@ class int:
%1 = fcmp oeq double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(self, b: float) -> bool:
@ -228,12 +268,14 @@ class int:
%1 = fcmp one double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(self, b: float) -> bool:
@ -241,12 +283,14 @@ class int:
%1 = fcmp olt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(self, b: float) -> bool:
@ -254,12 +298,14 @@ class int:
%1 = fcmp ogt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(self, b: float) -> bool:
@ -267,12 +313,14 @@ class int:
%1 = fcmp ole double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(self, b: float) -> bool:
@ -280,10 +328,17 @@ class int:
%1 = fcmp oge double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
def __new__(s: str) -> int:
return int._from_str(s, 10)
def __new__(s: str, base: int) -> int:
return int._from_str(s, base)
@pure
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __pow__(self, exp: float):
return float(self) ** exp
def __pow__(self, exp: int):
if exp < 0:
return 0
@ -296,53 +351,65 @@ class int:
break
self *= self
return result
def __pow__(self, exp: float):
return float(self) ** exp
def popcnt(self):
return Int[64](self).popcnt()
@llvm
def __atomic_xchg__(d: Ptr[int], b: int) -> void:
%tmp = atomicrmw xchg i64* %d, i64 %b seq_cst
ret void
@llvm
def __atomic_add__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw add i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_sub__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw sub i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_and__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw and i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_nand__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw nand i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_or__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw or i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def _atomic_xor(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw xor i64* %d, i64 %b seq_cst
ret i64 %tmp
def __atomic_xor__(self, b: int) -> int:
return int._atomic_xor(__ptr__(self), b)
@llvm
def __atomic_min__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw min i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_max__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw max i64* %d, i64 %b seq_cst
ret i64 %tmp
def __match__(self, i: int):
return self == i
@property
def real(self):
return self
@property
def imag(self):
return 0

View File

@ -10,9 +10,11 @@ class Int:
def __new__() -> Int[N]:
check_N(N)
return Int[N](0)
def __new__(what: Int[N]) -> Int[N]:
check_N(N)
return what
def __new__(what: int) -> Int[N]:
check_N(N)
if N < 64:
@ -21,10 +23,12 @@ class Int:
return what
else:
return __internal__.int_sext(what, 64, N)
@pure
@llvm
def __new__(what: UInt[N]) -> Int[N]:
ret i{=N} %what
def __new__(what: str) -> Int[N]:
check_N(N)
ret = Int[N]()
@ -39,6 +43,7 @@ class Int:
ret = ret * Int[N](10) + Int[N](int(what.ptr[i]) - 48)
i += 1
return sign * ret
def __int__(self) -> int:
if N > 64:
return __internal__.int_trunc(self, N, 64)
@ -46,37 +51,47 @@ class Int:
return self
else:
return __internal__.int_sext(self, N, 64)
def __index__(self):
return int(self)
def __copy__(self) -> Int[N]:
return self
def __deepcopy__(self) -> Int[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure
@llvm
def __float__(self) -> float:
%0 = sitofp i{=N} %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i{=N} %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> Int[N]:
return self
@pure
@llvm
def __neg__(self) -> Int[N]:
%0 = sub i{=N} 0, %self
ret i{=N} %0
@pure
@llvm
def __invert__(self) -> Int[N]:
%0 = xor i{=N} %self, -1
ret i{=N} %0
@pure
@commutative
@associative
@ -84,11 +99,13 @@ class Int:
def __add__(self, other: Int[N]) -> Int[N]:
%0 = add i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __sub__(self, other: Int[N]) -> Int[N]:
%0 = sub i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -97,11 +114,13 @@ class Int:
def __mul__(self, other: Int[N]) -> Int[N]:
%0 = mul i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __floordiv__(self, other: Int[N]) -> Int[N]:
%0 = sdiv i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __truediv__(self, other: Int[N]) -> float:
@ -109,11 +128,13 @@ class Int:
%1 = sitofp i{=N} %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __mod__(self, other: Int[N]) -> Int[N]:
%0 = srem i{=N} %self, %other
ret i{=N} %0
def __divmod__(self, other: Int[N]):
d = self // other
m = self - d*other
@ -121,52 +142,61 @@ class Int:
m += other
d -= Int[N](1)
return (d, m)
@pure
@llvm
def __lshift__(self, other: Int[N]) -> Int[N]:
%0 = shl i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __rshift__(self, other: Int[N]) -> Int[N]:
%0 = ashr i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __eq__(self, other: Int[N]) -> bool:
%0 = icmp eq i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: Int[N]) -> bool:
%0 = icmp ne i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: Int[N]) -> bool:
%0 = icmp slt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: Int[N]) -> bool:
%0 = icmp sgt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: Int[N]) -> bool:
%0 = icmp sle i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: Int[N]) -> bool:
%0 = icmp sge i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@commutative
@associative
@ -174,6 +204,7 @@ class Int:
def __and__(self, other: Int[N]) -> Int[N]:
%0 = and i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -181,6 +212,7 @@ class Int:
def __or__(self, other: Int[N]) -> Int[N]:
%0 = or i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -188,6 +220,7 @@ class Int:
def __xor__(self, other: Int[N]) -> Int[N]:
%0 = xor i{=N} %self, %other
ret i{=N} %0
@llvm
def __pickle__(self, dest: Ptr[byte]) -> void:
declare i32 @gzwrite(i8*, i8*, i32)
@ -198,6 +231,7 @@ class Int:
%szi = ptrtoint i{=N}* %sz to i32
%2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi)
ret void
@llvm
def __unpickle__(src: Ptr[byte]) -> Int[N]:
declare i32 @gzread(i8*, i8*, i32)
@ -208,18 +242,23 @@ class Int:
%2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi)
%3 = load i{=N}, i{=N}* %0
ret i{=N} %3
def __repr__(self) -> str:
return str.cat(('Int[', seq_str_int(N), '](', seq_str_int(int(self)), ')'))
def __str__(self) -> str:
return seq_str_int(int(self))
@pure
@llvm
def _popcnt(self) -> Int[N]:
declare i{=N} @llvm.ctpop.i{=N}(i{=N})
%0 = call i{=N} @llvm.ctpop.i{=N}(i{=N} %self)
ret i{=N} %0
def popcnt(self):
return int(self._popcnt())
def len() -> int:
return N
@ -228,9 +267,11 @@ class UInt:
def __new__() -> UInt[N]:
check_N(N)
return UInt[N](0)
def __new__(what: UInt[N]) -> UInt[N]:
check_N(N)
return what
def __new__(what: int) -> UInt[N]:
check_N(N)
if N < 64:
@ -239,13 +280,16 @@ class UInt:
return UInt[N](Int[N](what))
else:
return UInt[N](__internal__.int_zext(what, 64, N))
@pure
@llvm
def __new__(what: Int[N]) -> UInt[N]:
ret i{=N} %what
def __new__(what: str) -> UInt[N]:
check_N(N)
return UInt[N](Int[N](what))
def __int__(self) -> int:
if N > 64:
return __internal__.int_trunc(self, N, 64)
@ -253,37 +297,47 @@ class UInt:
return Int[64](self)
else:
return __internal__.int_zext(self, N, 64)
def __index__(self):
return int(self)
def __copy__(self) -> UInt[N]:
return self
def __deepcopy__(self) -> UInt[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure
@llvm
def __float__(self) -> float:
%0 = uitofp i{=N} %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i{=N} %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> UInt[N]:
return self
@pure
@llvm
def __neg__(self) -> UInt[N]:
%0 = sub i{=N} 0, %self
ret i{=N} %0
@pure
@llvm
def __invert__(self) -> UInt[N]:
%0 = xor i{=N} %self, -1
ret i{=N} %0
@pure
@commutative
@associative
@ -291,11 +345,13 @@ class UInt:
def __add__(self, other: UInt[N]) -> UInt[N]:
%0 = add i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __sub__(self, other: UInt[N]) -> UInt[N]:
%0 = sub i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -304,11 +360,13 @@ class UInt:
def __mul__(self, other: UInt[N]) -> UInt[N]:
%0 = mul i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __floordiv__(self, other: UInt[N]) -> UInt[N]:
%0 = udiv i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __truediv__(self, other: UInt[N]) -> float:
@ -316,59 +374,70 @@ class UInt:
%1 = uitofp i{=N} %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __mod__(self, other: UInt[N]) -> UInt[N]:
%0 = urem i{=N} %self, %other
ret i{=N} %0
def __divmod__(self, other: UInt[N]):
return (self // other, self % other)
@pure
@llvm
def __lshift__(self, other: UInt[N]) -> UInt[N]:
%0 = shl i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __rshift__(self, other: UInt[N]) -> UInt[N]:
%0 = lshr i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __eq__(self, other: UInt[N]) -> bool:
%0 = icmp eq i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: UInt[N]) -> bool:
%0 = icmp ne i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: UInt[N]) -> bool:
%0 = icmp ult i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: UInt[N]) -> bool:
%0 = icmp ugt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: UInt[N]) -> bool:
%0 = icmp ule i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: UInt[N]) -> bool:
%0 = icmp uge i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@commutative
@associative
@ -376,6 +445,7 @@ class UInt:
def __and__(self, other: UInt[N]) -> UInt[N]:
%0 = and i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -383,6 +453,7 @@ class UInt:
def __or__(self, other: UInt[N]) -> UInt[N]:
%0 = or i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -390,6 +461,7 @@ class UInt:
def __xor__(self, other: UInt[N]) -> UInt[N]:
%0 = xor i{=N} %self, %other
ret i{=N} %0
@llvm
def __pickle__(self, dest: Ptr[byte]) -> void:
declare i32 @gzwrite(i8*, i8*, i32)
@ -400,6 +472,7 @@ class UInt:
%szi = ptrtoint i{=N}* %sz to i32
%2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi)
ret void
@llvm
def __unpickle__(src: Ptr[byte]) -> UInt[N]:
declare i32 @gzread(i8*, i8*, i32)
@ -410,12 +483,16 @@ class UInt:
%2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi)
%3 = load i{=N}, i{=N}* %0
ret i{=N} %3
def __repr__(self) -> str:
return str.cat(('UInt[', seq_str_int(N), '](', seq_str_uint(int(self)), ')'))
def __str__(self) -> str:
return seq_str_uint(int(self))
def popcnt(self):
return int(Int[N](self)._popcnt())
def len() -> int:
return N

View File

@ -5,29 +5,36 @@ class Optional:
return __internal__.opt_tuple_new(T)
else:
return __internal__.opt_ref_new(T)
def __new__(what: T) -> Optional[T]:
if isinstance(T, ByVal):
return __internal__.opt_tuple_new_arg(what, T)
else:
return __internal__.opt_ref_new_arg(what, T)
def __bool__(self) -> bool:
if isinstance(T, ByVal):
return __internal__.opt_tuple_bool(self, T)
else:
return __internal__.opt_ref_bool(self, T)
def __invert__(self) -> T:
if isinstance(T, ByVal):
return __internal__.opt_tuple_invert(self, T)
else:
return __internal__.opt_ref_invert(self, T)
def __str__(self) -> str:
return 'None' if not self else str(~self)
def __repr__(self) -> str:
return 'None' if not self else (~self).__repr__()
def __is_optional__(self, other: Optional[T]):
if (not self) or (not other):
return (not self) and (not other)
return self.__invert__() is other.__invert__()
optional = Optional
def unwrap[T](opt: Optional[T]) -> T:

View File

@ -4,53 +4,63 @@ def seq_str_ptr(a: Ptr[byte]) -> str: pass
@extend
class Ptr:
@__internal__
def __new__(sz: int) -> Ptr[T]:
pass
@pure
@llvm
def __new__() -> Ptr[T]:
ret {=T}* null
@__internal__
def __new__(sz: int) -> Ptr[T]:
pass
@pure
@llvm
def __new__(other: Ptr[T]) -> Ptr[T]:
ret {=T}* %other
@pure
@llvm
def __new__(other: Ptr[byte]) -> Ptr[T]:
%0 = bitcast i8* %other to {=T}*
ret {=T}* %0
@pure
@llvm
def __new__(other: Ptr[T]) -> Ptr[T]:
ret {=T}* %other
@pure
@llvm
def __int__(self) -> int:
%0 = ptrtoint {=T}* %self to i64
ret i64 %0
@pure
@llvm
def __copy__(self) -> Ptr[T]:
ret {=T}* %self
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne {=T}* %self, null
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __getitem__(self, index: int) -> T:
%0 = getelementptr {=T}, {=T}* %self, i64 %index
%1 = load {=T}, {=T}* %0
ret {=T} %1
@llvm
def __setitem__(self, index: int, what: T) -> void:
%0 = getelementptr {=T}, {=T}* %self, i64 %index
store {=T} %what, {=T}* %0
ret void
@pure
@llvm
def __add__(self, other: int) -> Ptr[T]:
%0 = getelementptr {=T}, {=T}* %self, i64 %other
ret {=T}* %0
@pure
@llvm
def __sub__(self, other: Ptr[T]) -> int:
@ -59,90 +69,105 @@ class Ptr:
%2 = sub i64 %0, %1
%3 = sdiv exact i64 %2, ptrtoint ({=T}* getelementptr ({=T}, {=T}* null, i32 1) to i64)
ret i64 %3
@pure
@llvm
def __eq__(self, other: Ptr[T]) -> bool:
%0 = icmp eq {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: Ptr[T]) -> bool:
%0 = icmp ne {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: Ptr[T]) -> bool:
%0 = icmp slt {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: Ptr[T]) -> bool:
%0 = icmp sgt {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: Ptr[T]) -> bool:
%0 = icmp sle {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: Ptr[T]) -> bool:
%0 = icmp sge {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@llvm
def __prefetch_r0__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 0, i32 1)
ret void
@llvm
def __prefetch_r1__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 1, i32 1)
ret void
@llvm
def __prefetch_r2__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 2, i32 1)
ret void
@llvm
def __prefetch_r3__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 3, i32 1)
ret void
@llvm
def __prefetch_w0__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 0, i32 1)
ret void
@llvm
def __prefetch_w1__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 1, i32 1)
ret void
@llvm
def __prefetch_w2__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 2, i32 1)
ret void
@llvm
def __prefetch_w3__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 3, i32 1)
ret void
@pure
@llvm
def as_byte(self) -> Ptr[byte]:

View File

@ -7,41 +7,52 @@ class str:
@__internal__
def __new__(l: int, p: Ptr[byte]) -> str:
pass
def __new__(p: Ptr[byte], l: int) -> str:
return str(l, p)
def __new__() -> str:
return str(Ptr[byte](), 0)
def __new__[T](what: T) -> str: # lowest priority!
def __new__(what) -> str:
if hasattr(what, "__str__"):
return what.__str__()
else:
return what.__repr__()
def __str__(what: str) -> str:
return what
def __len__(self) -> int:
return self.len
def __bool__(self) -> bool:
return self.len != 0
def __copy__(self) -> str:
n = self.len
p = cobj(n)
str.memcpy(p, self.ptr, n)
return str(p, n)
@llvm
def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void:
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false)
ret void
@llvm
def memmove(dest: Ptr[byte], src: Ptr[byte], len: int) -> void:
declare void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false)
ret void
@llvm
def memset(dest: Ptr[byte], val: byte, len: int) -> void:
declare void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 0, i1 false)
ret void
def __add__(self, other: str) -> str:
len1 = self.len
len2 = other.len
@ -50,17 +61,20 @@ class str:
str.memcpy(p, self.ptr, len1)
str.memcpy(p + len1, other.ptr, len2)
return str(p, len3)
def c_str(self):
n = self.__len__()
p = cobj(n + 1)
str.memcpy(p, self.ptr, n)
p[n] = byte(0)
return p
def from_ptr(t: cobj) -> str:
n = strlen(t)
p = Ptr[byte](n)
str.memcpy(p, t, n)
return str(p, n)
def __eq__(self, other: str):
if self.len != other.len:
return False
@ -70,10 +84,13 @@ class str:
return False
i += 1
return True
def __match__(self, other: str):
return self.__eq__(other)
def __ne__(self, other: str):
return not self.__eq__(other)
def cat(*args):
total = 0
if staticlen(args) == 1 and hasattr(args[0], "__iter__") and hasattr(args[0], "__len__"):

View File

@ -394,22 +394,10 @@ class NormalDist:
self._mu = mu
self._sigma = sigma
def __init__(self, mu: float, sigma: float):
def __init__(self, mu, sigma):
self._init(float(mu), float(sigma))
def __init__(self, mu: int, sigma: int):
self._init(float(mu), float(sigma))
def __init__(self, mu: float, sigma: int):
self._init(float(mu), float(sigma))
def __init__(self, mu: int, sigma: float):
self._init(float(mu), float(sigma))
def __init__(self, mu: float):
self._init(mu, 1.0)
def __init__(self, mu: int):
def __init__(self, mu):
self._init(float(mu), 1.0)
def __init__(self):