mirror of https://github.com/exaloop/codon.git
Reorganize API
parent
5920148d8d
commit
c08a2d7d17
|
@ -11,7 +11,6 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "codon/cir/transform/lowering/pyextension.h"
|
|
||||||
#include "codon/compiler/compiler.h"
|
#include "codon/compiler/compiler.h"
|
||||||
#include "codon/compiler/error.h"
|
#include "codon/compiler/error.h"
|
||||||
#include "codon/compiler/jit.h"
|
#include "codon/compiler/jit.h"
|
||||||
|
@ -372,8 +371,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
||||||
case BuildKind::PyExtension:
|
case BuildKind::PyExtension:
|
||||||
compiler->getLLVMVisitor()->writeToPythonExtension(
|
compiler->getLLVMVisitor()->writeToPythonExtension(
|
||||||
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
|
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
|
||||||
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(),
|
compiler->getModule(), filename);
|
||||||
filename);
|
|
||||||
break;
|
break;
|
||||||
case BuildKind::Detect:
|
case BuildKind::Detect:
|
||||||
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
||||||
|
|
|
@ -39,6 +39,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const {
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute";
|
||||||
|
|
||||||
|
std::unique_ptr<Attribute> PythonWrapperAttribute::clone(util::CloneVisitor &cv) const {
|
||||||
|
return std::make_unique<PythonWrapperAttribute>(cast<Func>(cv.clone(original)));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<Attribute>
|
||||||
|
PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const {
|
||||||
|
return std::make_unique<PythonWrapperAttribute>(cv.forceClone(original));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const {
|
||||||
|
fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute";
|
const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute";
|
||||||
|
|
||||||
const std::string DocstringAttribute::AttributeName = "docstringAttribute";
|
const std::string DocstringAttribute::AttributeName = "docstringAttribute";
|
||||||
|
|
|
@ -135,6 +135,26 @@ private:
|
||||||
std::ostream &doFormat(std::ostream &os) const override;
|
std::ostream &doFormat(std::ostream &os) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Attribute used to mark Python wrappers of Codon functions
|
||||||
|
struct PythonWrapperAttribute : public Attribute {
|
||||||
|
static const std::string AttributeName;
|
||||||
|
|
||||||
|
/// the function being wrapped
|
||||||
|
Func *original;
|
||||||
|
|
||||||
|
/// Constructs a PythonWrapperAttribute.
|
||||||
|
/// @param original the function being wrapped
|
||||||
|
explicit PythonWrapperAttribute(Func *original) : original(original) {}
|
||||||
|
|
||||||
|
bool needsClone() const override { return false; }
|
||||||
|
|
||||||
|
std::unique_ptr<Attribute> clone(util::CloneVisitor &cv) const override;
|
||||||
|
std::unique_ptr<Attribute> forceClone(util::CloneVisitor &cv) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::ostream &doFormat(std::ostream &os) const override;
|
||||||
|
};
|
||||||
|
|
||||||
/// Attribute attached to IR structures corresponding to tuple literals
|
/// Attribute attached to IR structures corresponding to tuple literals
|
||||||
struct TupleLiteralAttribute : public Attribute {
|
struct TupleLiteralAttribute : public Attribute {
|
||||||
static const std::string AttributeName;
|
static const std::string AttributeName;
|
||||||
|
|
|
@ -655,18 +655,24 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
|
||||||
return wrap;
|
return wrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
void LLVMVisitor::writeToPythonExtension(
|
void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module,
|
||||||
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
|
const std::string &filename) {
|
||||||
const std::string &filename) {
|
|
||||||
// Construct PyMethodDef array
|
// Construct PyMethodDef array
|
||||||
auto *ptr = B->getInt8PtrTy();
|
auto *ptr = B->getInt8PtrTy();
|
||||||
auto *null = llvm::Constant::getNullValue(ptr);
|
auto *null = llvm::Constant::getNullValue(ptr);
|
||||||
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
|
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
|
||||||
std::vector<llvm::Constant *> pyMethods;
|
std::vector<llvm::Constant *> pyMethods;
|
||||||
|
|
||||||
for (auto &p : funcs) {
|
for (auto *var : *module) {
|
||||||
auto *original = p.first;
|
auto *generated = cast<Func>(var);
|
||||||
auto *generated = p.second;
|
if (!generated)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto *pyWrapAttr = generated->getAttribute<PythonWrapperAttribute>();
|
||||||
|
if (!pyWrapAttr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto *original = pyWrapAttr->original;
|
||||||
auto llvmName = getNameForFunction(generated);
|
auto llvmName = getNameForFunction(generated);
|
||||||
auto *llvmFunc = M->getFunction(llvmName);
|
auto *llvmFunc = M->getFunction(llvmName);
|
||||||
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
|
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
|
||||||
|
@ -735,17 +741,14 @@ void LLVMVisitor::writeToPythonExtension(
|
||||||
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
|
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
|
||||||
|
|
||||||
auto *docsConst = null;
|
auto *docsConst = null;
|
||||||
if (!funcs.empty()) {
|
if (auto *docsAttr = module->getAttribute<DocstringAttribute>()) {
|
||||||
if (auto *docsAttr =
|
auto docs = docsAttr->docstring;
|
||||||
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) {
|
auto *docsVar = new llvm::GlobalVariable(
|
||||||
auto docs = docsAttr->docstring;
|
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
||||||
auto *docsVar = new llvm::GlobalVariable(
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
||||||
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
|
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
||||||
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
|
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
||||||
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
|
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
||||||
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
|
|
||||||
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);
|
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);
|
||||||
|
|
|
@ -345,12 +345,13 @@ public:
|
||||||
bool library = false,
|
bool library = false,
|
||||||
const std::vector<std::string> &libs = {},
|
const std::vector<std::string> &libs = {},
|
||||||
const std::string &lflags = "");
|
const std::string &lflags = "");
|
||||||
/// Writes module as Python extension shared object.
|
/// Writes module as Python extension object. Exposes
|
||||||
|
/// functions based on "PythonWrapperAttribute" attached
|
||||||
|
/// to IR functions.
|
||||||
/// @param name the module's name
|
/// @param name the module's name
|
||||||
/// @param funcs extension functions
|
/// @param module the IR module
|
||||||
/// @param filename the file to write to
|
/// @param filename the file to write to
|
||||||
void writeToPythonExtension(const std::string &name,
|
void writeToPythonExtension(const std::string &name, const Module *module,
|
||||||
const std::vector<std::pair<Func *, Func *>> &funcs,
|
|
||||||
const std::string &filename);
|
const std::string &filename);
|
||||||
/// Runs optimization passes on module and writes the result
|
/// Runs optimization passes on module and writes the result
|
||||||
/// to the specified file. The output type is determined by
|
/// to the specified file. The output type is determined by
|
||||||
|
|
|
@ -94,7 +94,7 @@ void PythonExtensionLowering::run(Module *module) {
|
||||||
|
|
||||||
if (auto *g = generateExtensionFunc(f)) {
|
if (auto *g = generateExtensionFunc(f)) {
|
||||||
LOG("[pyext] exporting {}", f->getName());
|
LOG("[pyext] exporting {}", f->getName());
|
||||||
extFuncs.emplace_back(f, g);
|
g->setAttribute(std::make_unique<PythonWrapperAttribute>(f));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,24 +13,11 @@ namespace transform {
|
||||||
namespace lowering {
|
namespace lowering {
|
||||||
|
|
||||||
class PythonExtensionLowering : public Pass {
|
class PythonExtensionLowering : public Pass {
|
||||||
private:
|
|
||||||
/// vector of original function (1st) and generated
|
|
||||||
/// extension wrapper (2nd)
|
|
||||||
std::vector<std::pair<Func *, Func *>> extFuncs;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static const std::string KEY;
|
static const std::string KEY;
|
||||||
std::string getKey() const override { return KEY; }
|
std::string getKey() const override { return KEY; }
|
||||||
|
|
||||||
/// Constructs a PythonExtensionLowering pass.
|
|
||||||
PythonExtensionLowering() : Pass(), extFuncs() {}
|
|
||||||
|
|
||||||
void run(Module *module) override;
|
void run(Module *module) override;
|
||||||
|
|
||||||
/// @return extension function (original, generated) pairs
|
|
||||||
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
|
|
||||||
return extFuncs;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lowering
|
} // namespace lowering
|
||||||
|
|
|
@ -148,12 +148,6 @@ void PassManager::invalidate(const std::string &key) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void PassManager::registerStandardPasses(PassManager::Init init) {
|
void PassManager::registerStandardPasses(PassManager::Init init) {
|
||||||
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
|
|
||||||
if (pyExtension) {
|
|
||||||
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
|
|
||||||
pyExtensionPass = pyExtPass.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (init) {
|
switch (init) {
|
||||||
case Init::EMPTY:
|
case Init::EMPTY:
|
||||||
break;
|
break;
|
||||||
|
@ -162,7 +156,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
||||||
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
|
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
|
||||||
registerPass(std::make_unique<parallel::OpenMPPass>());
|
registerPass(std::make_unique<parallel::OpenMPPass>());
|
||||||
if (pyExtension)
|
if (pyExtension)
|
||||||
registerPass(std::move(pyExtPass));
|
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Init::RELEASE:
|
case Init::RELEASE:
|
||||||
|
@ -210,7 +204,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
||||||
{cfgKey, globalKey});
|
{cfgKey, globalKey});
|
||||||
|
|
||||||
if (pyExtension)
|
if (pyExtension)
|
||||||
registerPass(std::move(pyExtPass));
|
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
|
||||||
|
|
||||||
if (init != Init::JIT) {
|
if (init != Init::JIT) {
|
||||||
// Don't demote globals in JIT mode, since they might be used later
|
// Don't demote globals in JIT mode, since they might be used later
|
||||||
|
|
|
@ -98,9 +98,6 @@ private:
|
||||||
/// true if we are compiling as a Python extension
|
/// true if we are compiling as a Python extension
|
||||||
bool pyExtension;
|
bool pyExtension;
|
||||||
|
|
||||||
/// pointer to Python extension lowering pass, if applicable
|
|
||||||
lowering::PythonExtensionLowering *pyExtensionPass;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// PassManager initialization mode.
|
/// PassManager initialization mode.
|
||||||
enum Init {
|
enum Init {
|
||||||
|
@ -113,8 +110,8 @@ public:
|
||||||
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
||||||
bool pyNumerics = false, bool pyExtension = false)
|
bool pyNumerics = false, bool pyExtension = false)
|
||||||
: km(), passes(), analyses(), executionOrder(), results(),
|
: km(), passes(), analyses(), executionOrder(), results(),
|
||||||
disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension),
|
disabled(std::move(disabled)), pyNumerics(pyNumerics),
|
||||||
pyExtensionPass(nullptr) {
|
pyExtension(pyExtension) {
|
||||||
registerStandardPasses(init);
|
registerStandardPasses(init);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,11 +179,6 @@ public:
|
||||||
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
|
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @return the Python extension lowering pass, or null if none
|
|
||||||
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
|
|
||||||
return pyExtensionPass;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runPass(Module *module, const std::string &name);
|
void runPass(Module *module, const std::string &name);
|
||||||
void registerStandardPasses(Init init);
|
void registerStandardPasses(Init init);
|
||||||
|
|
Loading…
Reference in New Issue