Reorganize API

pull/335/head
A. R. Shajii 2023-02-06 13:42:37 -05:00
parent 5920148d8d
commit c08a2d7d17
9 changed files with 67 additions and 56 deletions

View File

@ -11,7 +11,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "codon/cir/transform/lowering/pyextension.h"
#include "codon/compiler/compiler.h" #include "codon/compiler/compiler.h"
#include "codon/compiler/error.h" #include "codon/compiler/error.h"
#include "codon/compiler/jit.h" #include "codon/compiler/jit.h"
@ -372,8 +371,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
case BuildKind::PyExtension: case BuildKind::PyExtension:
compiler->getLLVMVisitor()->writeToPythonExtension( compiler->getLLVMVisitor()->writeToPythonExtension(
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule, pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(), compiler->getModule(), filename);
filename);
break; break;
case BuildKind::Detect: case BuildKind::Detect:
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);

View File

@ -39,6 +39,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const {
return os; return os;
} }
const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute";
std::unique_ptr<Attribute> PythonWrapperAttribute::clone(util::CloneVisitor &cv) const {
return std::make_unique<PythonWrapperAttribute>(cast<Func>(cv.clone(original)));
}
std::unique_ptr<Attribute>
PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const {
return std::make_unique<PythonWrapperAttribute>(cv.forceClone(original));
}
std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const {
fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString());
return os;
}
const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute";
const std::string DocstringAttribute::AttributeName = "docstringAttribute"; const std::string DocstringAttribute::AttributeName = "docstringAttribute";

View File

@ -135,6 +135,26 @@ private:
std::ostream &doFormat(std::ostream &os) const override; std::ostream &doFormat(std::ostream &os) const override;
}; };
/// Attribute used to mark Python wrappers of Codon functions
struct PythonWrapperAttribute : public Attribute {
static const std::string AttributeName;
/// the function being wrapped
Func *original;
/// Constructs a PythonWrapperAttribute.
/// @param original the function being wrapped
explicit PythonWrapperAttribute(Func *original) : original(original) {}
bool needsClone() const override { return false; }
std::unique_ptr<Attribute> clone(util::CloneVisitor &cv) const override;
std::unique_ptr<Attribute> forceClone(util::CloneVisitor &cv) const override;
private:
std::ostream &doFormat(std::ostream &os) const override;
};
/// Attribute attached to IR structures corresponding to tuple literals /// Attribute attached to IR structures corresponding to tuple literals
struct TupleLiteralAttribute : public Attribute { struct TupleLiteralAttribute : public Attribute {
static const std::string AttributeName; static const std::string AttributeName;

View File

@ -655,18 +655,24 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
return wrap; return wrap;
} }
void LLVMVisitor::writeToPythonExtension( void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module,
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs, const std::string &filename) {
const std::string &filename) {
// Construct PyMethodDef array // Construct PyMethodDef array
auto *ptr = B->getInt8PtrTy(); auto *ptr = B->getInt8PtrTy();
auto *null = llvm::Constant::getNullValue(ptr); auto *null = llvm::Constant::getNullValue(ptr);
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr); auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
std::vector<llvm::Constant *> pyMethods; std::vector<llvm::Constant *> pyMethods;
for (auto &p : funcs) { for (auto *var : *module) {
auto *original = p.first; auto *generated = cast<Func>(var);
auto *generated = p.second; if (!generated)
continue;
auto *pyWrapAttr = generated->getAttribute<PythonWrapperAttribute>();
if (!pyWrapAttr)
continue;
auto *original = pyWrapAttr->original;
auto llvmName = getNameForFunction(generated); auto llvmName = getNameForFunction(generated);
auto *llvmFunc = M->getFunction(llvmName); auto *llvmFunc = M->getFunction(llvmName);
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName); seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
@ -735,17 +741,14 @@ void LLVMVisitor::writeToPythonExtension(
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr); auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *docsConst = null; auto *docsConst = null;
if (!funcs.empty()) { if (auto *docsAttr = module->getAttribute<DocstringAttribute>()) {
if (auto *docsAttr = auto docs = docsAttr->docstring;
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) { auto *docsVar = new llvm::GlobalVariable(
auto docs = docsAttr->docstring; *M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
auto *docsVar = new llvm::GlobalVariable( /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1), llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring"); docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
} }
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr); auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);

View File

@ -345,12 +345,13 @@ public:
bool library = false, bool library = false,
const std::vector<std::string> &libs = {}, const std::vector<std::string> &libs = {},
const std::string &lflags = ""); const std::string &lflags = "");
/// Writes module as Python extension shared object. /// Writes module as Python extension object. Exposes
/// functions based on "PythonWrapperAttribute" attached
/// to IR functions.
/// @param name the module's name /// @param name the module's name
/// @param funcs extension functions /// @param module the IR module
/// @param filename the file to write to /// @param filename the file to write to
void writeToPythonExtension(const std::string &name, void writeToPythonExtension(const std::string &name, const Module *module,
const std::vector<std::pair<Func *, Func *>> &funcs,
const std::string &filename); const std::string &filename);
/// Runs optimization passes on module and writes the result /// Runs optimization passes on module and writes the result
/// to the specified file. The output type is determined by /// to the specified file. The output type is determined by

