From 12e8fe76660b2c203c134a8e9e1c3a25693b23f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ibrahim=20Numanagi=C4=87?=
 <inumanag@users.noreply.github.com>
Date: Thu, 31 Mar 2022 01:22:26 -0700
Subject: [PATCH] @codon Python decorator and Python interop fixes (#19)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Codon decorator

* Move to extra/cython, add error handling

* Small fixes

* CR

* CR

* Fix cython CI

* Fix cython CI v2

* Fix cython CI v3

* Fix cython CI v4

* Fix cython CI v5

* Fix cython CI v6

* Fix cython CI v7

* Fix cython CI v8

* Fix cython CI v9

* Fix cython CI v10

* Fix cython CI v11

* CR

* Fix CI

* Fix CI

* Fix CI

* Fix CI

* Fix CI

Co-authored-by: Ishak Numanagić <ishak.numanagic@gmail.com>
---
 .github/actions/build-manylinux/entrypoint.sh |  11 +
 .github/workflows/ci.yml                      |  12 +-
 .gitignore                                    |   2 +
 codon/app/main.cpp                            |   2 +-
 codon/compiler/jit.cpp                        |  11 +-
 codon/compiler/jit.h                          |  26 +-
 codon/parser/visitors/simplify/simplify.cpp   |   2 +-
 extra/jupyter/jupyter.cpp                     |   2 +-
 extra/python/README.md                        |  14 +
 extra/python/pyproject.toml                   |   2 +
 extra/python/setup.py                         |  57 ++++
 extra/python/src/__init__.py                  |   3 +
 extra/python/src/decorator.py                 | 174 +++++++++++
 extra/python/src/jit.pxd                      |  16 +
 extra/python/src/jit.pyx                      |  31 ++
 stdlib/internal/dlopen.codon                  |   2 +-
 stdlib/internal/python.codon                  | 275 ++++++++++--------
 test/python/cython_jit.py                     |  87 ++++++
 18 files changed, 602 insertions(+), 127 deletions(-)
 create mode 100644 extra/python/README.md
 create mode 100644 extra/python/pyproject.toml
 create mode 100644 extra/python/setup.py
 create mode 100644 extra/python/src/__init__.py
 create mode 100644 extra/python/src/decorator.py
 create mode 100644 extra/python/src/jit.pxd
 create mode 100644 extra/python/src/jit.pyx
 create mode 100644 test/python/cython_jit.py

diff --git a/.github/actions/build-manylinux/entrypoint.sh b/.github/actions/build-manylinux/entrypoint.sh
index e8495ae6..4a3a69df 100755
--- a/.github/actions/build-manylinux/entrypoint.sh
+++ b/.github/actions/build-manylinux/entrypoint.sh
@@ -26,11 +26,22 @@ export LLVM_DIR=$(llvm/bin/llvm-config --cmakedir)
                       -DCMAKE_CXX_COMPILER=${CXX})
 cmake --build build --config Release -- VERBOSE=1
 
+# build cython
+export PATH=$PATH:$(pwd)/llvm/bin
+export LD_LIBRARY_PATH=$(pwd)/build:$LD_LIBRARY_PATH
+export CODON_INCLUDE_DIR=$(pwd)/build/include
+export CODON_LIB_DIR=$(pwd)/build
+python3 -m pip install cython
+python3 -m pip install -v extra/python
+
 # test
+export CODON_PATH=$(pwd)/stdlib
 ln -s build/libcodonrt.so .
 build/codon_test
 build/codon run test/core/helloworld.codon
 build/codon run test/core/exit.codon || if [[ $? -ne 42 ]]; then false; fi
+export PYTHONPATH=$(pwd):$PYTHONPATH
+python3 test/python/cython_jit.py
 
 # package
 export CODON_BUILD_ARCHIVE=codon-$(uname -s | awk '{print tolower($0)}')-$(uname -m).tar.gz
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4c22050b..5805b3d5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -151,6 +151,14 @@ jobs:
           CC: clang
           CXX: clang++
 
+      - name: Build Cython
+        run: |
+          python -m pip install -v extra/python
+        env:
+          CC: clang
+          CXX: clang++
+          PYTHONPATH: ./test/python
+
       - name: Test
         run: |
           echo $CODON_PYTHON
@@ -158,9 +166,11 @@ jobs:
           build/codon_test
           build/codon run test/core/helloworld.codon
           build/codon run test/core/exit.codon || if [[ $? -ne 42 ]]; then false; fi
+          python test/python/cython_jit.py
         env:
           CODON_PATH: ./stdlib
-          PYTHONPATH: ./test/python
+          PYTHONPATH: .:./test/python
+          LD_LIBRARY_PATH: ./build
 
       - name: Build Documentation
         run: |
diff --git a/.gitignore b/.gitignore
index 3bc55219..72101f09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,7 @@
 *.pyc
 build/
 build_*/
+extra/python/src/jit.cpp
 extra/jupyter/build/
 
 # Packages #
@@ -29,6 +30,7 @@ extra/jupyter/build/
 *.rar
 *.tar
 *.zip
