diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index 48150480..23a8cf2d 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -556,6 +556,75 @@ constexpr int PYEXT_METH_METHOD = 0x0200; constexpr int PYEXT_PYTHON_ABI_VERSION = 1013; } // namespace +llvm::Function *LLVMVisitor::createPyTryCatchWrapper(llvm::Function *func) { + auto *wrap = + cast(M->getOrInsertFunction((func->getName() + ".tc_wrap").str(), + func->getFunctionType()) + .getCallee()); + wrap->setPersonalityFn(llvm::cast(makePersonalityFunc().getCallee())); + auto *entry = llvm::BasicBlock::Create(*context, "entry", wrap); + auto *normal = llvm::BasicBlock::Create(*context, "normal", wrap); + auto *unwind = llvm::BasicBlock::Create(*context, "unwind", wrap); + + B->SetInsertPoint(entry); + std::vector args; + for (auto &arg : wrap->args()) { + args.push_back(&arg); + } + auto *result = B->CreateInvoke(func, normal, unwind, args); + + B->SetInsertPoint(normal); + B->CreateRet(result); + + B->SetInsertPoint(unwind); + auto *caughtResult = B->CreateLandingPad(getPadType(), 1); + caughtResult->setCleanup(true); + caughtResult->addClause(getTypeIdxVar(nullptr)); + auto *unwindType = llvm::StructType::get(B->getInt64Ty()); // header only + auto *unwindException = B->CreateExtractValue(caughtResult, 0); + auto *unwindExceptionClass = B->CreateLoad( + B->getInt64Ty(), + B->CreateStructGEP( + unwindType, B->CreatePointerCast(unwindException, unwindType->getPointerTo()), + 0)); + unwindException = B->CreateExtractValue(caughtResult, 0); + auto *excType = llvm::StructType::get(getTypeInfoType(), B->getInt8PtrTy()); + auto *excVal = + B->CreatePointerCast(B->CreateConstGEP1_64(B->getInt8Ty(), unwindException, + (uint64_t)seq_exc_offset()), + excType->getPointerTo()); + auto *loadedExc = B->CreateLoad(excType, excVal); + auto *objPtr = B->CreateExtractValue(loadedExc, 1); + + auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy()); + auto *excHeader = llvm::StructType::get(strType, strType); + auto *header = B->CreateLoad(excHeader, objPtr); + auto *msg = B->CreateExtractValue(header, 1); + auto *msgLen = B->CreateExtractValue(msg, 0); + auto *msgPtr = B->CreateExtractValue(msg, 1); + + // copy msg into new null-terminated buffer + auto alloc = makeAllocFunc(/*atomic=*/true); + auto *buf = B->CreateCall(alloc, B->CreateAdd(msgLen, B->getInt64(1))); + B->CreateMemCpy(buf, {}, msgPtr, {}, msgLen); + auto *last = B->CreateInBoundsGEP(B->getInt8Ty(), buf, msgLen); + B->CreateStore(B->getInt8(0), last); + + auto *pyErrSetString = M->getNamedValue("PyErr_SetString"); + seqassertn(pyErrSetString, "'PyErr_SetString' not found in module"); + auto *pyExcRuntimeError = M->getNamedValue("PyExc_RuntimeError"); + seqassertn(pyExcRuntimeError, "'PyExc_RuntimeError' not found in module"); + B->CreateCall(llvm::FunctionCallee( + llvm::FunctionType::get(B->getVoidTy(), + {B->getInt8PtrTy(), B->getInt8PtrTy()}, + /*isVarArg=*/false), + B->CreateLoad(B->getInt8PtrTy(), pyErrSetString)), + {B->CreateLoad(B->getInt8PtrTy(), pyExcRuntimeError), buf}); + B->CreateRet(llvm::Constant::getNullValue(wrap->getReturnType())); + + return wrap; +} + void LLVMVisitor::writeToPythonExtension( const std::string &name, const std::vector> &funcs, const std::string &filename, const std::string &argv0, @@ -570,8 +639,9 @@ void LLVMVisitor::writeToPythonExtension( auto *original = p.first; auto *generated = p.second; auto llvmName = getNameForFunction(generated); - auto *llvmFunc = M->getNamedValue(llvmName); + auto *llvmFunc = M->getFunction(llvmName); seqassertn(llvmFunc, "function {} not found in LLVM module", llvmName); + llvmFunc = createPyTryCatchWrapper(llvmFunc); auto name = original->getUnmangledName(); auto *nameVar = new llvm::GlobalVariable( @@ -658,7 +728,7 @@ void LLVMVisitor::writeToPythonExtension( auto *pyModuleConst = llvm::ConstantExpr::getBitCast(pyModuleVar, ptr); // Construct initialization hook - auto pyModuleCreate = cast( + auto *pyModuleCreate = cast( M->getOrInsertFunction("PyModule_Create2", ptr, ptr, B->getInt32Ty()) .getCallee()); pyModuleCreate->setDoesNotThrow(); diff --git a/codon/cir/llvm/llvisitor.h b/codon/cir/llvm/llvisitor.h index 9d253705..a3bec9c2 100644 --- a/codon/cir/llvm/llvisitor.h +++ b/codon/cir/llvm/llvisitor.h @@ -198,6 +198,9 @@ private: // Shared library setup void setupGlobalCtorForSharedLibrary(); + // Python extension setup + llvm::Function *createPyTryCatchWrapper(llvm::Function *func); + // LLVM passes void runLLVMPipeline(); diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 58755839..f6e88578 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -107,7 +107,7 @@ PyObject_IsInstance = Function[[cobj, cobj], i32](cobj()) Py_None = cobj() Py_True = cobj() Py_False = cobj() -PyExc_TypeError = cobj() +PyExc_RuntimeError = cobj() Py_LT = 0 Py_LE = 1 Py_EQ = 2 @@ -246,7 +246,7 @@ def init_dl_handles(py_handle: cobj): global Py_None global Py_True global Py_False - global PyExc_TypeError + global PyExc_RuntimeError Py_DecRef = dlsym(py_handle, "Py_DecRef") Py_IncRef = dlsym(py_handle, "Py_IncRef") @@ -342,7 +342,7 @@ def init_dl_handles(py_handle: cobj): Py_None = dlsym(py_handle, "_Py_NoneStruct") Py_True = dlsym(py_handle, "_Py_TrueStruct") Py_False = dlsym(py_handle, "_Py_FalseStruct") - PyExc_TypeError = Ptr[cobj](dlsym(py_handle, "PyExc_TypeError"))[0] + PyExc_RuntimeError = Ptr[cobj](dlsym(py_handle, "PyExc_RuntimeError"))[0] def setup_python(python_loaded: bool): global _PY_INITIALIZED @@ -717,7 +717,7 @@ def _isinstance(what: pyobj, typ: pyobj) -> bool: def _extension_bad_args(got: int, expected: int) -> bool: if got != expected: - PyErr_SetString(PyExc_TypeError, + PyErr_SetString(PyExc_RuntimeError, f"expected {expected} arguments, but got {got}".c_str()) return True return False