mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Add extension module codegen
This commit is contained in:
parent
947b9fe52b
commit
2285057005
@ -4,12 +4,14 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "codon/cir/transform/lowering/pyextension.h"
|
||||||
#include "codon/compiler/compiler.h"
|
#include "codon/compiler/compiler.h"
|
||||||
#include "codon/compiler/error.h"
|
#include "codon/compiler/error.h"
|
||||||
#include "codon/compiler/jit.h"
|
#include "codon/compiler/jit.h"
|
||||||
@ -83,7 +85,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
|
|||||||
codon::getLogger().parse(std::string(d));
|
codon::getLogger().parse(std::string(d));
|
||||||
}
|
}
|
||||||
|
|
||||||
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect };
|
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect };
|
||||||
enum OptMode { Debug, Release };
|
enum OptMode { Debug, Release };
|
||||||
enum Numerics { C, Python };
|
enum Numerics { C, Python };
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -109,8 +111,9 @@ int docMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||||||
return EXIT_SUCCESS;
|
return EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &args,
|
std::unique_ptr<codon::Compiler> processSource(
|
||||||
bool standalone) {
|
const std::vector<const char *> &args, bool standalone,
|
||||||
|
std::function<bool()> pyExtension = [] { return false; }) {
|
||||||
llvm::cl::opt<std::string> input(llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
llvm::cl::opt<std::string> input(llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
||||||
llvm::cl::init("-"));
|
llvm::cl::init("-"));
|
||||||
auto regs = llvm::cl::getRegisteredOptions();
|
auto regs = llvm::cl::getRegisteredOptions();
|
||||||
@ -163,9 +166,9 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
|
|||||||
|
|
||||||
const bool isDebug = (optMode == OptMode::Debug);
|
const bool isDebug = (optMode == OptMode::Debug);
|
||||||
std::vector<std::string> disabledOptsVec(disabledOpts);
|
std::vector<std::string> disabledOptsVec(disabledOpts);
|
||||||
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec,
|
auto compiler = std::make_unique<codon::Compiler>(
|
||||||
/*isTest=*/false,
|
args[0], isDebug, disabledOptsVec,
|
||||||
(numerics == Numerics::Python));
|
/*isTest=*/false, (numerics == Numerics::Python), pyExtension());
|
||||||
compiler->getLLVMVisitor()->setStandalone(standalone);
|
compiler->getLLVMVisitor()->setStandalone(standalone);
|
||||||
|
|
||||||
// load plugins
|
// load plugins
|
||||||
@ -296,13 +299,15 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||||||
llvm::cl::desc("Pass given flags to linker"));
|
llvm::cl::desc("Pass given flags to linker"));
|
||||||
llvm::cl::opt<BuildKind> buildKind(
|
llvm::cl::opt<BuildKind> buildKind(
|
||||||
llvm::cl::desc("output type"),
|
llvm::cl::desc("output type"),
|
||||||
llvm::cl::values(clEnumValN(LLVM, "llvm", "Generate LLVM IR"),
|
llvm::cl::values(
|
||||||
clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"),
|
clEnumValN(LLVM, "llvm", "Generate LLVM IR"),
|
||||||
clEnumValN(Object, "obj", "Generate native object file"),
|
clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"),
|
||||||
clEnumValN(Executable, "exe", "Generate executable"),
|
clEnumValN(Object, "obj", "Generate native object file"),
|
||||||
clEnumValN(Library, "lib", "Generate shared library"),
|
clEnumValN(Executable, "exe", "Generate executable"),
|
||||||
clEnumValN(Detect, "detect",
|
clEnumValN(Library, "lib", "Generate shared library"),
|
||||||
"Detect output type based on output file extension")),
|
clEnumValN(PyExtension, "pyext", "Generate Python extension module"),
|
||||||
|
clEnumValN(Detect, "detect",
|
||||||
|
"Detect output type based on output file extension")),
|
||||||
llvm::cl::init(Detect));
|
llvm::cl::init(Detect));
|
||||||
llvm::cl::opt<std::string> output(
|
llvm::cl::opt<std::string> output(
|
||||||
"o",
|
"o",
|
||||||
@ -310,7 +315,8 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||||||
"Write compiled output to specified file. Supported extensions: "
|
"Write compiled output to specified file. Supported extensions: "
|
||||||
"none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)"));
|
"none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)"));
|
||||||
|
|
||||||
auto compiler = processSource(args, /*standalone=*/true);
|
auto compiler = processSource(args, /*standalone=*/true,
|
||||||
|
[&] { return buildKind == BuildKind::PyExtension; });
|
||||||
if (!compiler)
|
if (!compiler)
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
std::vector<std::string> libsVec(libs);
|
std::vector<std::string> libsVec(libs);
|
||||||
@ -329,6 +335,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||||||
extension = ".o";
|
extension = ".o";
|
||||||
break;
|
break;
|
||||||
case BuildKind::Library:
|
case BuildKind::Library:
|
||||||
|
case BuildKind::PyExtension:
|
||||||
extension = isMacOS() ? ".dylib" : ".so";
|
extension = isMacOS() ? ".dylib" : ".so";
|
||||||
break;
|
break;
|
||||||
case BuildKind::Executable:
|
case BuildKind::Executable:
|
||||||
@ -358,6 +365,12 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||||||
compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec,
|
compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec,
|
||||||
lflags);
|
lflags);
|
||||||
break;
|
break;
|
||||||
|
case BuildKind::PyExtension:
|
||||||
|
compiler->getLLVMVisitor()->writeToPythonExtension(
|
||||||
|
"mymodule", // TODO
|
||||||
|
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(),
|
||||||
|
filename, argv0, libsVec, lflags);
|
||||||
|
break;
|
||||||
case BuildKind::Detect:
|
case BuildKind::Detect:
|
||||||
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
||||||
break;
|
break;
|
||||||
|
@ -404,9 +404,9 @@ void executeCommand(const std::vector<std::string> &args) {
|
|||||||
void LLVMVisitor::setupGlobalCtorForSharedLibrary() {
|
void LLVMVisitor::setupGlobalCtorForSharedLibrary() {
|
||||||
const std::string llvmCtor = "llvm.global_ctors";
|
const std::string llvmCtor = "llvm.global_ctors";
|
||||||
auto *main = M->getFunction("main");
|
auto *main = M->getFunction("main");
|
||||||
main->setName(".main"); // avoid clash with other main
|
|
||||||
if (M->getNamedValue(llvmCtor) || !main)
|
if (M->getNamedValue(llvmCtor) || !main)
|
||||||
return;
|
return;
|
||||||
|
main->setName(".main"); // avoid clash with other main
|
||||||
|
|
||||||
auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false);
|
auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false);
|
||||||
auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(),
|
auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(),
|
||||||
@ -541,6 +541,137 @@ void LLVMVisitor::writeToExecutable(const std::string &filename,
|
|||||||
llvm::sys::fs::remove(objFile);
|
llvm::sys::fs::remove(objFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// https://github.com/python/cpython/blob/main/Include/methodobject.h
|
||||||
|
constexpr int PYEXT_METH_VARARGS = 0x0001;
|
||||||
|
constexpr int PYEXT_METH_KEYWORDS = 0x0002;
|
||||||
|
constexpr int PYEXT_METH_NOARGS = 0x0004;
|
||||||
|
constexpr int PYEXT_METH_O = 0x0008;
|
||||||
|
constexpr int PYEXT_METH_CLASS = 0x0010;
|
||||||
|
constexpr int PYEXT_METH_STATIC = 0x0020;
|
||||||
|
constexpr int PYEXT_METH_COEXIST = 0x0040;
|
||||||
|
constexpr int PYEXT_METH_FASTCALL = 0x0080;
|
||||||
|
constexpr int PYEXT_METH_METHOD = 0x0200;
|
||||||
|
// https://github.com/python/cpython/blob/main/Include/modsupport.h
|
||||||
|
constexpr int PYEXT_PYTHON_ABI_VERSION = 3;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void LLVMVisitor::writeToPythonExtension(
|
||||||
|
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
|
||||||
|
const std::string &filename, const std::string &argv0,
|
||||||
|
const std::vector<std::string> &libs, const std::string &lflags) {
|
||||||
|
// Construct PyMethodDef array
|
||||||
|
auto *ptr = B->getInt8PtrTy();
|
||||||
|
auto *null = llvm::Constant::getNullValue(ptr);
|
||||||
|
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
|
||||||
|
std::vector<llvm::Constant *> pyMethods;
|
||||||
|
|
||||||
|
for (auto &p : funcs) {
|
||||||
|
auto *original = p.first;
|
||||||
|
auto *generated = p.second;
|
||||||
|
auto llvmName = getNameForFunction(generated);
|
||||||
|
auto *llvmFunc = M->getNamedValue(llvmName);
|
||||||
|
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
|
||||||
|
|
||||||
|
auto name = original->getUnmangledName();
|
||||||
|
auto *nameVar = new llvm::GlobalVariable(
|
||||||
|
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
llvm::ConstantDataArray::getString(*context, name), ".pyext_func_name");
|
||||||
|
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||||
|
|
||||||
|
auto *nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
|
||||||
|
auto *funcConst = llvm::ConstantExpr::getBitCast(llvmFunc, ptr);
|
||||||
|
auto *flagConst = B->getInt32(PYEXT_METH_FASTCALL);
|
||||||
|
auto *docsConst = null;
|
||||||
|
if (auto *docsAttr = original->getAttribute<DocstringAttribute>()) {
|
||||||
|
auto docs = docsAttr->docstring;
|
||||||
|
auto *docsVar = new llvm::GlobalVariable(
|
||||||
|
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
||||||
|
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||||
|
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
||||||
|
}
|
||||||
|
pyMethods.push_back(llvm::ConstantStruct::get(pyMethodDefType, nameConst, funcConst,
|
||||||
|
flagConst, docsConst));
|
||||||
|
}
|
||||||
|
pyMethods.push_back(
|
||||||
|
llvm::ConstantStruct::get(pyMethodDefType, null, null, B->getInt32(0), null));
|
||||||
|
|
||||||
|
auto *pyMethodDefArrayType = llvm::ArrayType::get(pyMethodDefType, pyMethods.size());
|
||||||
|
auto *pyMethodDefArray = new llvm::GlobalVariable(
|
||||||
|
*M, pyMethodDefArrayType,
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods");
|
||||||
|
|
||||||
|
// Construct PyModuleDef array
|
||||||
|
auto *pyObjectType = llvm::StructType::get(B->getInt64Ty(), ptr);
|
||||||
|
auto *pyModuleDefBaseType =
|
||||||
|
llvm::StructType::get(pyObjectType, ptr, B->getInt64Ty(), ptr);
|
||||||
|
auto *pyModuleDefType =
|
||||||
|
llvm::StructType::get(pyModuleDefBaseType, ptr, ptr, B->getInt64Ty(),
|
||||||
|
pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr);
|
||||||
|
|
||||||
|
auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null);
|
||||||
|
auto *pyModuleDefBaseConst = llvm::ConstantStruct::get(
|
||||||
|
pyModuleDefBaseType, pyObjectConst, null, B->getInt64(0), null);
|
||||||
|
|
||||||
|
auto *nameVar = new llvm::GlobalVariable(
|
||||||
|
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
llvm::ConstantDataArray::getString(*context, name), ".pyext_module_name");
|
||||||
|
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||||
|
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
|
||||||
|
|
||||||
|
auto *docsConst = null;
|
||||||
|
if (!funcs.empty()) {
|
||||||
|
if (auto *docsAttr =
|
||||||
|
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) {
|
||||||
|
auto docs = docsAttr->docstring;
|
||||||
|
auto *docsVar = new llvm::GlobalVariable(
|
||||||
|
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
||||||
|
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||||
|
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);
|
||||||
|
auto *pyModuleDef = llvm::ConstantStruct::get(
|
||||||
|
pyModuleDefType, pyModuleDefBaseConst, nameConst, docsConst, B->getInt64(-1),
|
||||||
|
pyMethodArrayConst, null, null, null, null);
|
||||||
|
auto *pyModuleVar =
|
||||||
|
new llvm::GlobalVariable(*M, pyModuleDef->getType(),
|
||||||
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
|
pyModuleDef, ".pyext_module");
|
||||||
|
auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr);
|
||||||
|
|
||||||
|
// Construct initialization hook
|
||||||
|
auto pyModuleCreate = cast<llvm::Function>(
|
||||||
|
M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty())
|
||||||
|
.getCallee());
|
||||||
|
pyModuleCreate->setDoesNotThrow();
|
||||||
|
|
||||||
|
auto *pyModuleInit =
|
||||||
|
cast<llvm::Function>(M->getOrInsertFunction("PyInit_" + name, ptr).getCallee());
|
||||||
|
auto *entry = llvm::BasicBlock::Create(*context, "entry", pyModuleInit);
|
||||||
|
B->SetInsertPoint(entry);
|
||||||
|
if (auto *main = M->getFunction("main")) {
|
||||||
|
main->setName(".main");
|
||||||
|
B->CreateCall({main->getFunctionType(), main},
|
||||||
|
{B->getInt32(0),
|
||||||
|
llvm::ConstantPointerNull::get(B->getInt8PtrTy()->getPointerTo())});
|
||||||
|
}
|
||||||
|
B->CreateRet(B->CreateCall(pyModuleCreate,
|
||||||
|
{pyModuleConst, B->getInt32(PYEXT_PYTHON_ABI_VERSION)}));
|
||||||
|
|
||||||
|
// Generate shared object
|
||||||
|
// (This will not create a global ctor since we renamed the 'main' function above.)
|
||||||
|
writeToExecutable(filename, argv0, /*library=*/true, libs, lflags);
|
||||||
|
}
|
||||||
|
|
||||||
void LLVMVisitor::compile(const std::string &filename, const std::string &argv0,
|
void LLVMVisitor::compile(const std::string &filename, const std::string &argv0,
|
||||||
const std::vector<std::string> &libs,
|
const std::vector<std::string> &libs,
|
||||||
const std::string &lflags) {
|
const std::string &lflags) {
|
||||||
|
@ -342,6 +342,18 @@ public:
|
|||||||
bool library = false,
|
bool library = false,
|
||||||
const std::vector<std::string> &libs = {},
|
const std::vector<std::string> &libs = {},
|
||||||
const std::string &lflags = "");
|
const std::string &lflags = "");
|
||||||
|
/// Writes module as Python extension shared object.
|
||||||
|
/// @param name the module's name
|
||||||
|
/// @param funcs extension functions
|
||||||
|
/// @param filename the file to write to
|
||||||
|
/// @param argv0 compiler's argv[0] used to set rpath
|
||||||
|
/// @param libs library names to link
|
||||||
|
/// @param lflags extra flags to pass linker
|
||||||
|
void writeToPythonExtension(const std::string &name,
|
||||||
|
const std::vector<std::pair<Func *, Func *>> &funcs,
|
||||||
|
const std::string &filename, const std::string &argv0,
|
||||||
|
const std::vector<std::string> &libs = {},
|
||||||
|
const std::string &lflags = "");
|
||||||
/// Runs optimization passes on module and writes the result
|
/// Runs optimization passes on module and writes the result
|
||||||
/// to the specified file. The output type is determined by
|
/// to the specified file. The output type is determined by
|
||||||
/// the file extension (.ll for LLVM IR, .bc for LLVM bitcode
|
/// the file extension (.ll for LLVM IR, .bc for LLVM bitcode
|
||||||
|
@ -23,7 +23,7 @@ Func *generateExtensionFunc(Func *f) {
|
|||||||
|
|
||||||
auto *M = f->getModule();
|
auto *M = f->getModule();
|
||||||
auto *cobj = M->getPointerType(M->getByteType());
|
auto *cobj = M->getPointerType(M->getByteType());
|
||||||
auto *ext = M->Nr<BodiedFunc>("__py_extension");
|
auto *ext = M->Nr<BodiedFunc>("__.py_extension.__");
|
||||||
ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}),
|
ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}),
|
||||||
{"self", "args", "nargs"});
|
{"self", "args", "nargs"});
|
||||||
auto *body = M->Nr<SeriesFlow>();
|
auto *body = M->Nr<SeriesFlow>();
|
||||||
@ -62,8 +62,7 @@ void PythonExtensionLowering::run(Module *module) {
|
|||||||
if (!util::hasAttribute(f, EXPORT_ATTR))
|
if (!util::hasAttribute(f, EXPORT_ATTR))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
std::cout << f->getName() << std::endl;
|
extFuncs.emplace_back(f, generateExtensionFunc(f));
|
||||||
std::cout << *generateExtensionFunc(f) << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "codon/cir/transform/pass.h"
|
#include "codon/cir/transform/pass.h"
|
||||||
|
|
||||||
namespace codon {
|
namespace codon {
|
||||||
@ -10,10 +13,24 @@ namespace transform {
|
|||||||
namespace lowering {
|
namespace lowering {
|
||||||
|
|
||||||
class PythonExtensionLowering : public Pass {
|
class PythonExtensionLowering : public Pass {
|
||||||
|
private:
|
||||||
|
/// vector of original function (1st) and generated
|
||||||
|
/// extension wrapper (2nd)
|
||||||
|
std::vector<std::pair<Func *, Func *>> extFuncs;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static const std::string KEY;
|
static const std::string KEY;
|
||||||
std::string getKey() const override { return KEY; }
|
std::string getKey() const override { return KEY; }
|
||||||
|
|
||||||
|
/// Constructs a PythonExtensionLowering pass.
|
||||||
|
PythonExtensionLowering() : Pass(), extFuncs() {}
|
||||||
|
|
||||||
void run(Module *module) override;
|
void run(Module *module) override;
|
||||||
|
|
||||||
|
/// @return extension function (original, generated) pairs
|
||||||
|
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
|
||||||
|
return extFuncs;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lowering
|
} // namespace lowering
|
||||||
|
@ -148,6 +148,12 @@ void PassManager::invalidate(const std::string &key) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PassManager::registerStandardPasses(PassManager::Init init) {
|
void PassManager::registerStandardPasses(PassManager::Init init) {
|
||||||
|
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
|
||||||
|
if (pyExtension) {
|
||||||
|
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
|
||||||
|
pyExtensionPass = pyExtPass.get();
|
||||||
|
}
|
||||||
|
|
||||||
switch (init) {
|
switch (init) {
|
||||||
case Init::EMPTY:
|
case Init::EMPTY:
|
||||||
break;
|
break;
|
||||||
@ -155,6 +161,8 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||||||
registerPass(std::make_unique<lowering::PipelineLowering>());
|
registerPass(std::make_unique<lowering::PipelineLowering>());
|
||||||
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
|
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
|
||||||
registerPass(std::make_unique<parallel::OpenMPPass>());
|
registerPass(std::make_unique<parallel::OpenMPPass>());
|
||||||
|
if (pyExtension)
|
||||||
|
registerPass(std::move(pyExtPass));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Init::RELEASE:
|
case Init::RELEASE:
|
||||||
@ -201,6 +209,9 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||||||
registerPass(std::make_unique<parallel::OpenMPPass>(), /*insertBefore=*/"", {},
|
registerPass(std::make_unique<parallel::OpenMPPass>(), /*insertBefore=*/"", {},
|
||||||
{cfgKey, globalKey});
|
{cfgKey, globalKey});
|
||||||
|
|
||||||
|
if (pyExtension)
|
||||||
|
registerPass(std::move(pyExtPass));
|
||||||
|
|
||||||
if (init != Init::JIT) {
|
if (init != Init::JIT) {
|
||||||
// Don't demote globals in JIT mode, since they might be used later
|
// Don't demote globals in JIT mode, since they might be used later
|
||||||
// by another user input.
|
// by another user input.
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "codon/cir/analyze/analysis.h"
|
#include "codon/cir/analyze/analysis.h"
|
||||||
#include "codon/cir/module.h"
|
#include "codon/cir/module.h"
|
||||||
|
#include "codon/cir/transform/lowering/pyextension.h"
|
||||||
#include "codon/cir/transform/pass.h"
|
#include "codon/cir/transform/pass.h"
|
||||||
|
|
||||||
namespace codon {
|
namespace codon {
|
||||||
@ -94,6 +95,12 @@ private:
|
|||||||
/// whether to use Python (vs. C) numeric semantics in passes
|
/// whether to use Python (vs. C) numeric semantics in passes
|
||||||
bool pyNumerics;
|
bool pyNumerics;
|
||||||
|
|
||||||
|
/// true if we are compiling as a Python extension
|
||||||
|
bool pyExtension;
|
||||||
|
|
||||||
|
/// pointer to Python extension lowering pass, if applicable
|
||||||
|
lowering::PythonExtensionLowering *pyExtensionPass;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// PassManager initialization mode.
|
/// PassManager initialization mode.
|
||||||
enum Init {
|
enum Init {
|
||||||
@ -104,16 +111,17 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
||||||
bool pyNumerics = false)
|
bool pyNumerics = false, bool pyExtension = false)
|
||||||
: km(), passes(), analyses(), executionOrder(), results(),
|
: km(), passes(), analyses(), executionOrder(), results(),
|
||||||
disabled(std::move(disabled)), pyNumerics(pyNumerics) {
|
disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension),
|
||||||
|
pyExtensionPass(nullptr) {
|
||||||
registerStandardPasses(init);
|
registerStandardPasses(init);
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {},
|
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {},
|
||||||
bool pyNumerics = false)
|
bool pyNumerics = false, bool pyExtension = false)
|
||||||
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled),
|
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled),
|
||||||
pyNumerics) {}
|
pyNumerics, pyExtension) {}
|
||||||
|
|
||||||
/// Checks if the given pass is included in this manager.
|
/// Checks if the given pass is included in this manager.
|
||||||
/// @param key the pass key
|
/// @param key the pass key
|
||||||
@ -174,6 +182,11 @@ public:
|
|||||||
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
|
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @return the Python extension lowering pass, or null if none
|
||||||
|
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
|
||||||
|
return pyExtensionPass;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runPass(Module *module, const std::string &name);
|
void runPass(Module *module, const std::string &name);
|
||||||
void registerStandardPasses(Init init);
|
void registerStandardPasses(Init init);
|
||||||
|
@ -32,13 +32,13 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is
|
|||||||
|
|
||||||
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
|
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
|
||||||
const std::vector<std::string> &disabledPasses, bool isTest,
|
const std::vector<std::string> &disabledPasses, bool isTest,
|
||||||
bool pyNumerics)
|
bool pyNumerics, bool pyExtension)
|
||||||
: argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(),
|
: argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics),
|
||||||
plm(std::make_unique<PluginManager>(argv0)),
|
pyExtension(pyExtension), input(), plm(std::make_unique<PluginManager>(argv0)),
|
||||||
cache(std::make_unique<ast::Cache>(argv0)),
|
cache(std::make_unique<ast::Cache>(argv0)),
|
||||||
module(std::make_unique<ir::Module>()),
|
module(std::make_unique<ir::Module>()),
|
||||||
pm(std::make_unique<ir::transform::PassManager>(getPassManagerInit(mode, isTest),
|
pm(std::make_unique<ir::transform::PassManager>(
|
||||||
disabledPasses, pyNumerics)),
|
getPassManagerInit(mode, isTest), disabledPasses, pyNumerics, pyExtension)),
|
||||||
llvisitor(std::make_unique<ir::LLVMVisitor>()) {
|
llvisitor(std::make_unique<ir::LLVMVisitor>()) {
|
||||||
cache->module = module.get();
|
cache->module = module.get();
|
||||||
cache->pythonCompat = pyNumerics;
|
cache->pythonCompat = pyNumerics;
|
||||||
@ -181,6 +181,14 @@ std::unordered_map<std::string, std::string> Compiler::getEarlyDefines() {
|
|||||||
std::unordered_map<std::string, std::string> earlyDefines;
|
std::unordered_map<std::string, std::string> earlyDefines;
|
||||||
earlyDefines.emplace("__debug__", debug ? "1" : "0");
|
earlyDefines.emplace("__debug__", debug ? "1" : "0");
|
||||||
earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0");
|
earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0");
|
||||||
|
earlyDefines.emplace("__py_extension__", pyExtension ? "1" : "0");
|
||||||
|
earlyDefines.emplace("__apple__",
|
||||||
|
#if __APPLE__
|
||||||
|
"1"
|
||||||
|
#else
|
||||||
|
"0"
|
||||||
|
#endif
|
||||||
|
);
|
||||||
return earlyDefines;
|
return earlyDefines;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ private:
|
|||||||
std::string argv0;
|
std::string argv0;
|
||||||
bool debug;
|
bool debug;
|
||||||
bool pyNumerics;
|
bool pyNumerics;
|
||||||
|
bool pyExtension;
|
||||||
std::string input;
|
std::string input;
|
||||||
std::unique_ptr<PluginManager> plm;
|
std::unique_ptr<PluginManager> plm;
|
||||||
std::unique_ptr<ast::Cache> cache;
|
std::unique_ptr<ast::Cache> cache;
|
||||||
@ -42,13 +43,14 @@ private:
|
|||||||
public:
|
public:
|
||||||
Compiler(const std::string &argv0, Mode mode,
|
Compiler(const std::string &argv0, Mode mode,
|
||||||
const std::vector<std::string> &disabledPasses = {}, bool isTest = false,
|
const std::vector<std::string> &disabledPasses = {}, bool isTest = false,
|
||||||
bool pyNumerics = false);
|
bool pyNumerics = false, bool pyExtension = false);
|
||||||
|
|
||||||
explicit Compiler(const std::string &argv0, bool debug = false,
|
explicit Compiler(const std::string &argv0, bool debug = false,
|
||||||
const std::vector<std::string> &disabledPasses = {},
|
const std::vector<std::string> &disabledPasses = {},
|
||||||
bool isTest = false, bool pyNumerics = false)
|
bool isTest = false, bool pyNumerics = false,
|
||||||
|
bool pyExtension = false)
|
||||||
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest,
|
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest,
|
||||||
pyNumerics) {}
|
pyNumerics, pyExtension) {}
|
||||||
|
|
||||||
std::string getInput() const { return input; }
|
std::string getInput() const { return input; }
|
||||||
PluginManager *getPluginManager() const { return plm.get(); }
|
PluginManager *getPluginManager() const { return plm.get(); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user