+**/**.egg-info
 
 # Logs and databases #
 ######################
diff --git a/codon/app/main.cpp b/codon/app/main.cpp
index 0387a8b1..d3dcb128 100644
--- a/codon/app/main.cpp
+++ b/codon/app/main.cpp
@@ -189,7 +189,7 @@ int runMode(const std::vector<const char *> &args) {
 
 namespace {
 std::string jitExec(codon::jit::JIT *jit, const std::string &code) {
-  auto result = jit->exec(code);
+  auto result = jit->execute(code);
   if (auto err = result.takeError()) {
     std::string output;
     llvm::handleAllErrors(
diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp
index 881e5a82..0a70cb03 100644
--- a/codon/compiler/jit.cpp
+++ b/codon/compiler/jit.cpp
@@ -100,7 +100,7 @@ llvm::Expected<std::string> JIT::run(const ir::Func *input) {
   return getCapturedOutput();
 }
 
-llvm::Expected<std::string> JIT::exec(const std::string &code) {
+llvm::Expected<std::string> JIT::execute(const std::string &code) {
   auto *cache = compiler->getCache();
   ast::StmtPtr node = ast::parseCode(cache, JIT_FILENAME, code, /*startLine=*/0);
 
@@ -160,5 +160,14 @@ llvm::Expected<std::string> JIT::exec(const std::string &code) {
   }
 }
 
+JITResult JIT::executeSafe(const std::string &code) {
+  auto result = this->execute(code);
+  if (auto err = result.takeError()) {
+    auto errorInfo = llvm::toString(std::move(err));
+    return JITResult::error(errorInfo);
+  }
+  return JITResult::success(result.get());
+}
+
 } // namespace jit
 } // namespace codon
