From 2285057005e22b052f648d556bfb76df24d68c44 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Sat, 28 Jan 2023 22:59:49 -0500 Subject: [PATCH] Add extension module codegen --- codon/app/main.cpp | 41 ++++-- codon/cir/llvm/llvisitor.cpp | 133 ++++++++++++++++++- codon/cir/llvm/llvisitor.h | 12 ++ codon/cir/transform/lowering/pyextension.cpp | 5 +- codon/cir/transform/lowering/pyextension.h | 17 +++ codon/cir/transform/manager.cpp | 11 ++ codon/cir/transform/manager.h | 21 ++- codon/compiler/compiler.cpp | 18 ++- codon/compiler/compiler.h | 8 +- 9 files changed, 236 insertions(+), 30 deletions(-) diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 90adba77..e726cc32 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -4,12 +4,14 @@ #include #include #include +#include #include #include #include #include #include +#include "codon/cir/transform/lowering/pyextension.h" #include "codon/compiler/compiler.h" #include "codon/compiler/error.h" #include "codon/compiler/jit.h" @@ -83,7 +85,7 @@ void initLogFlags(const llvm::cl::opt &log) { codon::getLogger().parse(std::string(d)); } -enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect }; +enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect }; enum OptMode { Debug, Release }; enum Numerics { C, Python }; } // namespace @@ -109,8 +111,9 @@ int docMode(const std::vector &args, const std::string &argv0) { return EXIT_SUCCESS; } -std::unique_ptr processSource(const std::vector &args, - bool standalone) { +std::unique_ptr processSource( + const std::vector &args, bool standalone, + std::function pyExtension = [] { return false; }) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); auto regs = llvm::cl::getRegisteredOptions(); @@ -163,9 +166,9 @@ std::unique_ptr processSource(const std::vector & const bool isDebug = (optMode == OptMode::Debug); std::vector disabledOptsVec(disabledOpts); - auto compiler = std::make_unique(args[0], isDebug, disabledOptsVec, - /*isTest=*/false, - (numerics == Numerics::Python)); + auto compiler = std::make_unique( + args[0], isDebug, disabledOptsVec, + /*isTest=*/false, (numerics == Numerics::Python), pyExtension()); compiler->getLLVMVisitor()->setStandalone(standalone); // load plugins @@ -296,13 +299,15 @@ int buildMode(const std::vector &args, const std::string &argv0) { llvm::cl::desc("Pass given flags to linker")); llvm::cl::opt buildKind( llvm::cl::desc("output type"), - llvm::cl::values(clEnumValN(LLVM, "llvm", "Generate LLVM IR"), - clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), - clEnumValN(Object, "obj", "Generate native object file"), - clEnumValN(Executable, "exe", "Generate executable"), - clEnumValN(Library, "lib", "Generate shared library"), - clEnumValN(Detect, "detect", - "Detect output type based on output file extension")), + llvm::cl::values( + clEnumValN(LLVM, "llvm", "Generate LLVM IR"), + clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), + clEnumValN(Object, "obj", "Generate native object file"), + clEnumValN(Executable, "exe", "Generate executable"), + clEnumValN(Library, "lib", "Generate shared library"), + clEnumValN(PyExtension, "pyext", "Generate Python extension module"), + clEnumValN(Detect, "detect", + "Detect output type based on output file extension")), llvm::cl::init(Detect)); llvm::cl::opt output( "o", @@ -310,7 +315,8 @@ int buildMode(const std::vector &args, const std::string &argv0) { "Write compiled output to specified file. Supported extensions: " "none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)")); - auto compiler = processSource(args, /*standalone=*/true); + auto compiler = processSource(args, /*standalone=*/true, + [&] { return buildKind == BuildKind::PyExtension; }); if (!compiler) return EXIT_FAILURE; std::vector libsVec(libs); @@ -329,6 +335,7 @@ int buildMode(const std::vector &args, const std::string &argv0) { extension = ".o"; break; case BuildKind::Library: + case BuildKind::PyExtension: extension = isMacOS() ? ".dylib" : ".so"; break; case BuildKind::Executable: @@ -358,6 +365,12 @@ int buildMode(const std::vector &args, const std::string &argv0) { compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec, lflags); break; + case BuildKind::PyExtension: + compiler->getLLVMVisitor()->writeToPythonExtension( + "mymodule", // TODO + compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(), + filename, argv0, libsVec, lflags); + break; case BuildKind::Detect: compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); break; diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index f4050602..d9668ee2 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -404,9 +404,9 @@ void executeCommand(const std::vector &args) { void LLVMVisitor::setupGlobalCtorForSharedLibrary() { const std::string llvmCtor = "llvm.global_ctors"; auto *main = M->getFunction("main"); - main->setName(".main"); // avoid clash with other main if (M->getNamedValue(llvmCtor) || !main) return; + main->setName(".main"); // avoid clash with other main auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false); auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(), @@ -541,6 +541,137 @@ void LLVMVisitor::writeToExecutable(const std::string &filename, llvm::sys::fs::remove(objFile); } +namespace { +// https://github.com/python/cpython/blob/main/Include/methodobject.h +constexpr int PYEXT_METH_VARARGS = 0x0001; +constexpr int PYEXT_METH_KEYWORDS = 0x0002; +constexpr int PYEXT_METH_NOARGS = 0x0004; +constexpr int PYEXT_METH_O = 0x0008; +constexpr int PYEXT_METH_CLASS = 0x0010; +constexpr int PYEXT_METH_STATIC = 0x0020; +constexpr int PYEXT_METH_COEXIST = 0x0040; +constexpr int PYEXT_METH_FASTCALL = 0x0080; +constexpr int PYEXT_METH_METHOD = 0x0200; +// https://github.com/python/cpython/blob/main/Include/modsupport.h +constexpr int PYEXT_PYTHON_ABI_VERSION = 3; +} // namespace + +void LLVMVisitor::writeToPythonExtension( + const std::string &name, const std::vector> &funcs, + const std::string &filename, const std::string &argv0, + const std::vector &libs, const std::string &lflags) { + // Construct PyMethodDef array + auto *ptr = B->getInt8PtrTy(); + auto *null = llvm::Constant::getNullValue(ptr); + auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr); + std::vector pyMethods; + + for (auto &p : funcs) { + auto *original = p.first; + auto *generated = p.second; + auto llvmName = getNameForFunction(generated); + auto *llvmFunc = M->getNamedValue(llvmName); + seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName); + + auto name = original->getUnmangledName(); + auto *nameVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, name), ".pyext_func_name"); + nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + + auto *nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr); + auto *funcConst = llvm::ConstantExpr::getBitCast(llvmFunc, ptr); + auto *flagConst = B->getInt32(PYEXT_METH_FASTCALL); + auto *docsConst = null; + if (auto *docsAttr = original->getAttribute()) { + auto docs = docsAttr->docstring; + auto *docsVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring"); + docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr); + } + pyMethods.push_back(llvm::ConstantStruct::get(pyMethodDefType, nameConst, funcConst, + flagConst, docsConst)); + } + pyMethods.push_back( + llvm::ConstantStruct::get(pyMethodDefType, null, null, B->getInt32(0), null)); + + auto *pyMethodDefArrayType = llvm::ArrayType::get(pyMethodDefType, pyMethods.size()); + auto *pyMethodDefArray = new llvm::GlobalVariable( + *M, pyMethodDefArrayType, + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods"); + + // Construct PyModuleDef array + auto *pyObjectType = llvm::StructType::get(B->getInt64Ty(), ptr); + auto *pyModuleDefBaseType = + llvm::StructType::get(pyObjectType, ptr, B->getInt64Ty(), ptr); + auto *pyModuleDefType = + llvm::StructType::get(pyModuleDefBaseType, ptr, ptr, B->getInt64Ty(), + pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr); + + auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null); + auto *pyModuleDefBaseConst = llvm::ConstantStruct::get( + pyModuleDefBaseType, pyObjectConst, null, B->getInt64(0), null); + + auto *nameVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, name), ".pyext_module_name"); + nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr); + + auto *docsConst = null; + if (!funcs.empty()) { + if (auto *docsAttr = + funcs[0].first->getModule()->getAttribute()) { + auto docs = docsAttr->docstring; + auto *docsVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring"); + docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr); + } + } + + auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr); + auto *pyModuleDef = llvm::ConstantStruct::get( + pyModuleDefType, pyModuleDefBaseConst, nameConst, docsConst, B->getInt64(-1), + pyMethodArrayConst, null, null, null, null); + auto *pyModuleVar = + new llvm::GlobalVariable(*M, pyModuleDef->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + pyModuleDef, ".pyext_module"); + auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr); + + // Construct initialization hook + auto pyModuleCreate = cast( + M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty()) + .getCallee()); + pyModuleCreate->setDoesNotThrow(); + + auto *pyModuleInit = + cast(M->getOrInsertFunction("PyInit_" + name, ptr).getCallee()); + auto *entry = llvm::BasicBlock::Create(*context, "entry", pyModuleInit); + B->SetInsertPoint(entry); + if (auto *main = M->getFunction("main")) { + main->setName(".main"); + B->CreateCall({main->getFunctionType(), main}, + {B->getInt32(0), + llvm::ConstantPointerNull::get(B->getInt8PtrTy()->getPointerTo())}); + } + B->CreateRet(B->CreateCall(pyModuleCreate, + {pyModuleConst, B->getInt32(PYEXT_PYTHON_ABI_VERSION)})); + + // Generate shared object + // (This will not create a global ctor since we renamed the 'main' function above.) + writeToExecutable(filename, argv0, /*library=*/true, libs, lflags); +} + void LLVMVisitor::compile(const std::string &filename, const std::string &argv0, const std::vector &libs, const std::string &lflags) { diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index 499bf98a..9d253705 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -342,6 +342,18 @@ public: bool library = false, const std::vector &libs = {}, const std::string &lflags = ""); + /// Writes module as Python extension shared object. + /// @param name the module's name + /// @param funcs extension functions + /// @param filename the file to write to + /// @param argv0 compiler's argv[0] used to set rpath + /// @param libs library names to link + /// @param lflags extra flags to pass linker + void writeToPythonExtension(const std::string &name, + const std::vector> &funcs, + const std::string &filename, const std::string &argv0, + const std::vector &libs = {}, + const std::string &lflags = ""); /// Runs optimization passes on module and writes the result /// to the specified file. The output type is determined by /// the file extension (.ll for LLVM IR, .bc for LLVM bitcode diff --git a/codon/cir/transform/lowering/pyextension.cpp b/codon/cir/transform/lowering/pyextension.cpp index 49e9eb82..3adbe28a 100644 --- a/codon/cir/transform/lowering/pyextension.cpp +++ b/codon/cir/transform/lowering/pyextension.cpp @@ -23,7 +23,7 @@ Func *generateExtensionFunc(Func *f) { auto *M = f->getModule(); auto *cobj = M->getPointerType(M->getByteType()); - auto *ext = M->Nr("__py_extension"); + auto *ext = M->Nr("__.py_extension.__"); ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}), {"self", "args", "nargs"}); auto *body = M->Nr(); @@ -62,8 +62,7 @@ void PythonExtensionLowering::run(Module *module) { if (!util::hasAttribute(f, EXPORT_ATTR)) continue; - std::cout << f->getName() << std::endl; - std::cout << *generateExtensionFunc(f) << std::endl; + extFuncs.emplace_back(f, generateExtensionFunc(f)); } } } diff --git a/codon/cir/transform/lowering/pyextension.h b/codon/cir/transform/lowering/pyextension.h index 0e689fe8..e5b68480 100644 --- a/codon/cir/transform/lowering/pyextension.h +++ b/codon/cir/transform/lowering/pyextension.h @@ -2,6 +2,9 @@ #pragma once +#include +#include + #include "codon/cir/transform/pass.h" namespace codon { @@ -10,10 +13,24 @@ namespace transform { namespace lowering { class PythonExtensionLowering : public Pass { +private: + /// vector of original function (1st) and generated + /// extension wrapper (2nd) + std::vector> extFuncs; + public: static const std::string KEY; std::string getKey() const override { return KEY; } + + /// Constructs a PythonExtensionLowering pass. + PythonExtensionLowering() : Pass(), extFuncs() {} + void run(Module *module) override; + + /// @return extension function (original, generated) pairs + std::vector> getExtensionFunctions() const { + return extFuncs; + } }; } // namespace lowering diff --git a/codon/cir/transform/manager.cpp b/codon/cir/transform/manager.cpp index b7ebab46..3a65038a 100644 --- a/codon/cir/transform/manager.cpp +++ b/codon/cir/transform/manager.cpp @@ -148,6 +148,12 @@ void PassManager::invalidate(const std::string &key) { } void PassManager::registerStandardPasses(PassManager::Init init) { + std::unique_ptr pyExtPass; + if (pyExtension) { + pyExtPass = std::make_unique(); + pyExtensionPass = pyExtPass.get(); + } + switch (init) { case Init::EMPTY: break; @@ -155,6 +161,8 @@ void PassManager::registerStandardPasses(PassManager::Init init) { registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); + if (pyExtension) + registerPass(std::move(pyExtPass)); break; } case Init::RELEASE: @@ -201,6 +209,9 @@ void PassManager::registerStandardPasses(PassManager::Init init) { registerPass(std::make_unique(), /*insertBefore=*/"", {}, {cfgKey, globalKey}); + if (pyExtension) + registerPass(std::move(pyExtPass)); + if (init != Init::JIT) { // Don't demote globals in JIT mode, since they might be used later // by another user input. diff --git a/codon/cir/transform/manager.h b/codon/cir/transform/manager.h index 9ef74bee..3c32fa0d 100644 --- a/codon/cir/transform/manager.h +++ b/codon/cir/transform/manager.h @@ -11,6 +11,7 @@ #include "codon/cir/analyze/analysis.h" #include "codon/cir/module.h" +#include "codon/cir/transform/lowering/pyextension.h" #include "codon/cir/transform/pass.h" namespace codon { @@ -94,6 +95,12 @@ private: /// whether to use Python (vs. C) numeric semantics in passes bool pyNumerics; + /// true if we are compiling as a Python extension + bool pyExtension; + + /// pointer to Python extension lowering pass, if applicable + lowering::PythonExtensionLowering *pyExtensionPass; + public: /// PassManager initialization mode. enum Init { @@ -104,16 +111,17 @@ public: }; explicit PassManager(Init init, std::vector disabled = {}, - bool pyNumerics = false) + bool pyNumerics = false, bool pyExtension = false) : km(), passes(), analyses(), executionOrder(), results(), - disabled(std::move(disabled)), pyNumerics(pyNumerics) { + disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension), + pyExtensionPass(nullptr) { registerStandardPasses(init); } explicit PassManager(bool debug = false, std::vector disabled = {}, - bool pyNumerics = false) + bool pyNumerics = false, bool pyExtension = false) : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled), - pyNumerics) {} + pyNumerics, pyExtension) {} /// Checks if the given pass is included in this manager. /// @param key the pass key @@ -174,6 +182,11 @@ public: return std::find(disabled.begin(), disabled.end(), key) != disabled.end(); } + /// @return the Python extension lowering pass, or null if none + lowering::PythonExtensionLowering *getPythonExtensionPass() const { + return pyExtensionPass; + } + private: void runPass(Module *module, const std::string &name); void registerStandardPasses(Init init); diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 54f7dd86..fdfde1e2 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -32,13 +32,13 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is Compiler::Compiler(const std::string &argv0, Compiler::Mode mode, const std::vector &disabledPasses, bool isTest, - bool pyNumerics) - : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(), - plm(std::make_unique(argv0)), + bool pyNumerics, bool pyExtension) + : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), + pyExtension(pyExtension), input(), plm(std::make_unique(argv0)), cache(std::make_unique(argv0)), module(std::make_unique()), - pm(std::make_unique(getPassManagerInit(mode, isTest), - disabledPasses, pyNumerics)), + pm(std::make_unique( + getPassManagerInit(mode, isTest), disabledPasses, pyNumerics, pyExtension)), llvisitor(std::make_unique()) { cache->module = module.get(); cache->pythonCompat = pyNumerics; @@ -181,6 +181,14 @@ std::unordered_map Compiler::getEarlyDefines() { std::unordered_map earlyDefines; earlyDefines.emplace("__debug__", debug ? "1" : "0"); earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0"); + earlyDefines.emplace("__py_extension__", pyExtension ? "1" : "0"); + earlyDefines.emplace("__apple__", +#if __APPLE__ + "1" +#else + "0" +#endif + ); return earlyDefines; } diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index cc13bfa7..de188e67 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -28,6 +28,7 @@ private: std::string argv0; bool debug; bool pyNumerics; + bool pyExtension; std::string input; std::unique_ptr plm; std::unique_ptr cache; @@ -42,13 +43,14 @@ private: public: Compiler(const std::string &argv0, Mode mode, const std::vector &disabledPasses = {}, bool isTest = false, - bool pyNumerics = false); + bool pyNumerics = false, bool pyExtension = false); explicit Compiler(const std::string &argv0, bool debug = false, const std::vector &disabledPasses = {}, - bool isTest = false, bool pyNumerics = false) + bool isTest = false, bool pyNumerics = false, + bool pyExtension = false) : Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest, - pyNumerics) {} + pyNumerics, pyExtension) {} std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); }