1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Add extension module codegen

This commit is contained in:
A. R. Shajii 2023-01-28 22:59:49 -05:00
parent 947b9fe52b
commit 2285057005
9 changed files with 236 additions and 30 deletions

View File

@ -4,12 +4,14 @@
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
#include <functional>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "codon/cir/transform/lowering/pyextension.h"
#include "codon/compiler/compiler.h" #include "codon/compiler/compiler.h"
#include "codon/compiler/error.h" #include "codon/compiler/error.h"
#include "codon/compiler/jit.h" #include "codon/compiler/jit.h"
@ -83,7 +85,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
codon::getLogger().parse(std::string(d)); codon::getLogger().parse(std::string(d));
} }
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect }; enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect };
enum OptMode { Debug, Release }; enum OptMode { Debug, Release };
enum Numerics { C, Python }; enum Numerics { C, Python };
} // namespace } // namespace
@ -109,8 +111,9 @@ int docMode(const std::vector<const char *> &args, const std::string &argv0) {
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &args, std::unique_ptr<codon::Compiler> processSource(
bool standalone) { const std::vector<const char *> &args, bool standalone,
std::function<bool()> pyExtension = [] { return false; }) {
llvm::cl::opt<std::string> input(llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::opt<std::string> input(llvm::cl::Positional, llvm::cl::desc("<input file>"),
llvm::cl::init("-")); llvm::cl::init("-"));
auto regs = llvm::cl::getRegisteredOptions(); auto regs = llvm::cl::getRegisteredOptions();
@ -163,9 +166,9 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
const bool isDebug = (optMode == OptMode::Debug); const bool isDebug = (optMode == OptMode::Debug);
std::vector<std::string> disabledOptsVec(disabledOpts); std::vector<std::string> disabledOptsVec(disabledOpts);
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec, auto compiler = std::make_unique<codon::Compiler>(
/*isTest=*/false, args[0], isDebug, disabledOptsVec,
(numerics == Numerics::Python)); /*isTest=*/false, (numerics == Numerics::Python), pyExtension());
compiler->getLLVMVisitor()->setStandalone(standalone); compiler->getLLVMVisitor()->setStandalone(standalone);
// load plugins // load plugins
@ -296,13 +299,15 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
llvm::cl::desc("Pass given flags to linker")); llvm::cl::desc("Pass given flags to linker"));
llvm::cl::opt<BuildKind> buildKind( llvm::cl::opt<BuildKind> buildKind(
llvm::cl::desc("output type"), llvm::cl::desc("output type"),
llvm::cl::values(clEnumValN(LLVM, "llvm", "Generate LLVM IR"), llvm::cl::values(
clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"), clEnumValN(LLVM, "llvm", "Generate LLVM IR"),
clEnumValN(Object, "obj", "Generate native object file"), clEnumValN(Bitcode, "bc", "Generate LLVM bitcode"),
clEnumValN(Executable, "exe", "Generate executable"), clEnumValN(Object, "obj", "Generate native object file"),
clEnumValN(Library, "lib", "Generate shared library"), clEnumValN(Executable, "exe", "Generate executable"),
clEnumValN(Detect, "detect", clEnumValN(Library, "lib", "Generate shared library"),
"Detect output type based on output file extension")), clEnumValN(PyExtension, "pyext", "Generate Python extension module"),
clEnumValN(Detect, "detect",
"Detect output type based on output file extension")),
llvm::cl::init(Detect)); llvm::cl::init(Detect));
llvm::cl::opt<std::string> output( llvm::cl::opt<std::string> output(
"o", "o",
@ -310,7 +315,8 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
"Write compiled output to specified file. Supported extensions: " "Write compiled output to specified file. Supported extensions: "
"none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)")); "none (executable), .o (object file), .ll (LLVM IR), .bc (LLVM bitcode)"));
auto compiler = processSource(args, /*standalone=*/true); auto compiler = processSource(args, /*standalone=*/true,
[&] { return buildKind == BuildKind::PyExtension; });
if (!compiler) if (!compiler)
return EXIT_FAILURE; return EXIT_FAILURE;
std::vector<std::string> libsVec(libs); std::vector<std::string> libsVec(libs);
@ -329,6 +335,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
extension = ".o"; extension = ".o";
break; break;
case BuildKind::Library: case BuildKind::Library:
case BuildKind::PyExtension:
extension = isMacOS() ? ".dylib" : ".so"; extension = isMacOS() ? ".dylib" : ".so";
break; break;
case BuildKind::Executable: case BuildKind::Executable:
@ -358,6 +365,12 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec, compiler->getLLVMVisitor()->writeToExecutable(filename, argv0, true, libsVec,
lflags); lflags);
break; break;
case BuildKind::PyExtension:
compiler->getLLVMVisitor()->writeToPythonExtension(
"mymodule", // TODO
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(),
filename, argv0, libsVec, lflags);
break;
case BuildKind::Detect: case BuildKind::Detect:
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
break; break;