diff --git a/codon/compiler/jit.h b/codon/compiler/jit.h
index ab3e42a7..61c92889 100644
--- a/codon/compiler/jit.h
+++ b/codon/compiler/jit.h
@@ -15,6 +15,29 @@
 namespace codon {
 namespace jit {
 
+struct JITResult {
+  std::string data;
+  bool isError;
+
+  JITResult():
+    data(""), isError(false) {}
+
+  JITResult(const std::string &data, bool isError):
+    data(data), isError(isError) {}
+
+  operator bool() {
+    return !this->isError;
+  }
+
+  static JITResult success(const std::string &output) {
+    return JITResult(output, false);
+  }
+
+  static JITResult error(const std::string &errorInfo) {
+    return JITResult(errorInfo, true);
+  }
+};
+
 class JIT {
 private:
   std::unique_ptr<Compiler> compiler;
@@ -29,7 +52,8 @@ public:
 
   llvm::Error init();
   llvm::Expected<std::string> run(const ir::Func *input);
-  llvm::Expected<std::string> exec(const std::string &code);
+  llvm::Expected<std::string> execute(const std::string &code);
+  JITResult executeSafe(const std::string &code);
 };
 
 } // namespace jit
diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp
index 6c149a94..07f068e1 100644
--- a/codon/parser/visitors/simplify/simplify.cpp
+++ b/codon/parser/visitors/simplify/simplify.cpp
@@ -98,7 +98,7 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
     stdlib->isStdlibLoading = true;
     stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
     auto baseTypeCode =
-        "@__internal__\n@tuple\nclass pyobj:\n  p: Ptr[byte]\n"
+        "@__internal__\nclass pyobj:\n  p: Ptr[byte]\n"
         "@__internal__\n@tuple\nclass str:\n  ptr: Ptr[byte]\n  len: int\n";
     SimplifyVisitor(stdlib, preamble)
         .transform(parseCode(stdlib->cache, stdlibPath->path, baseTypeCode));
diff --git a/extra/jupyter/jupyter.cpp b/extra/jupyter/jupyter.cpp
index eb068f28..f3ec268f 100644
--- a/extra/jupyter/jupyter.cpp
+++ b/extra/jupyter/jupyter.cpp
@@ -33,7 +33,7 @@ nl::json CodonJupyter::execute_request_impl(int execution_counter, const string
                                             bool silent, bool store_history,
                                             nl::json user_expressions,
                                             bool allow_stdin) {
-  auto result = jit->exec(code);
+  auto result = jit->execute(code);
   string failed;
   llvm::handleAllErrors(
       result.takeError(),
diff --git a/extra/python/README.md b/extra/python/README.md
new file mode 100644
index 00000000..8c2878cb
--- /dev/null
+++ b/extra/python/README.md
@@ -0,0 +1,14 @@
+To install:
+
+```bash
+$ pip install extra/python
+```
+
+To use:
+
+```python
+from codon import codon, JitError
+
+@codon
+def ...
+```
diff --git a/extra/python/pyproject.toml b/extra/python/pyproject.toml
new file mode 100644
index 00000000..c753eee1
--- /dev/null
+++ b/extra/python/pyproject.toml
@@ -0,0 +1,2 @@
+[build-system]
+requires = ["cython", "setuptools", "wheel"]
diff --git a/extra/python/setup.py b/extra/python/setup.py
new file mode 100644
index 00000000..e658a9c9
--- /dev/null
+++ b/extra/python/setup.py
@@ -0,0 +1,57 @@
+import os
+import subprocess
+
+from Cython.Distutils import build_ext
+from setuptools import setup
+from setuptools.extension import Extension
+
+
+def exists(executable):
+    ps = subprocess.run(["which", executable], stdout=subprocess.PIPE)
+    return ps.returncode == 0
+
+
+def get_output(*args):
+    ps = subprocess.run(args, stdout=subprocess.PIPE)
+    return ps.stdout.decode("utf8").strip()
+
+
+from_root = lambda relpath: os.path.realpath(f"{os.getcwd()}/../../{relpath}")
+
+llvm_config: str
+llvm_config_candidates = ["llvm-config-12", "llvm-config", from_root("llvm/bin/llvm-config")]
+for candidate in llvm_config_candidates:
+    if exists(candidate):
+        llvm_config = candidate
+        break
+else:
+    raise FileNotFoundError("Cannot find llvm-config; is llvm installed?")
+
+llvm_include_dir = get_output(llvm_config, "--includedir")
+llvm_lib_dir = get_output(llvm_config, "--libdir")
+
+codon_include_dir = os.environ.get("CODON_INCLUDE_DIR", from_root("build/include"))
+codon_lib_dir = os.environ.get("CODON_LIB_DIR", from_root("build"))
+
+print(f"<llvm>  {llvm_include_dir}, {llvm_lib_dir}")
+print(f"<codon> {codon_include_dir}, {codon_lib_dir}")
+
+jit_extension = Extension(
+    "codon_jit",
+    sources=["src/jit.pyx"],
+    libraries=["codonc", "codonrt"],
+    language="c++",
+    extra_compile_args=["-w", "-std=c++17"],
+    extra_link_args=[f"-Wl,-rpath,{codon_lib_dir}"],
+    include_dirs=[llvm_include_dir, codon_include_dir],
+    library_dirs=[llvm_lib_dir, codon_lib_dir],
+)
+
+setup(
+    name="codon",
+    version="0.1.0",
+    cmdclass={"build_ext": build_ext},
+    ext_modules=[jit_extension],
+    packages=["codon"],
+    package_dir={"codon": "src"}
+)
diff --git a/extra/python/src/__init__.py b/extra/python/src/__init__.py
new file mode 100644
index 00000000..30589239
--- /dev/null
+++ b/extra/python/src/__init__.py
@@ -0,0 +1,3 @@
+__all__ = ["codon", "JitError"]
+
+from .decorator import codon, JitError
diff --git a/extra/python/src/decorator.py b/extra/python/src/decorator.py
new file mode 100644
index 00000000..2a668c66
--- /dev/null
+++ b/extra/python/src/decorator.py
@@ -0,0 +1,174 @@
+import ctypes
+import inspect
+import importlib
+import importlib.util
+import sys
+
+from typing import List, Tuple
+
+sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
+
+from codon_jit import Jit, JitError
+
+
+separator = "__"
+
+
+# codon wrapper stubs
+
+
+def _wrapper_stub_init():
+    from internal.python import (
+        pyobj,
+        ensure_initialized,
+        Py_None,
+        PyImport_AddModule,
+        PyObject_GetAttrString,
+        PyObject_SetAttrString,
+        PyTuple_GetItem,
+    )
+
+    ensure_initialized(True)
+
+    module = PyImport_AddModule("__codon_interop__".c_str())
+
+
+def _wrapper_stub_header():
+    argt = PyObject_GetAttrString(module, "__codon_args__".c_str())
+
+
+def _wrapper_stub_footer_ret():
+    PyObject_SetAttrString(module, "__codon_ret__".c_str(), ret.p)
+
+
+def _wrapper_stub_footer_void():
+    pyobj.incref(Py_None)
+    PyObject_SetAttrString(module, "__codon_ret__".c_str(), Py_None)
+
+
+# helpers
+
+
+def _reset_jit():
+    global jit
+    jit = Jit()
+    lines = inspect.getsourcelines(_wrapper_stub_init)[0][1:]
+    jit.execute("".join([l[4:] for l in lines]))
+
+    return jit
+
+
+def _init():
+    spec = importlib.machinery.ModuleSpec("__codon_interop__", None)
+    module = importlib.util.module_from_spec(spec)
+    exec("__codon_args__ = ()\n__codon_ret__ = None", module.__dict__)
+    sys.modules["__codon_interop__"] = module
+    exec("import __codon_interop__")
+
+    return _reset_jit(), module
+
+
+jit, module = _init()
+
+
+def _obj_to_str(obj) -> str:
+    if inspect.isclass(obj):
+        lines = inspect.getsourcelines(obj)[0]
+        extra_spaces = lines[0].find("class")
+        obj_str = "".join(l[extra_spaces:] for l in lines)
+    elif callable(obj):
+        lines = inspect.getsourcelines(obj)[0]
+        extra_spaces = lines[0].find("@")
+        obj_str = "".join(l[extra_spaces:] for l in lines[1:])
+    else:
+        raise TypeError(f"Function or class expected, got {type(obj).__name__}.")
+    return obj_str.replace("_@par", "@par")
+
+
+def _obj_name(obj) -> str:
+    if inspect.isclass(obj) or callable(obj):
+        return obj.__name__
+    else:
+        raise TypeError(f"Function or class expected, got {type(obj).__name__}.")
+
+
+def _obj_name_full(obj) -> str:
+    obj_name = _obj_name(obj)
+    obj_name_stack = [obj_name]
+    frame = inspect.currentframe()
+    while frame.f_code.co_name != "codon":
+        frame = frame.f_back
+    frame = frame.f_back
+    while frame:
+        if frame.f_code.co_name == "<module>":
+            obj_name_stack += [frame.f_globals["__name__"]]
+            break
+        else:
+            obj_name_stack += [frame.f_code.co_name]
+        frame = frame.f_back
+    return obj_name, separator.join(reversed(obj_name_stack))
+
+
+def _parse_decorated(obj):
+    return _obj_name(obj), _obj_to_str(obj)
+
+
+def _get_type_info(obj) -> Tuple[List[str], str]:
+    sgn = inspect.signature(obj)
+    par = [p.annotation for p in sgn.parameters.values()]
+    ret = sgn.return_annotation
+    return par, ret
+
+
+def _type_str(typ) -> str:
+    if typ in (int, float, str, bool):
+        return typ.__name__
+    obj_str = str(typ)
+    return obj_str[7:] if obj_str.startswith("typing.") else obj_str
+
+
+def _build_wrapper(obj, obj_name) -> str:
+    arg_types, ret_type = _get_type_info(obj)
+    arg_count = len(arg_types)
+    wrap_name = f"{obj_name}{separator}wrapped"
+    wrap = [f"def {wrap_name}():\n"]
+    wrap += inspect.getsourcelines(_wrapper_stub_header)[0][1:]
+    wrap += [
+        f"    arg_{i} = {_type_str(arg_types[i])}.__from_py__(pyobj(PyTuple_GetItem(argt, {i})))\n"
+        for i in range(arg_count)
+    ]
+    args = ", ".join([f"arg_{i}" for i in range(arg_count)])
+    if ret_type != inspect._empty:
+        wrap += [f"    ret = {obj_name}({args}).__to_py__()\n"]
+        wrap += inspect.getsourcelines(_wrapper_stub_footer_ret)[0][1:]
+    else:
+        wrap += [f"    {obj_name}({args})\n"]
+        wrap += inspect.getsourcelines(_wrapper_stub_footer_void)[0][1:]
+    return wrap_name, "".join(wrap)
+
+
+# decorator
+
+
+def codon(obj):
+    try:
+        obj_name, obj_str = _parse_decorated(obj)
+        jit.execute(obj_str)
+
+        wrap_name, wrap_str = _build_wrapper(obj, obj_name)
+        jit.execute(wrap_str)
+    except JitError as e:
+        _reset_jit()
+        raise
+
+    def wrapped(*args, **kwargs):
+        try:
+            module.__codon_args__ = (*args, *kwargs.values())
+            stdout = jit.execute(f"{wrap_name}()")
+            print(stdout, end="")
+            return module.__codon_ret__
+        except JitError as e:
+            _reset_jit()
+            raise
+
+    return wrapped
diff --git a/extra/python/src/jit.pxd b/extra/python/src/jit.pxd
new file mode 100644
index 00000000..21c608cb
--- /dev/null
+++ b/extra/python/src/jit.pxd
@@ -0,0 +1,16 @@
+from libcpp.string cimport string
+
+
+cdef extern from "llvm/Support/Error.h" namespace "llvm":
+    cdef cppclass Error
+
+
+cdef extern from "codon/compiler/jit.h" namespace "codon::jit":
+    cdef cppclass JITResult:
+        string data
+        bint operator bool()
+
+    cdef cppclass JIT:
+        JIT(string)
+        Error init()
+        JITResult executeSafe(string)
diff --git a/extra/python/src/jit.pyx b/extra/python/src/jit.pyx
new file mode 100644
index 00000000..aa2546a8
--- /dev/null
+++ b/extra/python/src/jit.pyx
@@ -0,0 +1,31 @@
+# distutils: language=c++
+# cython: language_level=3
+# cython: c_string_type=unicode
+# cython: c_string_encoding=ascii
+
+from cython.operator import dereference as dref
+from libcpp.string cimport string
+
+from src.jit cimport JIT, JITResult
+
+
+class JitError(Exception):
+    pass
+
+
+cdef class Jit:
+    cdef JIT* jit
+
+    def __cinit__(self):
+        self.jit = new JIT(b"codon jit")
+        dref(self.jit).init()
+
+    def __dealloc__(self):
+        del self.jit
+
+    def execute(self, code: str) -> str:
+        result = dref(self.jit).executeSafe(code)
+        if <bint>result:
+            return result.data
+        else:
+            raise JitError(result.data)
diff --git a/stdlib/internal/dlopen.codon b/stdlib/internal/dlopen.codon
index b7cf65b2..ecd59ba9 100644
--- a/stdlib/internal/dlopen.codon
+++ b/stdlib/internal/dlopen.codon
@@ -31,7 +31,7 @@ def dlerror() -> str:
 
 
 def dlopen(name: str, flag: int = RTLD_NOW | RTLD_GLOBAL) -> cobj:
-    h = c_dlopen(cobj(0) if name == "" else name.c_str(), flag)
+    h = c_dlopen(cobj() if name == "" else name.c_str(), flag)
     if h == cobj():
         raise CError(dlerror())
     return h
diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon
index 6ef03a3c..91f9f908 100644
--- a/stdlib/internal/python.codon
+++ b/stdlib/internal/python.codon
@@ -1,44 +1,47 @@
 # (c) 2022 Exaloop Inc. All rights reserved.
 
 import os
+
 from internal.dlopen import *
 
-PyUnicode_AsEncodedString = Function[[cobj, cobj, cobj], cobj](cobj())
-PyBytes_AsString = Function[[cobj], cobj](cobj())
-PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], void](cobj())
-PyObject_GetAttrString = Function[[cobj, cobj], cobj](cobj())
-PyObject_GetAttr = Function[[cobj, cobj], cobj](cobj())
-PyObject_Str = Function[[cobj], cobj](cobj())
-PyRun_SimpleString = Function[[cobj], void](cobj())
-Py_IncRef = Function[[cobj], void](cobj())
 Py_DecRef = Function[[cobj], void](cobj())
-PyObject_Call = Function[[cobj, cobj, cobj], cobj](cobj())
-PyObject_SetAttrString = Function[[cobj, cobj, cobj], cobj](cobj())
-PyObject_Length = Function[[cobj], int](cobj())
+Py_IncRef = Function[[cobj], void](cobj())
 Py_Initialize = Function[[], void](cobj())
-PyImport_ImportModule = Function[[cobj], cobj](cobj())
-PyLong_FromLong = Function[[int], cobj](cobj())
-PyLong_AsLong = Function[[cobj], int](cobj())
-PyFloat_FromDouble = Function[[float], cobj](cobj())
-PyFloat_AsDouble = Function[[cobj], float](cobj())
+Py_None = cobj()
 PyBool_FromLong = Function[[int], cobj](cobj())
-PyObject_IsTrue = Function[[cobj], int](cobj())
-PyUnicode_DecodeFSDefaultAndSize = Function[[cobj, int], cobj](cobj())
-PyTuple_New = Function[[int], cobj](cobj())
-PyTuple_SetItem = Function[[cobj, int, cobj], void](cobj())
-PyTuple_GetItem = Function[[cobj, int], cobj](cobj())
+PyBytes_AsString = Function[[cobj], cobj](cobj())
+PyDict_New = Function[[], cobj](cobj())
+PyDict_Next = Function[[cobj, Ptr[int], Ptr[cobj], Ptr[cobj]], int](cobj())
+PyDict_SetItem = Function[[cobj, cobj, cobj], cobj](cobj())
+PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], void](cobj())
+PyFloat_AsDouble = Function[[cobj], float](cobj())
+PyFloat_FromDouble = Function[[float], cobj](cobj())
+PyImport_AddModule = Function[[cobj], cobj](cobj())
+PyImport_ImportModule = Function[[cobj], cobj](cobj())
+PyIter_Next = Function[[cobj], cobj](cobj())
+PyList_GetItem = Function[[cobj, int], cobj](cobj())
 PyList_New = Function[[int], cobj](cobj())
 PyList_SetItem = Function[[cobj, int, cobj], cobj](cobj())
