From c08a2d7d17976941f75172be52862d726164f010 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Mon, 6 Feb 2023 13:42:37 -0500 Subject: [PATCH] Reorganize API --- codon/app/main.cpp | 4 +-- codon/cir/attribute.cpp | 16 +++++++++ codon/cir/attribute.h | 20 +++++++++++ codon/cir/llvm/llvisitor.cpp | 37 +++++++++++--------- codon/cir/llvm/llvisitor.h | 9 ++--- codon/cir/transform/lowering/pyextension.cpp | 2 +- codon/cir/transform/lowering/pyextension.h | 13 ------- codon/cir/transform/manager.cpp | 10 ++---- codon/cir/transform/manager.h | 12 ++----- 9 files changed, 67 insertions(+), 56 deletions(-) diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 75d827df..48917038 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -11,7 +11,6 @@ #include #include -#include "codon/cir/transform/lowering/pyextension.h" #include "codon/compiler/compiler.h" #include "codon/compiler/error.h" #include "codon/compiler/jit.h" @@ -372,8 +371,7 @@ int buildMode(const std::vector &args, const std::string &argv0) { case BuildKind::PyExtension: compiler->getLLVMVisitor()->writeToPythonExtension( pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule, - compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(), - filename); + compiler->getModule(), filename); break; case BuildKind::Detect: compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); diff --git a/codon/cir/attribute.cpp b/codon/cir/attribute.cpp index 0f104e99..9a496698 100644 --- a/codon/cir/attribute.cpp +++ b/codon/cir/attribute.cpp @@ -39,6 +39,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const { return os; } +const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute"; + +std::unique_ptr PythonWrapperAttribute::clone(util::CloneVisitor &cv) const { + return std::make_unique(cast(cv.clone(original))); +} + +std::unique_ptr +PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const { + return std::make_unique(cv.forceClone(original)); +} + +std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const { + fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString()); + return os; +} + const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; const std::string DocstringAttribute::AttributeName = "docstringAttribute"; diff --git a/codon/cir/attribute.h b/codon/cir/attribute.h index 2fc8a841..77482282 100644 --- a/codon/cir/attribute.h +++ b/codon/cir/attribute.h @@ -135,6 +135,26 @@ private: std::ostream &doFormat(std::ostream &os) const override; }; +/// Attribute used to mark Python wrappers of Codon functions +struct PythonWrapperAttribute : public Attribute { + static const std::string AttributeName; + + /// the function being wrapped + Func *original; + + /// Constructs a PythonWrapperAttribute. + /// @param original the function being wrapped + explicit PythonWrapperAttribute(Func *original) : original(original) {} + + bool needsClone() const override { return false; } + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + /// Attribute attached to IR structures corresponding to tuple literals struct TupleLiteralAttribute : public Attribute { static const std::string AttributeName; diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index 4f8bcda4..b0f8151a 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -655,18 +655,24 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) { return wrap; } -void LLVMVisitor::writeToPythonExtension( - const std::string &name, const std::vector> &funcs, - const std::string &filename) { +void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module, + const std::string &filename) { // Construct PyMethodDef array auto *ptr = B->getInt8PtrTy(); auto *null = llvm::Constant::getNullValue(ptr); auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr); std::vector pyMethods; - for (auto &p : funcs) { - auto *original = p.first; - auto *generated = p.second; + for (auto *var : *module) { + auto *generated = cast(var); + if (!generated) + continue; + + auto *pyWrapAttr = generated->getAttribute(); + if (!pyWrapAttr) + continue; + + auto *original = pyWrapAttr->original; auto llvmName = getNameForFunction(generated); auto *llvmFunc = M->getFunction(llvmName); seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName); @@ -735,17 +741,14 @@ void LLVMVisitor::writeToPythonExtension( auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr); auto *docsConst = null; - if (!funcs.empty()) { - if (auto *docsAttr = - funcs[0].first->getModule()->getAttribute()) { - auto docs = docsAttr->docstring; - auto *docsVar = new llvm::GlobalVariable( - *M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, - llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring"); - docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); - docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr); - } + if (auto *docsAttr = module->getAttribute()) { + auto docs = docsAttr->docstring; + auto *docsVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring"); + docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr); } auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr); diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index b4de5549..7c08612a 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -345,12 +345,13 @@ public: bool library = false, const std::vector &libs = {}, const std::string &lflags = ""); - /// Writes module as Python extension shared object. + /// Writes module as Python extension object. Exposes + /// functions based on "PythonWrapperAttribute" attached + /// to IR functions. /// @param name the module's name - /// @param funcs extension functions + /// @param module the IR module /// @param filename the file to write to - void writeToPythonExtension(const std::string &name, - const std::vector> &funcs, + void writeToPythonExtension(const std::string &name, const Module *module, const std::string &filename); /// Runs optimization passes on module and writes the result /// to the specified file. The output type is determined by diff --git a/codon/cir/transform/lowering/pyextension.cpp b/codon/cir/transform/lowering/pyextension.cpp index 8287ca8f..a114f298 100644 --- a/codon/cir/transform/lowering/pyextension.cpp +++ b/codon/cir/transform/lowering/pyextension.cpp @@ -94,7 +94,7 @@ void PythonExtensionLowering::run(Module *module) { if (auto *g = generateExtensionFunc(f)) { LOG("[pyext] exporting {}", f->getName()); - extFuncs.emplace_back(f, g); + g->setAttribute(std::make_unique(f)); } } } diff --git a/codon/cir/transform/lowering/pyextension.h b/codon/cir/transform/lowering/pyextension.h index e5b68480..a90e5cdc 100644 --- a/codon/cir/transform/lowering/pyextension.h +++ b/codon/cir/transform/lowering/pyextension.h @@ -13,24 +13,11 @@ namespace transform { namespace lowering { class PythonExtensionLowering : public Pass { -private: - /// vector of original function (1st) and generated - /// extension wrapper (2nd) - std::vector> extFuncs; - public: static const std::string KEY; std::string getKey() const override { return KEY; } - /// Constructs a PythonExtensionLowering pass. - PythonExtensionLowering() : Pass(), extFuncs() {} - void run(Module *module) override; - - /// @return extension function (original, generated) pairs - std::vector> getExtensionFunctions() const { - return extFuncs; - } }; } // namespace lowering diff --git a/codon/cir/transform/manager.cpp b/codon/cir/transform/manager.cpp index 3a65038a..a641a4d8 100644 --- a/codon/cir/transform/manager.cpp +++ b/codon/cir/transform/manager.cpp @@ -148,12 +148,6 @@ void PassManager::invalidate(const std::string &key) { } void PassManager::registerStandardPasses(PassManager::Init init) { - std::unique_ptr pyExtPass; - if (pyExtension) { - pyExtPass = std::make_unique(); - pyExtensionPass = pyExtPass.get(); - } - switch (init) { case Init::EMPTY: break; @@ -162,7 +156,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) { registerPass(std::make_unique()); registerPass(std::make_unique()); if (pyExtension) - registerPass(std::move(pyExtPass)); + registerPass(std::make_unique()); break; } case Init::RELEASE: @@ -210,7 +204,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) { {cfgKey, globalKey}); if (pyExtension) - registerPass(std::move(pyExtPass)); + registerPass(std::make_unique()); if (init != Init::JIT) { // Don't demote globals in JIT mode, since they might be used later diff --git a/codon/cir/transform/manager.h b/codon/cir/transform/manager.h index 3c32fa0d..6174c024 100644 --- a/codon/cir/transform/manager.h +++ b/codon/cir/transform/manager.h @@ -98,9 +98,6 @@ private: /// true if we are compiling as a Python extension bool pyExtension; - /// pointer to Python extension lowering pass, if applicable - lowering::PythonExtensionLowering *pyExtensionPass; - public: /// PassManager initialization mode. enum Init { @@ -113,8 +110,8 @@ public: explicit PassManager(Init init, std::vector disabled = {}, bool pyNumerics = false, bool pyExtension = false) : km(), passes(), analyses(), executionOrder(), results(), - disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension), - pyExtensionPass(nullptr) { + disabled(std::move(disabled)), pyNumerics(pyNumerics), + pyExtension(pyExtension) { registerStandardPasses(init); } @@ -182,11 +179,6 @@ public: return std::find(disabled.begin(), disabled.end(), key) != disabled.end(); } - /// @return the Python extension lowering pass, or null if none - lowering::PythonExtensionLowering *getPythonExtensionPass() const { - return pyExtensionPass; - } - private: void runPass(Module *module, const std::string &name); void registerStandardPasses(Init init);