/*
 * simplify_statement.cpp --- AST statement simplifications.
 *
 * (c) Seq project. All rights reserved.
 * This file is subject to the terms and conditions defined in
 * file 'LICENSE', which is part of this source code package.
 */

#include
#include
#include
#include

#include "parser/ast.h"
#include "parser/common.h"
#include "parser/peg/peg.h"
#include "parser/visitors/format/format.h"
#include "parser/visitors/simplify/simplify.h"

using fmt::format;
using std::dynamic_pointer_cast;

namespace seq {
namespace ast {

/// AST visitor that substitutes identifier expressions using a lookup table.
struct ReplacementVisitor : ReplaceASTVisitor {
  /// Identifier name -> replacement expression. Not owned. It is dereferenced without
  /// a null check, so it must be set before transform(ExprPtr &) is called.
  const unordered_map *table;

  /// Visit e with a fresh visitor that shares this table; afterwards, if e itself is an
  /// identifier whose name is in the table, replace e with a clone of the mapped
  /// expression. A null e is left untouched.
  void transform(ExprPtr &e) override {
    if (!e)
      return;
    ReplacementVisitor v;
    v.table = table;
    e->accept(v);
    if (auto i = e->getId()) {
      auto it = table->find(i->value);
      if (it != table->end())
        e = it->second->clone();
    }
  }