-PyList_GetItem = Function[[cobj, int], cobj](cobj())
-PySet_New = Function[[cobj], cobj](cobj())
-PySet_Add = Function[[cobj, cobj], cobj](cobj())
-PyDict_New = Function[[], cobj](cobj())
-PyDict_SetItem = Function[[cobj, cobj, cobj], cobj](cobj())
-PyDict_Next = Function[[cobj, Ptr[int], Ptr[cobj], Ptr[cobj]], int](cobj())
+PyLong_AsLong = Function[[cobj], int](cobj())
+PyLong_FromLong = Function[[int], cobj](cobj())
+PyObject_Call = Function[[cobj, cobj, cobj], cobj](cobj())
+PyObject_GetAttr = Function[[cobj, cobj], cobj](cobj())
+PyObject_GetAttrString = Function[[cobj, cobj], cobj](cobj())
 PyObject_GetIter = Function[[cobj], cobj](cobj())
-PyIter_Next = Function[[cobj], cobj](cobj())
 PyObject_HasAttrString = Function[[cobj, cobj], int](cobj())
-PyImport_AddModule = Function[[cobj], cobj](cobj())
+PyObject_IsTrue = Function[[cobj], int](cobj())
+PyObject_Length = Function[[cobj], int](cobj())
+PyObject_SetAttrString = Function[[cobj, cobj, cobj], cobj](cobj())
+PyObject_Str = Function[[cobj], cobj](cobj())
+PyRun_SimpleString = Function[[cobj], void](cobj())
+PySet_Add = Function[[cobj, cobj], cobj](cobj())
+PySet_New = Function[[cobj], cobj](cobj())
+PyTuple_GetItem = Function[[cobj, int], cobj](cobj())
+PyTuple_New = Function[[int], cobj](cobj())
+PyTuple_SetItem = Function[[cobj, int, cobj], void](cobj())
+PyUnicode_AsEncodedString = Function[[cobj, cobj, cobj], cobj](cobj())
+PyUnicode_DecodeFSDefaultAndSize = Function[[cobj, int], cobj](cobj())
+PyUnicode_FromString = Function[[cobj], cobj](cobj())
 
 _PY_MODULE_CACHE = Dict[str, pyobj]()
 
