Improved logic for handling overloaded functions (#10)

* Backport seq-lang/seq@develop fixes

* Backport seq-lang/seq@develop fixes

* Select the last matching overload by default (remove scoring logic); Add dispatch stubs for partial overload support

* Select the last matching overload by default [wip]

* Fix various bugs and update tests

* Add support for partial functions with *args/**kwargs; Fix partial method dispatch

* Update .gitignore

* Fix grammar to allow variable names that have reserved word as a prefix

* Add support for super() call

* Add super() tests; Allow static inheritance to inherit @extend methods

* Support for overloaded functions [wip; base logic done]

* Support for overloaded functions

* Update .gitignore

* Fix partial dots

* Rename function overload 'super' to 'superf'

* Add support for super()

* Add tests for super()

* Add tuple_offsetof

* Add tuple support for super()

* Add isinstance support for inherited classes; Fix review issues

Co-authored-by: A. R. Shajii <ars@ars.me>
pull/27/head
Ibrahim Numanagić 2022-01-11 17:39:15 -08:00 committed by GitHub
parent 240f2947c5
commit cc634d1940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1611 additions and 666 deletions

2
.gitignore vendored
View File

@ -54,3 +54,5 @@ Thumbs.db
extra/jupyter/share/jupyter/kernels/codon/kernel.json
scratch.*
_*
.ipynb_checkpoints

View File

@ -68,11 +68,24 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code,
auto transformed = ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt),
abspath, defines, (testFlags > 1));
t2.log();
if (codon::getLogger().flags & codon::Logger::FLAG_USER) {
auto fo = fopen("_dump_simplify.sexp", "w");
fmt::print(fo, "{}\n", transformed->toString(0));
fclose(fo);
}
Timer t3("typecheck");
auto typechecked =
ast::TypecheckVisitor::apply(cache.get(), std::move(transformed));
t3.log();
if (codon::getLogger().flags & codon::Logger::FLAG_USER) {
auto fo = fopen("_dump_typecheck.sexp", "w");
fmt::print(fo, "{}\n", typechecked->toString(0));
for (auto &f : cache->functions)
for (auto &r : f.second.realizations)
fmt::print(fo, "{}\n", r.second->ast->toString(0));
fclose(fo);
}
Timer t4("translate");
ast::TranslateVisitor::apply(cache.get(), std::move(typechecked));

View File

@ -275,7 +275,7 @@ struct KeywordStarExpr : public Expr {
struct TupleExpr : public Expr {
std::vector<ExprPtr> items;
explicit TupleExpr(std::vector<ExprPtr> items);
explicit TupleExpr(std::vector<ExprPtr> items = {});
TupleExpr(const TupleExpr &expr);
std::string toString() const override;

View File

@ -50,20 +50,22 @@ types::ClassTypePtr Cache::findClass(const std::string &name) const {
types::FuncTypePtr Cache::findFunction(const std::string &name) const {
auto f = typeCtx->find(name);
if (f && f->type && f->kind == TypecheckItem::Func)
return f->type->getFunc();
f = typeCtx->find(name + ":0");
if (f && f->type && f->kind == TypecheckItem::Func)
return f->type->getFunc();
return nullptr;
}
types::FuncTypePtr
Cache::findMethod(types::ClassType *typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args) {
types::FuncTypePtr Cache::findMethod(types::ClassType *typ, const std::string &member,
const std::vector<types::TypePtr> &args) {
auto e = std::make_shared<IdExpr>(typ->name);
e->type = typ->getClass();
seqassert(e->type, "not a class");
int oldAge = typeCtx->age;
typeCtx->age = 99999;
auto f = typeCtx->findBestMethod(e.get(), member, args);
auto f = TypecheckVisitor(typeCtx).findBestMethod(e.get(), member, args);
typeCtx->age = oldAge;
return f;
}

View File

@ -106,19 +106,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// Non-simplified AST. Used for base class instantiation.
std::shared_ptr<ClassStmt> originalAst;
/// A class function method.
struct ClassMethod {
/// Canonical name of a method (e.g. __init__.1).
std::string name;
/// A corresponding generic function type.
types::FuncTypePtr type;
/// Method age (how many class extension were seen before a method definition).
/// Used to prevent the usage of a method before it was defined in the code.
int age;
};
/// Class method lookup table. Each name points to a list of ClassMethod instances
/// that share the same method name (a list because methods can be overloaded).
std::unordered_map<std::string, std::vector<ClassMethod>> methods;
/// Class method lookup table. Each non-canonical name points
/// to a root function name of a corresponding method.
std::unordered_map<std::string, std::string> methods;
/// A class field (member).
struct ClassField {
@ -143,6 +133,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// Realization lookup table that maps a realized class name to the corresponding
/// ClassRealization instance.
std::unordered_map<std::string, std::shared_ptr<ClassRealization>> realizations;
/// List of inherited class. We also keep the number of fields each of inherited
/// class.
std::vector<std::pair<std::string, int>> parentClasses;
Class() : ast(nullptr), originalAst(nullptr) {}
};
@ -177,6 +170,20 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// corresponding Function instance.
std::unordered_map<std::string, Function> functions;
struct Overload {
/// Canonical name of an overload (e.g. Foo.__init__.1).
std::string name;
/// Overload age (how many class extension were seen before a method definition).
/// Used to prevent the usage of an overload before it was defined in the code.
/// TODO: I have no recollection of how this was supposed to work. Most likely
/// it does not work at all...
int age;
};
/// Maps a "root" name of each function to the list of names of the function
/// overloads.
std::unordered_map<std::string, std::vector<Overload>> overloads;
/// Pointer to the later contexts needed for IR API access.
std::shared_ptr<TypeContext> typeCtx;
std::shared_ptr<TranslateContext> codegenCtx;
@ -223,9 +230,8 @@ public:
types::FuncTypePtr findFunction(const std::string &name) const;
/// Find the class method in a given class type that best matches the given arguments.
/// Returns an _uninstantiated_ type.
types::FuncTypePtr
findMethod(types::ClassType *typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
types::FuncTypePtr findMethod(types::ClassType *typ, const std::string &member,
const std::vector<types::TypePtr> &args);
/// Given a class type and the matching generic vector, instantiate the type and
/// realize it.

View File

@ -61,12 +61,12 @@ small_stmt <-
/ 'break' &(SPACE / ';' / EOL) { return any(ast<BreakStmt>(LOC)); }
/ 'continue' &(SPACE / ';' / EOL) { return any(ast<ContinueStmt>(LOC)); }
/ global_stmt
/ yield_stmt
/ yield_stmt &(SPACE / ';' / EOL)
/ assert_stmt
/ del_stmt
/ return_stmt
/ raise_stmt
/ print_stmt
/ return_stmt &(SPACE / ';' / EOL)
/ raise_stmt &(SPACE / ';' / EOL)
/ print_stmt
/ import_stmt
/ expressions &(_ ';' / _ EOL) { return any(ast<ExprStmt>(LOC, ac_expr(V0))); }
/ NAME SPACE expressions {
@ -253,7 +253,7 @@ with_stmt <- 'with' SPACE (with_parens_item / with_item) _ ':' _ suite {
with_parens_item <- '(' _ tlist(',', as_item) _ ')' { return VS; }
with_item <- list(',', as_item) { return VS; }
as_item <-
/ expression SPACE 'as' SPACE star_target &(_ (',' / ')' / ':')) {
/ expression SPACE 'as' SPACE id &(_ (',' / ')' / ':')) {
return pair(ac_expr(V0), ac_expr(V1));
}
/ expression { return pair(ac_expr(V0), (ExprPtr)nullptr); }

View File

@ -52,8 +52,6 @@ std::shared_ptr<peg::Grammar> initParser() {
template <typename T>
T parseCode(Cache *cache, const std::string &file, std::string code, int line_offset,
int col_offset, const std::string &rule) {
TIME("peg");
// Initialize
if (!grammar)
grammar = initParser();

View File

@ -88,8 +88,10 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
}
// Reserve the following static identifiers.
for (auto name : {"staticlen", "compile_error", "isinstance", "hasattr", "type",
"TypeVar", "Callable", "argv"})
"TypeVar", "Callable", "argv", "super", "superf"})
stdlib->generateCanonicalName(name);
stdlib->add(SimplifyItem::Var, "super", "super", true);
stdlib->add(SimplifyItem::Var, "superf", "superf", true);
// This code must be placed in a preamble (these are not POD types but are
// referenced by the various preamble Function.N and Tuple.N stubs)

View File

@ -488,6 +488,9 @@ private:
// suite recursively, and assumes that each statement is either a function or a
// doc-string.
std::vector<StmtPtr> getClassMethods(const StmtPtr &s);
// Generate dispatch method for partial overloaded calls.
void generateDispatch(const std::string &name);
};
} // namespace ast

View File

@ -14,8 +14,9 @@ namespace codon {
namespace ast {
SimplifyItem::SimplifyItem(Kind k, std::string base, std::string canonicalName,
bool global)
: kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global) {}
bool global, std::string moduleName)
: kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global),
moduleName(move(moduleName)) {}
SimplifyContext::SimplifyContext(std::string filename, Cache *cache)
: Context<SimplifyItem>(move(filename)), cache(move(cache)),
@ -31,6 +32,7 @@ std::shared_ptr<SimplifyItem> SimplifyContext::add(SimplifyItem::Kind kind,
bool global) {
seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
auto t = std::make_shared<SimplifyItem>(kind, getBase(), canonicalName, global);
t->moduleName = getModule();
Context<SimplifyItem>::add(name, t);
Context<SimplifyItem>::add(canonicalName, t);
return t;
@ -60,22 +62,29 @@ std::string SimplifyContext::getBase() const {
return bases.back().name;
}
std::string SimplifyContext::getModule() const {
std::string base = moduleName.status == ImportFile::STDLIB ? "std." : "";
base += moduleName.module;
if (startswith(base, "__main__"))
base = base.substr(8);
return base;
}
std::string SimplifyContext::generateCanonicalName(const std::string &name,
bool includeBase) const {
bool includeBase,
bool zeroId) const {
std::string newName = name;
if (includeBase && name.find('.') == std::string::npos) {
bool alreadyGenerated = name.find('.') != std::string::npos;
if (includeBase && !alreadyGenerated) {
std::string base = getBase();
if (base.empty()) {
base = moduleName.status == ImportFile::STDLIB ? "std." : "";
base += moduleName.module;
if (startswith(base, "__main__"))
base = base.substr(8);
}
if (base.empty())
base = getModule();
newName = (base.empty() ? "" : (base + ".")) + newName;
}
auto num = cache->identifierCount[newName]++;
newName = num ? format("{}.{}", newName, num) : newName;
if (newName != name)
if (num)
newName = format("{}.{}", newName, num);
if (name != newName && !zeroId)
cache->identifierCount[newName]++;
cache->reverseIdentifierLookup[newName] = name;
return newName;

View File

@ -32,13 +32,16 @@ struct SimplifyItem {
bool global;
/// Non-empty string if a variable is import variable
std::string importPath;
/// Full module name
std::string moduleName;
public:
SimplifyItem(Kind k, std::string base, std::string canonicalName,
bool global = false);
SimplifyItem(Kind k, std::string base, std::string canonicalName, bool global = false,
std::string moduleName = "");
/// Convenience getters.
std::string getBase() const { return base; }
std::string getModule() const { return moduleName; }
bool isGlobal() const { return global; }
bool isVar() const { return kind == Var; }
bool isFunc() const { return kind == Func; }
@ -107,14 +110,16 @@ public:
/// Return a canonical name of the top-most base, or an empty string if this is a
/// top-level base.
std::string getBase() const;
/// Return the current module.
std::string getModule() const;
/// Return the current base nesting level (note: bases, not blocks).
int getLevel() const { return bases.size(); }
/// Pretty-print the current context state.
void dump() override { dump(0); }
/// Generate a unique identifier (name) for a given string.
std::string generateCanonicalName(const std::string &name,
bool includeBase = false) const;
std::string generateCanonicalName(const std::string &name, bool includeBase = false,
bool zeroId = false) const;
bool inFunction() const { return getLevel() && !bases.back().isType(); }
bool inClass() const { return getLevel() && bases.back().isType(); }

View File

@ -385,12 +385,22 @@ void SimplifyVisitor::visit(IndexExpr *expr) {
}
// IndexExpr[i1, ..., iN] is internally stored as IndexExpr[TupleExpr[i1, ..., iN]]
// for N > 1, so make sure to check that case.
std::vector<ExprPtr> it;
if (auto t = index->getTuple())
for (auto &i : t->items)
it.push_back(transform(i, true));
it.push_back(i);
else
it.push_back(transform(index, true));
it.push_back(index);
for (auto &i: it) {
if (auto es = i->getStar())
i = N<StarExpr>(transform(es->what));
else if (auto ek = CAST(i, KeywordStarExpr))
i = N<KeywordStarExpr>(transform(ek->what));
else
i = transform(i, true);
}
if (e->isType()) {
resultExpr = N<InstantiateExpr>(e, it);
resultExpr->markType();
@ -617,7 +627,7 @@ void SimplifyVisitor::visit(DotExpr *expr) {
auto s = join(chain, ".", importEnd, i + 1);
val = fctx->find(s);
// Make sure that we access only global imported variables.
if (val && (importName.empty() || val->isGlobal())) {
if (val && (importName.empty() || val->isType() || val->isGlobal())) {
itemName = val->canonicalName;
itemEnd = i + 1;
if (!importName.empty())

View File

@ -437,28 +437,30 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (stmt->decorators.size() != 1)
error("__attribute__ cannot be mixed with other decorators");
attr.isAttribute = true;
} else if (d->isId(Attr::LLVM))
} else if (d->isId(Attr::LLVM)) {
attr.set(Attr::LLVM);
else if (d->isId(Attr::Python))
} else if (d->isId(Attr::Python)) {
attr.set(Attr::Python);
else if (d->isId(Attr::Internal))
} else if (d->isId(Attr::Internal)) {
attr.set(Attr::Internal);
else if (d->isId(Attr::Atomic))
} else if (d->isId(Attr::Atomic)) {
attr.set(Attr::Atomic);
else if (d->isId(Attr::Property))
} else if (d->isId(Attr::Property)) {
attr.set(Attr::Property);
else if (d->isId(Attr::ForceRealize))
} else if (d->isId(Attr::ForceRealize)) {
attr.set(Attr::ForceRealize);
else {
} else {
// Let's check if this is a attribute
auto dt = transform(clone(d));
if (dt && dt->getId()) {
auto ci = ctx->find(dt->getId()->value);
if (ci && ci->kind == SimplifyItem::Func) {
if (ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute) {
attr.set(ci->canonicalName);
continue;
}
if (ctx->cache->overloads[ci->canonicalName].size() == 1)
if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
.ast->attributes.isAttribute) {
attr.set(ci->canonicalName);
continue;
}
}
}
decorators.emplace_back(clone(d));
@ -472,8 +474,23 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
return;
}
auto canonicalName = ctx->generateCanonicalName(stmt->name, true);
bool isClassMember = ctx->inClass();
std::string rootName;
if (isClassMember) {
auto &m = ctx->cache->classes[ctx->bases.back().name].methods;
auto i = m.find(stmt->name);
if (i != m.end())
rootName = i->second;
} else if (auto c = ctx->find(stmt->name)) {
if (c->isFunc() && c->getModule() == ctx->getModule() &&
c->getBase() == ctx->getBase())
rootName = c->canonicalName;
}
if (rootName.empty())
rootName = ctx->generateCanonicalName(stmt->name, true);
auto canonicalName =
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name;
bool isEnclosedFunc = ctx->inFunction();
if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember))
@ -483,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ctx->bases = std::vector<SimplifyContext::Base>();
if (!isClassMember)
// Class members are added to class' method table
ctx->add(SimplifyItem::Func, stmt->name, canonicalName, ctx->isToplevel());
ctx->add(SimplifyItem::Func, stmt->name, rootName, ctx->isToplevel());
if (isClassMember)
ctx->bases.push_back(oldBases[0]);
ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base...
@ -527,6 +544,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (!typeAst && isClassMember && ia == 0 && a.name == "self") {
typeAst = ctx->bases[ctx->bases.size() - 2].ast;
attr.set(".changedSelf");
attr.set(Attr::Method);
}
if (attr.has(Attr::C)) {
@ -602,8 +620,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// ... set the enclosing class name...
attr.parentClass = ctx->bases.back().name;
// ... add the method to class' method list ...
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].push_back(
{canonicalName, nullptr, ctx->cache->age});
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name] = rootName;
// ... and if the function references outer class variable (by definition a
// generic), mark it as not static as it needs fully instantiated class to be
// realized. For example, in class A[T]: def foo(): pass, A.foo() can be realized
@ -612,6 +629,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (isMethod)
attr.set(Attr::Method);
}
ctx->cache->overloads[rootName].push_back({canonicalName, ctx->cache->age});
std::vector<CallExpr::Arg> partialArgs;
if (!captures.empty()) {
@ -732,6 +750,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
ClassStmt *originalAST = nullptr;
auto classItem =
std::make_shared<SimplifyItem>(SimplifyItem::Type, "", "", ctx->isToplevel());
classItem->moduleName = ctx->getModule();
if (!extension) {
classItem->canonicalName = canonicalName =
ctx->generateCanonicalName(name, !attr.has(Attr::Internal));
@ -770,6 +789,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
std::vector<std::unordered_map<std::string, ExprPtr>> substitutions;
std::vector<int> argSubstitutions;
std::unordered_set<std::string> seenMembers;
std::vector<int> baseASTsFields;
for (auto &baseClass : stmt->baseClasses) {
std::string bcName;
std::vector<ExprPtr> subs;
@ -810,6 +830,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
if (!extension)
ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr});
}
baseASTsFields.push_back(args.size());
}
// Add generics, if any, to the context.
@ -891,6 +912,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
ctx->moduleName.module);
ctx->cache->classes[canonicalName].ast =
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), attr);
for (int i = 0; i < baseASTs.size(); i++)
ctx->cache->classes[canonicalName].parentClasses.push_back(
{baseASTs[i]->name, baseASTsFields[i]});
std::vector<StmtPtr> fns;
ExprPtr codeType = ctx->bases.back().ast->clone();
std::vector<std::string> magics{};
@ -934,29 +958,45 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
suite->stmts.push_back(preamble->functions.back());
}
}
for (int ai = 0; ai < baseASTs.size(); ai++)
for (auto sp : getClassMethods(baseASTs[ai]->suite))
if (auto f = sp->getFunction()) {
for (int ai = 0; ai < baseASTs.size(); ai++) {
// FUNCS
for (auto &mm : ctx->cache->classes[baseASTs[ai]->name].methods)
for (auto &mf : ctx->cache->overloads[mm.second]) {
auto f = ctx->cache->functions[mf.name].ast;
if (f->attributes.has("autogenerated"))
continue;
auto subs = substitutions[ai];
auto newName = ctx->generateCanonicalName(
ctx->cache->reverseIdentifierLookup[f->name], true);
auto nf = std::dynamic_pointer_cast<FunctionStmt>(replace(sp, subs));
subs[nf->name] = N<IdExpr>(newName);
nf->name = newName;
std::string rootName;
auto &mts = ctx->cache->classes[ctx->bases.back().name].methods;
auto it = mts.find(ctx->cache->reverseIdentifierLookup[f->name]);
if (it != mts.end())
rootName = it->second;
else
rootName = ctx->generateCanonicalName(
ctx->cache->reverseIdentifierLookup[f->name], true);
auto newCanonicalName =
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
ctx->cache->reverseIdentifierLookup[newCanonicalName] =
ctx->cache->reverseIdentifierLookup[f->name];
auto nf = std::dynamic_pointer_cast<FunctionStmt>(
replace(std::static_pointer_cast<Stmt>(f), subs));
subs[nf->name] = N<IdExpr>(newCanonicalName);
nf->name = newCanonicalName;
suite->stmts.push_back(nf);
nf->attributes.parentClass = ctx->bases.back().name;
// check original ast...
if (nf->attributes.has(".changedSelf"))
if (nf->attributes.has(".changedSelf")) // replace self type with new class
nf->args[0].type = transformType(ctx->bases.back().ast);
preamble->functions.push_back(clone(nf));
ctx->cache->functions[newName].ast = nf;
ctx->cache->overloads[rootName].push_back({newCanonicalName, ctx->cache->age});
ctx->cache->functions[newCanonicalName].ast = nf;
ctx->cache->classes[ctx->bases.back().name]
.methods[ctx->cache->reverseIdentifierLookup[f->name]]
.push_back({newName, nullptr, ctx->cache->age});
.methods[ctx->cache->reverseIdentifierLookup[f->name]] = rootName;
}
}
for (auto sp : getClassMethods(stmt->suite))
if (sp && !sp->getClass()) {
transform(sp);
@ -1227,8 +1267,10 @@ StmtPtr SimplifyVisitor::transformCImport(const std::string &name,
auto f = N<FunctionStmt>(name, ret ? ret->clone() : N<IdExpr>("void"), fnArgs,
nullptr, attr);
StmtPtr tf = transform(f); // Already in the preamble
if (!altName.empty())
if (!altName.empty()) {
ctx->add(altName, ctx->find(name));
ctx->remove(name);
}
return tf;
}
@ -1371,10 +1413,11 @@ void SimplifyVisitor::transformNewImport(const ImportFile &file) {
stmts[0] = N<SuiteStmt>();
// Add a def import(): ... manually to the cache and to the preamble (it won't be
// transformed here!).
ctx->cache->functions[importVar].ast =
N<FunctionStmt>(importVar, nullptr, std::vector<Param>{}, N<SuiteStmt>(stmts),
Attr({Attr::ForceRealize}));
preamble->functions.push_back(ctx->cache->functions[importVar].ast->clone());
ctx->cache->overloads[importVar].push_back({importVar + ":0", ctx->cache->age});
ctx->cache->functions[importVar + ":0"].ast =
N<FunctionStmt>(importVar + ":0", nullptr, std::vector<Param>{},
N<SuiteStmt>(stmts), Attr({Attr::ForceRealize}));
preamble->functions.push_back(ctx->cache->functions[importVar + ":0"].ast->clone());
;
}
}

View File

@ -296,7 +296,8 @@ void TranslateVisitor::visit(ForStmt *stmt) {
auto c = stmt->decorator->getCall();
seqassert(c, "for par is not a call: {}", stmt->decorator->toString());
auto fc = c->expr->getType()->getFunc();
seqassert(fc && fc->ast->name == "std.openmp.for_par", "for par is not a function");
seqassert(fc && fc->ast->name == "std.openmp.for_par:0",
"for par is not a function");
auto schedule =
fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString();
bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt();

View File

@ -33,16 +33,21 @@ StmtPtr TypecheckVisitor::apply(Cache *cache, StmtPtr stmts) {
return std::move(infer.second);
}
TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b) {
TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess) {
if (!a)
return a = b;
seqassert(b, "rhs is nullptr");
types::Type::Unification undo;
if (a->unify(b.get(), &undo) >= 0)
if (a->unify(b.get(), &undo) >= 0) {
if (undoOnSuccess)
undo.undo();
return a;
undo.undo();
} else {
undo.undo();
}
// LOG("{} / {}", a->debugString(true), b->debugString(true));
a->unify(b.get(), &undo);
if (!undoOnSuccess)
a->unify(b.get(), &undo);
error("cannot unify {} and {}", a->toString(), b->toString());
return nullptr;
}

View File

@ -283,9 +283,31 @@ private:
void generateFnCall(int n);
/// Make an empty partial call fn(...) for a function fn.
ExprPtr partializeFunction(ExprPtr expr);
/// Picks the best method of a given expression that matches the given argument
/// types. Prefers methods whose signatures are closer to the given arguments:
/// e.g. foo(int) will match (int) better that a foo(T).
/// Also takes care of the Optional arguments.
/// If multiple equally good methods are found, return the first one.
/// Return nullptr if no methods were found.
types::FuncTypePtr findBestMethod(const Expr *expr, const std::string &member,
const std::vector<types::TypePtr> &args);
types::FuncTypePtr findBestMethod(const Expr *expr, const std::string &member,
const std::vector<CallExpr::Arg> &args);
types::FuncTypePtr findBestMethod(const std::string &fn,
const std::vector<CallExpr::Arg> &args);
std::vector<types::FuncTypePtr> findSuperMethods(const types::FuncTypePtr &func);
std::vector<types::FuncTypePtr>
findMatchingMethods(types::ClassType *typ,
const std::vector<types::FuncTypePtr> &methods,
const std::vector<CallExpr::Arg> &args);
ExprPtr transformSuper(const CallExpr *expr);
std::vector<types::ClassTypePtr> getSuperTypes(const types::ClassTypePtr &cls);
private:
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b);
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b,
bool undoOnSuccess = false);
types::TypePtr realizeType(types::ClassType *typ);
types::TypePtr realizeFunc(types::FuncType *typ);
std::pair<int, StmtPtr> inferTypes(StmtPtr stmt, bool keepLast,
@ -293,10 +315,12 @@ private:
codon::ir::types::Type *getLLVMType(const types::ClassType *t);
bool wrapExpr(ExprPtr &expr, types::TypePtr expectedType,
const types::FuncTypePtr &callee);
const types::FuncTypePtr &callee, bool undoOnSuccess = false);
int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false);
int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop,
int64_t step);
types::FuncTypePtr findDispatch(const std::string &fn);
std::string getRootName(const std::string &name);
friend struct Cache;
};

View File

@ -103,19 +103,20 @@ types::TypePtr TypeContext::instantiate(const Expr *expr, types::TypePtr type,
if (auto l = i.second->getLink()) {
if (l->kind != types::LinkType::Unbound)
continue;
i.second->setSrcInfo(expr->getSrcInfo());
if (expr)
i.second->setSrcInfo(expr->getSrcInfo());
if (activeUnbounds.find(i.second) == activeUnbounds.end()) {
LOG_TYPECHECK("[ub] #{} -> {} (during inst of {}): {} ({})", i.first,
i.second->debugString(true), type->debugString(true),
expr->toString(), activate);
expr ? expr->toString() : "", activate);
if (activate && allowActivation)
activeUnbounds[i.second] =
format("{} of {} in {}", l->genericName.empty() ? "?" : l->genericName,
type->toString(), cache->getContent(expr->getSrcInfo()));
activeUnbounds[i.second] = format(
"{} of {} in {}", l->genericName.empty() ? "?" : l->genericName,
type->toString(), expr ? cache->getContent(expr->getSrcInfo()) : "");
}
}
}
LOG_TYPECHECK("[inst] {} -> {}", expr->toString(), t->debugString(true));
LOG_TYPECHECK("[inst] {} -> {}", expr ? expr->toString() : "", t->debugString(true));
return t;
}
@ -135,24 +136,29 @@ TypeContext::instantiateGeneric(const Expr *expr, types::TypePtr root,
return instantiate(expr, root, g.get());
}
std::vector<types::FuncTypePtr>
TypeContext::findMethod(const std::string &typeName, const std::string &method) const {
std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeName,
const std::string &method,
bool hideShadowed) const {
auto m = cache->classes.find(typeName);
if (m != cache->classes.end()) {
auto t = m->second.methods.find(method);
if (t != m->second.methods.end()) {
std::unordered_map<std::string, int> signatureLoci;
auto mt = cache->overloads[t->second];
std::unordered_set<std::string> signatureLoci;
std::vector<types::FuncTypePtr> vv;
for (auto &mt : t->second) {
// LOG("{}::{} @ {} vs. {}", typeName, method, age, mt.age);
if (mt.age <= age) {
auto sig = cache->functions[mt.name].ast->signature();
auto it = signatureLoci.find(sig);
if (it != signatureLoci.end())
vv[it->second] = mt.type;
else {
signatureLoci[sig] = vv.size();
vv.emplace_back(mt.type);
for (int mti = int(mt.size()) - 1; mti >= 0; mti--) {
auto &m = mt[mti];
if (endswith(m.name, ":dispatch"))
continue;
if (m.age <= age) {
if (hideShadowed) {
auto sig = cache->functions[m.name].ast->signature();
if (!in(signatureLoci, sig)) {
signatureLoci.insert(sig);
vv.emplace_back(cache->functions[m.name].type);
}
} else {
vv.emplace_back(cache->functions[m.name].type);
}
}
}
@ -177,110 +183,6 @@ types::TypePtr TypeContext::findMember(const std::string &typeName,
return nullptr;
}
types::FuncTypePtr TypeContext::findBestMethod(
const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args, bool checkSingle) {
auto typ = expr->getType()->getClass();
seqassert(typ, "not a class");
auto methods = findMethod(typ->name, member);
if (methods.empty())
return nullptr;
if (methods.size() == 1 && !checkSingle) // methods is not overloaded
return methods[0];
// Calculate the unification score for each available methods and pick the one with
// highest score.
std::vector<std::pair<int, int>> scores;
for (int mi = 0; mi < methods.size(); mi++) {
auto method = instantiate(expr, methods[mi], typ.get(), false)->getFunc();
std::vector<types::TypePtr> reordered;
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args) {
callArgs.push_back({a.first, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a.second);
}
auto score = reorderNamedArgs(
method.get(), callArgs,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(si == s || si == k || slots[si].size() != 1
? nullptr
: args[slots[si][0]].second);
}
return 0;
},
[](const std::string &) { return -1; });
if (score == -1)
continue;
// Scoring system for each argument:
// Generics, traits and default arguments get a score of zero (lowest priority).
// Optional unwrap gets the score of 1.
// Optional wrap gets the score of 2.
// Successful unification gets the score of 3 (highest priority).
for (int ai = 0, mi = 1, gi = 0; ai < reordered.size(); ai++) {
auto argType = reordered[ai];
if (!argType)
continue;
auto expectedType = method->ast->args[ai].generic ? method->generics[gi++].type
: method->args[mi++];
auto expectedClass = expectedType->getClass();
// Ignore traits, *args/**kwargs and default arguments.
if (expectedClass && expectedClass->name == "Generator")
continue;
// LOG("<~> {} {}", argType->toString(), expectedType->toString());
auto argClass = argType->getClass();
types::Type::Unification undo;
int u = argType->unify(expectedType.get(), &undo);
undo.undo();
if (u >= 0) {
score += u + 3;
continue;
}
if (!method->ast->args[ai].generic) {
// Unification failed: maybe we need to wrap an argument?
if (expectedClass && expectedClass->name == TYPE_OPTIONAL && argClass &&
argClass->name != expectedClass->name) {
u = argType->unify(expectedClass->generics[0].type.get(), &undo);
undo.undo();
if (u >= 0) {
score += u + 2;
continue;
}
}
// ... or unwrap it (less ideal)?
if (argClass && argClass->name == TYPE_OPTIONAL && expectedClass &&
argClass->name != expectedClass->name) {
u = argClass->generics[0].type->unify(expectedType.get(), &undo);
undo.undo();
if (u >= 0) {
score += u;
continue;
}
}
}
// This method cannot be selected, ignore it.
score = -1;
break;
}
// LOG("{} {} / {}", typ->toString(), method->toString(), score);
if (score >= 0)
scores.emplace_back(std::make_pair(score, mi));
}
if (scores.empty())
return nullptr;
// Get the best score.
sort(scores.begin(), scores.end(), std::greater<>());
// LOG("Method: {}", methods[scores[0].second]->toString());
// std::string x;
// for (auto &a : args)
// x += format("{}{},", a.first.empty() ? "" : a.first + ": ",
// a.second->toString());
// LOG(" {} :: {} ( {} )", typ->toString(), member, x);
return methods[scores[0].second];
}
int TypeContext::reorderNamedArgs(types::FuncType *func,
const std::vector<CallExpr::Arg> &args,
ReorderDoneFn onDone, ReorderErrorFn onError,
@ -300,15 +202,17 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
int starArgIndex = -1, kwstarArgIndex = -1;
for (int i = 0; i < func->ast->args.size(); i++) {
if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "**"))
// if (!known.empty() && known[i] && !partial)
// continue;
if (startswith(func->ast->args[i].name, "**"))
kwstarArgIndex = i, score -= 2;
else if ((known.empty() || !known[i]) && startswith(func->ast->args[i].name, "*"))
else if (startswith(func->ast->args[i].name, "*"))
starArgIndex = i, score -= 2;
}
seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
"partial *args");
seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
"partial **kwargs");
// seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex],
// "partial *args");
// seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex],
// "partial **kwargs");
// 1. Assign positional arguments to slots
// Each slot contains a list of arg's indices

View File

@ -48,6 +48,8 @@ struct TypeContext : public Context<TypecheckItem> {
/// Map of locally realized types and functions.
std::unordered_map<std::string, std::pair<TypecheckItem::Kind, types::TypePtr>>
visitedAsts;
/// List of functions that can be accessed via super()
std::vector<types::FuncTypePtr> supers;
};
std::vector<RealizationBase> bases;
@ -121,23 +123,13 @@ public:
/// Returns the list of generic methods that correspond to typeName.method.
std::vector<types::FuncTypePtr> findMethod(const std::string &typeName,
const std::string &method) const;
const std::string &method,
bool hideShadowed = true) const;
/// Returns the generic type of typeName.member, if it exists (nullptr otherwise).
/// Special cases: __elemsize__ and __atomic__.
types::TypePtr findMember(const std::string &typeName,
const std::string &member) const;
/// Picks the best method of a given expression that matches the given argument
/// types. Prefers methods whose signatures are closer to the given arguments:
/// e.g. foo(int) will match (int) better that a foo(T).
/// Also takes care of the Optional arguments.
/// If multiple equally good methods are found, return the first one.
/// Return nullptr if no methods were found.
types::FuncTypePtr
findBestMethod(const Expr *expr, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args,
bool checkSingle = false);
typedef std::function<int(int, int, const std::vector<std::vector<int>> &, bool)>
ReorderDoneFn;
typedef std::function<int(std::string)> ReorderErrorFn;

View File

@ -1,3 +1,4 @@
#include <algorithm>
#include <map>
#include <memory>
#include <string>
@ -104,6 +105,17 @@ void TypecheckVisitor::visit(IdExpr *expr) {
return;
}
auto val = ctx->find(expr->value);
if (!val) {
auto i = ctx->cache->overloads.find(expr->value);
if (i != ctx->cache->overloads.end()) {
if (i->second.size() == 1) {
val = ctx->find(i->second[0].name);
} else {
auto d = findDispatch(expr->value);
val = ctx->find(d->ast->name);
}
}
}
seqassert(val, "cannot find IdExpr '{}' ({})", expr->value, expr->getSrcInfo());
auto t = ctx->instantiate(expr, val->type);
@ -683,8 +695,8 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
if (isAtomic) {
auto ptrlt =
ctx->instantiateGeneric(expr->lexpr.get(), ctx->findInternal("Ptr"), {lt});
method = ctx->findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic),
{{"", ptrlt}, {"", rt}});
method =
findBestMethod(expr->lexpr.get(), format("__atomic_{}__", magic), {ptrlt, rt});
if (method) {
expr->lexpr = N<PtrExpr>(expr->lexpr);
if (noReturn)
@ -693,19 +705,16 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
}
// Check if lt.__iop__(lt, rt) exists.
if (!method && expr->inPlace) {
method = ctx->findBestMethod(expr->lexpr.get(), format("__i{}__", magic),
{{"", lt}, {"", rt}});
method = findBestMethod(expr->lexpr.get(), format("__i{}__", magic), {lt, rt});
if (method && noReturn)
*noReturn = true;
}
// Check if lt.__op__(lt, rt) exists.
if (!method)
method = ctx->findBestMethod(expr->lexpr.get(), format("__{}__", magic),
{{"", lt}, {"", rt}});
method = findBestMethod(expr->lexpr.get(), format("__{}__", magic), {lt, rt});
// Check if rt.__rop__(rt, lt) exists.
if (!method) {
method = ctx->findBestMethod(expr->rexpr.get(), format("__r{}__", magic),
{{"", rt}, {"", lt}});
method = findBestMethod(expr->rexpr.get(), format("__r{}__", magic), {rt, lt});
if (method)
swap(expr->lexpr, expr->rexpr);
}
@ -727,13 +736,17 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr,
ExprPtr &index) {
if (!tuple->getRecord() ||
in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
if (!tuple->getRecord())
return nullptr;
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL))
// in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
// Ptr, pyobj and str are internal types and have only one overloaded __getitem__
return nullptr;
if (ctx->cache->classes[tuple->name].methods["__getitem__"].size() != 1)
// TODO: be smarter! there might be a compatible getitem?
return nullptr;
// if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) {
// ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]]
// .size() != 1)
// return nullptr;
// }
// Extract a static integer value from a compatible expression.
auto getInt = [&](int64_t *o, const ExprPtr &e) {
@ -867,14 +880,16 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
// If it exists, return a simple IdExpr with that method's name.
// Append a "self" variable to the front if needed.
if (args) {
std::vector<std::pair<std::string, TypePtr>> argTypes;
std::vector<CallExpr::Arg> argTypes;
bool isType = expr->expr->isType();
if (!isType)
argTypes.emplace_back(make_pair("", typ)); // self variable
if (!isType) {
ExprPtr expr = N<IdExpr>("self");
expr->setType(typ);
argTypes.emplace_back(CallExpr::Arg{"", expr});
}
for (const auto &a : *args)
argTypes.emplace_back(make_pair(a.name, a.value->getType()));
if (auto bestMethod =
ctx->findBestMethod(expr->expr.get(), expr->member, argTypes)) {
argTypes.emplace_back(a);
if (auto bestMethod = findBestMethod(expr->expr.get(), expr->member, argTypes)) {
ExprPtr e = N<IdExpr>(bestMethod->ast->name);
auto t = ctx->instantiate(expr, bestMethod, typ.get());
unify(e->type, t);
@ -891,7 +906,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
// No method was found, print a nice error message.
std::vector<std::string> nice;
for (auto &t : argTypes)
nice.emplace_back(format("{} = {}", t.first, t.second->toString()));
nice.emplace_back(format("{} = {}", t.name, t.value->type->toString()));
error("cannot find a method '{}' in {} with arguments {}", expr->member,
typ->toString(), join(nice, ", "));
}
@ -901,23 +916,25 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
auto oldType = expr->getType() ? expr->getType()->getClass() : nullptr;
if (methods.size() > 1 && oldType && oldType->getFunc()) {
// If old type is already a function, use its arguments to pick the best call.
std::vector<std::pair<std::string, TypePtr>> methodArgs;
std::vector<TypePtr> methodArgs;
if (!expr->expr->isType()) // self argument
methodArgs.emplace_back(make_pair("", typ));
methodArgs.emplace_back(typ);
for (auto i = 1; i < oldType->generics.size(); i++)
methodArgs.emplace_back(make_pair("", oldType->generics[i].type));
bestMethod = ctx->findBestMethod(expr->expr.get(), expr->member, methodArgs);
methodArgs.emplace_back(oldType->generics[i].type);
bestMethod = findBestMethod(expr->expr.get(), expr->member, methodArgs);
if (!bestMethod) {
// Print a nice error message.
std::vector<std::string> nice;
for (auto &t : methodArgs)
nice.emplace_back(format("{} = {}", t.first, t.second->toString()));
nice.emplace_back(format("{}", t->toString()));
error("cannot find a method '{}' in {} with arguments {}", expr->member,
typ->toString(), join(nice, ", "));
}
} else if (methods.size() > 1) {
auto m = ctx->cache->classes.find(typ->name);
auto t = m->second.methods.find(expr->member);
bestMethod = findDispatch(t->second);
} else {
// HACK: if we still have multiple valid methods, we just use the first one.
// TODO: handle this better (maybe hold these types until they can be selected?)
bestMethod = methods[0];
}
@ -947,8 +964,8 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
if (bestMethod->ast->attributes.has(Attr::Property))
methodArgs.pop_back();
ExprPtr e = N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs);
ExprPtr r = transform(e, false, allowVoidExpr);
return r;
auto ex = transform(e, false, allowVoidExpr);
return ex;
}
}
@ -1004,6 +1021,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ai--;
} else {
// Case 3: Normal argument
// LOG("-> {}", expr->args[ai].value->toString());
expr->args[ai].value = transform(expr->args[ai].value, true);
// Unbound inType might become a generator that will need to be extracted, so
// don't unify it yet.
@ -1020,22 +1038,62 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
seenNames.insert(i.name);
}
// Intercept dot-callees (e.g. expr.foo). Needed in order to select a proper
// overload for magic methods and to avoid dealing with partial calls
// (a non-intercepted object DotExpr (e.g. expr.foo) will get transformed into a
// partial call).
ExprPtr *lhs = &expr->expr;
// Make sure to check for instantiation DotExpr (e.g. a.b[T]) as well.
if (auto ei = const_cast<IndexExpr *>(expr->expr->getIndex())) {
// A potential function instantiation
lhs = &ei->expr;
} else if (auto eii = CAST(expr->expr, InstantiateExpr)) {
// Real instantiation
lhs = &eii->typeExpr;
if (expr->expr->isId("superf")) {
if (ctx->bases.back().supers.empty())
error("no matching superf methods are available");
auto parentCls = ctx->bases.back().type->getFunc()->funcParent;
auto m =
findMatchingMethods(parentCls ? CAST(parentCls, types::ClassType) : nullptr,
ctx->bases.back().supers, expr->args);
if (m.empty())
error("no matching superf methods are available");
// LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(),
// m[0]->toString());
ExprPtr e = N<CallExpr>(N<IdExpr>(m[0]->ast->name), expr->args);
return transform(e, false, true);
}
if (auto ed = const_cast<DotExpr *>((*lhs)->getDot())) {
if (auto edt = transformDot(ed, &expr->args))
*lhs = edt;
if (expr->expr->isId("super"))
return transformSuper(expr);
bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() &&
!expr->args.back().value->getEllipsis()->isPipeArg &&
expr->args.back().name.empty();
if (!isPartial) {
// Intercept dot-callees (e.g. expr.foo). Needed in order to select a proper
// overload for magic methods and to avoid dealing with partial calls
// (a non-intercepted object DotExpr (e.g. expr.foo) will get transformed into a
// partial call).
ExprPtr *lhs = &expr->expr;
// Make sure to check for instantiation DotExpr (e.g. a.b[T]) as well.
if (auto ei = const_cast<IndexExpr *>(expr->expr->getIndex())) {
// A potential function instantiation
lhs = &ei->expr;
} else if (auto eii = CAST(expr->expr, InstantiateExpr)) {
// Real instantiation
lhs = &eii->typeExpr;
}
if (auto ed = const_cast<DotExpr *>((*lhs)->getDot())) {
if (auto edt = transformDot(ed, &expr->args))
*lhs = edt;
} else if (auto ei = const_cast<IdExpr *>((*lhs)->getId())) {
// check if this is an overloaded function?
auto i = ctx->cache->overloads.find(ei->value);
if (i != ctx->cache->overloads.end() && i->second.size() != 1) {
if (auto bestMethod = findBestMethod(ei->value, expr->args)) {
ExprPtr e = N<IdExpr>(bestMethod->ast->name);
auto t = ctx->instantiate(expr, bestMethod);
unify(e->type, t);
unify(ei->type, e->type);
*lhs = e;
} else {
std::vector<std::string> nice;
for (auto &t : expr->args)
nice.emplace_back(format("{} = {}", t.name, t.value->type->toString()));
error("cannot find an overload '{}' with arguments {}", ei->value,
join(nice, ", "));
}
}
}
}
expr->expr = transform(expr->expr, true);
@ -1086,9 +1144,21 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
std::vector<CallExpr::Arg> args;
std::vector<ExprPtr> typeArgs;
int typeArgCount = 0;
bool isPartial = false;
// bool isPartial = false;
int ellipsisStage = -1;
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
auto getPartialArg = [&](int pi) {
auto id = transform(N<IdExpr>(partialVar));
ExprPtr it = N<IntExpr>(pi);
// Manual call to transformStaticTupleIndex needed because otherwise
// IndexExpr routes this to InstantiateExpr.
auto ex = transformStaticTupleIndex(callee.get(), id, it);
seqassert(ex, "partial indexing failed");
return ex;
};
ExprPtr partialStarArgs = nullptr;
ExprPtr partialKwstarArgs = nullptr;
if (expr->ordered)
args = expr->args;
else
@ -1096,7 +1166,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
calleeFn.get(), expr->args,
[&](int starArgIndex, int kwstarArgIndex,
const std::vector<std::vector<int>> &slots, bool partial) {
isPartial = partial;
ctx->addBlock(); // add generics for default arguments.
addFunctionGenerics(calleeFn->getFunc().get());
for (int si = 0, pi = 0; si < slots.size(); si++) {
@ -1105,17 +1174,38 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
: expr->args[slots[si][0]].value);
typeArgCount += typeArgs.back() != nullptr;
newMask[si] = slots[si].empty() ? 0 : 1;
} else if (si == starArgIndex && !(partial && slots[si].empty())) {
} else if (si == starArgIndex) {
std::vector<ExprPtr> extra;
if (!known.empty())
extra.push_back(N<StarExpr>(getPartialArg(-2)));
for (auto &e : slots[si]) {
extra.push_back(expr->args[e].value);
if (extra.back()->getEllipsis())
ellipsisStage = args.size();
}
args.push_back({"", transform(N<TupleExpr>(extra))});
} else if (si == kwstarArgIndex && !(partial && slots[si].empty())) {
auto e = transform(N<TupleExpr>(extra));
if (partial) {
partialStarArgs = e;
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
} else {
args.push_back({"", e});
}
} else if (si == kwstarArgIndex) {
std::vector<std::string> names;
std::vector<CallExpr::Arg> values;
if (!known.empty()) {
auto e = getPartialArg(-1);
auto t = e->getType()->getRecord();
seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple",
e->toString());
auto &ff = ctx->cache->classes[t->name].fields;
for (int i = 0; i < t->getRecord()->args.size(); i++) {
names.emplace_back(ff[i].name);
values.emplace_back(
CallExpr::Arg{"", transform(N<DotExpr>(clone(e), ff[i].name))});
}
}
for (auto &e : slots[si]) {
names.emplace_back(expr->args[e].name);
values.emplace_back(CallExpr::Arg{"", expr->args[e].value});
@ -1123,16 +1213,17 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ellipsisStage = args.size();
}
auto kwName = generateTupleStub(names.size(), "KwTuple", names);
args.push_back({"", transform(N<CallExpr>(N<IdExpr>(kwName), values))});
auto e = transform(N<CallExpr>(N<IdExpr>(kwName), values));
if (partial) {
partialKwstarArgs = e;
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
} else {
args.push_back({"", e});
}
} else if (slots[si].empty()) {
if (!known.empty() && known[si]) {
// Manual call to transformStaticTupleIndex needed because otherwise
// IndexExpr routes this to InstantiateExpr.
auto id = transform(N<IdExpr>(partialVar));
ExprPtr it = N<IntExpr>(pi++);
auto ex = transformStaticTupleIndex(callee.get(), id, it);
seqassert(ex, "partial indexing failed");
args.push_back({"", ex});
args.push_back({"", getPartialArg(pi++)});
} else if (partial) {
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
@ -1160,6 +1251,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
if (isPartial) {
deactivateUnbounds(expr->args.back().value->getType().get());
expr->args.pop_back();
if (!partialStarArgs)
partialStarArgs = transform(N<TupleExpr>());
if (!partialKwstarArgs) {
auto kwName = generateTupleStub(0, "KwTuple", {});
partialKwstarArgs = transform(N<CallExpr>(N<IdExpr>(kwName)));
}
}
// Typecheck given arguments with the expected (signature) types.
@ -1181,8 +1278,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
// Special case: function instantiation
if (isPartial && typeArgCount && typeArgCount == expr->args.size()) {
for (auto &a : args) {
seqassert(a.value->getEllipsis(), "expected ellipsis");
deactivateUnbounds(a.value->getType().get());
if (a.value->getEllipsis())
deactivateUnbounds(a.value->getType().get());
}
auto e = transform(expr->expr);
unify(expr->type, e->getType());
@ -1252,11 +1349,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
deactivateUnbounds(pt->func.get());
calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si];
}
if (auto rt = realize(calleeFn)) {
unify(rt, std::static_pointer_cast<Type>(calleeFn));
expr->expr = transform(expr->expr);
if (!isPartial) {
if (auto rt = realize(calleeFn)) {
unify(rt, std::static_pointer_cast<Type>(calleeFn));
expr->expr = transform(expr->expr);
}
}
expr->done &= expr->expr->done;
// Emit the final call.
@ -1269,6 +1367,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
for (auto &r : args)
if (!r.value->getEllipsis())
newArgs.push_back(r.value);
newArgs.push_back(partialStarArgs);
newArgs.push_back(partialKwstarArgs);
std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = nullptr;
@ -1323,8 +1423,15 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
expr->args[1].value =
transformType(expr->args[1].value, /*disableActivation*/ true);
auto t = expr->args[1].value->type;
auto unifyOK = typ->unify(t.get(), nullptr) >= 0;
return {true, transform(N<BoolExpr>(unifyOK))};
auto hierarchy = getSuperTypes(typ->getClass());
for (auto &tx: hierarchy) {
auto unifyOK = tx->unify(t.get(), nullptr) >= 0;
if (unifyOK) {
return {true, transform(N<BoolExpr>(true))};
}
}
return {true, transform(N<BoolExpr>(false))};
}
}
} else if (val == "staticlen") {
@ -1355,18 +1462,17 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
if (!typ || !expr->args[1].value->staticValue.evaluated)
return {true, nullptr};
auto member = expr->args[1].value->staticValue.getString();
std::vector<std::pair<std::string, TypePtr>> args{{std::string(), typ}};
std::vector<TypePtr> args{typ};
for (int i = 2; i < expr->args.size(); i++) {
expr->args[i].value = transformType(expr->args[i].value);
if (!expr->args[i].value->getType()->getClass())
return {true, nullptr};
args.push_back({std::string(), expr->args[i].value->getType()});
args.push_back(expr->args[i].value->getType());
}
bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() ||
ctx->findMember(typ->getClass()->name, member);
if (exists && args.size() > 1)
exists &=
ctx->findBestMethod(expr->args[0].value.get(), member, args, true) != nullptr;
exists &= findBestMethod(expr->args[0].value.get(), member, args) != nullptr;
return {true, transform(N<BoolExpr>(exists))};
} else if (val == "compile_error") {
expr->args[0].value = transform(expr->args[0].value);
@ -1531,7 +1637,8 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
tupleSize++;
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
if (!ctx->find(typeName))
generateTupleStub(tupleSize, typeName, {}, false);
// 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them)
generateTupleStub(tupleSize + 2, typeName, {}, false);
return typeName;
}
@ -1597,9 +1704,12 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
auto partialTypeName = generatePartialStub(mask, fn.get());
deactivateUnbounds(fn.get());
std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = N<StmtExpr>(
N<AssignStmt>(N<IdExpr>(var), N<CallExpr>(N<IdExpr>(partialTypeName))),
N<IdExpr>(var));
auto kwName = generateTupleStub(0, "KwTuple", {});
ExprPtr call =
N<StmtExpr>(N<AssignStmt>(N<IdExpr>(var),
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
N<CallExpr>(N<IdExpr>(kwName)))),
N<IdExpr>(var));
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getRecord() &&
startswith(call->type->getRecord()->name, partialTypeName) &&
@ -1609,8 +1719,122 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
return call;
}
types::FuncTypePtr
TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member,
const std::vector<types::TypePtr> &args) {
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args) {
callArgs.push_back({"", std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
return findBestMethod(expr, member, callArgs);
}
types::FuncTypePtr
TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member,
const std::vector<CallExpr::Arg> &args) {
auto typ = expr->getType()->getClass();
seqassert(typ, "not a class");
auto methods = ctx->findMethod(typ->name, member, false);
auto m = findMatchingMethods(typ.get(), methods, args);
return m.empty() ? nullptr : m[0];
}
types::FuncTypePtr
TypecheckVisitor::findBestMethod(const std::string &fn,
const std::vector<CallExpr::Arg> &args) {
std::vector<types::FuncTypePtr> methods;
for (auto &m : ctx->cache->overloads[fn])
if (!endswith(m.name, ":dispatch"))
methods.push_back(ctx->cache->functions[m.name].type);
std::reverse(methods.begin(), methods.end());
auto m = findMatchingMethods(nullptr, methods, args);
return m.empty() ? nullptr : m[0];
}
std::vector<types::FuncTypePtr>
TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) {
if (func->ast->attributes.parentClass.empty() ||
endswith(func->ast->name, ":dispatch"))
return {};
auto p = ctx->find(func->ast->attributes.parentClass)->type;
if (!p || !p->getClass())
return {};
auto methodName = ctx->cache->reverseIdentifierLookup[func->ast->name];
auto m = ctx->cache->classes.find(p->getClass()->name);
std::vector<types::FuncTypePtr> result;
if (m != ctx->cache->classes.end()) {
auto t = m->second.methods.find(methodName);
if (t != m->second.methods.end()) {
for (auto &m : ctx->cache->overloads[t->second]) {
if (endswith(m.name, ":dispatch"))
continue;
if (m.name == func->ast->name)
break;
result.emplace_back(ctx->cache->functions[m.name].type);
}
}
}
std::reverse(result.begin(), result.end());
return result;
}
std::vector<types::FuncTypePtr>
TypecheckVisitor::findMatchingMethods(types::ClassType *typ,
const std::vector<types::FuncTypePtr> &methods,
const std::vector<CallExpr::Arg> &args) {
// Pick the last method that accepts the given arguments.
std::vector<types::FuncTypePtr> results;
for (int mi = 0; mi < methods.size(); mi++) {
auto m = ctx->instantiate(nullptr, methods[mi], typ, false)->getFunc();
std::vector<types::TypePtr> reordered;
auto score = ctx->reorderNamedArgs(
m.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
if (m->ast->args[si].generic) {
// Ignore type arguments
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.emplace_back(nullptr);
} else {
reordered.emplace_back(args[slots[si][0]].value->type);
}
}
return 0;
},
[](const std::string &) { return -1; });
for (int ai = 0, mi = 1, gi = 0; score != -1 && ai < reordered.size(); ai++) {
auto expectTyp =
m->ast->args[ai].generic ? m->generics[gi++].type : m->args[mi++];
auto argType = reordered[ai];
if (!argType)
continue;
try {
ExprPtr dummy = std::make_shared<IdExpr>("");
dummy->type = argType;
dummy->done = true;
wrapExpr(dummy, expectTyp, m, /*undoOnSuccess*/ true);
} catch (const exc::ParserException &) {
score = -1;
}
}
if (score != -1) {
// std::vector<std::string> ar;
// for (auto &a: args) {
// if (a.first.empty()) ar.push_back(a.second->toString());
// else ar.push_back(format("{}: {}", a.first, a.second->toString()));
// }
// LOG("- {} vs {}", m->toString(), join(ar, "; "));
results.push_back(methods[mi]);
}
}
return results;
}
bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType,
const FuncTypePtr &callee) {
const FuncTypePtr &callee, bool undoOnSuccess) {
auto expectedClass = expectedType->getClass();
auto exprClass = expr->getType()->getClass();
if (callee && expr->isType())
@ -1637,7 +1861,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType,
// Case 7: wrap raw Seq functions into Partial(...) call for easy realization.
expr = partializeFunction(expr);
}
unify(expr->type, expectedType);
unify(expr->type, expectedType, undoOnSuccess);
return true;
}
@ -1690,5 +1914,132 @@ int64_t TypecheckVisitor::sliceAdjustIndices(int64_t length, int64_t *start,
return 0;
}
types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) {
for (auto &m : ctx->cache->overloads[fn])
if (endswith(ctx->cache->functions[m.name].ast->name, ":dispatch"))
return ctx->cache->functions[m.name].type;
// Generate dispatch and return it!
auto name = fn + ":dispatch";
ExprPtr root;
auto a = ctx->cache->functions[ctx->cache->overloads[fn][0].name].ast;
if (!a->attributes.parentClass.empty())
root = N<DotExpr>(N<IdExpr>(a->attributes.parentClass),
ctx->cache->reverseIdentifierLookup[fn]);
else
root = N<IdExpr>(fn);
root = N<CallExpr>(root, N<StarExpr>(N<IdExpr>("args")),
N<KeywordStarExpr>(N<IdExpr>("kwargs")));
auto ast = N<FunctionStmt>(
name, nullptr, std::vector<Param>{Param("*args"), Param("**kwargs")},
N<SuiteStmt>(N<IfStmt>(
N<CallExpr>(N<IdExpr>("isinstance"), root->clone(), N<IdExpr>("void")),
N<ExprStmt>(root->clone()), N<ReturnStmt>(root))),
Attr({"autogenerated"}));
ctx->cache->reverseIdentifierLookup[name] = ctx->cache->reverseIdentifierLookup[fn];
auto baseType =
ctx->instantiate(N<IdExpr>(name).get(), ctx->find(generateFunctionStub(2))->type,
nullptr, false)
->getRecord();
auto typ = std::make_shared<FuncType>(baseType, ast.get());
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel));
ctx->add(TypecheckItem::Func, name, typ);
ctx->cache->overloads[fn].insert(ctx->cache->overloads[fn].begin(), {name, 0});
ctx->cache->functions[name].ast = ast;
ctx->cache->functions[name].type = typ;
prependStmts->push_back(ast);
// LOG("dispatch: {}", ast->toString(1));
return typ;
}
ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) {
// For now, we just support casting to the _FIRST_ overload (i.e. empty super())
if (!expr->args.empty())
error("super does not take arguments");
if (ctx->bases.empty() || !ctx->bases.back().type)
error("no parent classes available");
auto fptyp = ctx->bases.back().type->getFunc();
if (!fptyp || !fptyp->ast->hasAttr(Attr::Method))
error("no parent classes available");
if (fptyp->args.size() < 2)
error("no parent classes available");
ClassTypePtr typ = fptyp->args[1]->getClass();
auto &cands = ctx->cache->classes[typ->name].parentClasses;
if (cands.empty())
error("no parent classes available");
// if (typ->getRecord())
// error("cannot use super on tuple types");
// find parent typ
// unify top N args with parent typ args
// realize & do bitcast
// call bitcast() . method
auto name = cands[0].first;
int fields = cands[0].second;
auto val = ctx->find(name);
seqassert(val, "cannot find '{}'", name);
auto ftyp = ctx->instantiate(expr, val->type)->getClass();
if (typ->getRecord()) {
std::vector<ExprPtr> members;
for (int i = 0; i < fields; i++)
members.push_back(N<DotExpr>(N<IdExpr>(fptyp->ast->args[0].name),
ctx->cache->classes[typ->name].fields[i].name));
ExprPtr e = transform(
N<CallExpr>(N<IdExpr>(format(TYPE_TUPLE "{}", members.size())), members));
unify(e->type, ftyp);
e->type = ftyp;
return e;
} else {
for (int i = 0; i < fields; i++) {
auto t = ctx->cache->classes[typ->name].fields[i].type;
t = ctx->instantiate(expr, t, typ.get());
auto ft = ctx->cache->classes[name].fields[i].type;
ft = ctx->instantiate(expr, ft, ftyp.get());
unify(t, ft);
}
ExprPtr typExpr = N<IdExpr>(name);
typExpr->setType(ftyp);
auto self = fptyp->ast->args[0].name;
ExprPtr e = transform(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "to_class_ptr"),
N<CallExpr>(N<DotExpr>(N<IdExpr>(self), "__raw__")), typExpr));
return e;
}
}
std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) {
std::vector<ClassTypePtr> result;
if (!cls)
return result;
result.push_back(cls);
int start = 0;
for (auto &cand: ctx->cache->classes[cls->name].parentClasses) {
auto name = cand.first;
int fields = cand.second;
auto val = ctx->find(name);
seqassert(val, "cannot find '{}'", name);
auto ftyp = ctx->instantiate(nullptr, val->type)->getClass();
for (int i = start; i < fields; i++) {
auto t = ctx->cache->classes[cls->name].fields[i].type;
t = ctx->instantiate(nullptr, t, cls.get());
auto ft = ctx->cache->classes[name].fields[i].type;
ft = ctx->instantiate(nullptr, ft, ftyp.get());
unify(t, ft);
}
start += fields;
for (auto &t: getSuperTypes(ftyp))
result.push_back(t);
}
return result;
}
} // namespace ast
} // namespace codon

