mirror of https://github.com/exaloop/codon.git
Reorganize API
parent
5920148d8d
commit
c08a2d7d17
|
@ -11,7 +11,6 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/transform/lowering/pyextension.h"
|
||||
#include "codon/compiler/compiler.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/compiler/jit.h"
|
||||
|
@ -372,8 +371,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||
case BuildKind::PyExtension:
|
||||
compiler->getLLVMVisitor()->writeToPythonExtension(
|
||||
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
|
||||
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(),
|
||||
filename);
|
||||
compiler->getModule(), filename);
|
||||
break;
|
||||
case BuildKind::Detect:
|
||||
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
||||
|
|
|
@ -39,6 +39,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const {
|
|||
return os;
|
||||
}
|
||||
|
||||
const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute";
|
||||
|
||||
std::unique_ptr<Attribute> PythonWrapperAttribute::clone(util::CloneVisitor &cv) const {
|
||||
return std::make_unique<PythonWrapperAttribute>(cast<Func>(cv.clone(original)));
|
||||
}
|
||||
|
||||
std::unique_ptr<Attribute>
|
||||
PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const {
|
||||
return std::make_unique<PythonWrapperAttribute>(cv.forceClone(original));
|
||||
}
|
||||
|
||||
std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const {
|
||||
fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString());
|
||||
return os;
|
||||
}
|
||||
|
||||
const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute";
|
||||
|
||||
const std::string DocstringAttribute::AttributeName = "docstringAttribute";
|
||||
|
|
|
@ -135,6 +135,26 @@ private:
|
|||
std::ostream &doFormat(std::ostream &os) const override;
|
||||
};
|
||||
|
||||
/// Attribute used to mark Python wrappers of Codon functions
|
||||
struct PythonWrapperAttribute : public Attribute {
|
||||
static const std::string AttributeName;
|
||||
|
||||
/// the function being wrapped
|
||||
Func *original;
|
||||
|
||||
/// Constructs a PythonWrapperAttribute.
|
||||
/// @param original the function being wrapped
|
||||
explicit PythonWrapperAttribute(Func *original) : original(original) {}
|
||||
|
||||
bool needsClone() const override { return false; }
|
||||
|
||||
std::unique_ptr<Attribute> clone(util::CloneVisitor &cv) const override;
|
||||
std::unique_ptr<Attribute> forceClone(util::CloneVisitor &cv) const override;
|
||||
|
||||
private:
|
||||
std::ostream &doFormat(std::ostream &os) const override;
|
||||
};
|
||||
|
||||
/// Attribute attached to IR structures corresponding to tuple literals
|
||||
struct TupleLiteralAttribute : public Attribute {
|
||||
static const std::string AttributeName;
|
||||
|
|
|
@ -655,18 +655,24 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
|
|||
return wrap;
|
||||
}
|
||||
|
||||
void LLVMVisitor::writeToPythonExtension(
|
||||
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
|
||||
const std::string &filename) {
|
||||
void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module,
|
||||
const std::string &filename) {
|
||||
// Construct PyMethodDef array
|
||||
auto *ptr = B->getInt8PtrTy();
|
||||
auto *null = llvm::Constant::getNullValue(ptr);
|
||||
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
|
||||
std::vector<llvm::Constant *> pyMethods;
|
||||
|
||||
for (auto &p : funcs) {
|
||||
auto *original = p.first;
|
||||
auto *generated = p.second;
|
||||
for (auto *var : *module) {
|
||||
auto *generated = cast<Func>(var);
|
||||
if (!generated)
|
||||
continue;
|
||||
|
||||
auto *pyWrapAttr = generated->getAttribute<PythonWrapperAttribute>();
|
||||
if (!pyWrapAttr)
|
||||
continue;
|
||||
|
||||
auto *original = pyWrapAttr->original;
|
||||
auto llvmName = getNameForFunction(generated);
|
||||
auto *llvmFunc = M->getFunction(llvmName);
|
||||
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
|
||||
|
@ -735,17 +741,14 @@ void LLVMVisitor::writeToPythonExtension(
|
|||
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
|
||||
|
||||
auto *docsConst = null;
|
||||
if (!funcs.empty()) {
|
||||
if (auto *docsAttr =
|
||||
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) {
|
||||
auto docs = docsAttr->docstring;
|
||||
auto *docsVar = new llvm::GlobalVariable(
|
||||
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
||||
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
||||
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
||||
}
|
||||
if (auto *docsAttr = module->getAttribute<DocstringAttribute>()) {
|
||||
auto docs = docsAttr->docstring;
|
||||
auto *docsVar = new llvm::GlobalVariable(
|
||||
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
||||
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
||||
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
||||
}
|
||||
|
||||
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);
|
||||
|
|
|
@ -345,12 +345,13 @@ public:
|
|||
bool library = false,
|
||||
const std::vector<std::string> &libs = {},
|
||||
const std::string &lflags = "");
|
||||
/// Writes module as Python extension shared object.
|
||||
/// Writes module as Python extension object. Exposes
|
||||
/// functions based on "PythonWrapperAttribute" attached
|
||||
/// to IR functions.
|
||||
/// @param name the module's name
|
||||
/// @param funcs extension functions
|
||||
/// @param module the IR module
|
||||
/// @param filename the file to write to
|
||||
void writeToPythonExtension(const std::string &name,
|
||||
const std::vector<std::pair<Func *, Func *>> &funcs,
|
||||
void writeToPythonExtension(const std::string &name, const Module *module,
|
||||
const std::string &filename);
|
||||
/// Runs optimization passes on module and writes the result
|
||||
/// to the specified file. The output type is determined by
|
||||
|
|
|
@ -94,7 +94,7 @@ void PythonExtensionLowering::run(Module *module) {
|
|||
|
||||
if (auto *g = generateExtensionFunc(f)) {
|
||||
LOG("[pyext] exporting {}", f->getName());
|
||||
extFuncs.emplace_back(f, g);
|
||||
g->setAttribute(std::make_unique<PythonWrapperAttribute>(f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,24 +13,11 @@ namespace transform {
|
|||
namespace lowering {
|
||||
|
||||
class PythonExtensionLowering : public Pass {
|
||||
private:
|
||||
/// vector of original function (1st) and generated
|
||||
/// extension wrapper (2nd)
|
||||
std::vector<std::pair<Func *, Func *>> extFuncs;
|
||||
|
||||
public:
|
||||
static const std::string KEY;
|
||||
std::string getKey() const override { return KEY; }
|
||||
|
||||
/// Constructs a PythonExtensionLowering pass.
|
||||
PythonExtensionLowering() : Pass(), extFuncs() {}
|
||||
|
||||
void run(Module *module) override;
|
||||
|
||||
/// @return extension function (original, generated) pairs
|
||||
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
|
||||
return extFuncs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
|
|
|
@ -148,12 +148,6 @@ void PassManager::invalidate(const std::string &key) {
|
|||
}
|
||||
|
||||
void PassManager::registerStandardPasses(PassManager::Init init) {
|
||||
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
|
||||
if (pyExtension) {
|
||||
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
|
||||
pyExtensionPass = pyExtPass.get();
|
||||
}
|
||||
|
||||
switch (init) {
|
||||
case Init::EMPTY:
|
||||
break;
|
||||
|
@ -162,7 +156,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
|
||||
registerPass(std::make_unique<parallel::OpenMPPass>());
|
||||
if (pyExtension)
|
||||
registerPass(std::move(pyExtPass));
|
||||
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
|
||||
break;
|
||||
}
|
||||
case Init::RELEASE:
|
||||
|
@ -210,7 +204,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||
{cfgKey, globalKey});
|
||||
|
||||
if (pyExtension)
|
||||
registerPass(std::move(pyExtPass));
|
||||
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
|
||||
|
||||
if (init != Init::JIT) {
|
||||
// Don't demote globals in JIT mode, since they might be used later
|
||||
|
|
|
@ -98,9 +98,6 @@ private:
|
|||
/// true if we are compiling as a Python extension
|
||||
bool pyExtension;
|
||||
|
||||
/// pointer to Python extension lowering pass, if applicable
|
||||
lowering::PythonExtensionLowering *pyExtensionPass;
|
||||
|
||||
public:
|
||||
/// PassManager initialization mode.
|
||||
enum Init {
|
||||
|
@ -113,8 +110,8 @@ public:
|
|||
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
||||
bool pyNumerics = false, bool pyExtension = false)
|
||||
: km(), passes(), analyses(), executionOrder(), results(),
|
||||
disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension),
|
||||
pyExtensionPass(nullptr) {
|
||||
disabled(std::move(disabled)), pyNumerics(pyNumerics),
|
||||
pyExtension(pyExtension) {
|
||||
registerStandardPasses(init);
|
||||
}
|
||||
|
||||
|
@ -182,11 +179,6 @@ public:
|
|||
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
|
||||
}
|
||||
|
||||
/// @return the Python extension lowering pass, or null if none
|
||||
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
|
||||
return pyExtensionPass;
|
||||
}
|
||||
|
||||
private:
|
||||
void runPass(Module *module, const std::string &name);
|
||||
void registerStandardPasses(Init init);
|
||||
|
|
Loading…
Reference in New Issue