diff --git a/CMakeLists.txt b/CMakeLists.txt index da3f0243..d468575f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,6 +192,7 @@ set(CODON_HPPFILES codon/cir/transform/folding/rule.h codon/cir/transform/lowering/imperative.h codon/cir/transform/lowering/pipeline.h + codon/cir/transform/lowering/pyextension.h codon/cir/transform/manager.h codon/cir/transform/parallel/openmp.h codon/cir/transform/parallel/schedule.h @@ -300,6 +301,7 @@ set(CODON_CPPFILES codon/cir/transform/folding/folding.cpp codon/cir/transform/lowering/imperative.cpp codon/cir/transform/lowering/pipeline.cpp + codon/cir/transform/lowering/pyextension.cpp codon/cir/transform/manager.cpp codon/cir/transform/parallel/openmp.cpp codon/cir/transform/parallel/schedule.cpp diff --git a/codon/cir/transform/lowering/pyextension.cpp b/codon/cir/transform/lowering/pyextension.cpp new file mode 100644 index 00000000..49e9eb82 --- /dev/null +++ b/codon/cir/transform/lowering/pyextension.cpp @@ -0,0 +1,74 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#include "pyextension.h" + +#include + +#include "codon/cir/util/cloning.h" +#include "codon/cir/util/irtools.h" +#include "codon/cir/util/matching.h" + +namespace codon { +namespace ir { +namespace transform { +namespace lowering { +namespace { + +const std::string EXPORT_ATTR = "std.internal.attributes.export"; + +Func *generateExtensionFunc(Func *f) { + // PyObject *_PyCFunctionFast(PyObject *self, + // PyObject *const *args, + // Py_ssize_t nargs); + + auto *M = f->getModule(); + auto *cobj = M->getPointerType(M->getByteType()); + auto *ext = M->Nr("__py_extension"); + ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}), + {"self", "args", "nargs"}); + auto *body = M->Nr(); + ext->setBody(body); + std::vector extArgs(ext->arg_begin(), ext->arg_end()); + std::vector vars; + auto *args = extArgs[1]; + // auto *nargs = extArgs[2]; + + // TODO: check nargs + + int idx = 0; + for (auto it = f->arg_begin(); it != f->arg_end(); ++it) { + auto *type = (*it)->getType(); + auto *fromPy = M->getOrRealizeMethod(type, "__from_py__", {cobj}); + seqassertn(fromPy, "__from_py__ method not found"); + auto *pyItem = util::call(fromPy, {(*M->Nr(args))[*M->getInt(idx++)]}); + vars.push_back(util::makeVar(pyItem, body, ext)); + } + + auto *retType = util::getReturnType(f); + auto *toPy = M->getOrRealizeMethod(retType, "__to_py__", {retType}); + seqassertn(toPy, "__to_py__ method not found"); + auto *retVal = util::call(toPy, {util::call(f, vars)}); + body->push_back(M->Nr(retVal)); + return ext; +} + +} // namespace + +const std::string PythonExtensionLowering::KEY = "core-python-extension-lowering"; + +void PythonExtensionLowering::run(Module *module) { + for (auto *var : *module) { + if (auto *f = cast(var)) { + if (!util::hasAttribute(f, EXPORT_ATTR)) + continue; + + std::cout << f->getName() << std::endl; + std::cout << *generateExtensionFunc(f) << std::endl; + } + } +} + +} // namespace lowering +} // namespace transform +} // namespace ir +} // namespace codon diff --git a/codon/cir/transform/lowering/pyextension.h b/codon/cir/transform/lowering/pyextension.h new file mode 100644 index 00000000..0e689fe8 --- /dev/null +++ b/codon/cir/transform/lowering/pyextension.h @@ -0,0 +1,22 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#pragma once + +#include "codon/cir/transform/pass.h" + +namespace codon { +namespace ir { +namespace transform { +namespace lowering { + +class PythonExtensionLowering : public Pass { +public: + static const std::string KEY; + std::string getKey() const override { return KEY; } + void run(Module *module) override; +}; + +} // namespace lowering +} // namespace transform +} // namespace ir +} // namespace codon