Update pyextension codegen (WIP)

pull/335/head
A. R. Shajii 2023-02-10 11:28:04 -05:00
parent c467645aec
commit 7d3f62c014
3 changed files with 50 additions and 63 deletions

View File

@ -369,9 +369,10 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
lflags);
break;
case BuildKind::PyExtension:
compiler->getLLVMVisitor()->writeToPythonExtension(
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule,
compiler->getModule(), filename);
compiler->getCache()->pyModule->name =
pyModule.empty() ? llvm::sys::path::stem(compiler->getInput()).str() : pyModule;
compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule,
filename);
break;
case BuildKind::Detect:
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);

View File

@ -655,65 +655,63 @@ llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
return wrap;
}
void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *module,
void LLVMVisitor::writeToPythonExtension(const PyModule &pymod,
const std::string &filename) {
// Construct PyMethodDef array
// Setup LLVM types & constants
auto *i64 = B->getInt64Ty();
auto *i32 = B->getInt32Ty();
auto *ptr = B->getInt8PtrTy();
auto *pyMethodDefType = llvm::StructType::create("PyMethodDef", ptr, ptr, i32, ptr);
auto *pyObjectType = llvm::StructType::create("PyObject", i64, ptr);
auto *pyModuleDefBaseType =
llvm::StructType::create("PyMethodDefBase", pyObjectType, ptr, i64, ptr);
auto *pyModuleDefType =
llvm::StructType::create("PyModuleDef", pyModuleDefBaseType, ptr, ptr, i64,
pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr);
auto *zero64 = B->getInt64(0);
auto *zero32 = B->getInt32(0);
auto *null = llvm::Constant::getNullValue(ptr);
auto *pyMethodDefType = llvm::StructType::get(ptr, ptr, B->getInt32Ty(), ptr);
// Handle functions
std::vector<llvm::Constant *> pyMethods;
for (auto *var : *module) {
auto *generated = cast<Func>(var);
if (!generated)
continue;
auto *pyWrapAttr = generated->getAttribute<PythonWrapperAttribute>();
if (!pyWrapAttr)
continue;
auto *original = pyWrapAttr->original;
auto llvmName = getNameForFunction(generated);
for (auto &pyfunc : pymod.functions) {
auto llvmName = getNameForFunction(pyfunc.func);
auto *llvmFunc = M->getFunction(llvmName);
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
llvmFunc = createPyTryCatchWrapper(llvmFunc);
auto name = original->getUnmangledName();
if (ast::startswith(name, "._py_"))
name = name.substr(5);
auto *nameVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
*M, llvm::ArrayType::get(B->getInt8Ty(), pyfunc.name.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, name), ".pyext_func_name");
llvm::ConstantDataArray::getString(*context, pyfunc.name), ".pyext_func_name");
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
auto *nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *funcConst = llvm::ConstantExpr::getBitCast(llvmFunc, ptr);
auto numArgs = std::distance(original->arg_begin(), original->arg_end());
int flag = 0;
if (numArgs == 0) {
switch (pyfunc.nargs) {
case 0:
flag = PYEXT_METH_NOARGS;
} else if (numArgs == 1) {
case 1:
flag = PYEXT_METH_O;
} else {
default:
flag = PYEXT_METH_FASTCALL;
}
auto *flagConst = B->getInt32(flag);
auto *docsConst = null;
if (auto *docsAttr = original->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
if (pyfunc.doc.empty()) {
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
*M, llvm::ArrayType::get(B->getInt8Ty(), pyfunc.doc.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
llvm::ConstantDataArray::getString(*context, pyfunc.doc), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
pyMethods.push_back(llvm::ConstantStruct::get(pyMethodDefType, nameConst, funcConst,
pyMethods.push_back(llvm::ConstantStruct::get(pyMethodDefType, nameVar, llvmFunc,
flagConst, docsConst));
}
pyMethods.push_back(
llvm::ConstantStruct::get(pyMethodDefType, null, null, B->getInt32(0), null));
llvm::ConstantStruct::get(pyMethodDefType, null, null, zero32, null));
auto *pyMethodDefArrayType = llvm::ArrayType::get(pyMethodDefType, pyMethods.size());
auto *pyMethodDefArray = new llvm::GlobalVariable(
@ -722,31 +720,23 @@ void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *
llvm::ConstantArray::get(pyMethodDefArrayType, pyMethods), ".pyext_methods");
// Construct PyModuleDef array
auto *pyObjectType = llvm::StructType::get(B->getInt64Ty(), ptr);
auto *pyModuleDefBaseType =
llvm::StructType::get(pyObjectType, ptr, B->getInt64Ty(), ptr);
auto *pyModuleDefType =
llvm::StructType::get(pyModuleDefBaseType, ptr, ptr, B->getInt64Ty(),
pyMethodDefType->getPointerTo(), ptr, ptr, ptr, ptr);
auto *pyObjectConst = llvm::ConstantStruct::get(pyObjectType, B->getInt64(1), null);
auto *pyModuleDefBaseConst = llvm::ConstantStruct::get(
pyModuleDefBaseType, pyObjectConst, null, B->getInt64(0), null);
auto *nameVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), name.length() + 1),
*M, llvm::ArrayType::get(B->getInt8Ty(), pymod.name.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, name), ".pyext_module_name");
llvm::ConstantDataArray::getString(*context, pymod.name), ".pyext_module_name");
nameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
auto nameConst = llvm::ConstantExpr::getBitCast(nameVar, ptr);
auto *docsConst = null;
if (auto *docsAttr = module->getAttribute<DocstringAttribute>()) {
auto docs = docsAttr->docstring;
if (!pymod.doc.empty()) {
auto *docsVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(B->getInt8Ty(), docs.length() + 1),
*M, llvm::ArrayType::get(B->getInt8Ty(), pymod.doc.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(*context, docs), ".pyext_docstring");
llvm::ConstantDataArray::getString(*context, pymod.doc), ".pyext_docstring");
docsVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
docsConst = llvm::ConstantExpr::getBitCast(docsVar, ptr);
}
@ -762,24 +752,23 @@ void LLVMVisitor::writeToPythonExtension(const std::string &name, const Module *
auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr);
// Construct initialization hook
auto *pyModuleCreate = cast<llvm::Function>(
M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty())
.getCallee());
auto *pyModuleCreate = llvm::cast<llvm::Function>(
M->getOrInsertFunction("PyModule_Create2", ptr, ptr, i32).getCallee());
pyModuleCreate->setDoesNotThrow();
auto *pyModuleInit =
cast<llvm::Function>(M->getOrInsertFunction("PyInit_" + name, ptr).getCallee());
auto *pyModuleInit = llvm::cast<llvm::Function>(
M->getOrInsertFunction("PyInit_" + pymod.name, ptr).getCallee());
auto *entry = llvm::BasicBlock::Create(*context, "entry", pyModuleInit);
B->SetInsertPoint(entry);
if (auto *main = M->getFunction("main")) {
main->setName(".main");
B->CreateCall({main->getFunctionType(), main},
{B->getInt32(0),
llvm::ConstantPointerNull::get(B->getInt8PtrTy()->getPointerTo())});
B->CreateCall({main->getFunctionType(), main}, {zero32, null});
}
B->CreateRet(B->CreateCall(pyModuleCreate,
{pyModuleConst, B->getInt32(PYEXT_PYTHON_ABI_VERSION)}));
// TODO: Codegen types, methods, etc.
writeToObjectFile(filename);
}

View File

@ -4,6 +4,7 @@
#include "codon/cir/cir.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/cir/pyextension.h"
#include "codon/dsl/plugins.h"
#include "codon/util/common.h"
@ -345,14 +346,10 @@ public:
bool library = false,
const std::vector<std::string> &libs = {},
const std::string &lflags = "");
/// Writes module as Python extension object. Exposes
/// functions based on "PythonWrapperAttribute" attached
/// to IR functions.
/// @param name the module's name
/// @param module the IR module
/// Writes module as Python extension object.
/// @param pymod extension module
/// @param filename the file to write to
void writeToPythonExtension(const std::string &name, const Module *module,
const std::string &filename);
void writeToPythonExtension(const PyModule &pymod, const std::string &filename);
/// Runs optimization passes on module and writes the result
/// to the specified file. The output type is determined by
/// the file extension (.ll for LLVM IR, .bc for LLVM bitcode