  /// Visit statement e with a fresh visitor that shares this table. A null e is ignored.
  void transform(StmtPtr &e) override {
    if (!e)
      return;
    ReplacementVisitor v;
    v.table = table;
    e->accept(v);
  }
};

/// Clone e and run a ReplacementVisitor configured with table s over the clone.
/// @return the rewritten clone; e itself is not modified.
template T replace(const T &e, const unordered_map &s) {
  ReplacementVisitor v;
  v.table = &s;
  auto ep = clone(e);
  v.transform(ep);
  return ep;
}

/// Simplify a statement with a fresh SimplifyVisitor that shares ctx and preamble.
/// @return nullptr for a null stmt; otherwise whatever the visitor stored in
///         resultStmt (which may itself be null), stamped with ctx->cache->age.
StmtPtr SimplifyVisitor::transform(const StmtPtr &stmt) {
  if (!stmt)
    return nullptr;

  SimplifyVisitor v(ctx, preamble);
  v.setSrcInfo(stmt->getSrcInfo());
  const_cast(stmt.get())->accept(v);
  if (v.resultStmt)
    v.resultStmt->age = ctx->cache->age;
  return v.resultStmt;
}

/// Fallback for statements without a dedicated visit(): the result is a plain clone.
void SimplifyVisitor::defaultVisit(Stmt *s) { resultStmt = s->clone(); }

/**************************************************************************************/

/// Transform every statement of the suite and flatten the results into one list. When
/// the suite owns its block, the transformation runs inside a context block.
void SimplifyVisitor::visit(SuiteStmt *stmt) {
  vector r;
  // Make sure to add context blocks if this suite requires it...
  if (stmt->ownBlock)
    ctx->addBlock();
  for (const auto &s : stmt->stmts)
    SuiteStmt::flatten(transform(s), r);
  // ... and to remove it later.
  if (stmt->ownBlock)
    ctx->popBlock();
  resultStmt = N(r, stmt->ownBlock);
}

/// Reports an error when no loop is being tracked in ctx->loops; otherwise the
/// statement is cloned unchanged.
void SimplifyVisitor::visit(ContinueStmt *stmt) {
  if (ctx->loops.empty())
    error("continue outside of a loop");
  resultStmt = stmt->clone();
}

/// Reports an error when no loop is being tracked in ctx->loops. When the innermost
/// entry of ctx->loops is a non-empty name, a transformed node built from that name and
/// `false` is emitted in front of the cloned break; otherwise the break is cloned as is.
void SimplifyVisitor::visit(BreakStmt *stmt) {
  if (ctx->loops.empty())
    error("break outside of a loop");
  if (!ctx->loops.back().empty()) {
    resultStmt = N(
        transform(N(N(ctx->loops.back()), N(false))),
        stmt->clone());
  } else {
    resultStmt = stmt->clone();
  }
}

/// Wrap the transformed expression (transform is called with its flag set to true).
void SimplifyVisitor::visit(ExprStmt *stmt) {
  resultStmt = N(transform(stmt->expr, true));
}

/// Dispatch an assignment to transformAssignment / unpackAssignments. A single
/// resulting statement is returned directly; otherwise the statements are wrapped.
void SimplifyVisitor::visit(AssignStmt *stmt) {
  vector stmts;
  if (stmt->rhs && stmt->rhs->getBinary() && stmt->rhs->getBinary()->inPlace) {
    /// Case 1: a += b
    seqassert(!stmt->type, "invalid AssignStmt {}", stmt->toString());
    // shadow = false, mustExist = true.
    stmts.push_back(transformAssignment(stmt->lhs, stmt->rhs, nullptr, false, true));
  } else if (stmt->type) {
    /// Case 2: the assignment carries a type (stmt->type is set).
    // shadow = true, mustExist = false.
    stmts.push_back(transformAssignment(stmt->lhs, stmt->rhs, stmt->type, true, false));
  } else {
    // Everything else, including tuple/list targets, goes through unpacking.
    unpackAssignments(stmt->lhs, stmt->rhs, stmts, stmt->shadow, false);
  }
  resultStmt = stmts.size() == 1 ?
stmts[0] : N(stmts); } void SimplifyVisitor::visit(DelStmt *stmt) { if (auto eix = stmt->expr->getIndex()) { resultStmt = N(transform( N(N(clone(eix->expr), "__delitem__"), clone(eix->index)))); } else if (auto ei = stmt->expr->getId()) { resultStmt = transform( N(clone(stmt->expr), N(N(N("type"), clone(stmt->expr))))); ctx->remove(ei->value); } else { error("invalid del statement"); } } void SimplifyVisitor::visit(PrintStmt *stmt) { vector args; for (auto &i : stmt->items) args.emplace_back(CallExpr::Arg{"", transform(i)}); if (stmt->isInline) args.emplace_back(CallExpr::Arg{"end", N(" ")}); resultStmt = N(N(transform(N("print")), args)); } void SimplifyVisitor::visit(ReturnStmt *stmt) { if (!ctx->inFunction()) error("expected function body"); resultStmt = N(transform(stmt->expr)); } void SimplifyVisitor::visit(YieldStmt *stmt) { if (!ctx->inFunction()) error("expected function body"); resultStmt = N(transform(stmt->expr)); } void SimplifyVisitor::visit(YieldFromStmt *stmt) { auto var = ctx->cache->getTemporaryVar("yield"); resultStmt = transform( N(N(var), clone(stmt->expr), N(N(var)))); } void SimplifyVisitor::visit(AssertStmt *stmt) { ExprPtr msg = N(""); if (stmt->message) msg = N(N("str"), clone(stmt->message)); if (ctx->getLevel() && ctx->bases.back().attributes & FLAG_TEST) resultStmt = transform( N(N("!", clone(stmt->expr)), N(N(N("__internal__", "seq_assert_test"), N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), msg)))); else resultStmt = transform( N(N("!", clone(stmt->expr)), N(N(N("__internal__", "seq_assert"), N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), msg)))); } void SimplifyVisitor::visit(WhileStmt *stmt) { ExprPtr cond = N(N(clone(stmt->cond), "__bool__")); string breakVar; StmtPtr assign = nullptr; if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { breakVar = ctx->cache->getTemporaryVar("no_break"); assign = transform(N(N(breakVar), N(true), nullptr, true)); } ctx->loops.push_back(breakVar); // needed for transforming 
break in loop..else blocks StmtPtr whileStmt = N(transform(cond), transform(stmt->suite)); ctx->loops.pop_back(); if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { resultStmt = N(assign, whileStmt, N(transform(N(N(breakVar, "__bool__"))), transform(stmt->elseSuite))); } else { resultStmt = whileStmt; } } void SimplifyVisitor::visit(ForStmt *stmt) { vector ompArgs; ExprPtr decorator = clone(stmt->decorator); if (decorator) { ExprPtr callee = decorator; if (auto c = callee->getCall()) callee = c->expr; if (!callee || !callee->isId("par")) error("for loop can only take parallel decorator"); vector args; string openmp; vector omp; if (auto c = decorator->getCall()) for (auto &a : c->args) { if (a.name == "openmp" || (a.name.empty() && openmp.empty() && a.value->getString())) { omp = parseOpenMP(ctx->cache, a.value->getString()->getValue(), a.value->getSrcInfo()); } else { args.push_back({a.name, transform(a.value)}); } } for (auto &a : omp) args.push_back({a.name, transform(a.value)}); decorator = N(transform(N("for_par")), args); } string breakVar; auto iter = transform(stmt->iter); // needs in-advance transformation to prevent // name clashes with the iterator variable StmtPtr assign = nullptr, forStmt = nullptr; if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { breakVar = ctx->cache->getTemporaryVar("no_break"); assign = transform(N(N(breakVar), N(true), nullptr, true)); } ctx->loops.push_back(breakVar); // needed for transforming break in loop..else blocks ctx->addBlock(); if (auto i = stmt->var->getId()) { ctx->add(SimplifyItem::Var, i->value, ctx->generateCanonicalName(i->value)); forStmt = N(transform(stmt->var), clone(iter), transform(stmt->suite), nullptr, decorator, ompArgs); } else { string varName = ctx->cache->getTemporaryVar("for"); ctx->add(SimplifyItem::Var, varName, varName); auto var = N(varName); vector stmts; stmts.push_back(N(clone(stmt->var), clone(var), nullptr, true)); stmts.push_back(clone(stmt->suite)); forStmt = N(clone(var), 
clone(iter), transform(N(stmts)), nullptr, decorator, ompArgs); } ctx->popBlock(); ctx->loops.pop_back(); if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { resultStmt = N(assign, forStmt, N(transform(N(N(breakVar, "__bool__"))), transform(stmt->elseSuite))); } else { resultStmt = forStmt; } } void SimplifyVisitor::visit(IfStmt *stmt) { seqassert(stmt->cond, "invalid if statement"); resultStmt = N(transform(stmt->cond), transform(stmt->ifSuite), transform(stmt->elseSuite)); } void SimplifyVisitor::visit(MatchStmt *stmt) { auto var = ctx->cache->getTemporaryVar("match"); auto result = N(); result->stmts.push_back( N(N(var), clone(stmt->what), nullptr, true)); for (auto &c : stmt->cases) { ctx->addBlock(); StmtPtr suite = N(clone(c.suite), N()); if (c.guard) suite = N(clone(c.guard), suite); result->stmts.push_back(transformPattern(N(var), clone(c.pattern), suite)); ctx->popBlock(); } result->stmts.push_back(N()); // break even if there is no case _. resultStmt = transform(N(N(true), result)); } void SimplifyVisitor::visit(TryStmt *stmt) { vector catches; auto suite = transform(stmt->suite); for (auto &ctch : stmt->catches) { ctx->addBlock(); auto var = ctch.var; if (!ctch.var.empty()) { var = ctx->generateCanonicalName(ctch.var); ctx->add(SimplifyItem::Var, ctch.var, var); } catches.push_back({var, transformType(ctch.exc), transform(ctch.suite)}); ctx->popBlock(); } resultStmt = N(suite, catches, transform(stmt->finally)); } void SimplifyVisitor::visit(ThrowStmt *stmt) { resultStmt = N(transform(stmt->expr)); } void SimplifyVisitor::visit(WithStmt *stmt) { assert(stmt->items.size()); vector content; for (int i = int(stmt->items.size()) - 1; i >= 0; i--) { string var = stmt->vars[i].empty() ? ctx->cache->getTemporaryVar("with") : stmt->vars[i]; content = vector{ N(N(var), clone(stmt->items[i]), nullptr, true), N(N(N(var, "__enter__"))), N(!content.empty() ? 
N(content, true) : clone(stmt->suite), vector{}, N(vector{N( N(N(var, "__exit__")))}, true))}; } resultStmt = transform(N(content, true)); } void SimplifyVisitor::visit(GlobalStmt *stmt) { if (ctx->bases.empty() || ctx->bases.back().isType()) error("global outside of a function"); auto val = ctx->find(stmt->var); if (!val || !val->isVar()) error("identifier '{}' not found", stmt->var); if (!val->getBase().empty()) error("not a top-level variable"); seqassert(!val->canonicalName.empty(), "'{}' does not have a canonical name", stmt->var); ctx->cache->globals.insert(val->canonicalName); val->global = true; ctx->add(SimplifyItem::Var, stmt->var, val->canonicalName, true); } void SimplifyVisitor::visit(ImportStmt *stmt) { seqassert(!ctx->inClass(), "imports within a class"); if (stmt->from && stmt->from->isId("C")) { /// Handle C imports if (auto i = stmt->what->getId()) resultStmt = transformCImport(i->value, stmt->args, stmt->ret.get(), stmt->as); else if (auto d = stmt->what->getDot()) resultStmt = transformCDLLImport(d->expr.get(), d->member, stmt->args, stmt->ret.get(), stmt->as); else seqassert(false, "invalid C import statement"); return; } else if (stmt->from && stmt->from->isId("python") && stmt->what) { resultStmt = transformPythonImport(stmt->what.get(), stmt->args, stmt->ret.get(), stmt->as); return; } // Transform import a.b.c.d to "a/b/c/d". vector dirs; // Path components if (stmt->from) { Expr *e = stmt->from.get(); while (auto d = e->getDot()) { dirs.push_back(d->member); e = d->expr.get(); } if (!e->getId() || !stmt->args.empty() || stmt->ret || (stmt->what && !stmt->what->getId())) error("invalid import statement"); dirs.push_back(e->getId()->value); } // Handle dots (e.g. .. in from ..m import x). seqassert(stmt->dots >= 0, "negative dots in ImportStmt"); for (int i = 0; i < stmt->dots - 1; i++) dirs.emplace_back(".."); string path; for (int i = int(dirs.size()) - 1; i >= 0; i--) path += dirs[i] + (i ? "/" : ""); // Fetch the import! 
auto file = getImportFile(ctx->cache->argv0, path, ctx->getFilename(), false, ctx->cache->module0); if (!file) error("cannot locate import '{}'", join(dirs, ".")); // If the imported file has not been seen before, load it. if (ctx->cache->imports.find(file->path) == ctx->cache->imports.end()) transformNewImport(*file); const auto &import = ctx->cache->imports[file->path]; string importVar = import.importVar; string importDoneVar = importVar + "_done"; // Import variable is empty if it has already been loaded during the standard library // initialization. if (!ctx->isStdlibLoading && !importVar.empty()) { vector ifSuite; ifSuite.emplace_back(N(N(N(importVar)))); ifSuite.emplace_back(N(N(importDoneVar), N(true))); resultStmt = N(N(N(importDoneVar, "__invert__")), N(ifSuite)); } if (!stmt->what) { // Case 1: import foo auto name = stmt->as.empty() ? path : stmt->as; auto var = importVar + "_var"; resultStmt = N( resultStmt, transform(N(N(var), N(N("Import"), N(file->module), N(file->path)), N("Import")))); ctx->add(SimplifyItem::Var, name, var); ctx->find(name)->importPath = file->path; } else if (stmt->what->isId("*")) { // Case 2: from foo import * seqassert(stmt->as.empty(), "renamed star-import"); // Just copy all symbols from import's context here. for (auto &i : *(import.ctx)) if (!startswith(i.first, "_") && i.second.front().second->isGlobal()) { ctx->add(i.first, i.second.front().second); ctx->add(i.second.front().second->canonicalName, i.second.front().second); } } else { // Case 3: from foo import bar auto i = stmt->what->getId(); seqassert(i, "not a valid import what expression"); auto c = import.ctx->find(i->value); // Make sure that we are importing an existing global symbol if (!c || !c->isGlobal()) error("symbol '{}' not found in {}", i->value, file->path); ctx->add(stmt->as.empty() ? 
i->value : stmt->as, c); ctx->add(c->canonicalName, c); } } void SimplifyVisitor::visit(FunctionStmt *stmt) { vector decorators; Attr attr = stmt->attributes; for (auto &d : stmt->decorators) { if (d->isId("__attribute__")) { if (stmt->decorators.size() != 1) error("__attribute__ cannot be mixed with other decorators"); attr.isAttribute = true; } else if (d->isId(Attr::LLVM)) attr.set(Attr::LLVM); else if (d->isId(Attr::Python)) attr.set(Attr::Python); else if (d->isId(Attr::Internal)) attr.set(Attr::Internal); else if (d->isId(Attr::Atomic)) attr.set(Attr::Atomic); else if (d->isId(Attr::Property)) attr.set(Attr::Property); else if (d->isId(Attr::ForceRealize)) attr.set(Attr::ForceRealize); else { // Let's check if this is a attribute auto dt = transform(clone(d)); if (dt && dt->getId()) { auto ci = ctx->find(dt->getId()->value); if (ci && ci->kind == SimplifyItem::Func) { if (ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute) { attr.set(ci->canonicalName); continue; } } } decorators.emplace_back(clone(d)); } } if (attr.has(Attr::Python)) { // Handle Python code separately resultStmt = transformPythonDefinition(stmt->name, stmt->args, stmt->ret.get(), stmt->suite->firstInBlock()); // TODO: error on decorators return; } auto canonicalName = ctx->generateCanonicalName(stmt->name, true); bool isClassMember = ctx->inClass(); bool isEnclosedFunc = ctx->inFunction(); if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember)) error("builtins must be defined at the toplevel"); auto oldBases = move(ctx->bases); ctx->bases = vector(); if (!isClassMember) // Class members are added to class' method table ctx->add(SimplifyItem::Func, stmt->name, canonicalName, ctx->isToplevel()); if (isClassMember) ctx->bases.push_back(oldBases[0]); ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base... ctx->addBlock(); // ... and a block! // Set atomic flag if @atomic attribute is present. 
if (attr.has(Attr::Atomic)) ctx->bases.back().attributes |= FLAG_ATOMIC; if (attr.has(Attr::Test)) ctx->bases.back().attributes |= FLAG_TEST; // Add generic identifiers to the context unordered_set seenArgs; // Parse function arguments and add them to the context. vector args; bool defaultsStarted = false, hasStarArg = false, hasKwArg = false; // Add generics first for (int ia = 0; ia < stmt->args.size(); ia++) { auto &a = stmt->args[ia]; // check if this is a generic! if (a.type && (a.type->isId("type") || a.type->isId("TypeVar") || (a.type->getIndex() && a.type->getIndex()->expr->isId("Static")))) a.generic = true; string varName = a.name; int stars = trimStars(varName); if (stars == 2) { if (hasKwArg || a.deflt || ia != stmt->args.size() - 1) error("invalid **kwargs"); hasKwArg = true; } else if (stars == 1) { if (hasStarArg || a.deflt) error("invalid *args"); hasStarArg = true; } if (in(seenArgs, varName)) error("'{}' declared twice", varName); seenArgs.insert(varName); if (!a.deflt && defaultsStarted && !stars && !a.generic) error("non-default argument '{}' after a default argument", varName); defaultsStarted |= bool(a.deflt); auto typeAst = a.type; if (!typeAst && isClassMember && ia == 0 && a.name == "self") { typeAst = ctx->bases[ctx->bases.size() - 2].ast; attr.set(".changedSelf"); } if (attr.has(Attr::C)) { if (a.deflt) error("C functions do not accept default argument"); if (stars != 1 && !typeAst) error("C functions require explicit type annotations"); if (stars == 1) attr.set(Attr::CVarArg); } // First add all generics! 
auto name = ctx->generateCanonicalName(varName); args.emplace_back(Param{string(stars, '*') + name, typeAst, a.deflt, a.generic}); if (a.generic) { if (a.type->getIndex() && a.type->getIndex()->expr->isId("Static")) ctx->add(SimplifyItem::Var, varName, name); else ctx->add(SimplifyItem::Type, varName, name); } } for (auto &a : args) { a.type = transformType(a.type, false); a.deflt = transform(a.deflt, true); } // Delay adding to context to prevent "def foo(a, b=a)" for (auto &a : args) { if (!a.generic) { string canName = a.name; trimStars(canName); ctx->add(SimplifyItem::Var, ctx->cache->reverseIdentifierLookup[canName], canName); } } // Parse the return type. if (!stmt->ret && (attr.has(Attr::LLVM) || attr.has(Attr::C))) error("LLVM functions must have a return type"); auto ret = transformType(stmt->ret, false); // Parse function body. StmtPtr suite = nullptr; std::map captures; if (!attr.has(Attr::Internal) && !attr.has(Attr::C)) { ctx->addBlock(); if (attr.has(Attr::LLVM)) { suite = transformLLVMDefinition(stmt->suite->firstInBlock()); } else if (attr.has(Attr::C)) { ; } else { if ((isEnclosedFunc || attr.has(Attr::Capture)) && !isClassMember) ctx->captures.emplace_back(std::map{}); suite = SimplifyVisitor(ctx, preamble).transform(stmt->suite); if ((isEnclosedFunc || attr.has(Attr::Capture)) && !isClassMember) { captures = ctx->captures.back(); ctx->captures.pop_back(); } } ctx->popBlock(); } // Once the body is done, check if this function refers to a variable (or generic) // from outer scope (e.g. it's parent is not -1). If so, store the name of the // innermost base that was referred to in this function. auto isMethod = ctx->bases.back().attributes & FLAG_METHOD; ctx->bases.pop_back(); ctx->bases = move(oldBases); ctx->popBlock(); attr.module = format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", ctx->moduleName.module); if (isClassMember) { // If this is a method... // ... set the enclosing class name... 
attr.parentClass = ctx->bases.back().name; // ... add the method to class' method list ... ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].push_back( {canonicalName, nullptr, ctx->cache->age}); // ... and if the function references outer class variable (by definition a // generic), mark it as not static as it needs fully instantiated class to be // realized. For example, in class A[T]: def foo(): pass, A.foo() can be realized // even if T is unknown. However, def bar(): return T() cannot because it needs T // (and is thus accordingly marked with ATTR_IS_METHOD). if (isMethod) attr.set(Attr::Method); } vector partialArgs; if (!captures.empty()) { Param kw; if (hasKwArg) { kw = args.back(); args.pop_back(); } for (auto &c : captures) { args.emplace_back(Param{c.second, nullptr, nullptr}); partialArgs.emplace_back(CallExpr::Arg{ c.second, N(ctx->cache->reverseIdentifierLookup[c.first])}); } if (hasKwArg) args.push_back(kw); partialArgs.emplace_back(CallExpr::Arg{"", N()}); } auto f = N(canonicalName, ret, args, suite, attr); preamble->functions.push_back( N(canonicalName, clone(f->ret), clone_nop(f->args), suite, attr)); // Make sure to cache this (generic) AST for later realization. ctx->cache->functions[canonicalName].ast = f; ExprPtr finalExpr; if (!captures.empty()) finalExpr = N(N(stmt->name), partialArgs); if (isClassMember && decorators.size()) error("decorators cannot be applied to class methods"); for (int j = int(decorators.size()) - 1; j >= 0; j--) { if (auto c = const_cast(decorators[j]->getCall())) { c->args.emplace(c->args.begin(), CallExpr::Arg{"", finalExpr ? finalExpr : N(stmt->name)}); finalExpr = N(c->expr, c->args); } else { finalExpr = N(decorators[j], finalExpr ? 
finalExpr : N(stmt->name)); } } if (finalExpr) resultStmt = transform(N(N(stmt->name), finalExpr)); } void SimplifyVisitor::visit(ClassStmt *stmt) { enum Magic { Init, Repr, Eq, Order, Hash, Pickle, Container, Python }; Attr attr = stmt->attributes; vector hasMagic(10, 2); hasMagic[Init] = hasMagic[Pickle] = 1; // @tuple(init=, repr=, eq=, order=, hash=, pickle=, container=, python=, add=, // internal=...) // @dataclass(...) // @extend for (auto &d : stmt->decorators) { if (auto c = d->getCall()) { if (c->expr->isId(Attr::Tuple)) attr.set(Attr::Tuple); else if (!c->expr->isId("dataclass")) error("invalid class attribute"); else if (attr.has(Attr::Tuple)) error("class already marked as tuple"); for (auto &a : c->args) { auto b = CAST(a.value, BoolExpr); if (!b) error("expected static boolean"); auto val = b->value; if (a.name == "init") hasMagic[Init] = val; else if (a.name == "repr") hasMagic[Repr] = val; else if (a.name == "eq") hasMagic[Eq] = val; else if (a.name == "order") hasMagic[Order] = val; else if (a.name == "hash") hasMagic[Hash] = val; else if (a.name == "pickle") hasMagic[Pickle] = val; else if (a.name == "python") hasMagic[Python] = val; else if (a.name == "container") hasMagic[Container] = val; else error("invalid decorator argument"); } } else if (d->isId(Attr::Tuple)) { if (attr.has(Attr::Tuple)) error("class already marked as tuple"); attr.set(Attr::Tuple); } else if (d->isId(Attr::Extend)) { attr.set(Attr::Extend); if (stmt->decorators.size() != 1) error("extend cannot be combined with other decorators"); if (!ctx->bases.empty()) error("extend is only allowed at the toplevel"); } else if (d->isId(Attr::Internal)) { attr.set(Attr::Internal); } } for (int i = 1; i < hasMagic.size(); i++) if (hasMagic[i] == 2) hasMagic[i] = attr.has(Attr::Tuple) ? 1 : 0; // Extensions (@extend) cases are handled bit differently // (no auto method-generation, no arguments etc.) 
// SimplifyVisitor::visit(ClassStmt), continued: naming, base classes and generics.
  bool extension = attr.has(Attr::Extend);
  bool isRecord = attr.has(Attr::Tuple); // does it have @tuple attribute

  // Special name handling is needed because of nested classes.
  // A class declared inside another type is named "<parent>.<name>".
  string name = stmt->name;
  if (!ctx->bases.empty() && ctx->bases.back().isType()) {
    const auto &a = ctx->bases.back().ast;
    string parentName =
        a->getId() ? a->getId()->value : a->getIndex()->expr->getId()->value;
    name = parentName + "." + name;
  }

  // Generate/find class' canonical name (unique ID) and AST
  string canonicalName;
  ClassStmt *originalAST = nullptr;
  auto classItem =
      make_shared(SimplifyItem::Type, "", "", ctx->isToplevel());
  if (!extension) {
    classItem->canonicalName = canonicalName =
        ctx->generateCanonicalName(name, !attr.has(Attr::Internal));
    // Reference types are added to the context at this stage.
    // Record types (tuples) are added after parsing class arguments to prevent
    // recursive record types (that are allowed for reference types).
    if (!isRecord) {
      ctx->add(name, classItem);
      ctx->cache->imports[STDLIB_IMPORT].ctx->addToplevel(canonicalName, classItem);
    }
    originalAST = stmt;
  } else {
    // Find the canonical name of a class that is to be extended
    auto val = ctx->find(name);
    if (!val || val->kind != SimplifyItem::Type)
      error("cannot find type '{}' to extend", name);
    canonicalName = val->canonicalName;
    const auto &astIter = ctx->cache->classes.find(canonicalName);
    if (astIter == ctx->cache->classes.end())
      error("cannot extend type alias or an instantiation ({})", name);
    originalAST = astIter->second.ast.get();
    if (stmt->args.size())
      error("extensions cannot be generic or declare members");
  }

  // Add the class base.
  // The class is processed against a fresh base stack; the previous one is kept in
  // oldBases.
  auto oldBases = move(ctx->bases);
  ctx->bases = vector();
  ctx->bases.emplace_back(SimplifyContext::Base(canonicalName));
  ctx->bases.back().ast = make_shared(name);

  if (extension && !stmt->baseClasses.empty())
    error("extensions cannot inherit other classes");
  vector baseASTs;
  vector args;
  // substitutions[k] maps the generic names of the k-th base class to the expressions
  // supplied for them; one more (empty) entry is appended below for the class itself.
  vector> substitutions;
  // argSubstitutions[i] is the index into `substitutions` that applies to args[i].
  vector argSubstitutions;
  unordered_set seenMembers;
  for (auto &baseClass : stmt->baseClasses) {
    string bcName;
    vector subs;
    if (auto i = baseClass->getId())
      bcName = i->value;
    else if (auto e = baseClass->getIndex()) {
      if (auto i = e->expr->getId()) {
        bcName = i->value;
        subs = e->index->getTuple() ? e->index->getTuple()->items
                                    : vector{e->index};
      }
    }
    // NOTE(review): bcName can still be empty here (base is neither `Id` nor `Id[...]`),
    // yet it is handed to transformType before the emptiness check below -- confirm
    // that transformType copes with an empty identifier.
    bcName = transformType(N(bcName))->getId()->value;
    if (bcName.empty() || !in(ctx->cache->classes, bcName))
      error(baseClass.get(), "invalid base class");
    baseASTs.push_back(ctx->cache->classes[bcName].ast.get());
    if (baseASTs.back()->attributes.has(Attr::Tuple) != isRecord)
      error("tuples cannot inherit reference classes (and vice versa)");
    if (baseASTs.back()->attributes.has(Attr::Internal))
      error("cannot inherit internal types");
    // Pair the base class generics, in order, with the supplied arguments; the counts
    // must match exactly.
    int si = 0;
    substitutions.push_back({});
    for (auto &a : baseASTs.back()->args)
      if (a.generic) {
        if (si >= subs.size())
          error(baseClass.get(), "wrong number of generics");
        substitutions.back()[a.name] = clone(subs[si++]);
      }
    if (si != subs.size())
      error(baseClass.get(), "wrong number of generics");
    // Inherit the non-generic members of the base class.
    for (auto &a : baseASTs.back()->args)
      if (!a.generic) {
        if (seenMembers.find(a.name) != seenMembers.end())
          error(a.type, "'{}' declared twice", a.name);
        seenMembers.insert(a.name);
        args.emplace_back(Param{a.name, a.type, a.deflt});
        argSubstitutions.push_back(substitutions.size() - 1);
        if (!extension)
          ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr});
      }
  }

  // Add generics, if any, to the context.
  ctx->addBlock();
  vector genAst;
  substitutions.push_back({});
  // For an extension the arguments of the extended class are used instead of stmt's.
  for (auto &a : (extension ? originalAST : stmt)->args) {
    seqassert(a.type, "no type provided for '{}'", a.name);
    if (a.type && (a.type->isId("type") || a.type->isId("TypeVar") ||
                   (a.type->getIndex() && a.type->getIndex()->expr->isId("Static"))))
      a.generic = true;
    if (seenMembers.find(a.name) != seenMembers.end())
      error(a.type, "'{}' declared twice", a.name);
    seenMembers.insert(a.name);
    if (a.generic) {
      // For an extension a.name is used as the canonical name directly and the context
      // name comes from the reverse lookup table; otherwise a new canonical name is
      // generated.
      auto varName = extension ? a.name : ctx->generateCanonicalName(a.name);
      auto name = extension ? ctx->cache->reverseIdentifierLookup[a.name] : a.name;
      if (a.type->getIndex() && a.type->getIndex()->expr->isId("Static"))
        ctx->add(SimplifyItem::Var, name, varName, true);
      else
        ctx->add(SimplifyItem::Type, name, varName, true);
      genAst.push_back(N(varName));
      args.emplace_back(Param{varName, a.type, a.deflt, a.generic});
    } else {
      args.emplace_back(Param{a.name, a.type, a.deflt});
      if (!extension)
        ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr});
    }
    argSubstitutions.push_back(substitutions.size() - 1);
  }
  // A generic class is referred to through its name combined with its generics.
  if (!genAst.empty())
    ctx->bases.back().ast =
        make_shared(N(name), N(genAst));

  vector stmts{nullptr}; // Will be filled later!
  // Parse nested classes
  for (auto sp : getClassMethods(stmt->suite))
    if (sp && sp->getClass()) {
      // Add dummy base to fix nested class' name.
// SimplifyVisitor::visit(ClassStmt), continued: body of the nested-class loop.
      ctx->bases.emplace_back(SimplifyContext::Base(canonicalName));
      ctx->bases.back().ast = make_shared(name);
      auto origName = sp->getClass()->name;
      stmts.emplace_back(transform(sp));
      // Make the nested class reachable under its original (unprefixed) name as well.
      ctx->add(origName,
               ctx->find(stmts.back()->getSuite()->stmts[0]->getClass()->name));
      ctx->bases.pop_back();
    }

  vector memberArgs;
  for (auto &s : substitutions)
    for (auto &i : s)
      i.second = transform(i.second, true);
  for (int ai = 0; ai < args.size(); ai++) {
    auto &a = args[ai];
    if (argSubstitutions[ai] == substitutions.size() - 1) {
      // Argument declared by this class: transform it directly.
      a.type = transformType(a.type, false);
      a.deflt = transform(a.deflt, true);
    } else {
      // Argument inherited from a base class: apply that base's substitutions.
      a.type = replace(a.type, substitutions[argSubstitutions[ai]]);
      a.deflt = replace(a.deflt, substitutions[argSubstitutions[ai]]);
    }
    if (!a.generic)
      memberArgs.push_back(a);
  }

  // Parse class members (arguments) and methods.
  auto suite = N();
  if (!extension) {
    // Now that we are done with arguments, add record type to the context.
    // However, we need to unroll a block/base, add it, and add the unrolled
    // block/base back.
    if (isRecord) {
      ctx->addPrevBlock(name, classItem);
      ctx->cache->imports[STDLIB_IMPORT].ctx->addToplevel(canonicalName, classItem);
    }
    // Create a cached AST.
    attr.module =
        format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
               ctx->moduleName.module);
    ctx->cache->classes[canonicalName].ast =
        N(canonicalName, args, N(), attr);

    vector fns;
    ExprPtr codeType = ctx->bases.back().ast->clone();
    vector magics{};
    // Internal classes do not get any auto-generated members.
    if (!attr.has(Attr::Internal)) {
      // Prepare a list of magics that are to be auto-generated.
      // NOTE(review): "hash" is added for every record here and hasMagic[Hash] is not
      // consulted anywhere in this view -- confirm that `hash=` is meant to be ignored.
      if (isRecord)
        magics = {"len", "hash"};
      else
        magics = {"new", "raw"};
      if (hasMagic[Init])
        magics.emplace_back(isRecord ? "new" : "init");
      if (hasMagic[Eq])
        for (auto &i : {"eq", "ne"})
          magics.emplace_back(i);
      if (hasMagic[Order])
        for (auto &i : {"lt", "gt", "le", "ge"})
          magics.emplace_back(i);
      if (hasMagic[Pickle])
        for (auto &i : {"pickle", "unpickle"})
          magics.emplace_back(i);
      if (hasMagic[Repr])
        magics.emplace_back("str");
      if (hasMagic[Container])
        for (auto &i : {"iter", "getitem"})
          magics.emplace_back(i);
      if (hasMagic[Python])
        for (auto &i : {"to_py", "from_py"})
          magics.emplace_back(i);

      // Classes whose name starts with TYPE_TUPLE get "contains" (with Container) and
      // "add"; all other classes get "dict".
      if (hasMagic[Container] && startswith(stmt->name, TYPE_TUPLE))
        magics.emplace_back("contains");
      if (!startswith(stmt->name, TYPE_TUPLE))
        magics.emplace_back("dict");
      if (startswith(stmt->name, TYPE_TUPLE))
        magics.emplace_back("add");
    }
    // Codegen default magic methods and add them to the final AST.
    // The transformed function is taken from the back of preamble->functions right
    // after each transform call.
    for (auto &m : magics) {
      transform(codegenMagic(m, ctx->bases.back().ast.get(), memberArgs, isRecord));
      suite->stmts.push_back(preamble->functions.back());
    }
  }
  // Copy the methods of every base class (skipping those marked "autogenerated"),
  // giving each copy a fresh canonical name and applying the base's substitutions.
  for (int ai = 0; ai < baseASTs.size(); ai++)
    for (auto sp : getClassMethods(baseASTs[ai]->suite))
      if (auto f = sp->getFunction()) {
        if (f->attributes.has("autogenerated"))
          continue;

        auto subs = substitutions[ai];
        auto newName = ctx->generateCanonicalName(
            ctx->cache->reverseIdentifierLookup[f->name], true);
        auto nf = std::dynamic_pointer_cast(replace(sp, subs));
        subs[nf->name] = N(newName);
        nf->name = newName;
        suite->stmts.push_back(nf);
        nf->attributes.parentClass = ctx->bases.back().name;

        // check original ast...
        // A `self` type that was filled in automatically must now refer to this class.
        if (nf->attributes.has(".changedSelf"))
          nf->args[0].type = transformType(ctx->bases.back().ast);
        preamble->functions.push_back(clone(nf));
        ctx->cache->functions[newName].ast = nf;
        ctx->cache->classes[ctx->bases.back().name]
            .methods[ctx->cache->reverseIdentifierLookup[f->name]]
            .push_back({newName, nullptr, ctx->cache->age});
      }
  // Transform the class' own members (everything that is not a nested class).
  for (auto sp : getClassMethods(stmt->suite))
    if (sp && !sp->getClass()) {
      transform(sp);
      suite->stmts.push_back(preamble->functions.back());
    }
  // Restore the base stack and drop the block opened for the generics.
  ctx->bases.pop_back();
  ctx->bases = move(oldBases);
  ctx->popBlock();

  auto c = ctx->cache->classes[canonicalName].ast.get();
  if (!extension) {
    // Update the cached AST.
    seqassert(c, "not a class AST for {}", canonicalName);
    preamble->globals.push_back(c->clone());
    c->suite = clone(suite);
    // if (stmt->baseClasses.size())
    //   LOG("{} -> {}", stmt->name, c->toString(0));
  }
  // The first resulting statement is a node for canonicalName carrying only
  // Attr::Extend; nested classes (if any) follow it.
  stmts[0] = N(canonicalName, vector{}, N(),
               Attr({Attr::Extend}), vector{},
               vector{});
  resultStmt = N(stmts);
}

