Add Python extension lowering pass

pull/335/head
A. R. Shajii 2023-01-27 22:57:59 -05:00
parent 5de12ee2f7
commit cf5a193274
3 changed files with 98 additions and 0 deletions

View File

@ -192,6 +192,7 @@ set(CODON_HPPFILES
codon/cir/transform/folding/rule.h
codon/cir/transform/lowering/imperative.h
codon/cir/transform/lowering/pipeline.h
codon/cir/transform/lowering/pyextension.h
codon/cir/transform/manager.h
codon/cir/transform/parallel/openmp.h
codon/cir/transform/parallel/schedule.h
@ -300,6 +301,7 @@ set(CODON_CPPFILES
codon/cir/transform/folding/folding.cpp
codon/cir/transform/lowering/imperative.cpp
codon/cir/transform/lowering/pipeline.cpp
codon/cir/transform/lowering/pyextension.cpp
codon/cir/transform/manager.cpp
codon/cir/transform/parallel/openmp.cpp
codon/cir/transform/parallel/schedule.cpp

View File

@ -0,0 +1,74 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
#include "pyextension.h"
#include <algorithm>
#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/matching.h"
namespace codon {
namespace ir {
namespace transform {
namespace lowering {
namespace {
const std::string EXPORT_ATTR = "std.internal.attributes.export";
Func *generateExtensionFunc(Func *f) {
// PyObject *_PyCFunctionFast(PyObject *self,
// PyObject *const *args,
// Py_ssize_t nargs);
auto *M = f->getModule();
auto *cobj = M->getPointerType(M->getByteType());
auto *ext = M->Nr<BodiedFunc>("__py_extension");
ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}),
{"self", "args", "nargs"});
auto *body = M->Nr<SeriesFlow>();
ext->setBody(body);
std::vector<Var *> extArgs(ext->arg_begin(), ext->arg_end());
std::vector<Value *> vars;
auto *args = extArgs[1];
// auto *nargs = extArgs[2];
// TODO: check nargs
int idx = 0;
for (auto it = f->arg_begin(); it != f->arg_end(); ++it) {
auto *type = (*it)->getType();
auto *fromPy = M->getOrRealizeMethod(type, "__from_py__", {cobj});
seqassertn(fromPy, "__from_py__ method not found");
auto *pyItem = util::call(fromPy, {(*M->Nr<VarValue>(args))[*M->getInt(idx++)]});
vars.push_back(util::makeVar(pyItem, body, ext));
}
auto *retType = util::getReturnType(f);
auto *toPy = M->getOrRealizeMethod(retType, "__to_py__", {retType});
seqassertn(toPy, "__to_py__ method not found");
auto *retVal = util::call(toPy, {util::call(f, vars)});
body->push_back(M->Nr<ReturnInstr>(retVal));
return ext;
}
} // namespace
const std::string PythonExtensionLowering::KEY = "core-python-extension-lowering";
void PythonExtensionLowering::run(Module *module) {
for (auto *var : *module) {
if (auto *f = cast<BodiedFunc>(var)) {
if (!util::hasAttribute(f, EXPORT_ATTR))
continue;
std::cout << f->getName() << std::endl;
std::cout << *generateExtensionFunc(f) << std::endl;
}
}
}
} // namespace lowering
} // namespace transform
} // namespace ir
} // namespace codon

View File

@ -0,0 +1,22 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
#pragma once
#include "codon/cir/transform/pass.h"
namespace codon {
namespace ir {
namespace transform {
namespace lowering {
class PythonExtensionLowering : public Pass {
public:
static const std::string KEY;
std::string getKey() const override { return KEY; }
void run(Module *module) override;
};
} // namespace lowering
} // namespace transform
} // namespace ir
} // namespace codon