diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 2c920274..cbb62c07 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -9,6 +9,8 @@ #include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/typecheck.h" +#include "codon/sir/util/irtools.h" +#include "codon/sir/util/operator.h" namespace codon { diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index 6e082d8e..34016480 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -108,7 +108,6 @@ llvm::Expected JIT::exec(const std::string &code) { auto sctx = cache->imports[MAIN_IMPORT].ctx; auto preamble = std::make_shared(); - ast::Cache bCache = *cache; ast::SimplifyContext bSimplify = *sctx; ast::TypeContext bType = *(cache->typeCtx); diff --git a/codon/parser/ast/expr.cpp b/codon/parser/ast/expr.cpp index ec5768a3..c4765443 100644 --- a/codon/parser/ast/expr.cpp +++ b/codon/parser/ast/expr.cpp @@ -46,9 +46,8 @@ std::string StaticValue::toString() const { return ""; if (!evaluated) return type == StaticValue::STRING ? "str" : "int"; - return type == StaticValue::STRING - ? "'" + escape(std::get(value)) + "'" - : std::to_string(std::get(value)); + return type == StaticValue::STRING ? "'" + escape(std::get(value)) + "'" + : std::to_string(std::get(value)); } int64_t StaticValue::getInt() const { seqassert(type == StaticValue::INT, "not an int"); @@ -398,11 +397,14 @@ StmtExpr::StmtExpr(std::shared_ptr stmt, std::shared_ptr stmt2, stmts.push_back(std::move(stmt2)); } StmtExpr::StmtExpr(const StmtExpr &expr) - : Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {} + : Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)), + attributes(expr.attributes) {} std::string StmtExpr::toString() const { return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString())); } ACCEPT_IMPL(StmtExpr, ASTVisitor); +bool StmtExpr::hasAttr(const std::string &attr) const { return in(attributes, attr); } +void StmtExpr::setAttr(const std::string &attr) { attributes.insert(attr); } PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {} PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {} diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index df5716c4..220d56f7 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -600,6 +600,8 @@ struct RangeExpr : public Expr { struct StmtExpr : public Expr { std::vector> stmts; ExprPtr expr; + /// Set of attributes. + std::set attributes; StmtExpr(std::vector> stmts, ExprPtr expr); StmtExpr(std::shared_ptr stmt, ExprPtr expr); @@ -610,6 +612,10 @@ struct StmtExpr : public Expr { ACCEPT(ASTVisitor); const StmtExpr *getStmtExpr() const override { return this; } + + /// Attribute helpers + bool hasAttr(const std::string &attr) const; + void setAttr(const std::string &attr); }; /// Pointer expression (__ptr__(expr)). diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 86f0ad64..e3440b83 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -42,6 +42,7 @@ std::string SuiteStmt::toString(int indent) const { } ACCEPT_IMPL(SuiteStmt, ASTVisitor); void SuiteStmt::flatten(StmtPtr s, std::vector &stmts) { + // WARNING: does not preserve attributes! if (!s) return; auto suite = const_cast(s->getSuite()); diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 63f7161f..feb234bf 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -170,7 +170,6 @@ struct Cache : public std::enable_shared_from_this { /// corresponding Function instance. std::unordered_map functions; - struct Overload { /// Canonical name of an overload (e.g. Foo.__init__.1). std::string name; @@ -189,6 +188,9 @@ struct Cache : public std::enable_shared_from_this { std::shared_ptr codegenCtx; /// Set of function realizations that are to be translated to IR. std::set> pendingRealizations; + /// Mapping of partial record names to function pointers and corresponding masks. + std::unordered_map>> + partials; /// Custom operators std::unordered_map(N(N(clone(var), "append"), clone(it))))); } } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ir::ListLiteralAttribute::AttributeName); + resultExpr = e; ctx->popBlock(); } @@ -207,7 +210,9 @@ void SimplifyVisitor::visit(SetExpr *expr) { stmts.push_back(transform( N(N(N(clone(var), "add"), clone(it))))); } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ir::SetLiteralAttribute::AttributeName); + resultExpr = e; ctx->popBlock(); } @@ -229,7 +234,9 @@ void SimplifyVisitor::visit(DictExpr *expr) { stmts.push_back(transform(N(N( N(clone(var), "__setitem__"), clone(it.key), clone(it.value))))); } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ir::DictLiteralAttribute::AttributeName); + resultExpr = e; ctx->popBlock(); } @@ -386,14 +393,13 @@ void SimplifyVisitor::visit(IndexExpr *expr) { // IndexExpr[i1, ..., iN] is internally stored as IndexExpr[TupleExpr[i1, ..., iN]] // for N > 1, so make sure to check that case. - std::vector it; if (auto t = index->getTuple()) for (auto &i : t->items) it.push_back(i); else it.push_back(index); - for (auto &i: it) { + for (auto &i : it) { if (auto es = i->getStar()) i = N(transform(es->what)); else if (auto ek = CAST(i, KeywordStarExpr)) @@ -689,7 +695,9 @@ void SimplifyVisitor::visit(StmtExpr *expr) { for (auto &s : expr->stmts) stmts.emplace_back(transform(s)); auto e = transform(expr->expr); - resultExpr = N(stmts, e); + auto s = N(stmts, e); + s->attributes = expr->attributes; + resultExpr = s; } /**************************************************************************************/ @@ -726,8 +734,8 @@ ExprPtr SimplifyVisitor::transformInt(const std::string &value, } /// Custom suffix sfx: use int.__suffix_sfx__(str) call. /// NOTE: you cannot neither use binary (0bXXX) format here. - return transform(N(N("int", format("__suffix_{}__", suffix)), - N(val))); + return transform( + N(N("int", format("__suffix_{}__", suffix)), N(val))); } ExprPtr SimplifyVisitor::transformFloat(const std::string &value, diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 7e0ba953..f1346ce1 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -209,6 +209,10 @@ void TranslateVisitor::visit(StmtExpr *expr) { transform(s); ctx->popSeries(); result = make(expr, bodySeries, transform(expr->expr)); + for (auto &a : expr->attributes) { + // if (a == ir::ListLiteralAttribute::AttributeName) + // result->setAttribute(ir::ListLiteralAttribute); + } } /************************************************************************************/ diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 91e0aa74..18820b16 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -304,7 +304,6 @@ private: ExprPtr transformSuper(const CallExpr *expr); std::vector getSuperTypes(const types::ClassTypePtr &cls); - private: types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b, bool undoOnSuccess = false); diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index 16f00d64..899dc4cc 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -10,6 +10,7 @@ #include "codon/parser/common.h" #include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" +#include "codon/sir/attribute.h" using fmt::format; @@ -741,14 +742,7 @@ ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &e if (!tuple->getRecord()) return nullptr; if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL)) - // in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) - // Ptr, pyobj and str are internal types and have only one overloaded __getitem__ return nullptr; - // if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) { - // ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]] - // .size() != 1) - // return nullptr; - // } // Extract a static integer value from a compatible expression. auto getInt = [&](int64_t *o, const ExprPtr &e) { @@ -1123,30 +1117,32 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in N(N(N(clone(var), "__init__"), expr->args))), clone(var))); } - } else if (auto pc = callee->getPartial()) { - ExprPtr var = N(partialVar = ctx->cache->getTemporaryVar("pt")); - expr->expr = transform(N(N(clone(var), expr->expr), - N(pc->func->ast->name))); - calleeFn = expr->expr->type->getFunc(); - for (int i = 0, j = 0; i < pc->known.size(); i++) - if (pc->func->ast->args[i].generic) { - if (pc->known[i]) - unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type); - j++; - } - known = pc->known; - seqassert(calleeFn, "not a function: {}", expr->expr->type->toString()); - } else if (!callee->getFunc()) { - // Case 3: callee is not a named function. Route it through a __call__ method. - ExprPtr newCall = N(N(expr->expr, "__call__"), expr->args); - return transform(newCall, false, allowVoidExpr); + } else { + auto pc = callee->getPartial(); + if (pc) { + ExprPtr var = N(partialVar = ctx->cache->getTemporaryVar("pt")); + expr->expr = transform(N(N(clone(var), expr->expr), + N(pc->func->ast->name))); + calleeFn = expr->expr->type->getFunc(); + for (int i = 0, j = 0; i < pc->known.size(); i++) + if (pc->func->ast->args[i].generic) { + if (pc->known[i]) + unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type); + j++; + } + known = pc->known; + seqassert(calleeFn, "not a function: {}", expr->expr->type->toString()); + } else if (!callee->getFunc()) { + // Case 3: callee is not a named function. Route it through a __call__ method. + ExprPtr newCall = N(N(expr->expr, "__call__"), expr->args); + return transform(newCall, false, allowVoidExpr); + } } // Handle named and default arguments std::vector args; std::vector typeArgs; int typeArgCount = 0; - // bool isPartial = false; int ellipsisStage = -1; auto newMask = std::vector(calleeFn->ast->args.size(), 1); auto getPartialArg = [&](int pi) { @@ -1385,12 +1381,14 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in N(N(partialTypeName), newArgs)), N(var)); } + const_cast(call->getStmtExpr()) + ->setAttr(ir::PartialFunctionAttribute::AttributeName); call = transform(call, false, allowVoidExpr); - seqassert(call->type->getRecord() && - startswith(call->type->getRecord()->name, partialTypeName) && - !call->type->getPartial(), - "bad partial transformation"); - call->type = N(call->type->getRecord(), calleeFn, newMask); + // seqassert(call->type->getRecord() && + // startswith(call->type->getRecord()->name, partialTypeName) && + // !call->type->getPartial(), + // "bad partial transformation"); + // call->type = N(call->type->getRecord(), calleeFn, newMask); return call; } else { // Case 2. Normal function call. @@ -1427,7 +1425,7 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) auto t = expr->args[1].value->type; auto hierarchy = getSuperTypes(typ->getClass()); - for (auto &tx: hierarchy) { + for (auto &tx : hierarchy) { auto unifyOK = tx->unify(t.get(), nullptr) >= 0; if (unifyOK) { return {true, transform(N(true))}; @@ -1638,9 +1636,12 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, else if (!fn->ast->args[i].generic) tupleSize++; auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name); - if (!ctx->find(typeName)) + if (!ctx->find(typeName)) { + ctx->cache->partials[typeName] = { + std::static_pointer_cast(fn->shared_from_this()), mask}; // 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them) generateTupleStub(tupleSize + 2, typeName, {}, false); + } return typeName; } @@ -1712,12 +1713,14 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { N(N(partialTypeName), N(), N(N(kwName)))), N(var)); + const_cast(call->getStmtExpr()) + ->setAttr(ir::PartialFunctionAttribute::AttributeName); call = transform(call, false, allowVoidExpr); - seqassert(call->type->getRecord() && - startswith(call->type->getRecord()->name, partialTypeName) && - !call->type->getPartial(), - "bad partial transformation"); - call->type = N(call->type->getRecord(), fn, mask); + // seqassert(call->type->getRecord() && + // startswith(call->type->getRecord()->name, partialTypeName) && + // !call->type->getPartial(), + // "bad partial transformation"); + // call->type = N(call->type->getRecord(), fn, mask); return call; } @@ -2023,7 +2026,7 @@ std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cl return result; result.push_back(cls); int start = 0; - for (auto &cand: ctx->cache->classes[cls->name].parentClasses) { + for (auto &cand : ctx->cache->classes[cls->name].parentClasses) { auto name = cand.first; int fields = cand.second; auto val = ctx->find(name); @@ -2037,7 +2040,7 @@ std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cl unify(t, ft); } start += fields; - for (auto &t: getSuperTypes(ftyp)) + for (auto &t : getSuperTypes(ftyp)) result.push_back(t); } return result; diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index 42ece717..c7aa05e9 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -508,6 +508,13 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { else typ = std::make_shared( stmt->name, ctx->cache->reverseIdentifierLookup[stmt->name]); + if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) { + seqassert(in(ctx->cache->partials, stmt->name), + "invalid partial initialization: {}", stmt->name); + typ = std::make_shared(typ->getRecord(), + ctx->cache->partials[stmt->name].first, + ctx->cache->partials[stmt->name].second); + } typ->setSrcInfo(stmt->getSrcInfo()); ctx->add(TypecheckItem::Type, stmt->name, typ); ctx->bases[0].visitedAsts[stmt->name] = {TypecheckItem::Type, typ}; diff --git a/codon/sir/attribute.cpp b/codon/sir/attribute.cpp index b27541e5..d131c993 100644 --- a/codon/sir/attribute.cpp +++ b/codon/sir/attribute.cpp @@ -1,5 +1,8 @@ -#include "value.h" +#include "attribute.h" +#include "codon/sir/func.h" +#include "codon/sir/util/cloning.h" +#include "codon/sir/value.h" #include "codon/util/fmt/ostream.h" namespace codon { @@ -36,5 +39,136 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const { const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; +const std::string TupleLiteralAttribute::AttributeName = "tupleLiteralAttribute"; + +std::unique_ptr TupleLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.clone(val)); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +TupleLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.forceClone(val)); + return std::make_unique(elementsCloned); +} + +std::ostream &TupleLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : elements) + strings.push_back(fmt::format(FMT_STRING("{}"), *val)); + fmt::print(os, FMT_STRING("({})"), fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string ListLiteralAttribute::AttributeName = "listLiteralAttribute"; + +std::unique_ptr ListLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.clone(val)); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +ListLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.forceClone(val)); + return std::make_unique(elementsCloned); +} + +std::ostream &ListLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : elements) + strings.push_back(fmt::format(FMT_STRING("{}"), *val)); + fmt::print(os, FMT_STRING("[{}]"), fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string SetLiteralAttribute::AttributeName = "setLiteralAttribute"; + +std::unique_ptr SetLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.clone(val)); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +SetLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.forceClone(val)); + return std::make_unique(elementsCloned); +} + +std::ostream &SetLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : elements) + strings.push_back(fmt::format(FMT_STRING("{}"), *val)); + fmt::print(os, FMT_STRING("set([{}])"), + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string DictLiteralAttribute::AttributeName = "dictLiteralAttribute"; + +std::unique_ptr DictLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &val : elements) + elementsCloned.push_back({cv.clone(val.key), cv.clone(val.value)}); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +DictLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &val : elements) + elementsCloned.push_back({cv.forceClone(val.key), cv.forceClone(val.value)}); + return std::make_unique(elementsCloned); +} + +std::ostream &DictLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto &val : elements) + strings.push_back(fmt::format(FMT_STRING("{}:{}"), *val.key, *val.value)); + fmt::print(os, FMT_STRING("dict([{}])"), + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string PartialFunctionAttribute::AttributeName = "partialFunctionAttribute"; + +std::unique_ptr +PartialFunctionAttribute::clone(util::CloneVisitor &cv) const { + std::vector argsCloned; + for (auto *val : args) + argsCloned.push_back(cv.clone(val)); + return std::make_unique(cast(cv.clone(func)), + argsCloned); +} + +std::unique_ptr +PartialFunctionAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector argsCloned; + for (auto *val : args) + argsCloned.push_back(cv.forceClone(val)); + return std::make_unique(cast(cv.forceClone(func)), + argsCloned); +} + +std::ostream &PartialFunctionAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : args) + strings.push_back(val ? fmt::format(FMT_STRING("{}"), *val) : "..."); + fmt::print(os, FMT_STRING("{}({})"), func->getName(), + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + } // namespace ir } // namespace codon diff --git a/codon/sir/attribute.h b/codon/sir/attribute.h index 48beb5b7..cf34d072 100644 --- a/codon/sir/attribute.h +++ b/codon/sir/attribute.h @@ -14,6 +14,13 @@ namespace codon { namespace ir { +class Func; +class Value; + +namespace util { +class CloneVisitor; +} + /// Base for SIR attributes. struct Attribute { virtual ~Attribute() noexcept = default; @@ -26,14 +33,15 @@ struct Attribute { } /// @return a clone of the attribute - std::unique_ptr clone() const { - return std::unique_ptr(doClone()); + virtual std::unique_ptr clone(util::CloneVisitor &cv) const = 0; + + /// @return a clone of the attribute + virtual std::unique_ptr forceClone(util::CloneVisitor &cv) const { + return clone(cv); } private: virtual std::ostream &doFormat(std::ostream &os) const = 0; - - virtual Attribute *doClone() const = 0; }; /// Attribute containing SrcInfo @@ -48,10 +56,12 @@ struct SrcInfoAttribute : public Attribute { /// @param info the source info explicit SrcInfoAttribute(codon::SrcInfo info) : info(std::move(info)) {} + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override { return os << info; } - - Attribute *doClone() const override { return new SrcInfoAttribute(*this); } }; /// Attribute containing function information @@ -76,10 +86,12 @@ struct KeyValueAttribute : public Attribute { /// string if none std::string get(const std::string &key) const; + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override; - - Attribute *doClone() const override { return new KeyValueAttribute(*this); } }; /// Attribute containing type member information @@ -95,10 +107,106 @@ struct MemberAttribute : public Attribute { explicit MemberAttribute(std::map memberSrcInfo) : memberSrcInfo(std::move(memberSrcInfo)) {} + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override; +}; - Attribute *doClone() const override { return new MemberAttribute(*this); } +/// Attribute attached to IR structures corresponding to tuple literals +struct TupleLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// values contained in tuple literal + std::vector elements; + + explicit TupleLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to list literals +struct ListLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// values contained in list literal + std::vector elements; + + explicit ListLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to set literals +struct SetLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// values contained in set literal + std::vector elements; + + explicit SetLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to dict literals +struct DictLiteralAttribute : public Attribute { + struct KeyValuePair { + Value *key; + Value *value; + }; + + static const std::string AttributeName; + + /// keys and values contained in dict literal + std::vector elements; + + explicit DictLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to partial functions +struct PartialFunctionAttribute : public Attribute { + static const std::string AttributeName; + + /// function being called + Func *func; + + /// partial arguments, or null if none + /// e.g. "f(a, ..., b)" has elements [a, null, b] + std::vector args; + + PartialFunctionAttribute(Func *func, std::vector args) + : func(func), args(std::move(args)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; }; } // namespace ir diff --git a/codon/sir/util/cloning.h b/codon/sir/util/cloning.h index c62ce17f..e3aaf01f 100644 --- a/codon/sir/util/cloning.h +++ b/codon/sir/util/cloning.h @@ -82,7 +82,7 @@ public: for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { - ctx[id]->setAttribute(attr->clone(), *it); + ctx[id]->setAttribute(attr->clone(*this), *it); } } } @@ -125,7 +125,7 @@ public: for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { - ctx[id]->setAttribute(attr->clone(), *it); + ctx[id]->setAttribute(attr->forceClone(*this), *it); } } }