Auto-convert Codon exceptions to Python exceptions

pull/335/head
A. R. Shajii 2023-01-30 15:32:27 -05:00
parent 95f28e11a6
commit 999e42664e
3 changed files with 79 additions and 6 deletions

View File

@ -556,6 +556,75 @@ constexpr int PYEXT_METH_METHOD = 0x0200;
constexpr int PYEXT_PYTHON_ABI_VERSION = 1013;
} // namespace
llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) {
auto *wrap =
cast<llvm::Function>(M->getOrInsertFunction((func->getName() + ".tc_wrap").str(),
func->getFunctionType())
.getCallee());
wrap->setPersonalityFn(llvm::cast<llvm::Constant>(makePersonalityFunc().getCallee()));
auto *entry = llvm::BasicBlock::Create(*context, "entry", wrap);
auto *normal = llvm::BasicBlock::Create(*context, "normal", wrap);
auto *unwind = llvm::BasicBlock::Create(*context, "unwind", wrap);
B->SetInsertPoint(entry);
std::vector<llvm::Value *> args;
for (auto &arg : wrap->args()) {
args.push_back(&arg);
}
auto *result = B->CreateInvoke(func, normal, unwind, args);
B->SetInsertPoint(normal);
B->CreateRet(result);
B->SetInsertPoint(unwind);
auto *caughtResult = B->CreateLandingPad(getPadType(), 1);
caughtResult->setCleanup(true);
caughtResult->addClause(getTypeIdxVar(nullptr));
auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only
auto *unwindException = B->CreateExtractValue(caughtResult, 0);
auto *unwindExceptionClass = B->CreateLoad(
B->getInt64Ty(),
B->CreateStructGEP(
unwindType, B->CreatePointerCast(unwindException, unwindType->getPointerTo()),
0));
unwindException = B->CreateExtractValue(caughtResult, 0);
auto *excType = llvm::StructType::get(getTypeInfoType(), B->getInt8PtrTy());
auto *excVal =
B->CreatePointerCast(B->CreateConstGEP1_64(B->getInt8Ty(), unwindException,
(uint64_t)seq_exc_offset()),
excType->getPointerTo());
auto *loadedExc = B->CreateLoad(excType, excVal);
auto *objPtr = B->CreateExtractValue(loadedExc, 1);
auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy());
auto *excHeader = llvm::StructType::get(strType, strType);
auto *header = B->CreateLoad(excHeader, objPtr);
auto *msg = B->CreateExtractValue(header, 1);
auto *msgLen = B->CreateExtractValue(msg, 0);
auto *msgPtr = B->CreateExtractValue(msg, 1);
// copy msg into new null-terminated buffer
auto alloc = makeAllocFunc(/*atomic=*/true);
auto *buf = B->CreateCall(alloc, B->CreateAdd(msgLen, B->getInt64(1)));
B->CreateMemCpy(buf, {}, msgPtr, {}, msgLen);
auto *last = B->CreateInBoundsGEP(B->getInt8Ty(), buf, msgLen);
B->CreateStore(B->getInt8(0), last);
auto *pyErrSetString = M->getNamedValue("PyErr_SetString");
seqassertn(pyErrSetString, "'PyErr_SetString' not found in module");
auto *pyExcRuntimeError = M->getNamedValue("PyExc_RuntimeError");
seqassertn(pyExcRuntimeError, "'PyExc_RuntimeError' not found in module");
B->CreateCall(llvm::FunctionCallee(
llvm::FunctionType::get(B->getVoidTy(),
{B->getInt8PtrTy(), B->getInt8PtrTy()},
/*isVarArg=*/false),
B->CreateLoad(B->getInt8PtrTy(), pyErrSetString)),
{B->CreateLoad(B->getInt8PtrTy(), pyExcRuntimeError), buf});
B->CreateRet(llvm::Constant::getNullValue(wrap->getReturnType()));
return wrap;
}
void LLVMVisitor::writeToPythonExtension(
const std::string &name, const std::vector<std::pair<Func *, Func *>> &funcs,
const std::string &filename, const std::string &argv0,
@ -570,8 +639,9 @@ void LLVMVisitor::writeToPythonExtension(
auto *original = p.first;
auto *generated = p.second;
auto llvmName = getNameForFunction(generated);
auto *llvmFunc = M->getNamedValue(llvmName);
auto *llvmFunc = M->getFunction(llvmName);
seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName);
llvmFunc = createPyTryCatchWrapper(llvmFunc);
auto name = original->getUnmangledName();
auto *nameVar = new llvm::GlobalVariable(
@ -658,7 +728,7 @@ void LLVMVisitor::writeToPythonExtension(
auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr);
// Construct initialization hook
auto pyModuleCreate = cast<llvm::Function>(
auto *pyModuleCreate = cast<llvm::Function>(
M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty())
.getCallee());
pyModuleCreate->setDoesNotThrow();

View File

@ -198,6 +198,9 @@ private:
// Shared library setup
void setupGlobalCtorForSharedLibrary();
// Python extension setup
llvm::Function *createPyTryCatchWrapper(llvm::Function *func);
// LLVM passes
void runLLVMPipeline();

View File

@ -107,7 +107,7 @@ PyObject_IsInstance = Function[[cobj, cobj], i32](cobj())
Py_None = cobj()
Py_True = cobj()
Py_False = cobj()
PyExc_TypeError = cobj()
PyExc_RuntimeError = cobj()
Py_LT = 0
Py_LE = 1
Py_EQ = 2
@ -246,7 +246,7 @@ def init_dl_handles(py_handle: cobj):
global Py_None
global Py_True
global Py_False
global PyExc_TypeError
global PyExc_RuntimeError
Py_DecRef = dlsym(py_handle, "Py_DecRef")
Py_IncRef = dlsym(py_handle, "Py_IncRef")
@ -342,7 +342,7 @@ def init_dl_handles(py_handle: cobj):
Py_None = dlsym(py_handle, "_Py_NoneStruct")
Py_True = dlsym(py_handle, "_Py_TrueStruct")
Py_False = dlsym(py_handle, "_Py_FalseStruct")
PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0]
PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0]
def setup_python(python_loaded: bool):
global _PY_INITIALIZED
@ -717,7 +717,7 @@ def _isinstance(what: pyobj, typ: pyobj) -> bool:
def _extension_bad_args(got: int, expected: int) -> bool:
if got != expected:
PyErr_SetString(PyExc_TypeError,
PyErr_SetString(PyExc_RuntimeError,
f"expected {expected} arguments, but got {got}".c_str())
return True
return False