View File

@ -94,7 +94,7 @@ void PythonExtensionLowering::run(Module *module) {
if (auto *g = generateExtensionFunc(f)) { if (auto *g = generateExtensionFunc(f)) {
LOG("[pyext] exporting {}", f->getName()); LOG("[pyext] exporting {}", f->getName());
extFuncs.emplace_back(f, g); g->setAttribute(std::make_unique<PythonWrapperAttribute>(f));
} }
} }
} }

View File

@ -13,24 +13,11 @@ namespace transform {
namespace lowering { namespace lowering {
class PythonExtensionLowering : public Pass { class PythonExtensionLowering : public Pass {
private:
/// vector of original function (1st) and generated
/// extension wrapper (2nd)
std::vector<std::pair<Func *, Func *>> extFuncs;
public: public:
static const std::string KEY; static const std::string KEY;
std::string getKey() const override { return KEY; } std::string getKey() const override { return KEY; }
/// Constructs a PythonExtensionLowering pass.
PythonExtensionLowering() : Pass(), extFuncs() {}
void run(Module *module) override; void run(Module *module) override;
/// @return extension function (original, generated) pairs
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
return extFuncs;
}
}; };
} // namespace lowering } // namespace lowering

View File

@ -148,12 +148,6 @@ void PassManager::invalidate(const std::string &key) {
} }
void PassManager::registerStandardPasses(PassManager::Init init) { void PassManager::registerStandardPasses(PassManager::Init init) {
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
if (pyExtension) {
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
pyExtensionPass = pyExtPass.get();
}
switch (init) { switch (init) {
case Init::EMPTY: case Init::EMPTY:
break; break;
@ -162,7 +156,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>()); registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
registerPass(std::make_unique<parallel::OpenMPPass>()); registerPass(std::make_unique<parallel::OpenMPPass>());
if (pyExtension) if (pyExtension)
registerPass(std::move(pyExtPass)); registerPass(std::make_unique<lowering::PythonExtensionLowering>());
break; break;
} }
case Init::RELEASE: case Init::RELEASE:
@ -210,7 +204,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
{cfgKey, globalKey}); {cfgKey, globalKey});
if (pyExtension) if (pyExtension)
registerPass(std::move(pyExtPass)); registerPass(std::make_unique<lowering::PythonExtensionLowering>());
if (init != Init::JIT) { if (init != Init::JIT) {
// Don't demote globals in JIT mode, since they might be used later // Don't demote globals in JIT mode, since they might be used later

View File

@ -98,9 +98,6 @@ private:
/// true if we are compiling as a Python extension /// true if we are compiling as a Python extension
bool pyExtension; bool pyExtension;
/// pointer to Python extension lowering pass, if applicable
lowering::PythonExtensionLowering *pyExtensionPass;
public: public:
/// PassManager initialization mode. /// PassManager initialization mode.
enum Init { enum Init {
@ -113,8 +110,8 @@ public:
explicit PassManager(Init init, std::vector<std::string> disabled = {}, explicit PassManager(Init init, std::vector<std::string> disabled = {},
bool pyNumerics = false, bool pyExtension = false) bool pyNumerics = false, bool pyExtension = false)
: km(), passes(), analyses(), executionOrder(), results(), : km(), passes(), analyses(), executionOrder(), results(),
disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension), disabled(std::move(disabled)), pyNumerics(pyNumerics),
pyExtensionPass(nullptr) { pyExtension(pyExtension) {
registerStandardPasses(init); registerStandardPasses(init);
} }
@ -182,11 +179,6 @@ public:
return std::find(disabled.begin(), disabled.end(), key) != disabled.end(); return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
} }
/// @return the Python extension lowering pass, or null if none
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
return pyExtensionPass;
}
private: private:
void runPass(Module *module, const std::string &name); void runPass(Module *module, const std::string &name);
void registerStandardPasses(Init init); void registerStandardPasses(Init init);