View File

@ -404,9 +404,9 @@ void executeCommand(const std::vector<std::string> &args) {
void LLVMVisitor::setupGlobalCtorForSharedLibrary() { void LLVMVisitor::setupGlobalCtorForSharedLibrary() {
const std::string llvmCtor = "llvm.global_ctors"; const std::string llvmCtor = "llvm.global_ctors";
auto *main = M->getFunction("main"); auto *main = M->getFunction("main");
main->setName(".main"); // avoid clash with other main
if (M->getNamedValue(llvmCtor) || !main) if (M->getNamedValue(llvmCtor) || !main)
return; return;
main->setName(".main"); // avoid clash with other main
auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false); auto *ctorFuncTy = llvm::FunctionType::get(B->getVoidTy(), {}, /*isVarArg=*/false);
auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(), auto *ctorEntryTy = llvm::StructType::get(B->getInt32Ty(), ctorFuncTy->getPointerTo(),
@ -541,6 +541,137 @@ void LLVMVisitor::writeToExecutable(const std::string &filename,
llvm::sys::fs::remove(objFile); llvm::sys::fs::remove(objFile);
} }
namespace {
// https://github.com/python/cpython/blob/main/Include/methodobject.h
constexpr int PYEXT_METH_VARARGS = 0x0001;
constexpr int PYEXT_METH_KEYWORDS = 0x0002;
constexpr int PYEXT_METH_NOARGS = 0x0004;
constexpr int PYEXT_METH_O = 0x0008;
constexpr int PYEXT_METH_CLASS = 0x0010;
constexpr int PYEXT_METH_STATIC = 0x0020;
constexpr int PYEXT_METH_COEXIST = 0x0040;
constexpr int PYEXT_METH_FASTCALL = 0x0080;
constexpr int PYEXT_METH_METHOD = 0x0200;
// https://github.com/python/cpython/blob/main/Include/modsupport.h
constexpr int PYEXT_PYTHON_ABI_VERSION = 3;
} // namespace
void LLVMVisitor::writeToPythonExtension(
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
const std::string &filename, const std::string &argv0,
const std::vector<std::string> &libs, const std::string &lflags) {
// Construct PyMethodDef array
auto *ptr = B->getInt8PtrTy();
auto *null = llvm::Constant::getNullValue(ptr);
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
std::vector<llvm::Constant *> pyMethods;
for (auto &p : funcs) {
auto *original = p.first;
auto *generated = p.second;
auto llvmName = getNameForFunction(generated);
auto *llvmFunc = M->getNamedValue(llvmName);
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
auto name = original->getUnmangledName();
auto *nameVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, name), ".pyext_func_name");
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
auto *nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *funcConst = llvm::ConstantExpr::getBitCast(llvmFunc, ptr);
auto *flagConst = B->getInt32(PYEXT_METH_FASTCALL);
auto *docsConst = null;
if (auto *docsAttr = original->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
pyMethods.push_back(llvm::ConstantStruct::get(pyMethodDefType, nameConst, funcConst,
flagConst, docsConst));
}
pyMethods.push_back(
llvm::ConstantStruct::get(pyMethodDefType, null, null, B->getInt32(0), null));
auto *pyMethodDefArrayType = llvm::ArrayType::get(pyMethodDefType, pyMethods.size());
auto *pyMethodDefArray = new llvm::GlobalVariable(
*M, pyMethodDefArrayType,
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods");
// Construct PyModuleDef array
auto *pyObjectType = llvm::StructType::get(B->getInt64Ty(), ptr);
auto *pyModuleDefBaseType =
llvm::StructType::get(pyObjectType, ptr, B->getInt64Ty(), ptr);
auto *pyModuleDefType =
llvm::StructType::get(pyModuleDefBaseType, ptr, ptr, B->getInt64Ty(),
pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr);
auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null);
auto *pyModuleDefBaseConst = llvm::ConstantStruct::get(
pyModuleDefBaseType, pyObjectConst, null, B->getInt64(0), null);
auto *nameVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, name), ".pyext_module_name");
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *docsConst = null;
if (!funcs.empty()) {
if (auto *docsAttr =
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
}
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);
auto *pyModuleDef = llvm::ConstantStruct::get(
pyModuleDefType, pyModuleDefBaseConst, nameConst, docsConst, B->getInt64(-1),
pyMethodArrayConst, null, null, null, null);
auto *pyModuleVar =
new llvm::GlobalVariable(*M, pyModuleDef->getType(),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
pyModuleDef, ".pyext_module");
auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr);
// Construct initialization hook
auto pyModuleCreate = cast<llvm::Function>(
M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty())
.getCallee());
pyModuleCreate->setDoesNotThrow();
auto *pyModuleInit =
cast<llvm::Function>(M->getOrInsertFunction("PyInit_" + name, ptr).getCallee());
auto *entry = llvm::BasicBlock::Create(*context, "entry", pyModuleInit);
B->SetInsertPoint(entry);
if (auto *main = M->getFunction("main")) {
main->setName(".main");
B->CreateCall({main->getFunctionType(), main},
{B->getInt32(0),
llvm::ConstantPointerNull::get(B->getInt8PtrTy()->getPointerTo())});
}
B->CreateRet(B->CreateCall(pyModuleCreate,
{pyModuleConst, B->getInt32(PYEXT_PYTHON_ABI_VERSION)}));
// Generate shared object
// (This will not create a global ctor since we renamed the 'main' function above.)
writeToExecutable(filename, argv0, /*library=*/true, libs, lflags);
}
void LLVMVisitor::compile(const std::string &filename, const std::string &argv0, void LLVMVisitor::compile(const std::string &filename, const std::string &argv0,
const std::vector<std::string> &libs, const std::vector<std::string> &libs,
const std::string &lflags) { const std::string &lflags) {

View File

@ -342,6 +342,18 @@ public:
bool library = false, bool library = false,
const std::vector<std::string> &libs = {}, const std::vector<std::string> &libs = {},
const std::string &lflags = ""); const std::string &lflags = "");
/// Writes module as Python extension shared object.
/// @param name the module's name
/// @param funcs extension functions
/// @param filename the file to write to
/// @param argv0 compiler's argv[0] used to set rpath
/// @param libs library names to link
/// @param lflags extra flags to pass linker
void writeToPythonExtension(const std::string &name,
const std::vector<std::pair<Func *, Func *>> &funcs,
const std::string &filename, const std::string &argv0,
const std::vector<std::string> &libs = {},
const std::string &lflags = "");
/// Runs optimization passes on module and writes the result /// Runs optimization passes on module and writes the result
/// to the specified file. The output type is determined by /// to the specified file. The output type is determined by
/// the file extension (.ll for LLVM IR, .bc for LLVM bitcode /// the file extension (.ll for LLVM IR, .bc for LLVM bitcode