@@ -77,102 +80,125 @@ def __codon_repr__(fig):
 _PY_INITIALIZED = False
 
 
-def init():
+def init_dl_handles(py_handle: cobj):
+    global Py_DecRef
+    global Py_IncRef
+    global Py_Initialize
+    global Py_None
+    global PyBool_FromLong
+    global PyBytes_AsString
+    global PyDict_New
+    global PyDict_Next
+    global PyDict_SetItem
+    global PyErr_Fetch
+    global PyFloat_AsDouble
+    global PyFloat_FromDouble
+    global PyImport_AddModule
+    global PyImport_ImportModule
+    global PyIter_Next
+    global PyList_GetItem
+    global PyList_New
+    global PyList_SetItem
+    global PyLong_AsLong
+    global PyLong_FromLong
+    global PyObject_Call
+    global PyObject_GetAttr
+    global PyObject_GetAttrString
+    global PyObject_GetIter
+    global PyObject_HasAttrString
+    global PyObject_IsTrue
+    global PyObject_Length
+    global PyObject_SetAttrString
+    global PyObject_Str
+    global PyRun_SimpleString
+    global PySet_Add
+    global PySet_New
+    global PyTuple_GetItem
+    global PyTuple_New
+    global PyTuple_SetItem
+    global PyUnicode_AsEncodedString
+    global PyUnicode_DecodeFSDefaultAndSize
+    global PyUnicode_FromString
+    Py_DecRef = dlsym(py_handle, "Py_DecRef")
+    Py_IncRef = dlsym(py_handle, "Py_IncRef")
+    Py_Initialize = dlsym(py_handle, "Py_Initialize")
+    Py_None = dlsym(py_handle, "_Py_NoneStruct")
+    PyBool_FromLong = dlsym(py_handle, "PyBool_FromLong")
+    PyBytes_AsString = dlsym(py_handle, "PyBytes_AsString")
+    PyDict_New = dlsym(py_handle, "PyDict_New")
+    PyDict_Next = dlsym(py_handle, "PyDict_Next")
+    PyDict_SetItem = dlsym(py_handle, "PyDict_SetItem")
+    PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch")
+    PyFloat_AsDouble = dlsym(py_handle, "PyFloat_AsDouble")
+    PyFloat_FromDouble = dlsym(py_handle, "PyFloat_FromDouble")
+    PyImport_AddModule = dlsym(py_handle, "PyImport_AddModule")
+    PyImport_ImportModule = dlsym(py_handle, "PyImport_ImportModule")
+    PyIter_Next = dlsym(py_handle, "PyIter_Next")
+    PyList_GetItem = dlsym(py_handle, "PyList_GetItem")
+    PyList_New = dlsym(py_handle, "PyList_New")
+    PyList_SetItem = dlsym(py_handle, "PyList_SetItem")
+    PyLong_AsLong = dlsym(py_handle, "PyLong_AsLong")
+    PyLong_FromLong = dlsym(py_handle, "PyLong_FromLong")
+    PyObject_Call = dlsym(py_handle, "PyObject_Call")
+    PyObject_GetAttr = dlsym(py_handle, "PyObject_GetAttr")
+    PyObject_GetAttrString = dlsym(py_handle, "PyObject_GetAttrString")
+    PyObject_GetIter = dlsym(py_handle, "PyObject_GetIter")
+    PyObject_HasAttrString = dlsym(py_handle, "PyObject_HasAttrString")
+    PyObject_IsTrue = dlsym(py_handle, "PyObject_IsTrue")
+    PyObject_Length = dlsym(py_handle, "PyObject_Length")
+    PyObject_SetAttrString = dlsym(py_handle, "PyObject_SetAttrString")
+    PyObject_Str = dlsym(py_handle, "PyObject_Str")
+    PyRun_SimpleString = dlsym(py_handle, "PyRun_SimpleString")
+    PySet_Add = dlsym(py_handle, "PySet_Add")
+    PySet_New = dlsym(py_handle, "PySet_New")
+    PyTuple_GetItem = dlsym(py_handle, "PyTuple_GetItem")
+    PyTuple_New = dlsym(py_handle, "PyTuple_New")
+    PyTuple_SetItem = dlsym(py_handle, "PyTuple_SetItem")
+    PyUnicode_AsEncodedString = dlsym(py_handle, "PyUnicode_AsEncodedString")
+    PyUnicode_DecodeFSDefaultAndSize = dlsym(py_handle, "PyUnicode_DecodeFSDefaultAndSize")
+    PyUnicode_FromString = dlsym(py_handle, "PyUnicode_FromString")
+
+
+def init(python_loaded: bool = False):
     global _PY_INITIALIZED
     if _PY_INITIALIZED:
         return
 
