#pragma once #include "sir/sir.h" #include "visitor.h" namespace codon { namespace ir { namespace util { class CloneVisitor : public ConstVisitor { private: /// the clone context std::unordered_map ctx; /// the result Node *result = nullptr; /// the module Module *module; /// true if break/continue loops should be cloned bool cloneLoop; public: /// Constructs a clone visitor. /// @param module the module /// @param cloneLoop true if break/continue loops should be cloned explicit CloneVisitor(Module *module, bool cloneLoop = true) : module(module) {} virtual ~CloneVisitor() noexcept = default; void visit(const Var *v) override; void visit(const BodiedFunc *v) override; void visit(const ExternalFunc *v) override; void visit(const InternalFunc *v) override; void visit(const LLVMFunc *v) override; void visit(const VarValue *v) override; void visit(const PointerValue *v) override; void visit(const SeriesFlow *v) override; void visit(const IfFlow *v) override; void visit(const WhileFlow *v) override; void visit(const ForFlow *v) override; void visit(const ImperativeForFlow *v) override; void visit(const TryCatchFlow *v) override; void visit(const PipelineFlow *v) override; void visit(const dsl::CustomFlow *v) override; void visit(const IntConst *v) override; void visit(const FloatConst *v) override; void visit(const BoolConst *v) override; void visit(const StringConst *v) override; void visit(const dsl::CustomConst *v) override; void visit(const AssignInstr *v) override; void visit(const ExtractInstr *v) override; void visit(const InsertInstr *v) override; void visit(const CallInstr *v) override; void visit(const StackAllocInstr *v) override; void visit(const TypePropertyInstr *v) override; void visit(const YieldInInstr *v) override; void visit(const TernaryInstr *v) override; void visit(const BreakInstr *v) override; void visit(const ContinueInstr *v) override; void visit(const ReturnInstr *v) override; void visit(const YieldInstr *v) override; void visit(const ThrowInstr *v) override; void visit(const FlowInstr *v) override; void visit(const dsl::CustomInstr *v) override; /// Clones a value, returning the previous value if other has already been cloned. /// @param other the original /// @return the clone Value *clone(const Value *other) { if (!other) return nullptr; auto id = other->getId(); if (ctx.find(id) == ctx.end()) { other->accept(*this); ctx[id] = result; for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { ctx[id]->setAttribute(attr->clone(), *it); } } } return cast(ctx[id]); } /// Returns the original unless the variable has been force cloned. /// @param other the original /// @return the original or the previous clone Var *clone(const Var *other) { if (!other) return nullptr; auto id = other->getId(); if (ctx.find(id) != ctx.end()) return cast(ctx[id]); return const_cast(other); } /// Clones a flow, returning the previous value if other has already been cloned. /// @param other the original /// @return the clone Flow *clone(const Flow *other) { return cast(clone(static_cast(other))); } /// Forces a clone. No difference for values but ensures that variables are actually /// cloned. /// @param other the original /// @return the clone template NodeType *forceClone(const NodeType *other) { if (!other) return nullptr; auto id = other->getId(); if (ctx.find(id) == ctx.end()) { other->accept(*this); ctx[id] = result; for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { ctx[id]->setAttribute(attr->clone(), *it); } } } return cast(ctx[id]); } /// Remaps a clone. /// @param original the original /// @param newVal the clone template void forceRemap(const NodeType *original, const NodeType *newVal) { ctx[original->getId()] = const_cast(newVal); } PipelineFlow::Stage clone(const PipelineFlow::Stage &other) { std::vector args; for (const auto *a : other) args.push_back(clone(a)); return {clone(other.getCallee()), std::move(args), other.isGenerator(), other.isParallel()}; } private: template NodeType *Nt(const NodeType *source, Args... args) { return module->N(source, std::forward(args)..., source->getName()); } }; } // namespace util } // namespace ir } // namespace codon