/// Dispatch a custom statement to the handler registered for its keyword: block
/// handlers (ctx->cache->customBlockStmts) when the statement has a suite, expression
/// handlers (ctx->cache->customExprStmts) otherwise. An unknown keyword trips a
/// seqassert.
void SimplifyVisitor::visit(CustomStmt *stmt) {
  if (stmt->suite) {
    auto fn = ctx->cache->customBlockStmts.find(stmt->keyword);
    seqassert(fn != ctx->cache->customBlockStmts.end(), "unknown keyword {}",
              stmt->keyword);
    resultStmt = fn->second.second(this, stmt);
  } else {
    auto fn = ctx->cache->customExprStmts.find(stmt->keyword);
    seqassert(fn != ctx->cache->customExprStmts.end(), "unknown keyword {}",
              stmt->keyword);
    resultStmt = fn->second(this, stmt);
  }
}

/**************************************************************************************/

/// Transform a single (non-unpacking) assignment `lhs = rhs` with optional type.
///  - `a[i] = b` becomes a transformed `__setitem__` call on `a`.
///  - `a.m = b` becomes a member assignment on the transformed `a`.
///  - `a = b` is handled by the identifier branch that follows.
/// @param shadow when false (and no type is given), an existing variable may be updated
/// @param mustExist when true, an existing variable from a different base is an error
StmtPtr SimplifyVisitor::transformAssignment(const ExprPtr &lhs, const ExprPtr &rhs,
                                             const ExprPtr &type, bool shadow,
                                             bool mustExist) {
  if (auto ei = lhs->getIndex()) {
    seqassert(!type, "unexpected type annotation");
    return transform(N(N(N(clone(ei->expr), "__setitem__"),
                         clone(ei->index), rhs->clone())));
  } else if (auto ed = lhs->getDot()) {
    seqassert(!type, "unexpected type annotation");
    return N(transform(ed->expr), ed->member, transform(rhs, false));
  } else if (auto e = lhs->getId())
{ ExprPtr t = transformType(type, false); if (!shadow && !t) { auto val = ctx->find(e->value); if (e->value != "_" && val && val->isVar()) { if (val->getBase() == ctx->getBase()) return N(transform(lhs, false), transform(rhs, true), !ctx->bases.empty() && ctx->bases.back().attributes & FLAG_ATOMIC); else if (mustExist) error("variable '{}' is not global", e->value); } } // Function and type aliases are not normal assignments. They are treated like a // simple context renames. // Note: x = Ptr[byte] is not a simple alias, and is handled separately below. auto r = transform(rhs, true); if (r && r->getId()) { auto val = ctx->find(r->getId()->value); if (!val) error("cannot find '{}'", r->getId()->value); if (val->isType() || val->isFunc()) { ctx->add(e->value, val); return nullptr; } } // This assignment is a new variable assignment (not a rename or an update). // Generate new canonical variable name for this assignment and use it afterwards. auto canonical = ctx->generateCanonicalName(e->value); auto l = N(canonical); bool global = ctx->isToplevel(); bool isStatic = t && t->getIndex() && t->getIndex()->expr->isId("Static"); // ctx->moduleName != MODULE_MAIN; // ⚠️ TODO: should we make __main__ top-level variables NOT global by default? // Problem: a = [1]; def foo(): a.append(2) won't work anymore as in Python. if (global && !isStatic) ctx->cache->globals.insert(canonical); // Handle type aliases as well! ctx->add(r && r->isType() ? SimplifyItem::Type : SimplifyItem::Var, e->value, canonical, global); if (global && !isStatic) { if (r && r->isType()) { preamble->globals.push_back(N(N(canonical), clone(r))); } else { preamble->globals.push_back(N(N(canonical), nullptr, t)); return r ? 
N(l, r) : nullptr; } } else if (isStatic) { preamble->globals.push_back( N(N(canonical), clone(r), clone(t))); } return N(l, r, t); } else { error("invalid assignment"); return nullptr; } } /* Expands a possibly destructuring assignment into simple assignments appended to stmts. A tuple or list target is unpacked element by element through rhs[i] (recursively, so nested targets work); one starred target receives a slice and the targets after it are addressed with negative indices; a second starred target raises error("multiple unpack expressions"). When unpacking, an rhs that is not a plain identifier is first assigned to a temporary variable and the temporary is indexed instead. Any other target is forwarded unchanged to transformAssignment. NOTE(review): leftSide.size() is presumably unsigned, so the negative offsets computed from it rely on wrap-around followed by conversion to a signed integer - confirm intent. */ void SimplifyVisitor::unpackAssignments(ExprPtr lhs, ExprPtr rhs, vector &stmts, bool shadow, bool mustExist) { vector leftSide; if (auto et = lhs->getTuple()) { // (a, b) = ... for (auto &i : et->items) leftSide.push_back(i); } else if (auto el = lhs->getList()) { // [a, b] = ... for (auto &i : el->items) leftSide.push_back(i); } else { // A simple assignment. stmts.push_back(transformAssignment(lhs, rhs, nullptr, shadow, mustExist)); return; } // Prepare the right-side expression auto srcPos = rhs.get(); ExprPtr newRhs = nullptr; // This expression must not be deleted until the very end. if (!rhs->getId()) { // Store any non-trivial right-side expression (assign = rhs). auto var = ctx->cache->getTemporaryVar("assign"); newRhs = Nx(srcPos, var); stmts.push_back(transformAssignment(newRhs, rhs, nullptr, shadow, mustExist)); rhs = newRhs; } // Process each assignment until the first StarExpr (if any). int st; for (st = 0; st < leftSide.size(); st++) { if (leftSide[st]->getStar()) break; // Transformation: leftSide_st = rhs[st] auto rightSide = Nx(srcPos, rhs->clone(), Nx(srcPos, st)); // Recursively process the assignment (as we can have cases like (a, (b, c)) = d). unpackAssignments(leftSide[st], rightSide, stmts, shadow, mustExist); } // If there is a StarExpr, process it and the remaining assignments after it (if // any). if (st < leftSide.size() && leftSide[st]->getStar()) { // StarExpr becomes SliceExpr: in (a, *b, c) = d, b is d[1:-1] auto rightSide = Nx( srcPos, rhs->clone(), Nx(srcPos, Nx(srcPos, st), // This slice is either [st:] or [st:-lhs_len + st + 1] leftSide.size() == st + 1 ? nullptr : Nx(srcPos, -leftSide.size() + st + 1), nullptr)); unpackAssignments(leftSide[st]->getStar()->what, rightSide, stmts, shadow, mustExist); st += 1; // Keep going till the very end. 
Remaining assignments use negative indices (-1, // -2 etc) as we are not sure how big is StarExpr. for (; st < leftSide.size(); st++) { if (leftSide[st]->getStar()) error(leftSide[st], "multiple unpack expressions"); rightSide = Nx(srcPos, rhs->clone(), Nx(srcPos, -leftSide.size() + st)); unpackAssignments(leftSide[st], rightSide, stmts, shadow, mustExist); } } } /* Builds the statements that test var against a match pattern and run suite when it matches. Integer and boolean literals: isinstance check ("int" or "bool") plus an == comparison. Ranges: isinstance "int" plus >= start and <= stop. Tuples: isinstance "Tuple" plus a staticlen equality, with the element patterns wrapped around suite from last to first. Lists: isinstance "List" plus a len comparison, == without an ellipsis and >= (with the ellipsis not counted) when one is present; more than one ellipsis raises error("cannot have multiple ranges in a pattern"). The remaining pattern kinds are handled further down in the function. */ StmtPtr SimplifyVisitor::transformPattern(ExprPtr var, ExprPtr pattern, StmtPtr suite) { auto isinstance = [&](const ExprPtr &e, const string &typ) -> ExprPtr { return N(N("isinstance"), e->clone(), N(typ)); }; /* Returns the index of the single ellipsis item, or items.size() when there is none. */ auto findEllipsis = [&](const vector &items) { int i = items.size(); for (int it = 0; it < items.size(); it++) if (items[it]->getEllipsis()) { if (i != items.size()) error("cannot have multiple ranges in a pattern"); i = it; } return i; }; if (pattern->getInt() || CAST(pattern, BoolExpr)) { return N(isinstance(var, CAST(pattern, BoolExpr) ? "bool" : "int"), N(N(var->clone(), "==", pattern), suite)); } else if (auto er = CAST(pattern, RangeExpr)) { return N( isinstance(var, "int"), N( N(var->clone(), ">=", clone(er->start)), N(N(var->clone(), "<=", clone(er->stop)), suite))); } else if (auto et = pattern->getTuple()) { for (int it = int(et->items.size()) - 1; it >= 0; it--) suite = transformPattern(N(var->clone(), N(it)), clone(et->items[it]), suite); return N( isinstance(var, "Tuple"), N(N(N(N("staticlen"), clone(var)), "==", N(et->items.size())), suite)); } else if (auto el = pattern->getList()) { auto ellipsis = findEllipsis(el->items), sz = int(el->items.size()); string op; if (ellipsis == el->items.size()) op = "=="; else op = ">=", sz -= 1; /* Items after the ellipsis are addressed from the end (index it - size), items before it from the front (index it). */ for (int it = int(el->items.size()) - 1; it > ellipsis; it--) suite = transformPattern( N(var->clone(), N(it - el->items.size())), clone(el->items[it]), suite); for (int it = ellipsis - 1; it >= 0; it--) suite = transformPattern(N(var->clone(), N(it)), clone(el->items[it]), suite); return N(isinstance(var, "List"), N(N(N(N("len"), clone(var)), op, N(sz)), 
suite)); } else if (auto eb = pattern->getBinary()) { /* Alternatives "a | b": both sides are matched against var, the left one with its own clone of suite. */ if (eb->op == "|") { return N(transformPattern(clone(var), clone(eb->lexpr), clone(suite)), transformPattern(clone(var), clone(eb->rexpr), suite)); } } else if (auto ea = CAST(pattern, AssignExpr)) { seqassert(ea->var->getId(), "only simple assignment expressions are supported"); return N( vector{N(clone(ea->var), clone(var)), transformPattern(clone(var), clone(ea->expr), clone(suite))}, true); } else if (auto ei = pattern->getId()) { /* A bare name binds var to that name before running suite; "_" binds nothing. */ if (ei->value != "_") return N( vector{N(clone(pattern), clone(var)), suite}, true); else return suite; } /* Every other pattern, including a binary pattern whose operator is not "|", is matched through the value's __match__ method, guarded by a hasattr call. */ pattern = transform(pattern); // basically check for errors return N( N(N("hasattr"), var->clone(), N("__match__"), N(N("type"), pattern->clone())), N(N(N(var->clone(), "__match__"), pattern), suite)); } /* Builds and transforms a declaration for the C function name, carrying attribute Attr::C. Every argument must be unnamed, have no default and carry a type; parameters are named a0, a1, ... A final argument whose type passes the dynamic_cast check (the cast target is not visible in this listing; presumably the ellipsis expression) sets Attr::CVarArg and is declared as "*args". A null ret means "void". A non-empty altName is added to ctx as another name for the item found under name. NOTE(review): args[ai].name is asserted empty above, so the non-empty arm of the parameter-name conditional looks unreachable while assertions are enabled. */ StmtPtr SimplifyVisitor::transformCImport(const string &name, const vector &args, const Expr *ret, const string &altName) { vector fnArgs; auto attr = Attr({Attr::C}); for (int ai = 0; ai < args.size(); ai++) { seqassert(args[ai].name.empty(), "unexpected argument name"); seqassert(!args[ai].deflt, "unexpected default argument"); seqassert(args[ai].type, "missing type"); if (dynamic_cast(args[ai].type.get()) && ai + 1 == args.size()) { attr.set(Attr::CVarArg); fnArgs.emplace_back(Param{"*args", nullptr, nullptr}); } else { fnArgs.emplace_back( Param{args[ai].name.empty() ? format("a{}", ai) : args[ai].name, args[ai].type->clone(), nullptr}); } } auto f = N(name, ret ? ret->clone() : N("void"), fnArgs, nullptr, attr); StmtPtr tf = transform(f); // Already in the preamble if (!altName.empty()) ctx->add(altName, ctx->find(name)); return tf; } /* Imports the symbol name from the dynamic library expression dylib: the target (altName when given, otherwise name) is bound to a _dlsym call whose Fn argument is Function indexed by the list of argument types and the return type. Every argument must be unnamed, have no default and carry a type; a null ret means "void". */ StmtPtr SimplifyVisitor::transformCDLLImport(const Expr *dylib, const string &name, const vector &args, const Expr *ret, const string &altName) { // name : Function[args] = _dlsym(dylib, "name", Fn=Function[args]) vector fnArgs{N(vector{}), ret ? 
ret->clone() : N("void")}; for (const auto &a : args) { seqassert(a.name.empty(), "unexpected argument name"); seqassert(!a.deflt, "unexpected default argument"); seqassert(a.type, "missing type"); const_cast(fnArgs[0]->getList())->items.emplace_back(clone(a.type)); } auto type = N(N("Function"), N(fnArgs)); return transform(N( N(altName.empty() ? name : altName), N(N("_dlsym"), vector{{"", dylib->clone()}, {"", N(name)}, {"Fn", type}}))); } /* Imports a Python object named by the dotted expression what. With neither ret nor args the target (altName when given, otherwise the last path component) is bound to pyobj._import of the full dotted path. Otherwise a wrapper function is generated with parameters a0, a1, ...: it binds f to pyobj._import(lib)._getattr(name), calls f with the parameters and, unless ret is "void", passes the result through ret.__from_py__. NOTE(review): for a typed import without a module prefix lib stays empty, so _import receives an empty string - confirm this is intended. */ StmtPtr SimplifyVisitor::transformPythonImport(const Expr *what, const vector &args, const Expr *ret, const string &altName) { // Get a module name (e.g. os.path) vector dirs; auto e = what; /* The walk goes from the outermost dot inwards, so dirs[0] is the last path component (the imported name) and the remaining entries, joined in reverse order below, form the module path lib. */ while (auto d = e->getDot()) { dirs.push_back(d->member); e = d->expr.get(); } seqassert(e && e->getId(), "invalid import python statement"); dirs.push_back(e->getId()->value); string name = dirs[0], lib; for (int i = int(dirs.size()) - 1; i > 0; i--) lib += dirs[i] + (i > 1 ? "." : ""); // Simple module import: from python import foo if (!ret && args.empty()) // altName = pyobj._import("name") return transform(N( N(altName.empty() ? name : altName), N(N("pyobj", "_import"), N((lib.empty() ? "" : lib + ".") + name)))); // Typed function import: from python import foo.bar(int) -> float. // f = pyobj._import("lib")._getattr("name") auto call = N( N("f"), N(N(N(N("pyobj", "_import"), N(lib)), "_getattr"), N(name))); // Make a call expression: f(args...) vector params; vector callArgs; for (int i = 0; i < args.size(); i++) { params.emplace_back(Param{format("a{}", i), clone(args[i].type), nullptr}); callArgs.emplace_back(N(format("a{}", i))); } // Make a return expression: return f(args...), // or return retType.__from_py__(f(args...)) ExprPtr retExpr = N(N("f"), callArgs); if (ret && !ret->isId("void")) retExpr = N(N(ret->clone(), "__from_py__"), retExpr); StmtPtr retStmt = nullptr; /* A "void" result and every other result are wrapped in different node kinds (presumably an expression statement versus a return; the kinds are not visible in this listing). */ if (ret && ret->isId("void")) retStmt = N(retExpr); else retStmt = N(retExpr); // Return a wrapper function return transform(N(altName.empty() ? 
name : altName, ret ? ret->clone() : nullptr, params, N(call, retStmt))); } /* Parses and simplifies the file at file.path in a fresh context that is registered in ctx->cache->imports, with __name__ bound to the module name ahead of the file's statements. While the standard library is loading, the transformed statements become resultStmt directly. Otherwise the module's top-level statements, excluding function and class definitions, are collected into a generated function named from "import_" plus file.module, which is stored in the cache and appended to the preamble together with a global "<importVar>_done" flag initialised to false; the visible code does not assign resultStmt on that path. */ void SimplifyVisitor::transformNewImport(const ImportFile &file) { // Use a clean context to parse a new file. if (ctx->cache->age) ctx->cache->age++; auto ictx = make_shared(file.path, ctx->cache); ictx->isStdlibLoading = ctx->isStdlibLoading; ictx->moduleName = file; auto import = ctx->cache->imports.insert({file.path, {file.path, ictx}}).first; // __name__ = (set the Python's __name__ variable) auto sn = SimplifyVisitor(ictx, preamble) .transform(N(N(N("__name__"), N(ictx->moduleName.module), N("str"), true), parseFile(ctx->cache, file.path))); // If we are loading standard library, we won't wrap imports in functions as we // assume that standard library has no recursive imports. We will just append the // top-level statements as-is. if (ctx->isStdlibLoading) { resultStmt = N(vector{sn}, true); } else { // Generate import function identifier. string importVar = import->second.importVar = ctx->cache->getTemporaryVar(format("import_{}", file.module), '.'), importDoneVar; // import_done = False (global variable that indicates if an import has been // loaded) preamble->globals.push_back(N( N(importDoneVar = importVar + "_done"), N(false))); ctx->cache->globals.insert(importDoneVar); vector stmts; stmts.push_back(nullptr); // placeholder to be filled later! // We need to wrap all imported top-level statements (not signatures! they have // already been handled and are in the preamble) into a function. We also take the // list of global variables so that we can access them via "global" statement. auto processStmt = [&](StmtPtr s) { if (s->getAssign() && s->getAssign()->lhs->getId()) { // a = ... 
globals auto a = const_cast(s->getAssign()); bool isStatic = a->type && a->type->getIndex() && a->type->getIndex()->expr->isId("Static"); auto val = ictx->find(a->lhs->getId()->value); seqassert(val, "cannot locate '{}' in imported file {}", s->getAssign()->lhs->getId()->value, file.path); /* Non-static, global variables with an empty base are re-emitted as a node built from just lhs and rhs (the annotation is dropped); every other assignment is kept unchanged. */ if (val->kind == SimplifyItem::Var && val->global && val->base.empty() && !isStatic) { stmts.push_back(N(a->lhs, a->rhs)); } else { stmts.push_back(s); } } else if (!s->getFunction() && !s->getClass()) { stmts.push_back(s); } }; if (auto st = const_cast(sn->getSuite())) for (auto &ss : st->stmts) processStmt(ss); else processStmt(sn); /* Fill the placeholder reserved at index 0. */ stmts[0] = N(); // Add a def import(): ... manually to the cache and to the preamble (it won't be // transformed here!). ctx->cache->functions[importVar].ast = N(importVar, nullptr, vector{}, N(stmts), Attr({Attr::ForceRealize})); preamble->functions.push_back(ctx->cache->functions[importVar].ast->clone()); ; } } /* Turns a function whose body is Python source into simplified statements. codeStmt must be an expression statement holding a string; that string is prefixed with a "def name(arg names):" header and the resulting source is passed to pyobj._exec. A second statement is built from python, __main__.name, the arguments and ret (defaulting to pyobj) - presumably a Python import of the function just defined. Both are transformed together. */ StmtPtr SimplifyVisitor::transformPythonDefinition(const string &name, const vector &args, const Expr *ret, const Stmt *codeStmt) { seqassert(codeStmt && codeStmt->getExpr() && codeStmt->getExpr()->expr->getString(), "invalid Python definition"); auto code = codeStmt->getExpr()->expr->getString()->getValue(); vector pyargs; for (const auto &a : args) pyargs.emplace_back(a.name); code = format("def {}({}):\n{}\n", name, join(pyargs, ", "), code); return transform(N( N(N(N("pyobj", "_exec"), N(code))), N(N("python"), N("__main__", name), clone_nop(args), ret ? 
ret->clone() : N("pyobj")))); } /* Converts the LLVM body string held by codeStmt into a node built from a list of items. The first item wraps a string expression that ends up holding the code with every "{=expr}" substitution reduced to "{}" and the text in between passed through escapeFStringBraces; each substitution then contributes one further item holding its parsed and transformed expression. A substitution opened inside another one, or left unterminated, raises error("invalid LLVM substitution"). NOTE(review): the opening '{' is appended only when braceStart < i, while the closing '}' is always appended, so a substitution at offset 0 or directly after another one appears to leave an unmatched '}' - confirm. */ StmtPtr SimplifyVisitor::transformLLVMDefinition(const Stmt *codeStmt) { seqassert(codeStmt && codeStmt->getExpr() && codeStmt->getExpr()->expr->getString(), "invalid LLVM definition"); auto code = codeStmt->getExpr()->expr->getString()->getValue(); vector items; auto se = N(""); string finalCode = se->getValue(); items.push_back(N(se)); int braceCount = 0, braceStart = 0; for (int i = 0; i < code.size(); i++) { if (i < code.size() - 1 && code[i] == '{' && code[i + 1] == '=') { if (braceStart < i) finalCode += escapeFStringBraces(code, braceStart, i - braceStart) + '{'; if (!braceCount) { braceStart = i + 2; braceCount++; } else { error("invalid LLVM substitution"); } } else if (braceCount && code[i] == '}') { braceCount--; string exprCode = code.substr(braceStart, i - braceStart); /* Shift the reported column by the offset inside the code string. */ auto offset = getSrcInfo(); offset.col += i; auto expr = transform(parseExpr(ctx->cache, exprCode, offset), true); items.push_back(N(expr)); braceStart = i + 1; finalCode += '}'; } } if (braceCount) error("invalid LLVM substitution"); if (braceStart != code.size()) finalCode += escapeFStringBraces(code, braceStart, int(code.size()) - braceStart); se->strings[0].first = finalCode; return N(items); } /* Generates the AST of the automatically created magic method __<op>__ for the type typExpr whose fields are args. Handled values of op: new, init, raw, getitem, iter, contains, eq, ne, lt, gt, le, ge, hash, pickle, unpickle, len, to_py, from_py, str, dict and add; any other value trips seqassert(false, "invalid magic {}"). isRecord is only consulted by "new". The resulting function carries the "autogenerated" attribute and source info obtained from ctx->cache->generateSrcInfo(); it is returned untransformed. */ StmtPtr SimplifyVisitor::codegenMagic(const string &op, const Expr *typExpr, const vector &args, bool isRecord) { #define I(s) N(s) assert(typExpr); ExprPtr ret; vector fargs; vector stmts; Attr attr; attr.set("autogenerated"); if (op == "new") { // Classes: @internal def __new__() -> T // Tuples: @internal def __new__(a1: T1, ..., aN: TN) -> T ret = typExpr->clone(); if (isRecord) for (auto &a : args) fargs.emplace_back( Param{a.name, clone(a.type), a.deflt ? clone(a.deflt) : N(clone(a.type))}); attr.set(Attr::Internal); } else if (op == "init") { // Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> void: // self.aI = aI ... 
/* "init", continued: every field adds one self.<name> node to the body and one parameter whose default is a.deflt when present, otherwise a node built from the field type. */ ret = I("void"); fargs.emplace_back(Param{"self", typExpr->clone()}); for (auto &a : args) { stmts.push_back(N(N(I("self"), a.name), I(a.name))); fargs.emplace_back(Param{a.name, clone(a.type), a.deflt ? clone(a.deflt) : N(clone(a.type))}); } } else if (op == "raw") { // Classes: def __raw__(self: T) -> Ptr[byte]: // return __internal__.class_raw(self) fargs.emplace_back(Param{"self", typExpr->clone()}); ret = N(I("Ptr"), I("byte")); stmts.emplace_back(N( N(N(I("__internal__"), "class_raw"), I("self")))); } else if (op == "getitem") { // Tuples: def __getitem__(self: T, index: int) -> T1: // return __internal__.tuple_getitem(self, index, T, T1) // (error during a realizeFunc() method if T is a heterogeneous tuple) /* T1 is the type of the first field; with no fields the return type is void. */ fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"index", I("int")}); ret = !args.empty() ? clone(args[0].type) : I("void"); stmts.emplace_back(N( N(N(I("__internal__"), "tuple_getitem"), I("self"), I("index"), typExpr->clone(), ret->clone()))); } else if (op == "iter") { // Tuples: def __iter__(self: T) -> Generator[T1]: // yield self.aI ... // (error during a realizeFunc() method if T is a heterogeneous tuple) /* The generator element type is the type of the first field, or int when there are no fields. */ fargs.emplace_back(Param{"self", typExpr->clone()}); ret = N(I("Generator"), !args.empty() ? clone(args[0].type) : I("int")); for (auto &a : args) stmts.emplace_back(N(N("self", a.name))); if (args.empty()) // Hack for empty tuple: yield from List[int]() stmts.emplace_back( N(N(N(I("List"), I("int"))))); } else if (op == "contains") { // Tuples: def __contains__(self: T, what) -> bool: // if isinstance(what, T1): if what == self.a1: return True ... 
// return False /* The comparison is emitted as a call to what.__eq__ with the field as argument; the parameter what has no type annotation. */ fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"what", nullptr}); ret = I("bool"); for (auto &a : args) stmts.push_back(N(N(I("isinstance"), I("what"), clone(a.type)), N(N(N(I("what"), "__eq__"), N(I("self"), a.name)), N(N(true))))); stmts.emplace_back(N(N(false))); } else if (op == "eq") { // def __eq__(self: T, other: T) -> bool: // if not self.arg1.__eq__(other.arg1): return False ... // return True fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"other", typExpr->clone()}); ret = I("bool"); for (auto &a : args) stmts.push_back(N( N("!", N(N(N(I("self"), a.name), "__eq__"), N(I("other"), a.name))), N(N(false)))); stmts.emplace_back(N(N(true))); } else if (op == "ne") { // def __ne__(self: T, other: T) -> bool: // if self.arg1.__ne__(other.arg1): return True ... // return False fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"other", typExpr->clone()}); ret = I("bool"); for (auto &a : args) stmts.emplace_back( N(N(N(N(I("self"), a.name), "__ne__"), N(I("other"), a.name)), N(N(true)))); stmts.push_back(N(N(false))); } else if (op == "lt" || op == "gt") { // def __lt__(self: T, other: T) -> bool: (same for __gt__) // if self.arg1.__lt__(other.arg1): return True // elif self.arg1.__eq__(other.arg1): // ... (arg2, ...) ... 
// return False /* v always points at the statement list that receives the next field's comparison; it starts at the function body. */ fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"other", typExpr->clone()}); ret = I("bool"); vector *v = &stmts; for (int i = 0; i < (int)args.size() - 1; i++) { v->emplace_back(N( N( N(N(I("self"), args[i].name), format("__{}__", op)), N(I("other"), args[i].name)), N(N(true)), N( N(N(N(I("self"), args[i].name), "__eq__"), N(I("other"), args[i].name)), N()))); /* Descend into the still-empty suite of the nested __eq__ branch just created, so the next field is compared only when this one is equal. */ v = &((SuiteStmt *)(((IfStmt *)(((IfStmt *)(v->back().get()))->elseSuite.get())) ->ifSuite) .get()) ->stmts; } /* The last field needs no equality fallback: its comparison call is emitted on its own in the innermost suite. */ if (!args.empty()) v->emplace_back(N(N( N(N(I("self"), args.back().name), format("__{}__", op)), N(I("other"), args.back().name)))); stmts.emplace_back(N(N(false))); } else if (op == "le" || op == "ge") { // def __le__(self: T, other: T) -> bool: (same for __ge__) // if not self.arg1.__le__(other.arg1): return False // elif self.arg1.__eq__(other.arg1): // ... (arg2, ...) ... // return True fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"other", typExpr->clone()}); ret = I("bool"); vector *v = &stmts; for (int i = 0; i < (int)args.size() - 1; i++) { v->emplace_back(N( N("!", N(N(N(I("self"), args[i].name), format("__{}__", op)), N(I("other"), args[i].name))), N(N(false)), N( N(N(N(I("self"), args[i].name), "__eq__"), N(I("other"), args[i].name)), N()))); /* Same descent as in the lt/gt branch: continue inside the empty suite of the nested __eq__ branch. */ v = &((SuiteStmt *)(((IfStmt *)(((IfStmt *)(v->back().get()))->elseSuite.get())) ->ifSuite) .get()) ->stmts; } if (!args.empty()) v->emplace_back(N(N( N(N(I("self"), args.back().name), format("__{}__", op)), N(I("other"), args.back().name)))); stmts.emplace_back(N(N(true))); } else if (op == "hash") { // def __hash__(self: T) -> int: // seed = 0 // seed = ( // seed ^ ((self.arg1.__hash__() + 2654435769) + ((seed << 6) + (seed >> 2))) // ) ... 
// return seed /* 0x9e3779b9 below is the constant 2654435769 from the formula above. */ fargs.emplace_back(Param{"self", typExpr->clone()}); ret = I("int"); stmts.emplace_back(N(I("seed"), N(0))); for (auto &a : args) stmts.push_back(N( I("seed"), N( I("seed"), "^", N( N(N(N(N(I("self"), a.name), "__hash__")), "+", N(0x9e3779b9)), "+", N(N(I("seed"), "<<", N(6)), "+", N(I("seed"), ">>", N(2))))))); stmts.emplace_back(N(I("seed"))); } else if (op == "pickle") { // def __pickle__(self: T, dest: Ptr[byte]) -> void: // self.arg1.__pickle__(dest) ... fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"dest", N(I("Ptr"), I("byte"))}); ret = I("void"); for (auto &a : args) stmts.emplace_back(N(N( N(N(I("self"), a.name), "__pickle__"), I("dest")))); } else if (op == "unpickle") { // def __unpickle__(src: Ptr[byte]) -> T: // return T(T1.__unpickle__(src),...) fargs.emplace_back(Param{"src", N(I("Ptr"), I("byte"))}); ret = typExpr->clone(); vector ar; for (auto &a : args) ar.emplace_back(N(N(clone(a.type), "__unpickle__"), I("src"))); stmts.emplace_back(N(N(typExpr->clone(), ar))); } else if (op == "len") { // def __len__(self: T) -> int: // return N (number of args) fargs.emplace_back(Param{"self", typExpr->clone()}); ret = I("int"); stmts.emplace_back(N(N(args.size()))); } else if (op == "to_py") { // def __to_py__(self: T) -> pyobj: // o = pyobj._tuple_new(N) (number of args) // o._tuple_set(0, self.arg1.__to_py__()) ... (indices start at 0) // return o fargs.emplace_back(Param{"self", typExpr->clone()}); ret = I("pyobj"); stmts.emplace_back( N(I("o"), N(N(I("pyobj"), "_tuple_new"), N(args.size())))); for (int i = 0; i < args.size(); i++) stmts.push_back(N(N( N(I("o"), "_tuple_set"), N(i), N(N(N(I("self"), args[i].name), "__to_py__"))))); stmts.emplace_back(N(I("o"))); } else if (op == "from_py") { // def __from_py__(src: pyobj) -> T: // return T(T1.__from_py__(src._tuple_get(0)), ...) (indices start at 0) 
fargs.emplace_back(Param{"src", I("pyobj")}); ret = typExpr->clone(); vector ar; for (int i = 0; i < args.size(); i++) ar.push_back( N(N(clone(args[i].type), "__from_py__"), N(N(I("src"), "_tuple_get"), N(i)))); stmts.emplace_back(N(N(typExpr->clone(), ar))); } else if (op == "str") { // def __str__(self: T) -> str: // a = __array__[str](N) (number of args) // n = __array__[str](N) (number of args) // a.__setitem__(0, self.arg1.__str__()) ... // n.__setitem__(0, "arg1") ... (if not a Tuple.N; otherwise "") // return __internal__.tuple_str(a.ptr, n.ptr, N) fargs.emplace_back(Param{"self", typExpr->clone()}); ret = I("str"); if (!args.empty()) { stmts.emplace_back( N(I("a"), N(N(I("__array__"), I("str")), N(args.size())))); stmts.emplace_back( N(I("n"), N(N(I("__array__"), I("str")), N(args.size())))); for (int i = 0; i < args.size(); i++) { stmts.push_back(N(N( N(I("a"), "__setitem__"), N(i), N(N(N(I("self"), args[i].name), "__str__"))))); /* Field names are replaced by "" when typExpr is an index expression over an identifier that starts with TYPE_TUPLE. The lookup does not depend on i. */ auto name = typExpr->getIndex() ? typExpr->getIndex()->expr->getId() : nullptr; stmts.push_back(N(N( N(I("n"), "__setitem__"), N(i), N( name && startswith(name->value, TYPE_TUPLE) ? "" : args[i].name)))); } stmts.emplace_back(N(N( N(I("__internal__"), "tuple_str"), N(I("a"), "ptr"), N(I("n"), "ptr"), N(args.size())))); } else { /* No fields: the body is built from the literal "()" alone. */ stmts.emplace_back(N(N("()"))); } } else if (op == "dict") { // def __dict__(self: T): // d = List[str](N) // d.append('arg1') ... 
// return d /* "dict" leaves ret unset, so the generated method has no declared return type. */ fargs.emplace_back(Param{"self", typExpr->clone()}); stmts.emplace_back( N(I("d"), N(N(I("List"), I("str")), N(args.size())))); for (auto &a : args) stmts.push_back(N( N(N(I("d"), "append"), N(a.name)))); stmts.emplace_back(N(I("d"))); } else if (op == "add") { // def __add__(self, tup): // return (*self, *tup) fargs.emplace_back(Param{"self", typExpr->clone()}); fargs.emplace_back(Param{"tup", nullptr}); stmts.emplace_back(N( N(vector{N(I("self")), N(I("tup"))}))); } else { seqassert(false, "invalid magic {}", op); } #undef I /* Assemble the function named __<op>__ from the return type, parameters, body and attributes gathered above. */ auto t = make_shared(format("__{}__", op), ret, fargs, N(stmts), attr); t->setSrcInfo(ctx->cache->generateSrcInfo()); return t; } /* Collects the function and class definitions found in a class body. Suites are flattened recursively, string-expression statements (docstrings) are skipped, and any other statement raises error("only function and class definitions are allowed within classes"). A null statement yields an empty list. */ vector SimplifyVisitor::getClassMethods(const StmtPtr &s) { vector v; if (!s) return v; if (auto sp = s->getSuite()) { for (const auto &ss : sp->stmts) for (auto u : getClassMethods(ss)) v.push_back(u); } else if (s->getExpr() && s->getExpr()->expr->getString()) { /// Those are doc-strings, ignore them. } else if (!s->getFunction() && !s->getClass()) { error("only function and class definitions are allowed within classes"); } else { v.push_back(s); } return v; } } // namespace ast } // namespace seq