mirror of https://github.com/exaloop/codon.git
Add Python extension lowering pass
parent
5de12ee2f7
commit
cf5a193274
|
@ -192,6 +192,7 @@ set(CODON_HPPFILES
|
|||
codon/cir/transform/folding/rule.h
|
||||
codon/cir/transform/lowering/imperative.h
|
||||
codon/cir/transform/lowering/pipeline.h
|
||||
codon/cir/transform/lowering/pyextension.h
|
||||
codon/cir/transform/manager.h
|
||||
codon/cir/transform/parallel/openmp.h
|
||||
codon/cir/transform/parallel/schedule.h
|
||||
|
@ -300,6 +301,7 @@ set(CODON_CPPFILES
|
|||
codon/cir/transform/folding/folding.cpp
|
||||
codon/cir/transform/lowering/imperative.cpp
|
||||
codon/cir/transform/lowering/pipeline.cpp
|
||||
codon/cir/transform/lowering/pyextension.cpp
|
||||
codon/cir/transform/manager.cpp
|
||||
codon/cir/transform/parallel/openmp.cpp
|
||||
codon/cir/transform/parallel/schedule.cpp
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
#include "pyextension.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "codon/cir/util/cloning.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/cir/util/matching.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace transform {
|
||||
namespace lowering {
|
||||
namespace {
|
||||
|
||||
const std::string EXPORT_ATTR = "std.internal.attributes.export";
|
||||
|
||||
Func *generateExtensionFunc(Func *f) {
|
||||
// PyObject *_PyCFunctionFast(PyObject *self,
|
||||
// PyObject *const *args,
|
||||
// Py_ssize_t nargs);
|
||||
|
||||
auto *M = f->getModule();
|
||||
auto *cobj = M->getPointerType(M->getByteType());
|
||||
auto *ext = M->Nr<BodiedFunc>("__py_extension");
|
||||
ext->realize(M->getFuncType(cobj, {cobj, M->getPointerType(cobj), M->getIntType()}),
|
||||
{"self", "args", "nargs"});
|
||||
auto *body = M->Nr<SeriesFlow>();
|
||||
ext->setBody(body);
|
||||
std::vector<Var *> extArgs(ext->arg_begin(), ext->arg_end());
|
||||
std::vector<Value *> vars;
|
||||
auto *args = extArgs[1];
|
||||
// auto *nargs = extArgs[2];
|
||||
|
||||
// TODO: check nargs
|
||||
|
||||
int idx = 0;
|
||||
for (auto it = f->arg_begin(); it != f->arg_end(); ++it) {
|
||||
auto *type = (*it)->getType();
|
||||
auto *fromPy = M->getOrRealizeMethod(type, "__from_py__", {cobj});
|
||||
seqassertn(fromPy, "__from_py__ method not found");
|
||||
auto *pyItem = util::call(fromPy, {(*M->Nr<VarValue>(args))[*M->getInt(idx++)]});
|
||||
vars.push_back(util::makeVar(pyItem, body, ext));
|
||||
}
|
||||
|
||||
auto *retType = util::getReturnType(f);
|
||||
auto *toPy = M->getOrRealizeMethod(retType, "__to_py__", {retType});
|
||||
seqassertn(toPy, "__to_py__ method not found");
|
||||
auto *retVal = util::call(toPy, {util::call(f, vars)});
|
||||
body->push_back(M->Nr<ReturnInstr>(retVal));
|
||||
return ext;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const std::string PythonExtensionLowering::KEY = "core-python-extension-lowering";
|
||||
|
||||
void PythonExtensionLowering::run(Module *module) {
|
||||
for (auto *var : *module) {
|
||||
if (auto *f = cast<BodiedFunc>(var)) {
|
||||
if (!util::hasAttribute(f, EXPORT_ATTR))
|
||||
continue;
|
||||
|
||||
std::cout << f->getName() << std::endl;
|
||||
std::cout << *generateExtensionFunc(f) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lowering
|
||||
} // namespace transform
|
||||
} // namespace ir
|
||||
} // namespace codon
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "codon/cir/transform/pass.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace transform {
|
||||
namespace lowering {
|
||||
|
||||
class PythonExtensionLowering : public Pass {
|
||||
public:
|
||||
static const std::string KEY;
|
||||
std::string getKey() const override { return KEY; }
|
||||
void run(Module *module) override;
|
||||
};
|
||||
|
||||
} // namespace lowering
|
||||
} // namespace transform
|
||||
} // namespace ir
|
||||
} // namespace codon
|
Loading…
Reference in New Issue