Reorganize API

pull/335/head
A. R. Shajii 2023-02-06 13:42:37 -05:00
parent 5920148d8d
commit c08a2d7d17
9 changed files with 67 additions and 56 deletions

View File

@ -11,7 +11,6 @@
#include <unordered_map>
#include <vector>
#include "codon/cir/transform/lowering/pyextension.h"
#include "codon/compiler/compiler.h"
#include "codon/compiler/error.h"
#include "codon/compiler/jit.h"
@ -372,8 +371,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
case BuildKind::PyExtension:
compiler->getLLVMVisitor()->writeToPythonExtension(
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
compiler->getPassManager()->getPythonExtensionPass()->getExtensionFunctions(),
filename);
compiler->getModule(), filename);
break;
case BuildKind::Detect:
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);

View File

@ -39,6 +39,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const {
return os;
}
const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute";
std::unique_ptr<Attribute> PythonWrapperAttribute::clone(util::CloneVisitor &cv) const {
return std::make_unique<PythonWrapperAttribute>(cast<Func>(cv.clone(original)));
}
std::unique_ptr<Attribute>
PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const {
return std::make_unique<PythonWrapperAttribute>(cv.forceClone(original));
}
std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const {
fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString());
return os;
}
const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute";
const std::string DocstringAttribute::AttributeName = "docstringAttribute";

View File

@ -135,6 +135,26 @@ private:
std::ostream &doFormat(std::ostream &os) const override;
};
/// Attribute used to mark Python wrappers of Codon functions
struct PythonWrapperAttribute : public Attribute {
static const std::string AttributeName;
/// the function being wrapped
Func *original;
/// Constructs a PythonWrapperAttribute.
/// @param original the function being wrapped
explicit PythonWrapperAttribute(Func *original) : original(original) {}
bool needsClone() const override { return false; }
std::unique_ptr<Attribute> clone(util::CloneVisitor &cv) const override;
std::unique_ptr<Attribute> forceClone(util::CloneVisitor &cv) const override;
private:
std::ostream &doFormat(std::ostream &os) const override;
};
/// Attribute attached to IR structures corresponding to tuple literals
struct TupleLiteralAttribute : public Attribute {
static const std::string AttributeName;

View File

@ -655,18 +655,24 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
return wrap;
}
void LLVMVisitor::writeToPythonExtension(
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
const std::string &filename) {
void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module,
const std::string &filename) {
// Construct PyMethodDef array
auto *ptr = B->getInt8PtrTy();
auto *null = llvm::Constant::getNullValue(ptr);
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
std::vector<llvm::Constant *> pyMethods;
for (auto &p : funcs) {
auto *original = p.first;
auto *generated = p.second;
for (auto *var : *module) {
auto *generated = cast<Func>(var);
if (!generated)
continue;
auto *pyWrapAttr = generated->getAttribute<PythonWrapperAttribute>();
if (!pyWrapAttr)
continue;
auto *original = pyWrapAttr->original;
auto llvmName = getNameForFunction(generated);
auto *llvmFunc = M->getFunction(llvmName);
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
@ -735,17 +741,14 @@ void LLVMVisitor::writeToPythonExtension(
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *docsConst = null;
if (!funcs.empty()) {
if (auto *docsAttr =
funcs[0].first->getModule()->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
if (auto *docsAttr = module->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
auto *pyMethodArrayConst = llvm::ConstantExpr::getBitCast(pyMethodDefArray, ptr);

View File

@ -345,12 +345,13 @@ public:
bool library = false,
const std::vector<std::string> &libs = {},
const std::string &lflags = "");
/// Writes module as Python extension shared object.
/// Writes module as Python extension object. Exposes
/// functions based on "PythonWrapperAttribute" attached
/// to IR functions.
/// @param name the module's name
/// @param funcs extension functions
/// @param module the IR module
/// @param filename the file to write to
void writeToPythonExtension(const std::string &name,
const std::vector<std::pair<Func *, Func *>> &funcs,
void writeToPythonExtension(const std::string &name, const Module *module,
const std::string &filename);
/// Runs optimization passes on module and writes the result
/// to the specified file. The output type is determined by

View File

@ -94,7 +94,7 @@ void PythonExtensionLowering::run(Module *module) {
if (auto *g = generateExtensionFunc(f)) {
LOG("[pyext] exporting {}", f->getName());
extFuncs.emplace_back(f, g);
g->setAttribute(std::make_unique<PythonWrapperAttribute>(f));
}
}
}

View File

@ -13,24 +13,11 @@ namespace transform {
namespace lowering {
class PythonExtensionLowering : public Pass {
private:
/// vector of original function (1st) and generated
/// extension wrapper (2nd)
std::vector<std::pair<Func *, Func *>> extFuncs;
public:
static const std::string KEY;
std::string getKey() const override { return KEY; }
/// Constructs a PythonExtensionLowering pass.
PythonExtensionLowering() : Pass(), extFuncs() {}
void run(Module *module) override;
/// @return extension function (original, generated) pairs
std::vector<std::pair<Func *, Func *>> getExtensionFunctions() const {
return extFuncs;
}
};
} // namespace lowering

View File

@ -148,12 +148,6 @@ void PassManager::invalidate(const std::string &key) {
}
void PassManager::registerStandardPasses(PassManager::Init init) {
std::unique_ptr<lowering::PythonExtensionLowering> pyExtPass;
if (pyExtension) {
pyExtPass = std::make_unique<lowering::PythonExtensionLowering>();
pyExtensionPass = pyExtPass.get();
}
switch (init) {
case Init::EMPTY:
break;
@ -162,7 +156,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<lowering::ImperativeForFlowLowering>());
registerPass(std::make_unique<parallel::OpenMPPass>());
if (pyExtension)
registerPass(std::move(pyExtPass));
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
break;
}
case Init::RELEASE:
@ -210,7 +204,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
{cfgKey, globalKey});
if (pyExtension)
registerPass(std::move(pyExtPass));
registerPass(std::make_unique<lowering::PythonExtensionLowering>());
if (init != Init::JIT) {
// Don't demote globals in JIT mode, since they might be used later

View File

@ -98,9 +98,6 @@ private:
/// true if we are compiling as a Python extension
bool pyExtension;
/// pointer to Python extension lowering pass, if applicable
lowering::PythonExtensionLowering *pyExtensionPass;
public:
/// PassManager initialization mode.
enum Init {
@ -113,8 +110,8 @@ public:
explicit PassManager(Init init, std::vector<std::string> disabled = {},
bool pyNumerics = false, bool pyExtension = false)
: km(), passes(), analyses(), executionOrder(), results(),
disabled(std::move(disabled)), pyNumerics(pyNumerics), pyExtension(pyExtension),
pyExtensionPass(nullptr) {
disabled(std::move(disabled)), pyNumerics(pyNumerics),
pyExtension(pyExtension) {
registerStandardPasses(init);
}
@ -182,11 +179,6 @@ public:
return std::find(disabled.begin(), disabled.end(), key) != disabled.end();
}
/// @return the Python extension lowering pass, or null if none
lowering::PythonExtensionLowering *getPythonExtensionPass() const {
return pyExtensionPass;
}
private:
void runPass(Module *module, const std::string &name);
void registerStandardPasses(Init init);