/* * translate.cpp --- AST-to-IR translation. * * (c) Seq project. All rights reserved. * This file is subject to the terms and conditions defined in * file 'LICENSE', which is part of this source code package. */ #include #include #include #include #include "parser/ast.h" #include "parser/common.h" #include "parser/visitors/translate/translate.h" #include "parser/visitors/translate/translate_ctx.h" #include "sir/transform/parallel/schedule.h" #include "sir/util/cloning.h" using fmt::format; using seq::ir::cast; using seq::ir::transform::parallel::OMPSched; using std::function; using std::get; using std::make_shared; using std::move; using std::shared_ptr; using std::stack; using std::vector; namespace seq { namespace ast { TranslateVisitor::TranslateVisitor(shared_ptr ctx) : ctx(move(ctx)), result(nullptr) {} ir::Module *TranslateVisitor::apply(shared_ptr cache, StmtPtr stmts) { auto main = cast(cache->module->getMainFunc()); char buf[PATH_MAX + 1]; realpath(cache->module0.c_str(), buf); main->setSrcInfo({string(buf), 0, 0, 0}); auto block = cache->module->Nr("body"); main->setBody(block); cache->codegenCtx = make_shared(cache, block, main); TranslateVisitor(cache->codegenCtx).transform(stmts); return cache->module; } /************************************************************************************/ ir::Value *TranslateVisitor::transform(const ExprPtr &expr) { TranslateVisitor v(ctx); v.setSrcInfo(expr->getSrcInfo()); expr->accept(v); return v.result; } void TranslateVisitor::defaultVisit(Expr *n) { seqassert(false, "invalid node {}", n->toString()); } void TranslateVisitor::visit(BoolExpr *expr) { result = make(expr, expr->value, getType(expr->getType())); } void TranslateVisitor::visit(IntExpr *expr) { result = make(expr, expr->intValue, getType(expr->getType())); } void TranslateVisitor::visit(FloatExpr *expr) { result = make(expr, expr->floatValue, getType(expr->getType())); } void TranslateVisitor::visit(StringExpr *expr) { result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(IdExpr *expr) { auto val = ctx->find(expr->value); seqassert(val, "cannot find '{}'", expr->value); if (auto *v = val->getVar()) result = make(expr, v); else if (auto *f = val->getFunc()) result = make(expr, f); } void TranslateVisitor::visit(IfExpr *expr) { result = make(expr, transform(expr->cond), transform(expr->ifexpr), transform(expr->elsexpr)); } void TranslateVisitor::visit(CallExpr *expr) { auto ft = expr->expr->type->getFunc(); seqassert(ft, "not calling function: {}", ft->toString()); auto callee = transform(expr->expr); bool isVariadic = ft->ast->hasAttr(Attr::CVarArg); vector items; for (int i = 0; i < expr->args.size(); i++) { seqassert(!expr->args[i].value->getEllipsis(), "ellipsis not elided"); if (i + 1 == expr->args.size() && isVariadic) { auto call = expr->args[i].value->getCall(); seqassert(call && call->expr->getId() && startswith(call->expr->getId()->value, TYPE_TUPLE), "expected *args tuple"); for (auto &arg : call->args) items.emplace_back(transform(arg.value)); } else { items.emplace_back(transform(expr->args[i].value)); } } result = make(expr, callee, move(items)); } void TranslateVisitor::visit(StackAllocExpr *expr) { auto *arrayType = ctx->getModule()->unsafeGetArrayType(getType(expr->typeExpr->getType())); arrayType->setAstType(expr->getType()); seqassert(expr->expr->getInt(), "expected a static integer, got {}", expr->expr->toString()); result = make(expr, arrayType, expr->expr->getInt()->intValue); } void TranslateVisitor::visit(DotExpr *expr) { if (expr->member == "__atomic__" || expr->member == "__elemsize__") { seqassert(expr->expr->getId(), "expected IdExpr, got {}", expr->expr->toString()); auto type = ctx->find(expr->expr->getId()->value)->getType(); seqassert(type, "{} is not a type", expr->expr->getId()->value); result = make( expr, type, expr->member == "__atomic__" ? ir::TypePropertyInstr::Property::IS_ATOMIC : ir::TypePropertyInstr::Property::SIZEOF); } else { result = make(expr, transform(expr->expr), expr->member); } } void TranslateVisitor::visit(PtrExpr *expr) { seqassert(expr->expr->getId(), "expected IdExpr, got {}", expr->expr->toString()); auto val = ctx->find(expr->expr->getId()->value); seqassert(val && val->getVar(), "{} is not a variable", expr->expr->getId()->value); result = make(expr, val->getVar()); } void TranslateVisitor::visit(YieldExpr *expr) { result = make(expr, getType(expr->getType())); } void TranslateVisitor::visit(PipeExpr *expr) { auto isGen = [](const ir::Value *v) -> bool { auto *type = v->getType(); if (ir::isA(type)) return true; else if (auto *fn = cast(type)) { return ir::isA(fn->getReturnType()); } return false; }; vector stages; auto *firstStage = transform(expr->items[0].expr); auto firstIsGen = isGen(firstStage); stages.emplace_back(firstStage, vector(), firstIsGen, false); // Pipeline without generators (just function call sugar) auto simplePipeline = !firstIsGen; for (auto i = 1; i < expr->items.size(); i++) { auto call = expr->items[i].expr->getCall(); seqassert(call, "{} is not a call", expr->items[i].expr->toString()); auto fn = transform(call->expr); if (i + 1 != expr->items.size()) simplePipeline &= !isGen(fn); vector args; for (auto &a : call->args) args.emplace_back(a.value->getEllipsis() ? nullptr : transform(a.value)); stages.emplace_back(fn, args, isGen(fn), false); } if (simplePipeline) { // Transform a |> b |> c to c(b(a)) ir::util::CloneVisitor cv(ctx->getModule()); result = cv.clone(stages[0].getCallee()); for (auto i = 1; i < stages.size(); ++i) { vector newArgs; for (auto arg : stages[i]) newArgs.push_back(arg ? cv.clone(arg) : result); result = make(expr, cv.clone(stages[i].getCallee()), newArgs); } } else { for (int i = 0; i < expr->items.size(); i++) if (expr->items[i].op == "||>") stages[i].setParallel(); // This is a statement in IR. ctx->getSeries()->push_back(make(expr, stages)); } } void TranslateVisitor::visit(StmtExpr *expr) { auto *bodySeries = make(expr, "body"); ctx->addSeries(bodySeries); for (auto &s : expr->stmts) transform(s); ctx->popSeries(); result = make(expr, bodySeries, transform(expr->expr)); } /************************************************************************************/ ir::Value *TranslateVisitor::transform(const StmtPtr &stmt) { TranslateVisitor v(ctx); v.setSrcInfo(stmt->getSrcInfo()); stmt->accept(v); if (v.result) ctx->getSeries()->push_back(v.result); return v.result; } void TranslateVisitor::defaultVisit(Stmt *n) { seqassert(false, "invalid node {}", n->toString()); } void TranslateVisitor::visit(SuiteStmt *stmt) { for (auto &s : stmt->stmts) transform(s); } void TranslateVisitor::visit(BreakStmt *stmt) { result = make(stmt); } void TranslateVisitor::visit(ContinueStmt *stmt) { result = make(stmt); } void TranslateVisitor::visit(ExprStmt *stmt) { result = transform(stmt->expr); } void TranslateVisitor::visit(AssignStmt *stmt) { seqassert(stmt->lhs->getId(), "expected IdExpr, got {}", stmt->lhs->toString()); auto var = stmt->lhs->getId()->value; if (!stmt->rhs && var == VAR_ARGV) { ctx->add(TranslateItem::Var, var, ctx->getModule()->getArgVar()); } else if (!stmt->rhs || !stmt->rhs->isType()) { auto *newVar = make(stmt, getType((stmt->rhs ? stmt->rhs : stmt->lhs)->getType()), in(ctx->cache->globals, var), var); if (!in(ctx->cache->globals, var)) ctx->getBase()->push_back(newVar); ctx->add(TranslateItem::Var, var, newVar); if (stmt->rhs) result = make(stmt, newVar, transform(stmt->rhs)); } } void TranslateVisitor::visit(AssignMemberStmt *stmt) { result = make(stmt, transform(stmt->lhs), stmt->member, transform(stmt->rhs)); } void TranslateVisitor::visit(UpdateStmt *stmt) { seqassert(stmt->lhs->getId(), "expected IdExpr, got {}", stmt->lhs->toString()); auto val = ctx->find(stmt->lhs->getId()->value); seqassert(val && val->getVar(), "{} is not a variable", stmt->lhs->getId()->value); result = make(stmt, val->getVar(), transform(stmt->rhs)); } void TranslateVisitor::visit(ReturnStmt *stmt) { result = make(stmt, stmt->expr ? transform(stmt->expr) : nullptr); } void TranslateVisitor::visit(YieldStmt *stmt) { result = make(stmt, stmt->expr ? transform(stmt->expr) : nullptr); ctx->getBase()->setGenerator(); } void TranslateVisitor::visit(WhileStmt *stmt) { auto loop = make(stmt, transform(stmt->cond), make(stmt, "body")); ctx->addSeries(cast(loop->getBody())); transform(stmt->suite); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(ForStmt *stmt) { unique_ptr os = nullptr; if (stmt->decorator) { os = make_unique(); auto c = stmt->decorator->getCall(); seqassert(c, "for par is not a call: {}", stmt->decorator->toString()); auto fc = c->expr->getType()->getFunc(); seqassert(fc && fc->ast->name == "std.openmp.for_par", "for par is not a function"); auto schedule = fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString(); bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt(); auto threads = transform(c->args[0].value); auto chunk = transform(c->args[1].value); os = make_unique(schedule, threads, chunk, ordered); LOG_TYPECHECK("parsed {}", stmt->decorator->toString()); } seqassert(stmt->var->getId(), "expected IdExpr, got {}", stmt->var->toString()); auto varName = stmt->var->getId()->value; auto var = make(stmt, getType(stmt->var->getType()), false, varName); ctx->getBase()->push_back(var); auto bodySeries = make(stmt, "body"); auto loop = make(stmt, transform(stmt->iter), bodySeries, var); if (os) loop->setSchedule(move(os)); ctx->add(TranslateItem::Var, varName, var); ctx->addSeries(cast(loop->getBody())); transform(stmt->suite); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(IfStmt *stmt) { auto cond = transform(stmt->cond); auto trueSeries = make(stmt, "ifstmt_true"); ctx->addSeries(trueSeries); transform(stmt->ifSuite); ctx->popSeries(); ir::SeriesFlow *falseSeries = nullptr; if (stmt->elseSuite) { falseSeries = make(stmt, "ifstmt_false"); ctx->addSeries(falseSeries); transform(stmt->elseSuite); ctx->popSeries(); } result = make(stmt, cond, trueSeries, falseSeries); } void TranslateVisitor::visit(TryStmt *stmt) { auto *bodySeries = make(stmt, "body"); ctx->addSeries(bodySeries); transform(stmt->suite); ctx->popSeries(); auto finallySeries = make(stmt, "finally"); if (stmt->finally) { ctx->addSeries(finallySeries); transform(stmt->finally); ctx->popSeries(); } auto *tc = make(stmt, bodySeries, finallySeries); for (auto &c : stmt->catches) { auto *catchBody = make(stmt, "catch"); auto *excType = c.exc ? getType(c.exc->getType()) : nullptr; ir::Var *catchVar = nullptr; if (!c.var.empty()) { catchVar = make(stmt, excType, false, c.var); ctx->add(TranslateItem::Var, c.var, catchVar); ctx->getBase()->push_back(catchVar); } ctx->addSeries(catchBody); transform(c.suite); ctx->popSeries(); tc->push_back(ir::TryCatchFlow::Catch(catchBody, excType, catchVar)); } result = tc; } void TranslateVisitor::visit(ThrowStmt *stmt) { result = make(stmt, transform(stmt->expr)); } void TranslateVisitor::visit(FunctionStmt *stmt) { // Process all realizations. for (auto &real : ctx->cache->functions[stmt->name].realizations) { if (!in(ctx->cache->pendingRealizations, make_pair(stmt->name, real.first))) continue; ctx->cache->pendingRealizations.erase(make_pair(stmt->name, real.first)); LOG_TYPECHECK("[translate] generating fn {}", real.first); real.second->ir->setSrcInfo(getSrcInfo()); const auto &ast = real.second->ast; seqassert(ast, "AST not set for {}", real.first); if (!stmt->attributes.has(Attr::LLVM)) transformFunction(real.second->type.get(), ast.get(), real.second->ir); else transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir); } } void TranslateVisitor::visit(ClassStmt *stmt) { // Nothing to see here, as all type handles are already generated. // Methods will be handled by FunctionStmt visitor. } /************************************************************************************/ seq::ir::types::Type *TranslateVisitor::getType(const types::TypePtr &t) { seqassert(t && t->getClass(), "{} is not a class", t ? t->toString() : "-"); string name = t->getClass()->realizedTypeName(); auto i = ctx->find(name); seqassert(i, "type {} not realized", t->toString()); return i->getType(); } void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func) { vector names; vector indices; vector srcInfos; vector types; for (int i = 0, j = 1; i < ast->args.size(); i++) if (!ast->args[i].generic) { if (!type->args[j]->getFunc()) { types.push_back(getType(type->args[j])); names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); indices.push_back(i); } j++; } if (ast->hasAttr(Attr::CVarArg)) { types.pop_back(); names.pop_back(); indices.pop_back(); } auto irType = ctx->getModule()->unsafeGetFuncType( type->realizedName(), getType(type->args[0]), types, ast->hasAttr(Attr::CVarArg)); irType->setAstType(type->getFunc()); func->realize(irType, names); // TODO: refactor IR attribute API map attr; attr[".module"] = ast->attributes.module; for (auto &a : ast->attributes.customAttr) { // LOG("{} -> {}", ast->name, a); attr[a] = ""; } func->setAttribute(make_unique(attr)); for (int i = 0; i < names.size(); i++) func->getArgVar(names[i])->setSrcInfo(ast->args[indices[i]].getSrcInfo()); func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); if (!ast->attributes.has(Attr::C) && !ast->attributes.has(Attr::Internal)) { ctx->addBlock(); for (auto i = 0; i < names.size(); i++) ctx->add(TranslateItem::Var, ast->args[indices[i]].name, func->getArgVar(names[i])); auto body = make(ast, "body"); ctx->bases.push_back(cast(func)); ctx->addSeries(body); transform(ast->suite); ctx->popSeries(); ctx->bases.pop_back(); cast(func)->setBody(body); ctx->popBlock(); } } void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func) { vector names; vector types; vector indices; for (int i = 0, j = 1; i < ast->args.size(); i++) if (!ast->args[i].generic) { types.push_back(getType(type->args[j])); names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); indices.push_back(i); j++; } auto irType = ctx->getModule()->unsafeGetFuncType(type->realizedName(), getType(type->args[0]), types); irType->setAstType(type->getFunc()); auto f = cast(func); f->realize(irType, names); // TODO: refactor IR attribute API map attr; attr[".module"] = ast->attributes.module; for (auto &a : ast->attributes.customAttr) attr[a] = ""; func->setAttribute(make_unique(attr)); for (int i = 0; i < names.size(); i++) func->getArgVar(names[i])->setSrcInfo(ast->args[indices[i]].getSrcInfo()); seqassert(ast->suite->firstInBlock() && ast->suite->firstInBlock()->getExpr() && ast->suite->firstInBlock()->getExpr()->expr->getString(), "LLVM function does not begin with a string"); std::istringstream sin( ast->suite->firstInBlock()->getExpr()->expr->getString()->getValue()); vector literals; auto &ss = ast->suite->getSuite()->stmts; for (int i = 1; i < ss.size(); i++) { if (auto *ei = ss[i]->getExpr()->expr->getInt()) { // static integer expression literals.emplace_back(ei->intValue); } else { seqassert(ss[i]->getExpr()->expr->isType() && ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}", ss[i]->getExpr()->toString()); literals.emplace_back(getType(ss[i]->getExpr()->expr->getType())); } } bool isDeclare = true; string declare; vector lines; for (string l; getline(sin, l);) { string lp = l; ltrim(lp); rtrim(lp); // Extract declares and constants. if (isDeclare && !startswith(lp, "declare ")) { bool isConst = lp.find("private constant") != string::npos; if (!isConst) { isDeclare = false; if (!lp.empty() && lp.back() != ':') lines.emplace_back("entry:"); } } if (isDeclare) declare += lp + "\n"; else lines.emplace_back(l); } f->setLLVMBody(join(lines, "\n")); f->setLLVMDeclarations(declare); f->setLLVMLiterals(literals); func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); } } // namespace ast } // namespace seq