View File

@ -23,7 +23,7 @@ Func *generateExtensionFunc(Func *f) {
auto *M = f->getModule(); auto *M = f->getModule();
auto *cobj = M->getPointerType(M->getByteType()); auto *cobj = M->getPointerType(M->getByteType());
auto *ext = M->Nr<BodiedFunc>("__py_extension"); auto *ext = M->Nr<BodiedFunc>("__.py_extension.__");
ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}), ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}),
{"self", "args", "nargs"}); {"self", "args", "nargs"});
auto *body = M->Nr<SeriesFlow>(); auto *body = M->Nr<SeriesFlow>();
@ -62,8 +62,7 @@ void PythonExtensionLowering::run(Module *module) {
if (!util::hasAttribute(f, EXPORT_ATTR)) if (!util::hasAttribute(f, EXPORT_ATTR))
continue; continue;
std::cout << f->getName() << std::endl; extFuncs.emplace_back(f, generateExtensionFunc(f));
std::cout << *generateExtensionFunc(f) << std::endl;
} }
} }
} }

View File

@ -2,6 +2,9 @@
#pragma once #pragma once
#include <utility>
#include <vector>
#include "codon/cir/transform/pass.h" #include "codon/cir/transform/pass.h"
namespace codon { namespace codon {
@ -10,10 +13,24 @@ namespace transform {
namespace lowering { namespace lowering {
class PythonExtensionLowering : public Pass { class PythonExtensionLowering : public Pass {
private:
/// vector of original function (1st) and generated
/// extension wrapper (2nd)
std::vector<std::pair<Func *, Func *>> extFuncs;
public: public:
static const std::string KEY; static const std::string KEY;
std::string getKey() const override { return KEY; } std::string getKey() const override { return KEY; }
/// Constructs a PythonExtensionLowering pass.
PythonExtensionLowering() : Pass(), extFuncs() {}
void run(Module *module) override; void run(Module *module) override;
/// @return extension function (original, generated) pairs
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
return extFuncs;
}
}; };
} // namespace lowering } // namespace lowering

View File

@ -148,6 +148,12 @@ void PassManager::invalidate(const std::string &key) {
} }
void PassManager::registerStandardPasses(PassManager::Init init) { void PassManager::registerStandardPasses(PassManager::Init init) {
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
if (pyExtension) {
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
pyExtensionPass = pyExtPass.get();
}
switch (init) { switch (init) {
case Init::EMPTY: case Init::EMPTY:
break; break;
@ -155,6 +161,8 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<lowering::PipelineLowering>()); registerPass(std::make_unique<lowering::PipelineLowering>());
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>()); registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
registerPass(std::make_unique<parallel::OpenMPPass>()); registerPass(std::make_unique<parallel::OpenMPPass>());
if (pyExtension)
registerPass(std::move(pyExtPass));
break; break;
} }
case Init::RELEASE: case Init::RELEASE:
@ -201,6 +209,9 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<parallel::OpenMPPass>(), /*insertBefore=*/"", {}, registerPass(std::make_unique<parallel::OpenMPPass>(), /*insertBefore=*/"", {},
{cfgKey, globalKey}); {cfgKey, globalKey});
if (pyExtension)
registerPass(std::move(pyExtPass));
if (init != Init::JIT) { if (init != Init::JIT) {
// Don't demote globals in JIT mode, since they might be used later // Don't demote globals in JIT mode, since they might be used later
// by another user input. // by another user input.

View File

@ -11,6 +11,7 @@
#include "codon/cir/analyze/analysis.h" #include "codon/cir/analyze/analysis.h"
#include "codon/cir/module.h" #include "codon/cir/module.h"
#include "codon/cir/transform/lowering/pyextension.h"
#include "codon/cir/transform/pass.h" #include "codon/cir/transform/pass.h"
namespace codon { namespace codon {
@ -94,6 +95,12 @@ private:
/// whether to use Python (vs. C) numeric semantics in passes /// whether to use Python (vs. C) numeric semantics in passes
bool pyNumerics; bool pyNumerics;
/// true if we are compiling as a Python extension
bool pyExtension;
/// pointer to Python extension lowering pass, if applicable
lowering::PythonExtensionLowering *pyExtensionPass;
public: public:
/// PassManager initialization mode. /// PassManager initialization mode.
enum Init { enum Init {
@ -104,16 +111,17 @@ public:
}; };
explicit PassManager(Init init, std::vector<std::string> disabled = {}, explicit PassManager(Init init, std::vector<std::string> disabled = {},
bool pyNumerics = false) bool pyNumerics = false, bool pyExtension = false)
: km(), passes(), analyses(), executionOrder(), results(), : km(), passes(), analyses(), executionOrder(), results(),
disabled(std::move(disabled)), pyNumerics(pyNumerics) { disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension),
pyExtensionPass(nullptr) {
registerStandardPasses(init); registerStandardPasses(init);
} }
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {}, explicit PassManager(bool debug = false, std::vector<std::string> disabled = {},
bool pyNumerics = false) bool pyNumerics = false, bool pyExtension = false)
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled), : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled),
pyNumerics) {} pyNumerics, pyExtension) {}
/// Checks if the given pass is included in this manager. /// Checks if the given pass is included in this manager.
/// @param key the pass key /// @param key the pass key
@ -174,6 +182,11 @@ public:
return std::find(disabled.begin(), disabled.end(), key) != disabled.end(); return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
} }
/// @return the Python extension lowering pass, or null if none
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
return pyExtensionPass;
}
private: private:
void runPass(Module *module, const std::string &name); void runPass(Module *module, const std::string &name);
void registerStandardPasses(Init init); void registerStandardPasses(Init init);