View File

@ -154,7 +154,15 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
ctx->realizationDepth++;
ctx->addBlock();
ctx->typecheckLevel++;
ctx->bases.push_back({type->ast->name, type->getFunc(), type->args[0]});
// Find parents!
ctx->bases.push_back({type->ast->name, type->getFunc(), type->args[0],
{}, findSuperMethods(type->getFunc())});
// if (startswith(type->ast->name, "Foo")) {
// LOG(": {}", type->toString());
// for (auto &s: ctx->bases.back().supers)
// LOG(" - {}", s->toString());
// }
auto clonedAst = ctx->cache->functions[type->ast->name].ast->clone();
auto *ast = (FunctionStmt *)clonedAst.get();
addFunctionGenerics(type);

View File

@ -153,9 +153,9 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) {
ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass});
c->args[1].value = transform(c->args[1].value);
auto rhsTyp = c->args[1].value->getType()->getClass();
if (auto method = ctx->findBestMethod(
stmt->lhs.get(), format("__atomic_{}__", c->expr->getId()->value),
{{"", ptrTyp}, {"", rhsTyp}})) {
if (auto method = findBestMethod(stmt->lhs.get(),
format("__atomic_{}__", c->expr->getId()->value),
{ptrTyp, rhsTyp})) {
resultStmt = transform(N<ExprStmt>(N<CallExpr>(
N<IdExpr>(method->ast->name), N<PtrExpr>(stmt->lhs), c->args[1].value)));
return;
@ -168,8 +168,8 @@ void TypecheckVisitor::visit(UpdateStmt *stmt) {
if (stmt->isAtomic && lhsClass && rhsClass) {
auto ptrType =
ctx->instantiateGeneric(stmt->lhs.get(), ctx->findInternal("Ptr"), {lhsClass});
if (auto m = ctx->findBestMethod(stmt->lhs.get(), "__atomic_xchg__",
{{"", ptrType}, {"", rhsClass}})) {
if (auto m =
findBestMethod(stmt->lhs.get(), "__atomic_xchg__", {ptrType, rhsClass})) {
resultStmt = transform(N<ExprStmt>(
N<CallExpr>(N<IdExpr>(m->ast->name), N<PtrExpr>(stmt->lhs), stmt->rhs)));
return;
@ -474,12 +474,12 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel));
// Check if this is a class method; if so, update the class method lookup table.
if (isClassMember) {
auto &methods = ctx->cache->classes[attr.parentClass]
.methods[ctx->cache->reverseIdentifierLookup[stmt->name]];
auto m = ctx->cache->classes[attr.parentClass]
.methods[ctx->cache->reverseIdentifierLookup[stmt->name]];
bool found = false;
for (auto &i : methods)
for (auto &i : ctx->cache->overloads[m])
if (i.name == stmt->name) {
i.type = typ;
ctx->cache->functions[i.name].type = typ;
found = true;
break;
}
@ -570,5 +570,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
stmt->done = true;
}
std::string TypecheckVisitor::getRootName(const std::string &name) {
auto p = name.rfind(':');
seqassert(p != std::string::npos, ": not found in {}", name);
return name.substr(0, p);
}
} // namespace ast
} // namespace codon

View File

@ -22,12 +22,12 @@ translateGenerics(std::vector<types::Generic> &generics) {
return ret;
}
std::vector<std::pair<std::string, codon::ast::types::TypePtr>>
std::vector<codon::ast::types::TypePtr>
generateDummyNames(std::vector<types::Type *> &types) {
std::vector<std::pair<std::string, codon::ast::types::TypePtr>> ret;
std::vector<codon::ast::types::TypePtr> ret;
for (auto *t : types) {
seqassert(t->getAstType(), "{} must have an ast type", *t);
ret.emplace_back("", t->getAstType());
ret.emplace_back(t->getAstType());
}
return ret;
}

View File

@ -339,10 +339,12 @@ class Counter[T](Dict[T,int]):
result |= other
return result
@extend
class Dict:
def __init__(self: Dict[K,int], other: Counter[K]):
self._init_from(other)
def namedtuple(): # internal
pass

View File

@ -125,6 +125,24 @@ class __internal__:
def opt_ref_invert[T](what: Optional[T]) -> T:
ret i8* %what
@pure
@llvm
def to_class_ptr[T](ptr: Ptr[byte]) -> T:
%0 = bitcast i8* %ptr to {=T}
ret {=T} %0
@pure
def _tuple_offsetof(x, field: Static[int]):
@llvm
def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
%a = alloca {=T}
%b = getelementptr inbounds {=T}, {=T}* %a, i64 0, i32 {=idx}
%base = ptrtoint {=T}* %a to i64
%elem = ptrtoint {=TE}* %b to i64
%offset = sub i64 %elem, %base
ret i64 %offset
return _llvm_offsetof(type(x), field, type(x[field]))
def raw_type_str(p: Ptr[byte], name: str) -> str:
pstr = p.__repr__()
# '<[name] at [pstr]>'

View File

@ -1,14 +1,14 @@
from algorithms.timsort import tim_sort_inplace
from algorithms.pdqsort import pdq_sort_inplace
from algorithms.insertionsort import insertion_sort_inplace
from algorithms.heapsort import heap_sort_inplace
from algorithms.qsort import qsort_inplace
def sorted[T](
def sorted(
v: Generator[T],
key = Optional[int](),
algorithm: Optional[str] = None,
reverse: bool = False
reverse: bool = False,
T: type
):
"""
Return a sorted list of the elements in v
@ -27,8 +27,6 @@ def _sort_list(self, key, algorithm: str):
insertion_sort_inplace(self, key)
elif algorithm == 'heap':
heap_sort_inplace(self, key)
#case 'tim':
# tim_sort_inplace(self, key)
elif algorithm == 'quick':
qsort_inplace(self, key)
else:

View File

@ -1,15 +1,19 @@
from internal.gc import sizeof
@extend
class Array:
def __new__(ptr: Ptr[T], sz: int) -> Array[T]:
return (sz, ptr)
def __new__(sz: int) -> Array[T]:
return (sz, Ptr[T](sz))
def __copy__(self) -> Array[T]:
p = Ptr[T](self.len)
str.memcpy(p.as_byte(), self.ptr.as_byte(), self.len * sizeof(T))
return (self.len, p)
def __deepcopy__(self) -> Array[T]:
p = Ptr[T](self.len)
i = 0
@ -17,14 +21,21 @@ class Array:
p[i] = self.ptr[i].__deepcopy__()
i += 1
return (self.len, p)
def __len__(self) -> int:
return self.len
def __bool__(self) -> bool:
return bool(self.len)
def __getitem__(self, index: int) -> T:
return self.ptr[index]
def __setitem__(self, index: int, what: T):
self.ptr[index] = what
def slice(self, s: int, e: int) -> Array[T]:
return (e - s, self.ptr + s)
array = Array

View File

@ -4,7 +4,7 @@ from internal.attributes import commutative, associative
class bool:
def __new__() -> bool:
return False
def __new__[T](what: T) -> bool: # lowest priority!
def __new__(what) -> bool:
return what.__bool__()
def __repr__(self) -> str:
return "True" if self else "False"

View File

@ -2,9 +2,9 @@ import internal.gc as gc
@extend
class List:
def __init__(self, arr: Array[T], len: int):
self.arr = arr
self.len = len
def __init__(self):
self.arr = Array[T](10)
self.len = 0
def __init__(self, it: Generator[T]):
self.arr = Array[T](10)
@ -12,27 +12,27 @@ class List:
for i in it:
self.append(i)
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
self.len = 0
def __init__(self):
self.arr = Array[T](10)
self.len = 0
def __init__(self, other: List[T]):
self.arr = Array[T](other.len)
self.len = 0
for i in other:
self.append(i)
# Dummy __init__ used for list comprehension optimization
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
self.len = 0
def __init__(self, dummy: bool, other):
"""Dummy __init__ used for list comprehension optimization"""
if hasattr(other, '__len__'):
self.__init__(other.__len__())
else:
self.__init__()
def __init__(self, arr: Array[T], len: int):
self.arr = arr
self.len = len
def __len__(self):
return self.len

View File

@ -3,21 +3,15 @@ class complex:
real: float
imag: float
def __new__():
return complex(0.0, 0.0)
def __new__(real: int, imag: int):
return complex(float(real), float(imag))
def __new__(real: float, imag: int):
return complex(real, float(imag))
def __new__(real: int, imag: float):
return complex(float(real), imag)
def __new__() -> complex:
return (0.0, 0.0)
def __new__(other):
return other.__complex__()
def __new__(real, imag) -> complex:
return (float(real), float(imag))
def __complex__(self):
return self
@ -42,6 +36,42 @@ class complex:
def __hash__(self):
return self.real.__hash__() + self.imag.__hash__()*1000003
def __add__(self, other):
return self + complex(other)
def __sub__(self, other):
return self - complex(other)
def __mul__(self, other):
return self * complex(other)
def __truediv__(self, other):
return self / complex(other)
def __eq__(self, other):
return self == complex(other)
def __ne__(self, other):
return self != complex(other)
def __pow__(self, other):
return self ** complex(other)
def __radd__(self, other):
return complex(other) + self
def __rsub__(self, other):
return complex(other) - self
def __rmul__(self, other):
return complex(other) * self
def __rtruediv__(self, other):
return complex(other) / self
def __rpow__(self, other):
return complex(other) ** self
def __add__(self, other: complex):
return complex(self.real + other.real, self.imag + other.imag)
@ -160,42 +190,6 @@ class complex:
phase += other.imag * log(vabs)
return complex(len * cos(phase), len * sin(phase))
def __add__(self, other):
return self + complex(other)
def __sub__(self, other):
return self - complex(other)
def __mul__(self, other):
return self * complex(other)
def __truediv__(self, other):
return self / complex(other)
def __eq__(self, other):
return self == complex(other)
def __ne__(self, other):
return self != complex(other)
def __pow__(self, other):
return self ** complex(other)
def __radd__(self, other):
return complex(other) + self
def __rsub__(self, other):
return complex(other) - self
def __rmul__(self, other):
return complex(other) * self
def __rtruediv__(self, other):
return complex(other) / self
def __rpow__(self, other):
return complex(other) ** self
def __repr__(self):
@pure
@llvm

View File

@ -10,80 +10,105 @@ def seq_str_float(a: float) -> str: pass
class float:
def __new__() -> float:
return 0.0
def __new__[T](what: T):
def __new__(what):
return what.__float__()
def __new__(s: str) -> float:
from C import strtod(cobj, Ptr[cobj]) -> float
buf = __array__[byte](32)
n = s.__len__()
need_dyn_alloc = (n >= buf.__len__())
p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)
end = cobj()
result = strtod(p, __ptr__(end))
if need_dyn_alloc:
free(p)
if end != p + n:
raise ValueError("could not convert string to float: " + s)
return result
def __repr__(self) -> str:
s = seq_str_float(self)
return s if s != "-nan" else "nan"
def __copy__(self) -> float:
return self
def __deepcopy__(self) -> float:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi double %self to i64
ret i64 %0
def __float__(self):
return self
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp one double %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
def __complex__(self):
return complex(self, 0.0)
def __pos__(self) -> float:
return self
@pure
@llvm
def __neg__(self) -> float:
%0 = fneg double %self
ret double %0
@pure
@commutative
@llvm
def __add__(a: float, b: float) -> float:
%tmp = fadd double %a, %b
ret double %tmp
@commutative
def __add__(self, other: int) -> float:
return self.__add__(float(other))
@pure
@llvm
def __sub__(a: float, b: float) -> float:
%tmp = fsub double %a, %b
ret double %tmp
def __sub__(self, other: int) -> float:
return self.__sub__(float(other))
@pure
@commutative
@llvm
def __mul__(a: float, b: float) -> float:
%tmp = fmul double %a, %b
ret double %tmp
@commutative
def __mul__(self, other: int) -> float:
return self.__mul__(float(other))
def __floordiv__(self, other: float) -> float:
return self.__truediv__(other).__floor__()
def __floordiv__(self, other: int) -> float:
return self.__floordiv__(float(other))
@pure
@llvm
def __truediv__(a: float, b: float) -> float:
%tmp = fdiv double %a, %b
ret double %tmp
def __truediv__(self, other: int) -> float:
return self.__truediv__(float(other))
@pure
@llvm
def __mod__(a: float, b: float) -> float:
%tmp = frem double %a, %b
ret double %tmp
def __mod__(self, other: int) -> float:
return self.__mod__(float(other))
def __divmod__(self, other: float):
mod = self % other
div = (self - mod) / other
@ -103,16 +128,14 @@ class float:
floordiv = (0.0).copysign(self / other)
return (floordiv, mod)
def __divmod__(self, other: int):
return self.__divmod__(float(other))
@pure
@llvm
def __eq__(a: float, b: float) -> bool:
%tmp = fcmp oeq double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __eq__(self, other: int) -> bool:
return self.__eq__(float(other))
@pure
@llvm
def __ne__(a: float, b: float) -> bool:
@ -120,174 +143,190 @@ class float:
%tmp = fcmp one double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __ne__(self, other: int) -> bool:
return self.__ne__(float(other))
@pure
@llvm
def __lt__(a: float, b: float) -> bool:
%tmp = fcmp olt double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __lt__(self, other: int) -> bool:
return self.__lt__(float(other))
@pure
@llvm
def __gt__(a: float, b: float) -> bool:
%tmp = fcmp ogt double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __gt__(self, other: int) -> bool:
return self.__gt__(float(other))
@pure
@llvm
def __le__(a: float, b: float) -> bool:
%tmp = fcmp ole double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __le__(self, other: int) -> bool:
return self.__le__(float(other))
@pure
@llvm
def __ge__(a: float, b: float) -> bool:
%tmp = fcmp oge double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __ge__(self, other: int) -> bool:
return self.__ge__(float(other))
@pure
@llvm
def sqrt(a: float) -> float:
declare double @llvm.sqrt.f64(double %a)
%tmp = call double @llvm.sqrt.f64(double %a)
ret double %tmp
@pure
@llvm
def sin(a: float) -> float:
declare double @llvm.sin.f64(double %a)
%tmp = call double @llvm.sin.f64(double %a)
ret double %tmp
@pure
@llvm
def cos(a: float) -> float:
declare double @llvm.cos.f64(double %a)
%tmp = call double @llvm.cos.f64(double %a)
ret double %tmp
@pure
@llvm
def exp(a: float) -> float:
declare double @llvm.exp.f64(double %a)
%tmp = call double @llvm.exp.f64(double %a)
ret double %tmp
@pure
@llvm
def exp2(a: float) -> float:
declare double @llvm.exp2.f64(double %a)
%tmp = call double @llvm.exp2.f64(double %a)
ret double %tmp
@pure
@llvm
def log(a: float) -> float:
declare double @llvm.log.f64(double %a)
%tmp = call double @llvm.log.f64(double %a)
ret double %tmp
@pure
@llvm
def log10(a: float) -> float:
declare double @llvm.log10.f64(double %a)
%tmp = call double @llvm.log10.f64(double %a)
ret double %tmp
@pure
@llvm
def log2(a: float) -> float:
declare double @llvm.log2.f64(double %a)
%tmp = call double @llvm.log2.f64(double %a)
ret double %tmp
@pure
@llvm
def __abs__(a: float) -> float:
declare double @llvm.fabs.f64(double %a)
%tmp = call double @llvm.fabs.f64(double %a)
ret double %tmp
@pure
@llvm
def __floor__(a: float) -> float:
declare double @llvm.floor.f64(double %a)
%tmp = call double @llvm.floor.f64(double %a)
ret double %tmp
@pure
@llvm
def __ceil__(a: float) -> float:
declare double @llvm.ceil.f64(double %a)
%tmp = call double @llvm.ceil.f64(double %a)
ret double %tmp
@pure
@llvm
def __trunc__(a: float) -> float:
declare double @llvm.trunc.f64(double %a)
%tmp = call double @llvm.trunc.f64(double %a)
ret double %tmp
@pure
@llvm
def rint(a: float) -> float:
declare double @llvm.rint.f64(double %a)
%tmp = call double @llvm.rint.f64(double %a)
ret double %tmp
@pure
@llvm
def nearbyint(a: float) -> float:
declare double @llvm.nearbyint.f64(double %a)
%tmp = call double @llvm.nearbyint.f64(double %a)
ret double %tmp
@pure
@llvm
def __round__(a: float) -> float:
declare double @llvm.round.f64(double %a)
%tmp = call double @llvm.round.f64(double %a)
ret double %tmp
@pure
@llvm
def __pow__(a: float, b: float) -> float:
declare double @llvm.pow.f64(double %a, double %b)
%tmp = call double @llvm.pow.f64(double %a, double %b)
ret double %tmp
def __pow__(self, other: int) -> float:
return self.__pow__(float(other))
@pure
@llvm
def min(a: float, b: float) -> float:
declare double @llvm.minnum.f64(double %a, double %b)
%tmp = call double @llvm.minnum.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def max(a: float, b: float) -> float:
declare double @llvm.maxnum.f64(double %a, double %b)
%tmp = call double @llvm.maxnum.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def copysign(a: float, b: float) -> float:
declare double @llvm.copysign.f64(double %a, double %b)
%tmp = call double @llvm.copysign.f64(double %a, double %b)
ret double %tmp
@pure
@llvm
def fma(a: float, b: float, c: float) -> float:
declare double @llvm.fma.f64(double %a, double %b, double %c)
%tmp = call double @llvm.fma.f64(double %a, double %b, double %c)
ret double %tmp
@llvm
def __atomic_xchg__(d: Ptr[float], b: float) -> void:
%tmp = atomicrmw xchg double* %d, double %b seq_cst
ret void
@llvm
def __atomic_add__(d: Ptr[float], b: float) -> float:
%tmp = atomicrmw fadd double* %d, double %b seq_cst
ret double %tmp
@llvm
def __atomic_sub__(d: Ptr[float], b: float) -> float:
%tmp = atomicrmw fsub double* %d, double %b seq_cst
ret double %tmp
def __hash__(self):
from C import frexp(float, Ptr[Int[32]]) -> float
@ -332,31 +371,13 @@ class float:
x = -2
return x
def __new__(s: str) -> float:
from C import strtod(cobj, Ptr[cobj]) -> float
buf = __array__[byte](32)
n = s.__len__()
need_dyn_alloc = (n >= buf.__len__())
p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)
end = cobj()
result = strtod(p, __ptr__(end))
if need_dyn_alloc:
free(p)
if end != p + n:
raise ValueError("could not convert string to float: " + s)
return result
def __match__(self, i: float):
return self == i
@property
def real(self):
return self
@property
def imag(self):
return 0.0

View File

@ -1,50 +1,72 @@
from internal.attributes import commutative, associative, distributive
from internal.types.complex import complex
@pure
@C
def seq_str_int(a: int) -> str: pass
@pure
@C
def seq_str_uint(a: int) -> str: pass
@extend
class int:
@pure
@llvm
def __new__() -> int:
ret i64 0
def __new__[T](what: T) -> int: # lowest priority!
def __new__(what) -> int:
return what.__int__()
def __new__(s: str) -> int:
return int._from_str(s, 10)
def __new__(s: str, base: int) -> int:
return int._from_str(s, base)
def __int__(self) -> int:
return self
@pure
@llvm
def __float__(self) -> float:
%tmp = sitofp i64 %self to double
ret double %tmp
def __complex__(self):
return complex(float(self), 0.0)
def __index__(self):
return self
def __repr__(self) -> str:
return seq_str_int(self)
def __copy__(self) -> int:
return self
def __deepcopy__(self) -> int:
return self
def __hash__(self) -> int:
return self
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i64 %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> int:
return self
def __neg__(self) -> int:
return 0 - self
@pure
@llvm
def __abs__(self) -> int:
@ -52,23 +74,19 @@ class int:
%1 = sub i64 0, %self
%2 = select i1 %0, i64 %self, i64 %1
ret i64 %2
@pure
@llvm
def __lshift__(self, other: int) -> int:
%0 = shl i64 %self, %other
ret i64 %0
@pure
@llvm
def __rshift__(self, other: int) -> int:
%0 = ashr i64 %self, %other
ret i64 %0
@pure
@commutative
@associative
@llvm
def __add__(self, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
@ -76,17 +94,36 @@ class int:
%0 = sitofp i64 %self to double
%1 = fadd double %0, %other
ret double %1
@pure
@commutative
@associative
@llvm
def __sub__(self, b: int) -> int:
%tmp = sub i64 %self, %b
def __add__(self, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __sub__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fsub double %0, %other
ret double %1
@pure
@llvm
def __sub__(self, b: int) -> int:
%tmp = sub i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
def __mul__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fmul double %0, %other
ret double %1
@pure
@commutative
@associative
@ -95,18 +132,7 @@ class int:
def __mul__(self, b: int) -> int:
%tmp = mul i64 %self, %b
ret i64 %tmp
@pure
@commutative
@llvm
def __mul__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fmul double %0, %other
ret double %1
@pure
@llvm
def __floordiv__(self, b: int) -> int:
%tmp = sdiv i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __floordiv__(self, other: float) -> float:
@ -115,6 +141,20 @@ class int:
%1 = fdiv double %0, %other
%2 = call double @llvm.floor.f64(double %1)
ret double %2
@pure
@llvm
def __floordiv__(self, b: int) -> int:
%tmp = sdiv i64 %self, %b
ret i64 %tmp
@pure
@llvm
def __truediv__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@pure
@llvm
def __truediv__(self, other: int) -> float:
@ -122,23 +162,20 @@ class int:
%1 = sitofp i64 %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __truediv__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@pure
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
@pure
@llvm
def __mod__(self, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = frem double %0, %other
ret double %1
@pure
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
def __divmod__(self, other: int):
d = self // other
m = self - d*other
@ -146,11 +183,13 @@ class int:
m += other
d -= 1
return (d, m)
@pure
@llvm
def __invert__(a: int) -> int:
%tmp = xor i64 %a, -1
ret i64 %tmp
@pure
@commutative
@associative
@ -158,6 +197,7 @@ class int:
def __and__(a: int, b: int) -> int:
%tmp = and i64 %a, %b
ret i64 %tmp
@pure
@commutative
@associative
@ -165,6 +205,7 @@ class int:
def __or__(a: int, b: int) -> int:
%tmp = or i64 %a, %b
ret i64 %tmp
@pure
@commutative
@associative
@ -172,42 +213,42 @@ class int:
def __xor__(a: int, b: int) -> int:
%tmp = xor i64 %a, %b
ret i64 %tmp
@pure
@llvm
def __bitreverse__(a: int) -> int:
declare i64 @llvm.bitreverse.i64(i64 %a)
%tmp = call i64 @llvm.bitreverse.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __bswap__(a: int) -> int:
declare i64 @llvm.bswap.i64(i64 %a)
%tmp = call i64 @llvm.bswap.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __ctpop__(a: int) -> int:
declare i64 @llvm.ctpop.i64(i64 %a)
%tmp = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %tmp
@pure
@llvm
def __ctlz__(a: int) -> int:
declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false)
ret i64 %tmp
@pure
@llvm
def __cttz__(a: int) -> int:
declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false)
ret i64 %tmp
@pure
@llvm
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __eq__(self, b: float) -> bool:
@ -215,12 +256,14 @@ class int:
%1 = fcmp oeq double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(self, b: float) -> bool:
@ -228,12 +271,14 @@ class int:
%1 = fcmp one double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(self, b: float) -> bool:
@ -241,12 +286,14 @@ class int:
%1 = fcmp olt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(self, b: float) -> bool:
@ -254,12 +301,14 @@ class int:
%1 = fcmp ogt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(self, b: float) -> bool:
@ -267,12 +316,14 @@ class int:
%1 = fcmp ole double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@pure
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(self, b: float) -> bool:
@ -280,10 +331,17 @@ class int:
%1 = fcmp oge double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
def __new__(s: str) -> int:
return int._from_str(s, 10)
def __new__(s: str, base: int) -> int:
return int._from_str(s, base)
@pure
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __pow__(self, exp: float):
return float(self) ** exp
def __pow__(self, exp: int):
if exp < 0:
return 0
@ -296,53 +354,65 @@ class int:
break
self *= self
return result
def __pow__(self, exp: float):
return float(self) ** exp
def popcnt(self):
return Int[64](self).popcnt()
@llvm
def __atomic_xchg__(d: Ptr[int], b: int) -> void:
%tmp = atomicrmw xchg i64* %d, i64 %b seq_cst
ret void
@llvm
def __atomic_add__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw add i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_sub__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw sub i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_and__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw and i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_nand__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw nand i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_or__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw or i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def _atomic_xor(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw xor i64* %d, i64 %b seq_cst
ret i64 %tmp
def __atomic_xor__(self, b: int) -> int:
return int._atomic_xor(__ptr__(self), b)
@llvm
def __atomic_min__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw min i64* %d, i64 %b seq_cst
ret i64 %tmp
@llvm
def __atomic_max__(d: Ptr[int], b: int) -> int:
%tmp = atomicrmw max i64* %d, i64 %b seq_cst
ret i64 %tmp
def __match__(self, i: int):
return self == i
@property
def real(self):
return self
@property
def imag(self):
return 0

View File

@ -1,18 +1,22 @@
from internal.attributes import commutative, associative, distributive
def check_N(N: Static[int]):
if N <= 0:
compile_error("N must be greater than 0")
pass
@extend
class Int:
def __new__() -> Int[N]:
check_N(N)
return Int[N](0)
def __new__(what: Int[N]) -> Int[N]:
check_N(N)
return what
def __new__(what: int) -> Int[N]:
check_N(N)
if N < 64:
@ -21,10 +25,12 @@ class Int:
return what
else:
return __internal__.int_sext(what, 64, N)
@pure
@llvm
def __new__(what: UInt[N]) -> Int[N]:
ret i{=N} %what
def __new__(what: str) -> Int[N]:
check_N(N)
ret = Int[N]()
@ -39,6 +45,7 @@ class Int:
ret = ret * Int[N](10) + Int[N](int(what.ptr[i]) - 48)
i += 1
return sign * ret
def __int__(self) -> int:
if N > 64:
return __internal__.int_trunc(self, N, 64)
@ -46,37 +53,47 @@ class Int:
return self
else:
return __internal__.int_sext(self, N, 64)
def __index__(self):
return int(self)
def __copy__(self) -> Int[N]:
return self
def __deepcopy__(self) -> Int[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure
@llvm
def __float__(self) -> float:
%0 = sitofp i{=N} %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i{=N} %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> Int[N]:
return self
@pure
@llvm
def __neg__(self) -> Int[N]:
%0 = sub i{=N} 0, %self
ret i{=N} %0
@pure
@llvm
def __invert__(self) -> Int[N]:
%0 = xor i{=N} %self, -1
ret i{=N} %0
@pure
@commutative
@associative
@ -84,11 +101,13 @@ class Int:
def __add__(self, other: Int[N]) -> Int[N]:
%0 = add i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __sub__(self, other: Int[N]) -> Int[N]:
%0 = sub i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -97,11 +116,13 @@ class Int:
def __mul__(self, other: Int[N]) -> Int[N]:
%0 = mul i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __floordiv__(self, other: Int[N]) -> Int[N]:
%0 = sdiv i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __truediv__(self, other: Int[N]) -> float:
@ -109,11 +130,13 @@ class Int:
%1 = sitofp i{=N} %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __mod__(self, other: Int[N]) -> Int[N]:
%0 = srem i{=N} %self, %other
ret i{=N} %0
def __divmod__(self, other: Int[N]):
d = self // other
m = self - d*other
@ -121,52 +144,61 @@ class Int:
m += other
d -= Int[N](1)
return (d, m)
@pure
@llvm
def __lshift__(self, other: Int[N]) -> Int[N]:
%0 = shl i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __rshift__(self, other: Int[N]) -> Int[N]:
%0 = ashr i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __eq__(self, other: Int[N]) -> bool:
%0 = icmp eq i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: Int[N]) -> bool:
%0 = icmp ne i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: Int[N]) -> bool:
%0 = icmp slt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: Int[N]) -> bool:
%0 = icmp sgt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: Int[N]) -> bool:
%0 = icmp sle i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: Int[N]) -> bool:
%0 = icmp sge i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@commutative
@associative
@ -174,6 +206,7 @@ class Int:
def __and__(self, other: Int[N]) -> Int[N]:
%0 = and i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -181,6 +214,7 @@ class Int:
def __or__(self, other: Int[N]) -> Int[N]:
%0 = or i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -188,6 +222,7 @@ class Int:
def __xor__(self, other: Int[N]) -> Int[N]:
%0 = xor i{=N} %self, %other
ret i{=N} %0
@llvm
def __pickle__(self, dest: Ptr[byte]) -> void:
declare i32 @gzwrite(i8*, i8*, i32)
@ -198,6 +233,7 @@ class Int:
%szi = ptrtoint i{=N}* %sz to i32
%2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi)
ret void
@llvm
def __unpickle__(src: Ptr[byte]) -> Int[N]:
declare i32 @gzread(i8*, i8*, i32)
@ -208,29 +244,37 @@ class Int:
%2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi)
%3 = load i{=N}, i{=N}* %0
ret i{=N} %3
def __repr__(self) -> str:
return str.cat(('Int[', seq_str_int(N), '](', seq_str_int(int(self)), ')'))
def __str__(self) -> str:
return seq_str_int(int(self))
@pure
@llvm
def _popcnt(self) -> Int[N]:
declare i{=N} @llvm.ctpop.i{=N}(i{=N})
%0 = call i{=N} @llvm.ctpop.i{=N}(i{=N} %self)
ret i{=N} %0
def popcnt(self):
return int(self._popcnt())
def len() -> int:
return N
@extend
class UInt:
def __new__() -> UInt[N]:
check_N(N)
return UInt[N](0)
def __new__(what: UInt[N]) -> UInt[N]:
check_N(N)
return what
def __new__(what: int) -> UInt[N]:
check_N(N)
if N < 64:
@ -239,13 +283,16 @@ class UInt:
return UInt[N](Int[N](what))
else:
return UInt[N](__internal__.int_zext(what, 64, N))
@pure
@llvm
def __new__(what: Int[N]) -> UInt[N]:
ret i{=N} %what
def __new__(what: str) -> UInt[N]:
check_N(N)
return UInt[N](Int[N](what))
def __int__(self) -> int:
if N > 64:
return __internal__.int_trunc(self, N, 64)
@ -253,37 +300,47 @@ class UInt:
return Int[64](self)
else:
return __internal__.int_zext(self, N, 64)
def __index__(self):
return int(self)
def __copy__(self) -> UInt[N]:
return self
def __deepcopy__(self) -> UInt[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure
@llvm
def __float__(self) -> float:
%0 = uitofp i{=N} %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne i{=N} %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> UInt[N]:
return self
@pure
@llvm
def __neg__(self) -> UInt[N]:
%0 = sub i{=N} 0, %self
ret i{=N} %0
@pure
@llvm
def __invert__(self) -> UInt[N]:
%0 = xor i{=N} %self, -1
ret i{=N} %0
@pure
@commutative
@associative
@ -291,11 +348,13 @@ class UInt:
def __add__(self, other: UInt[N]) -> UInt[N]:
%0 = add i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __sub__(self, other: UInt[N]) -> UInt[N]:
%0 = sub i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -304,11 +363,13 @@ class UInt:
def __mul__(self, other: UInt[N]) -> UInt[N]:
%0 = mul i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __floordiv__(self, other: UInt[N]) -> UInt[N]:
%0 = udiv i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __truediv__(self, other: UInt[N]) -> float:
@ -316,59 +377,70 @@ class UInt:
%1 = uitofp i{=N} %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def __mod__(self, other: UInt[N]) -> UInt[N]:
%0 = urem i{=N} %self, %other
ret i{=N} %0
def __divmod__(self, other: UInt[N]):
return (self // other, self % other)
@pure
@llvm
def __lshift__(self, other: UInt[N]) -> UInt[N]:
%0 = shl i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __rshift__(self, other: UInt[N]) -> UInt[N]:
%0 = lshr i{=N} %self, %other
ret i{=N} %0
@pure
@llvm
def __eq__(self, other: UInt[N]) -> bool:
%0 = icmp eq i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: UInt[N]) -> bool:
%0 = icmp ne i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: UInt[N]) -> bool:
%0 = icmp ult i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: UInt[N]) -> bool:
%0 = icmp ugt i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: UInt[N]) -> bool:
%0 = icmp ule i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: UInt[N]) -> bool:
%0 = icmp uge i{=N} %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@commutative
@associative
@ -376,6 +448,7 @@ class UInt:
def __and__(self, other: UInt[N]) -> UInt[N]:
%0 = and i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -383,6 +456,7 @@ class UInt:
def __or__(self, other: UInt[N]) -> UInt[N]:
%0 = or i{=N} %self, %other
ret i{=N} %0
@pure
@commutative
@associative
@ -390,6 +464,7 @@ class UInt:
def __xor__(self, other: UInt[N]) -> UInt[N]:
%0 = xor i{=N} %self, %other
ret i{=N} %0
@llvm
def __pickle__(self, dest: Ptr[byte]) -> void:
declare i32 @gzwrite(i8*, i8*, i32)
@ -400,6 +475,7 @@ class UInt:
%szi = ptrtoint i{=N}* %sz to i32
%2 = call i32 @gzwrite(i8* %dest, i8* %1, i32 %szi)
ret void
@llvm
def __unpickle__(src: Ptr[byte]) -> UInt[N]:
declare i32 @gzread(i8*, i8*, i32)
@ -410,12 +486,16 @@ class UInt:
%2 = call i32 @gzread(i8* %src, i8* %1, i32 %szi)
%3 = load i{=N}, i{=N}* %0
ret i{=N} %3
def __repr__(self) -> str:
return str.cat(('UInt[', seq_str_int(N), '](', seq_str_uint(int(self)), ')'))
def __str__(self) -> str:
return seq_str_uint(int(self))
def popcnt(self):
return int(Int[N](self)._popcnt())
def len() -> int:
return N

View File

@ -5,29 +5,36 @@ class Optional:
return __internal__.opt_tuple_new(T)
else:
return __internal__.opt_ref_new(T)
def __new__(what: T) -> Optional[T]:
if isinstance(T, ByVal):
return __internal__.opt_tuple_new_arg(what, T)
else:
return __internal__.opt_ref_new_arg(what, T)
def __bool__(self) -> bool:
if isinstance(T, ByVal):
return __internal__.opt_tuple_bool(self, T)
else:
return __internal__.opt_ref_bool(self, T)
def __invert__(self) -> T:
if isinstance(T, ByVal):
return __internal__.opt_tuple_invert(self, T)
else:
return __internal__.opt_ref_invert(self, T)
def __str__(self) -> str:
return 'None' if not self else str(~self)
def __repr__(self) -> str:
return 'None' if not self else (~self).__repr__()
def __is_optional__(self, other: Optional[T]):
if (not self) or (not other):
return (not self) and (not other)
return self.__invert__() is other.__invert__()
optional = Optional
def unwrap[T](opt: Optional[T]) -> T:

View File

@ -4,53 +4,63 @@ def seq_str_ptr(a: Ptr[byte]) -> str: pass
@extend
class Ptr:
@__internal__
def __new__(sz: int) -> Ptr[T]:
pass
@pure
@llvm
def __new__() -> Ptr[T]:
ret {=T}* null
@__internal__
def __new__(sz: int) -> Ptr[T]:
pass
@pure
@llvm
def __new__(other: Ptr[T]) -> Ptr[T]:
ret {=T}* %other
@pure
@llvm
def __new__(other: Ptr[byte]) -> Ptr[T]:
%0 = bitcast i8* %other to {=T}*
ret {=T}* %0
@pure
@llvm
def __new__(other: Ptr[T]) -> Ptr[T]:
ret {=T}* %other
@pure
@llvm
def __int__(self) -> int:
%0 = ptrtoint {=T}* %self to i64
ret i64 %0
@pure
@llvm
def __copy__(self) -> Ptr[T]:
ret {=T}* %self
@pure
@llvm
def __bool__(self) -> bool:
%0 = icmp ne {=T}* %self, null
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __getitem__(self, index: int) -> T:
%0 = getelementptr {=T}, {=T}* %self, i64 %index
%1 = load {=T}, {=T}* %0
ret {=T} %1
@llvm
def __setitem__(self, index: int, what: T) -> void:
%0 = getelementptr {=T}, {=T}* %self, i64 %index
store {=T} %what, {=T}* %0
ret void
@pure
@llvm
def __add__(self, other: int) -> Ptr[T]:
%0 = getelementptr {=T}, {=T}* %self, i64 %other
ret {=T}* %0
@pure
@llvm
def __sub__(self, other: Ptr[T]) -> int:
@ -59,90 +69,105 @@ class Ptr:
%2 = sub i64 %0, %1
%3 = sdiv exact i64 %2, ptrtoint ({=T}* getelementptr ({=T}, {=T}* null, i32 1) to i64)
ret i64 %3
@pure
@llvm
def __eq__(self, other: Ptr[T]) -> bool:
%0 = icmp eq {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ne__(self, other: Ptr[T]) -> bool:
%0 = icmp ne {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __lt__(self, other: Ptr[T]) -> bool:
%0 = icmp slt {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __gt__(self, other: Ptr[T]) -> bool:
%0 = icmp sgt {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __le__(self, other: Ptr[T]) -> bool:
%0 = icmp sle {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@pure
@llvm
def __ge__(self, other: Ptr[T]) -> bool:
%0 = icmp sge {=T}* %self, %other
%1 = zext i1 %0 to i8
ret i8 %1
@llvm
def __prefetch_r0__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 0, i32 1)
ret void
@llvm
def __prefetch_r1__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 1, i32 1)
ret void
@llvm
def __prefetch_r2__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 2, i32 1)
ret void
@llvm
def __prefetch_r3__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 0, i32 3, i32 1)
ret void
@llvm
def __prefetch_w0__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 0, i32 1)
ret void
@llvm
def __prefetch_w1__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 1, i32 1)
ret void
@llvm
def __prefetch_w2__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 2, i32 1)
ret void
@llvm
def __prefetch_w3__(self) -> void:
declare void @llvm.prefetch(i8* nocapture readonly, i32, i32, i32)
%0 = bitcast {=T}* %self to i8*
call void @llvm.prefetch(i8* %0, i32 1, i32 3, i32 1)
ret void
@pure
@llvm
def as_byte(self) -> Ptr[byte]:
@ -151,20 +176,25 @@ class Ptr:
def __repr__(self) -> str:
return seq_str_ptr(self.as_byte())
ptr = Ptr
Jar = Ptr[byte]
cobj = Ptr[byte]
# Forward declarations
@__internal__
@tuple
class Array[T]:
len: int
ptr: Ptr[T]
class List[T]:
arr: Array[T]
len: int
@extend
class NoneType:
def __new__() -> NoneType:

View File

@ -7,45 +7,58 @@ class str:
@__internal__
def __new__(l: int, p: Ptr[byte]) -> str:
pass
def __new__(p: Ptr[byte], l: int) -> str:
return str(l, p)
def __new__() -> str:
return str(Ptr[byte](), 0)
def __new__[T](what: T) -> str: # lowest priority!
def __new__(what) -> str:
if hasattr(what, "__str__"):
return what.__str__()
else:
return what.__repr__()
def __str__(what: str) -> str:
return what
def __len__(self) -> int:
return self.len
def __bool__(self) -> bool:
return self.len != 0
def __copy__(self) -> str:
return self
def __deepcopy__(self) -> str:
return self
def __ptrcopy__(self) -> str:
n = self.len
p = cobj(n)
str.memcpy(p, self.ptr, n)
return str(p, n)
@llvm
def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void:
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false)
ret void
@llvm
def memmove(dest: Ptr[byte], src: Ptr[byte], len: int) -> void:
declare void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memmove.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 0, i1 false)
ret void
@llvm
def memset(dest: Ptr[byte], val: byte, len: int) -> void:
declare void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 %align, i1 %isvolatile)
call void @llvm.memset.p0i8.i64(i8* %dest, i8 %val, i64 %len, i32 0, i1 false)
ret void
def __add__(self, other: str) -> str:
len1 = self.len
len2 = other.len
@ -54,17 +67,20 @@ class str:
str.memcpy(p, self.ptr, len1)
str.memcpy(p + len1, other.ptr, len2)
return str(p, len3)
def c_str(self):
n = self.__len__()
p = cobj(n + 1)
str.memcpy(p, self.ptr, n)
p[n] = byte(0)
return p
def from_ptr(t: cobj) -> str:
n = strlen(t)
p = Ptr[byte](n)
str.memcpy(p, t, n)
return str(p, n)
def __eq__(self, other: str):
if self.len != other.len:
return False
@ -74,10 +90,13 @@ class str:
return False
i += 1
return True
def __match__(self, other: str):
return self.__eq__(other)
def __ne__(self, other: str):
return not self.__eq__(other)
def cat(*args):
total = 0
if staticlen(args) == 1 and hasattr(args[0], "__iter__") and hasattr(args[0], "__len__"):

View File

@ -394,22 +394,10 @@ class NormalDist:
self._mu = mu
self._sigma = sigma
def __init__(self, mu: float, sigma: float):
def __init__(self, mu, sigma):
self._init(float(mu), float(sigma))
def __init__(self, mu: int, sigma: int):
self._init(float(mu), float(sigma))
def __init__(self, mu: float, sigma: int):
self._init(float(mu), float(sigma))
def __init__(self, mu: int, sigma: float):
self._init(float(mu), float(sigma))
def __init__(self, mu: float):
self._init(mu, 1.0)
def __init__(self, mu: int):
def __init__(self, mu):
self._init(float(mu), 1.0)
def __init__(self):

View File

@ -538,7 +538,7 @@ def foo() -> int:
a{=a}
foo()
#! not a type or static expression
#! while realizing foo (arguments foo)
#! while realizing foo:0 (arguments foo:0)
#%% function_llvm_err_4,barebones
a = 5
@ -558,7 +558,7 @@ print f.foo() #: F
class Foo:
def foo(self):
return 'F'
Foo.foo(1) #! cannot unify int and Foo
Foo.foo(1) #! cannot find a method 'foo' in Foo with arguments = int
#%% function_nested,barebones
def foo(v):
@ -941,12 +941,12 @@ print FooBarBaz[str]().foo() #: foo 0
print FooBarBaz[float]().bar() #: bar 0/float
print FooBarBaz[str]().baz() #: baz! foo 0 bar /str
#%% inherit_class_2,barebones
#%% inherit_class_err_2,barebones
class defdict(Dict[str,float]):
def __init__(self, d: Dict[str, float]):
self.__init__(d.items())
z = defdict()
z[1.1] #! cannot unify float and str
z[1.1] #! cannot find a method '__getitem__' in defdict with arguments = defdict, = float
#%% inherit_tuple,barebones
class Foo:
@ -982,3 +982,16 @@ class Bar:
x: float
class FooBar(Foo, Bar):
pass #! 'x' declared twice
#%% keyword_prefix,barebones
def foo(return_, pass_, yield_, break_, continue_, print_, assert_):
return_.append(1)
pass_.append(2)
yield_.append(3)
break_.append(4)
continue_.append(5)
print_.append(6)
assert_.append(7)
return return_, pass_, yield_, break_, continue_, print_, assert_
print foo([1], [1], [1], [1], [1], [1], [1])
#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7])

View File

@ -256,21 +256,22 @@ a = [5]
a.foo #! cannot find 'foo' in List[int]
#%% dot_case_6,barebones
# Did heavy changes to this testcase because
# of the automatic optional wraps/unwraps and promotions
class Foo:
def bar(self, a: int):
print 'normal', a
def bar(self, a: Optional[int]):
print 'optional', a
def bar[T](self, a: Optional[T]):
print 'optional generic', a, a.__class__
def bar(self, a):
print 'generic', a, a.__class__
def bar(self, a: Optional[float]):
print 'optional', a
def bar(self, a: int):
print 'normal', a
f = Foo()
f.bar(1) #: normal 1
f.bar(Optional(1)) #: optional 1
f.bar(Optional('s')) #: optional generic s Optional[str]
f.bar(1.1) #: optional 1.1
f.bar(Optional('s')) #: generic s Optional[str]
f.bar('hehe') #: generic hehe str
#%% dot_case_6b,barebones
class Foo:
def bar(self, a, b):
@ -305,7 +306,7 @@ class Foo:
print 'foo'
def method(self, a):
print a
Foo().clsmethod() #! too many arguments for Foo.clsmethod (expected maximum 0, got 1)
Foo().clsmethod() #! cannot find a method 'clsmethod' in Foo with arguments = Foo
#%% call,barebones
def foo(a, b, c='hi'):
@ -373,7 +374,7 @@ def foo(i, j, k):
return i + j + k
print foo(1.1, 2.2, 3.3) #: 6.6
p = foo(6, ...)
print p.__class__ #: foo[int,...,...]
print p.__class__ #: foo:0[int,...,...]
print p(2, 1) #: 9
print p(k=3, j=6) #: 15
q = p(k=1, ...)
@ -389,11 +390,11 @@ print 42 |> add_two #: 44
def moo(a, b, c=3):
print a, b, c
m = moo(b=2, ...)
print m.__class__ #: moo[...,int,...]
print m.__class__ #: moo:0[...,int,...]
m('s', 1.1) #: s 2 1.1
# #
n = m(c=2.2, ...)
print n.__class__ #: moo[...,int,float]
print n.__class__ #: moo:0[...,int,float]
n('x') #: x 2 2.2
print n('y').__class__ #: void
@ -402,11 +403,11 @@ def ff(a, b, c):
print ff(1.1, 2, True).__class__ #: Tuple[float,int,bool]
print ff(1.1, ...)(2, True).__class__ #: Tuple[float,int,bool]
y = ff(1.1, ...)(c=True, ...)
print y.__class__ #: ff[float,...,bool]
print y.__class__ #: ff:0[float,...,bool]
print ff(1.1, ...)(2, ...)(True).__class__ #: Tuple[float,int,bool]
print y('hei').__class__ #: Tuple[float,str,bool]
z = ff(1.1, ...)(c='s', ...)
print z.__class__ #: ff[float,...,str]
print z.__class__ #: ff:0[float,...,str]
#%% call_arguments_partial,barebones
def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T):
@ -431,7 +432,7 @@ l = [1]
def adder(a, b): return a+b
doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...))
#: int int
#: adder[.. Generator[int]
#: adder:0[ Generator[int]
#: 5
#: 1 Optional[int]
#: 5 int
@ -446,16 +447,7 @@ q = p(zh=43, ...)
q(1) #: 1 () (zh: 43)
r = q(5, 38, ...)
r() #: 5 (38) (zh: 43)
#%% call_partial_star_error,barebones
def foo(x, *args, **kwargs):
print x, args, kwargs
p = foo(...)
p(1, z=5)
q = p(zh=43, ...)
q(1)
r = q(5, 38, ...)
r(1, a=1) #! too many arguments for foo[T1,T2,T3] (expected maximum 3, got 2)
r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1)
#%% call_kwargs,barebones
def kwhatever(**kwargs):
@ -503,6 +495,79 @@ foo(*(1,2)) #: (1, 2) ()
foo(3, f) #: (3, (x: 6, y: True)) ()
foo(k = 3, **f) #: () (k: 3, x: 6, y: True)
#%% call_partial_args_kwargs,barebones
def foo(*args):
print(args)
a = foo(1, 2, ...)
b = a(3, 4, ...)
c = b(5, ...)
c('zooooo')
#: (1, 2, 3, 4, 5, 'zooooo')
def fox(*args, **kwargs):
print(args, kwargs)
xa = fox(1, 2, x=5, ...)
xb = xa(3, 4, q=6, ...)
xc = xb(5, ...)
xd = xc(z=5.1, ...)
xd('zooooo', w='lele')
#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele')
class Foo:
i: int
def __str__(self):
return f'#{self.i}'
def foo(self, a):
return f'{self}:generic'
def foo(self, a: float):
return f'{self}:float'
def foo(self, a: int):
return f'{self}:int'
f = Foo(4)
def pacman(x, f):
print f(x, '5')
print f(x, 2.1)
print f(x, 4)
pacman(f, Foo.foo)
#: #4:generic
#: #4:float
#: #4:int
def macman(f):
print f('5')
print f(2.1)
print f(4)
macman(f.foo)
#: #4:generic
#: #4:float
#: #4:int
class Fox:
i: int
def __str__(self):
return f'#{self.i}'
def foo(self, a, b):
return f'{self}:generic b={b}'
def foo(self, a: float, c):
return f'{self}:float, c={c}'
def foo(self, a: int):
return f'{self}:int'
def foo(self, a: int, z, q):
return f'{self}:int z={z} q={q}'
ff = Fox(5)
def maxman(f):
print f('5', b=1)
print f(2.1, 3)
print f(4)
print f(5, 1, q=3)
maxman(ff.foo)
#: #5:generic b=1
#: #5:float, c=3
#: #5:int
#: #5:int z=1 q=3
#%% call_static,barebones
print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool)
#: True True False
@ -535,6 +600,30 @@ print hasattr(int, "__getitem__")
print hasattr([1, 2], "__getitem__", str)
#: False
#%% isinstance_inheritance,barebones
class AX[T]:
a: T
def __init__(self, a: T):
self.a = a
class Side:
def __init__(self):
pass
class BX[T,U](AX[T], Side):
b: U
def __init__(self, a: T, b: U):
super().__init__(a)
self.b = b
class CX[T,U](BX[T,U]):
c: int
def __init__(self, a: T, b: U):
super().__init__(a, b)
self.c = 1
c = CX('a', False)
print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side)
#: True True True True
print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int])
#: True False False
#%% staticlen_err,barebones
print staticlen([1, 2]) #! List[int] is not a tuple type
@ -614,3 +703,83 @@ def foo(x: Callable[[1,2], 3]): pass #! unexpected static type
#%% static_unify_2,barebones
def foo(x: List[1]): pass #! cannot unify T and 1
#%% super,barebones
class A[T]:
a: T
def __init__(self, t: T):
self.a = t
def foo(self):
return f'A:{self.a}'
class B(A[str]):
b: int
def __init__(self):
super().__init__('s')
self.b = 6
def baz(self):
return f'{super().foo()}::{self.b}'
b = B()
print b.foo() #: A:s
print b.baz() #: A:s::6
class AX[T]:
a: T
def __init__(self, a: T):
self.a = a
def foo(self):
return f'[AX:{self.a}]'
class BX[T,U](AX[T]):
b: U
def __init__(self, a: T, b: U):
print super().__class__
super().__init__(a)
self.b = b
def foo(self):
return f'[BX:{super().foo()}:{self.b}]'
class CX[T,U](BX[T,U]):
c: int
def __init__(self, a: T, b: U):
print super().__class__
super().__init__(a, b)
self.c = 1
def foo(self):
return f'CX:{super().foo()}:{self.c}'
c = CX('a', False)
print c.__class__, c.foo()
#: BX[str,bool]
#: AX[str]
#: CX[str,bool] CX:[BX:[AX:a]:False]:1
#%% super_tuple,barebones
@tuple
class A[T]:
a: T
x: int
def __new__(a: T) -> A[T]:
return (a, 1)
def foo(self):
return f'A:{self.a}'
@tuple
class B(A[str]):
b: int
def __new__() -> B:
return (*(A('s')), 6)
def baz(self):
return f'{super().foo()}::{self.b}'
b = B()
print b.foo() #: A:s
print b.baz() #: A:s::6
#%% super_error,barebones
class A:
def __init__(self):
super().__init__()
a = A()
#! no parent classes available
#! while realizing A.__init__:1 (arguments A.__init__:1[A])
#%% super_error_2,barebones
super().foo(1) #! no parent classes available

View File

@ -252,7 +252,7 @@ try:
except MyError:
print "my"
except OSError as o:
print "os", o._hdr[0], len(o._hdr[1]), o._hdr[3][-20:], o._hdr[4]
print "os", o._hdr.typename, len(o._hdr.msg), o._hdr.file[-20:], o._hdr.line
#: os std.internal.types.error.OSError 9 typecheck_stmt.codon 249
finally:
print "whoa" #: whoa
@ -263,7 +263,7 @@ def foo():
try:
foo()
except MyError as e:
print e._hdr[0], e._hdr[1] #: MyError foo!
print e._hdr.typename, e._hdr.msg #: MyError foo!
#%% throw_error,barebones
raise 'hello' #! cannot throw non-exception (first object member must be of type ExcHeader)
@ -291,24 +291,54 @@ def foo(x):
print len(x)
foo(5) #: 4
def foo(x):
def foo2(x):
if isinstance(x, int):
print x+1
return
print len(x)
foo(1) #: 2
foo('s') #: 1
foo2(1) #: 2
foo2('s') #: 1
def foo(x, y: Static[int] = 5):
if y < 3:
if y > 1:
if isinstance(x, int):
print x+1
return
if isinstance(x, int):
return
print len(x)
foo(1, 1)
foo(1, 2) #: 2
foo(1)
foo('s') #: 1
#%% superf,barebones
class Foo:
def foo(a):
# superf(a)
print 'foo-1', a
def foo(a: int):
superf(a)
print 'foo-2', a
def foo(a: str):
superf(a)
print 'foo-3', a
def foo(a):
superf(a)
print 'foo-4', a
Foo.foo(1)
#: foo-1 1
#: foo-2 1
#: foo-4 1
class Bear:
def woof(x):
return f'bear woof {x}'
@extend
class Bear:
def woof(x):
return superf(x) + f' bear w--f {x}'
print Bear.woof('!')
#: bear woof ! bear w--f !
class PolarBear(Bear):
def woof():
return 'polar ' + superf('@')
print PolarBear.woof()
#: polar bear woof @ bear w--f @
#%% superf_error,barebones
class Foo:
def foo(a):
superf(a)
print 'foo-1', a
Foo.foo(1)
#! no matching superf methods are available
#! while realizing Foo.foo:0

View File

@ -199,10 +199,10 @@ def f[T](x: T) -> T:
print f(1.2).__class__ #: float
print f('s').__class__ #: str
def f[T](x: T):
return f(x - 1, T) if x else 1
print f(1) #: 1
print f(1.1).__class__ #: int
def f2[T](x: T):
return f2(x - 1, T) if x else 1
print f2(1) #: 1
print f2(1.1).__class__ #: int
#%% recursive_error,barebones
@ -215,7 +215,7 @@ def rec3(x, y): #- ('a, 'b) -> 'b
return y
rec3(1, 's')
#! cannot unify str and int
#! while realizing rec3 (arguments rec3[int,str])
#! while realizing rec3:0 (arguments rec3:0[int,str])
#%% instantiate_function_2,barebones
def fx[T](x: T) -> T:
@ -298,7 +298,7 @@ print h(list(map(lambda i: i-1, map(lambda i: i+2, range(5)))))
#%% func_unify_error,barebones
def foo(x:int):
print x
z = 1 & foo #! cannot unify foo[...] and int
z = 1 & foo #! cannot find magic 'and' in int
#%% void_error,barebones
def foo():
@ -447,13 +447,13 @@ def f(x):
return g(x)
print f(5), f('s') #: 5 s
def f[U](x: U, y):
def f2[U](x: U, y):
def g[T, U](x: T, y: U):
return (x, y)
return g(y, x)
x, y = 1, 'haha'
print f(x, y).__class__ #: Tuple[str,int]
print f('aa', 1.1, U=str).__class__ #: Tuple[float,str]
print f2(x, y).__class__ #: Tuple[str,int]
print f2('aa', 1.1, U=str).__class__ #: Tuple[float,str]
#%% nested_fn_generic_error,barebones
def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u]
@ -464,7 +464,7 @@ print f(1.1, 1, int).__class__ #! cannot unify float and int
#%% fn_realization,barebones
def ff[T](x: T, y: tuple[T]):
print ff(T=str,...).__class__ #: ff[str,Tuple[str],str]
print ff(T=str,...).__class__ #: ff:0[str,Tuple[str],str]
return x
x = ff(1, (1,))
print x, x.__class__ #: 1 int
@ -474,7 +474,7 @@ def fg[T](x:T):
def g[T](y):
z = T()
return z
print fg(T=str,...).__class__ #: fg[str,str]
print fg(T=str,...).__class__ #: fg:0[str,str]
print g(1, T).__class__ #: int
fg(1)
print fg(1).__class__ #: void
@ -515,7 +515,7 @@ class A[T]:
def foo[W](t: V, u: V, v: V, w: W):
return (t, u, v, w)
print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo[bool,bool,bool,str,str]
print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo:0[bool,bool,bool,str,str]
print A.B.C.foo(1,1,1,True) #: (1, 1, 1, True)
print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x')
print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x')
@ -533,7 +533,8 @@ class A[T]:
c: V
def foo[W](t: V, u: V, v: V, w: W):
return (t, u, v, w)
print A.B.C[str].foo(1,1,1,True) #! cannot unify int and str
print A.B.C[str].foo(1,1,1,True) #! cannot find a method 'foo' in A.B.C[str] with arguments = int, = int, = int, = bool
#%% nested_deep_class_error_2,barebones
class A[T]:
@ -733,10 +734,10 @@ def test(name, sort, key):
def foo(l, f):
return [f(i) for i in l]
test('hi', foo, lambda x: x+1) #: hi [2, 3, 4, 5]
# TODO
# def foof(l: List[int], x, f: Callable[[int], int]):
# return [f(i)+x for i in l]
# test('qsort', foof(..., 3, ...))
def foof(l: List[int], x, f: Callable[[int], int]):
return [f(i)+x for i in l]
test('qsort', foof(x=3, ...), lambda x: x+1) #: qsort [5, 6, 7, 8]
#%% class_fn_access,barebones
class X[T]:
@ -744,8 +745,7 @@ class X[T]:
return (x+x, y+y)
y = X[X[int]]()
print y.__class__ #: X[X[int]]
print X[float].foo(U=int, ...).__class__ #: X.foo[X[float],float,int,int]
# print y.foo[float].__class__
print X[float].foo(U=int, ...).__class__ #: X.foo:0[X[float],float,int,int]
print X[int]().foo(1, 's') #: (2, 'ss')
#%% class_partial_access,barebones
@ -753,7 +753,8 @@ class X[T]:
def foo[U](self, x, y: U):
return (x+x, y+y)
y = X[X[int]]()
print y.foo(U=float,...).__class__ #: X.foo[X[X[int]],...,...]
# TODO: should this even be the case?
# print y.foo(U=float,...).__class__ -> X.foo:0[X[X[int]],...,...]
print y.foo(1, 2.2, float) #: (2, 4.4)
#%% forward,barebones
@ -764,10 +765,10 @@ def bar[T](x):
print x, T.__class__
foo(bar, 1)
#: 1 int
#: bar[...]
#: bar:0[...]
foo(bar(...), 's')
#: s str
#: bar[...]
#: bar:0[...]
z = bar
z('s', int)
#: s int
@ -785,8 +786,8 @@ def foo(f, x):
def bar[T](x):
print x, T.__class__
foo(bar(T=int,...), 1)
#! too many arguments for bar[T1,int] (expected maximum 2, got 2)
#! while realizing foo (arguments foo[bar[...],int])
#! too many arguments for bar:0[T1,int] (expected maximum 2, got 2)
#! while realizing foo:0 (arguments foo:0[bar:0[...],int])
#%% sort_partial
def foo(x, y):
@ -805,16 +806,16 @@ def frec(x, y):
return grec(x, y) if bl(y) else 2
print frec(1, 2).__class__, frec('s', 1).__class__
#! expression with void type
#! while realizing frec (arguments frec[int,int])
#! while realizing frec:0 (arguments frec:0[int,int])
#%% return_fn,barebones
def retfn(a):
def inner(b, *args, **kwargs):
print a, b, args, kwargs
print inner.__class__ #: retfn.inner[...,...,int,...]
print inner.__class__ #: retfn:0.inner:0[...,...,int,...]
return inner(15, ...)
f = retfn(1)
print f.__class__ #: retfn.inner[int,...,int,...]
print f.__class__ #: retfn:0.inner:0[int,...,int,...]
f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar')
#%% decorator_manual,barebones
@ -822,7 +823,7 @@ def foo(x, *args, **kwargs):
print x, args, kwargs
return 1
def dec(fn, a):
print 'decorating', fn.__class__ #: decorating foo[...,...,...]
print 'decorating', fn.__class__ #: decorating foo:0[...,...,...]
def inner(*args, **kwargs):
print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True)
return fn(a, *args, **kwargs)
@ -845,7 +846,7 @@ def dec(fn, a):
return inner
ff = dec(foo, 10)
print ff(5.5, 's', z=True)
#: decorating foo[...,...,...]
#: decorating foo:0[...,...,...]
#: decorator (5.5, 's') (z: True)
#: 10 (5.5, 's') (z: True)
#: 1
@ -855,7 +856,7 @@ def zoo(e, b, *args):
return f'zoo: {e}, {b}, {args}'
print zoo(2, 3)
print zoo('s', 3)
#: decorating zoo[...,...,...]
#: decorating zoo:0[...,...,...]
#: decorator (2, 3) ()
#: zoo: 5, 2, (3)
#: decorator ('s', 3) ()
@ -868,9 +869,9 @@ def mydecorator(func):
print("after")
return inner
@mydecorator
def foo():
def foo2():
print("foo")
foo()
foo2()
#: before
#: foo
#: after
@ -890,7 +891,7 @@ def factorial(num):
return n
factorial(10)
#: 3628800
#: time needed for factorial[...] is 3628799
#: time needed for factorial:0[...] is 3628799
def dx1(func):
def inner():
@ -920,9 +921,9 @@ def dy2(func):
return inner
@dy1
@dy2
def num(a, b):
def num2(a, b):
return a+b
print(num(10, 20)) #: 3600
print(num2(10, 20)) #: 3600
#%% hetero_iter,barebones
e = (1, 2, 3, 'foo', 5, 'bar', 6)
@ -969,14 +970,14 @@ def tee(iterable, n=2):
return list(gen(d) for d in deques)
it = [1,2,3,4]
a, b = tee(it) #! cannot typecheck the program
#! while realizing tee (arguments tee[List[int],int])
#! while realizing tee:0 (arguments tee:0[List[int],int])
#%% new_syntax,barebones
def foo[T,U](x: type, y, z: Static[int] = 10):
print T.__class__, U.__class__, x.__class__, y.__class__, Int[z+1].__class__
return List[x]()
print foo(T=int,U=str,...).__class__ #: foo[T1,x,z,int,str]
print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo[T1,bool,5,int,str]
print foo(T=int,U=str,...).__class__ #: foo:0[T1,x,z,int,str]
print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo:0[T1,bool,5,int,str]
print foo(float,3,T=int,U=str,z=5).__class__ #: List[float]
foo(float,1,10,str,int) #: str int float int Int[11]
@ -992,11 +993,11 @@ print Foo[5,int,float,6].__class__ #: Foo[5,int,float,6]
print Foo(1.1, 10i32, [False], 10u66).__class__ #: Foo[66,bool,float,32]
def foo[N: Static[int]]():
def foo2[N: Static[int]]():
print Int[N].__class__, N
x: Static[int] = 5
y: Static[int] = 105 - x * 2
foo(y-x) #: Int[90] 90
foo2(y-x) #: Int[90] 90
if 1.1+2.2 > 0:
x: Static[int] = 88
@ -1107,3 +1108,27 @@ v = [1]
methodcaller('append')(v, 42)
print v #: [1, 42]
print methodcaller('index')(v, 42) #: 1
#%% fn_overloads,barebones
def foo(x):
return 1, x
def foo(x, y):
def foo(x, y):
return f'{x}_{y}'
return 2, foo(x, y)
def foo(x):
if x == '':
return 3, 0
return 3, 1 + foo(x[1:])[1]
print foo('hi') #: (3, 2)
print foo('hi', 1) #: (2, 'hi_1')
#%% fn_overloads_error,barebones
def foo(x):
return 1, x
def foo(x, y):
return 2, x, y
foo('hooooooooy!', 1, 2) #! cannot find an overload 'foo' with arguments = str, = int, = int

View File

@ -708,10 +708,10 @@ class TestDate[theclass](TestCase):
iso_long_years = sorted(map(int, ISO_LONG_YEARS_TABLE.split()))
L = []
for i in range(400):
d = self.theclass(2000+i, 12, 31)
d1 = self.theclass(1600+i, 12, 31)
self.assertEqual(d.isocalendar()[1:], d1.isocalendar()[1:])
if d.isocalendar()[1] == 53:
d = self.theclass(2000+i, 12, 31).isocalendar()
d1 = self.theclass(1600+i, 12, 31).isocalendar()
self.assertEqual((d.week, d.weekday), (d1.week, d1.weekday))
if d.week == 53:
L.append(i)
self.assertEqual(L, iso_long_years)

View File

@ -10,60 +10,69 @@ class I:
def __float__(self: int) -> float:
%tmp = sitofp i64 %self to double
ret double %tmp
@llvm
def __bool__(self: int) -> bool:
%0 = icmp ne i64 %self, 0
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self: int) -> int:
return self
def __neg__(self: int) -> int:
return I.__sub__(0, self)
@llvm
def __abs__(self: int) -> int:
%0 = icmp sgt i64 %self, 0
%1 = sub i64 0, %self
%2 = select i1 %0, i64 %self, i64 %1
ret i64 %2
@llvm
def __lshift__(self: int, other: int) -> int:
%0 = shl i64 %self, %other
ret i64 %0
@llvm
def __rshift__(self: int, other: int) -> int:
%0 = ashr i64 %self, %other
ret i64 %0
@llvm
def __add__(self: int, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@llvm
def __add__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fadd double %0, %other
ret double %1
@llvm
def __sub__(self: int, b: int) -> int:
%tmp = sub i64 %self, %b
def __add__(self: int, b: int) -> int:
%tmp = add i64 %self, %b
ret i64 %tmp
@llvm
def __sub__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fsub double %0, %other
ret double %1
@llvm
def __mul__(self: int, b: int) -> int:
%tmp = mul i64 %self, %b
def __sub__(self: int, b: int) -> int:
%tmp = sub i64 %self, %b
ret i64 %tmp
@llvm
def __mul__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fmul double %0, %other
ret double %1
@llvm
def __floordiv__(self: int, b: int) -> int:
%tmp = sdiv i64 %self, %b
def __mul__(self: int, b: int) -> int:
%tmp = mul i64 %self, %b
ret i64 %tmp
@llvm
def __floordiv__(self: int, other: float) -> float:
declare double @llvm.floor.f64(double)
@ -71,141 +80,177 @@ class I:
%1 = fdiv double %0, %other
%2 = call double @llvm.floor.f64(double %1)
ret double %2
@llvm
def __floordiv__(self: int, b: int) -> int:
%tmp = sdiv i64 %self, %b
ret i64 %tmp
@llvm
def __truediv__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@llvm
def __truediv__(self: int, other: int) -> float:
%0 = sitofp i64 %self to double
%1 = sitofp i64 %other to double
%2 = fdiv double %0, %1
ret double %2
@llvm
def __truediv__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
@llvm
def __mod__(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = frem double %0, %other
ret double %1
@llvm
def __mod__(a: int, b: int) -> int:
%tmp = srem i64 %a, %b
ret i64 %tmp
@llvm
def __invert__(a: int) -> int:
%tmp = xor i64 %a, -1
ret i64 %tmp
@llvm
def __and__(a: int, b: int) -> int:
%tmp = and i64 %a, %b
ret i64 %tmp
@llvm
def __or__(a: int, b: int) -> int:
%tmp = or i64 %a, %b
ret i64 %tmp
@llvm
def __xor__(a: int, b: int) -> int:
%tmp = xor i64 %a, %b
ret i64 %tmp
@llvm
def __shr__(a: int, b: int) -> int:
%tmp = ashr i64 %a, %b
ret i64 %tmp
@llvm
def __shl__(a: int, b: int) -> int:
%tmp = shl i64 %a, %b
ret i64 %tmp
@llvm
def __bitreverse__(a: int) -> int:
declare i64 @llvm.bitreverse.i64(i64 %a)
%tmp = call i64 @llvm.bitreverse.i64(i64 %a)
ret i64 %tmp
@llvm
def __bswap__(a: int) -> int:
declare i64 @llvm.bswap.i64(i64 %a)
%tmp = call i64 @llvm.bswap.i64(i64 %a)
ret i64 %tmp
@llvm
def __ctpop__(a: int) -> int:
declare i64 @llvm.ctpop.i64(i64 %a)
%tmp = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %tmp
@llvm
def __ctlz__(a: int) -> int:
declare i64 @llvm.ctlz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.ctlz.i64(i64 %a, i1 false)
ret i64 %tmp
@llvm
def __cttz__(a: int) -> int:
declare i64 @llvm.cttz.i64(i64 %a, i1 %is_zero_undef)
%tmp = call i64 @llvm.cttz.i64(i64 %a, i1 false)
ret i64 %tmp
@llvm
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __eq__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp oeq double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
def __eq__(a: int, b: int) -> bool:
%tmp = icmp eq i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __ne__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp one double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
def __ne__(a: int, b: int) -> bool:
%tmp = icmp ne i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __lt__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp olt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
def __lt__(a: int, b: int) -> bool:
%tmp = icmp slt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __gt__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp ogt double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
def __gt__(a: int, b: int) -> bool:
%tmp = icmp sgt i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __le__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp ole double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
def __le__(a: int, b: int) -> bool:
%tmp = icmp sle i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@llvm
def __ge__(self: int, b: float) -> bool:
%0 = sitofp i64 %self to double
%1 = fcmp oge double %0, %b
%2 = zext i1 %1 to i8
ret i8 %2
@llvm
def __ge__(a: int, b: int) -> bool:
%tmp = icmp sge i64 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
def __pow__(self: int, exp: float):
return float(self) ** exp
def __pow__(self: int, exp: int):
if exp < 0:
return 0
@ -218,8 +263,6 @@ class I:
break
self *= self
return result
def __pow__(self: int, exp: float):
return float(self) ** exp
@extend
class int:
@ -227,158 +270,197 @@ class int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return self
def __float__(self) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__float__(self)
def __bool__(self) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__bool__(self)
def __pos__(self) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return self
def __neg__(self) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__neg__(self)
def __lshift__(self, other: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__lshift__(self, other)
def __rshift__(self, other: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__rshift__(self, other)
def __add__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__add__(self, b)
def __add__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__add__(self, other)
def __sub__(self, b: int) -> int:
def __add__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__sub__(self, b)
return I.__add__(self, b)
def __sub__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__sub__(self, other)
def __mul__(self, b: int) -> int:
def __sub__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__mul__(self, b)
return I.__sub__(self, b)
def __mul__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__mul__(self, other)
def __floordiv__(self, b: int) -> int:
def __mul__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__floordiv__(self, b)
return I.__mul__(self, b)
def __floordiv__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__floordiv__(self, other)
def __truediv__(self, other: int) -> float:
def __floordiv__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__truediv__(self, other)
return I.__floordiv__(self, b)
def __truediv__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__truediv__(self, other)
def __mod__(self, b: int) -> int:
def __truediv__(self, other: int) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__mod__(self, b)
return I.__truediv__(self, other)
def __mod__(self, other: float) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__mod__(self, other)
def __mod__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__mod__(self, b)
def __invert__(self) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__invert__(self)
def __and__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__and__(self, b)
def __or__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__or__(self, b)
def __xor__(self, b: int) -> int:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__xor__(self, b)
def __eq__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__eq__(self, b)
def __eq__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__eq__(self, b)
def __ne__(self, b: int) -> bool:
def __eq__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__ne__(self, b)
return I.__eq__(self, b)
def __ne__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__ne__(self, b)
def __lt__(self, b: int) -> bool:
def __ne__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__lt__(self, b)
return I.__ne__(self, b)
def __lt__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__lt__(self, b)
def __gt__(self, b: int) -> bool:
def __lt__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__gt__(self, b)
return I.__lt__(self, b)
def __gt__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__gt__(self, b)
def __le__(self, b: int) -> bool:
def __gt__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__le__(self, b)
return I.__gt__(self, b)
def __le__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__le__(self, b)
def __ge__(self, b: int) -> bool:
def __le__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__ge__(self, b)
return I.__le__(self, b)
def __ge__(self, b: float) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__ge__(self, b)
def __pow__(self, exp: int):
def __ge__(self, b: int) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__pow__(self, exp)
return I.__ge__(self, b)
def __pow__(self, exp: float):
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__pow__(self, exp)
def __pow__(self, exp: int):
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return I.__pow__(self, exp)
class F:
@llvm
def __int__(self: float) -> int:
%0 = fptosi double %self to i64
ret i64 %0
def __float__(self: float):
return self
@llvm
def __bool__(self: float) -> bool:
%0 = fcmp one double %self, 0.000000e+00
@ -391,10 +473,12 @@ class float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return F.__int__(self)
def __float__(self) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return self
def __bool__(self) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
@ -406,10 +490,12 @@ class bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return 1 if self else 0
def __float__(self) -> float:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)
return 1. if self else 0.
def __bool__(self) -> bool:
global OP_COUNT
OP_COUNT = inc(OP_COUNT)