-    LD = os.getenv("CODON_PYTHON", default=f"libpython.{dlext()}")
-    hnd = dlopen(LD, RTLD_LOCAL | RTLD_NOW)
+    py_handle: cobj
+    if python_loaded:
+        py_handle = dlopen("", RTLD_LOCAL | RTLD_NOW)
+    else:
+        LD = os.getenv("CODON_PYTHON", default="libpython." + dlext())
+        py_handle = dlopen(LD, RTLD_LOCAL | RTLD_NOW)
 
-    global PyUnicode_AsEncodedString
-    PyUnicode_AsEncodedString = dlsym(hnd, "PyUnicode_AsEncodedString")
-    global PyBytes_AsString
-    PyBytes_AsString = dlsym(hnd, "PyBytes_AsString")
-    global PyErr_Fetch
-    PyErr_Fetch = dlsym(hnd, "PyErr_Fetch")
-    global PyObject_GetAttrString
-    PyObject_GetAttrString = dlsym(hnd, "PyObject_GetAttrString")
-    global PyObject_GetAttr
-    PyObject_GetAttr = dlsym(hnd, "PyObject_GetAttr")
-    global PyObject_Str
-    PyObject_Str = dlsym(hnd, "PyObject_Str")
-    global PyRun_SimpleString
-    PyRun_SimpleString = dlsym(hnd, "PyRun_SimpleString")
-    global Py_IncRef
-    Py_IncRef = dlsym(hnd, "Py_IncRef")
-    global Py_DecRef
-    Py_DecRef = dlsym(hnd, "Py_DecRef")
-    global PyObject_Call
-    PyObject_Call = dlsym(hnd, "PyObject_Call")
-    global PyObject_SetAttrString
-    PyObject_SetAttrString = dlsym(hnd, "PyObject_SetAttrString")
-    global PyObject_Length
-    PyObject_Length = dlsym(hnd, "PyObject_Length")
-    global Py_Initialize
-    Py_Initialize = dlsym(hnd, "Py_Initialize")
-    global PyImport_ImportModule
-    PyImport_ImportModule = dlsym(hnd, "PyImport_ImportModule")
-    global PyLong_FromLong
-    PyLong_FromLong = dlsym(hnd, "PyLong_FromLong")
-    global PyLong_AsLong
-    PyLong_AsLong = dlsym(hnd, "PyLong_AsLong")
-    global PyFloat_FromDouble
-    PyFloat_FromDouble = dlsym(hnd, "PyFloat_FromDouble")
-    global PyFloat_AsDouble
-    PyFloat_AsDouble = dlsym(hnd, "PyFloat_AsDouble")
-    global PyBool_FromLong
-    PyBool_FromLong = dlsym(hnd, "PyBool_FromLong")
-    global PyObject_IsTrue
-    PyObject_IsTrue = dlsym(hnd, "PyObject_IsTrue")
-    global PyUnicode_DecodeFSDefaultAndSize
-    PyUnicode_DecodeFSDefaultAndSize = dlsym(hnd, "PyUnicode_DecodeFSDefaultAndSize")
-    global PyTuple_New
-    PyTuple_New = dlsym(hnd, "PyTuple_New")
-    global PyTuple_SetItem
-    PyTuple_SetItem = dlsym(hnd, "PyTuple_SetItem")
-    global PyTuple_GetItem
-    PyTuple_GetItem = dlsym(hnd, "PyTuple_GetItem")
-    global PyList_New
-    PyList_New = dlsym(hnd, "PyList_New")
-    global PyList_SetItem
-    PyList_SetItem = dlsym(hnd, "PyList_SetItem")
-    global PyList_GetItem
-    PyList_GetItem = dlsym(hnd, "PyList_GetItem")
-    global PySet_New
-    PySet_New = dlsym(hnd, "PySet_New")
-    global PySet_Add
-    PySet_Add = dlsym(hnd, "PySet_Add")
-    global PyDict_New
-    PyDict_New = dlsym(hnd, "PyDict_New")
-    global PyDict_SetItem
-    PyDict_SetItem = dlsym(hnd, "PyDict_SetItem")
-    global PyDict_Next
-    PyDict_Next = dlsym(hnd, "PyDict_Next")
-    global PyObject_GetIter
-    PyObject_GetIter = dlsym(hnd, "PyObject_GetIter")
-    global PyIter_Next
-    PyIter_Next = dlsym(hnd, "PyIter_Next")
-    global PyObject_HasAttrString
-    PyObject_HasAttrString = dlsym(hnd, "PyObject_HasAttrString")
-    global PyImport_AddModule
-    PyImport_AddModule = dlsym(hnd, "PyImport_AddModule")
+    init_dl_handles(py_handle)
+
+    if not python_loaded:
+        Py_Initialize()
 