View File

@ -32,13 +32,13 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode, Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
const std::vector<std::string> &disabledPasses, bool isTest, const std::vector<std::string> &disabledPasses, bool isTest,
bool pyNumerics) bool pyNumerics, bool pyExtension)
: argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(), : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics),
plm(std::make_unique<PluginManager>(argv0)), pyExtension(pyExtension), input(), plm(std::make_unique<PluginManager>(argv0)),
cache(std::make_unique<ast::Cache>(argv0)), cache(std::make_unique<ast::Cache>(argv0)),
module(std::make_unique<ir::Module>()), module(std::make_unique<ir::Module>()),
pm(std::make_unique<ir::transform::PassManager>(getPassManagerInit(mode, isTest), pm(std::make_unique<ir::transform::PassManager>(
disabledPasses, pyNumerics)), getPassManagerInit(mode, isTest), disabledPasses, pyNumerics, pyExtension)),
llvisitor(std::make_unique<ir::LLVMVisitor>()) { llvisitor(std::make_unique<ir::LLVMVisitor>()) {
cache->module = module.get(); cache->module = module.get();
cache->pythonCompat = pyNumerics; cache->pythonCompat = pyNumerics;
@ -181,6 +181,14 @@ std::unordered_map<std::string, std::string> Compiler::getEarlyDefines() {
std::unordered_map<std::string, std::string> earlyDefines; std::unordered_map<std::string, std::string> earlyDefines;
earlyDefines.emplace("__debug__", debug ? "1" : "0"); earlyDefines.emplace("__debug__", debug ? "1" : "0");
earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0"); earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0");
earlyDefines.emplace("__py_extension__", pyExtension ? "1" : "0");
earlyDefines.emplace("__apple__",
#if __APPLE__
"1"
#else
"0"
#endif
);
return earlyDefines; return earlyDefines;
} }

View File

@ -28,6 +28,7 @@ private:
std::string argv0; std::string argv0;
bool debug; bool debug;
bool pyNumerics; bool pyNumerics;
bool pyExtension;
std::string input; std::string input;
std::unique_ptr<PluginManager> plm; std::unique_ptr<PluginManager> plm;
std::unique_ptr<ast::Cache> cache; std::unique_ptr<ast::Cache> cache;
@ -42,13 +43,14 @@ private:
public: public:
Compiler(const std::string &argv0, Mode mode, Compiler(const std::string &argv0, Mode mode,
const std::vector<std::string> &disabledPasses = {}, bool isTest = false, const std::vector<std::string> &disabledPasses = {}, bool isTest = false,
bool pyNumerics = false); bool pyNumerics = false, bool pyExtension = false);
explicit Compiler(const std::string &argv0, bool debug = false, explicit Compiler(const std::string &argv0, bool debug = false,
const std::vector<std::string> &disabledPasses = {}, const std::vector<std::string> &disabledPasses = {},
bool isTest = false, bool pyNumerics = false) bool isTest = false, bool pyNumerics = false,
bool pyExtension = false)
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest, : Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest,
pyNumerics) {} pyNumerics, pyExtension) {}
std::string getInput() const { return input; } std::string getInput() const { return input; }
PluginManager *getPluginManager() const { return plm.get(); } PluginManager *getPluginManager() const { return plm.get(); }