-    Py_Initialize()
     PyRun_SimpleString(_PY_INIT.c_str())
     _PY_INITIALIZED = True
 
 
-def ensure_initialized():
+def ensure_initialized(python_loaded: bool = False):
     if not _PY_INITIALIZED:
-        init()
-        # raise ValueError("Python not initialized; make sure to 'import python'")
+        init(python_loaded)
 
 
 @extend
 class pyobj:
-    def __new__(p: Ptr[byte]) -> pyobj:
-        return (p,)
+    @__internal__
+    def __new__() -> pyobj:
+        pass
+
+    def __raw__(self) -> Ptr[byte]:
+        return __internal__.class_raw(self)
+
+    def __init__(self, p: Ptr[byte]):
+        self.p = p
+
+    def __del__(self):
+        self.decref()
 
     def _getattr(self, name: str) -> pyobj:
         return pyobj.exc_wrap(pyobj(PyObject_GetAttrString(self.p, name.c_str())))
@@ -203,8 +229,7 @@ class pyobj:
             raise ValueError("Python object is not iterable")
         while i := PyIter_Next(it):
             yield pyobj(pyobj.exc_wrap(i))
-            pyobj(i).decref()
-        pyobj(it).decref()
+        pyobj.decref(it)
         pyobj.exc_check()
 
     def to_str(self, errors: str, empty: str = "") -> str:
@@ -212,7 +237,7 @@ class pyobj:
         if obj == cobj():
             return empty
         bts = PyBytes_AsString(obj)
-        pyobj(obj).decref()
+        pyobj.decref(obj)
         return str.from_ptr(bts)
 
     def exc_check():
@@ -221,14 +246,12 @@ class pyobj:
         if ptype != cobj():
             py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue
             msg = pyobj(py_msg).to_str("ignore", "<empty Python message>")
-            typ = pyobj.to_str(
-                pyobj(PyObject_GetAttrString(ptype, "__name__".c_str())), "ignore"
-            )
+            typ = pyobj(PyObject_GetAttrString(ptype, "__name__".c_str())).to_str("ignore")
 
-            pyobj(ptype).decref()
-            pyobj(pvalue).decref()
-            pyobj(ptraceback).decref()
-            pyobj(py_msg).decref()
+            pyobj.decref(ptype)
+            pyobj.decref(pvalue)
+            pyobj.decref(ptraceback)
+            pyobj.decref(py_msg)
 
             raise PyError(msg, typ)
 
@@ -239,9 +262,21 @@ class pyobj:
     def incref(self):
         Py_IncRef(self.p)
 
+    def incref(obj: pyobj):
+        Py_IncRef(obj.p)
+
+    def incref(ptr: Ptr[byte]):
+        Py_IncRef(ptr)
+
     def decref(self):
         Py_DecRef(self.p)
 
+    def decref(obj: pyobj):
+        Py_DecRef(obj.p)
+
+    def decref(ptr: Ptr[byte]):
+        Py_DecRef(ptr)
+
     def __call__(self, *args, **kwargs) -> pyobj:
         names = iter(kwargs.__dict__())
         kws = dict[str, pyobj]()
diff --git a/test/python/cython_jit.py b/test/python/cython_jit.py
new file mode 100644
index 00000000..ede06865
--- /dev/null
+++ b/test/python/cython_jit.py
@@ -0,0 +1,87 @@
+import sys
+from io import StringIO
+from typing import Dict, List, Tuple
+
+from codon import codon, JitError
+
+
+# test stdout
+
+
+def test_stdout():
+    @codon
+    def run():
+        print("hello world!")
+
+    try:
+        output = StringIO()
+        sys.stdout = output
+        run()
+        assert output.getvalue() == "hello world!\n"
+    finally:
+        sys.stdout = sys.__stdout__
+
+
+test_stdout()
+
+
+# test error handling
+
+
+def test_error_handling():
+    @codon
+    def run() -> int:
+        return "not int"
+
+    try:
+        r = run()
+    except JitError:
+        assert True
+    except BaseException:
+        assert False
+    else:
+        assert False
+
+
+test_error_handling()
+
+
+# test type validity
+
+
+def test_return_type():
+    @codon
+    def run() -> Tuple[int, str, float, List[int], Dict[str, int]]:
+        return (1, "str", 2.45, [1, 2, 3], {"a": 1, "b": 2})
+
+    r = run()
+    assert type(r) == tuple
+    assert type(r[0]) == int
+    assert type(r[1]) == str
+    assert type(r[2]) == float
+    assert type(r[3]) == list
+    assert len(r[3]) == 3
+    assert type(r[3][0]) == int
+    assert type(r[4]) == dict
+    assert len(r[4].items()) == 2
+    assert type(next(iter(r[4].keys()))) == str
+    assert type(next(iter(r[4].values()))) == int
+
+
+test_return_type()
+
+
+def test_param_types():
+    @codon
+    def run(a: int, b: Tuple[int, int], c: List[int], d: Dict[str, int]) -> int:
+        s = 0
+        for v in [a, *b, *c, *d.values()]:
+            s += v
+        return s
+
+    r = run(1, (2, 3), [4, 5, 6], dict(a=7, b=8, c=9))
+    assert type(r) == int
+    assert r == 45
+
+
+test_param_types()