pyextension.h support for toplevel functions

pull/335/head
Ibrahim Numanagić 2023-02-09 18:01:49 -08:00
parent 14ea7127c7
commit c03d2e12cf
6 changed files with 92 additions and 92 deletions

View File

@ -19,7 +19,7 @@ namespace codon::ast {
Cache::Cache(std::string argv0)
: generatedSrcInfoCount(0), unboundCount(256), varCount(0), age(0),
argv0(std::move(argv0)), typeCtx(nullptr), codegenCtx(nullptr), isJit(false),
jitCell(0) {}
jitCell(0), pythonExt(false), pyModule(nullptr) {}
std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) {
return fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix,
@ -233,9 +233,77 @@ std::vector<ExprPtr> Cache::mergeC3(std::vector<std::vector<ExprPtr>> &seqs) {
void Cache::populatePythonModule() {
LOG("[py] ====== module generation =======");
#define N std::make_shared
auto getFn = [&](const std::string &canonicalName,
const std::string &className = "") -> ir::Func * {
auto &fna = functions[canonicalName].ast;
std::vector<Param> params;
std::vector<ExprPtr> args;
auto sctx = imports[MAIN_IMPORT].ctx;
bool isMethod = className.empty() ? false : fna->hasAttr(Attr::Method);
auto name =
fmt::format("{}{}", className.empty() ? "" : className + ".", canonicalName);
LOG("[py] {}: {} => {}", isMethod ? "method" : "classm", name, isMethod);
params = {Param{sctx->generateCanonicalName("self"), N<IdExpr>("cobj")},
Param{sctx->generateCanonicalName("args"), N<IdExpr>("cobj")}};
if (fna->args.size() > 1 + isMethod) {
params.back().type = N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type);
params.push_back(Param{sctx->generateCanonicalName("nargs"), N<IdExpr>("int")});
}
ExprPtr po = N<IdExpr>(params[0].name);
if (!className.empty())
po = N<CallExpr>(N<DotExpr>(N<IdExpr>(className), "__from_py__"), po);
if (isMethod)
args.push_back(po);
if (fna->args.size() > 1 + isMethod) {
for (size_t ai = isMethod; ai < fna->args.size(); ai++) {
ExprPtr po = N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - isMethod));
if (fna->args[ai].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[ai].type->clone(), "__from_py__"), po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
} else if (fna->args.size() == 1 + isMethod) {
ExprPtr po = N<IdExpr>(params[1].name);
if (fna->args[isMethod].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[isMethod].type->clone(), "__from_py__"), po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
auto stubName = sctx->generateCanonicalName(fmt::format("_py.{}", name));
auto node = N<FunctionStmt>(
stubName, N<IdExpr>("cobj"), params,
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<CallExpr>(N<IdExpr>(canonicalName), args), "__to_py__")))),
Attr({Attr::ForceRealize}));
functions[node->name].ast = node;
int oldAge = typeCtx->age;
typeCtx->age = 99999;
auto tv = TypecheckVisitor(typeCtx);
auto tnode = tv.transform(node);
seqassertn(tnode, "blah");
seqassertn(typeCtx->forceFind(stubName) && typeCtx->forceFind(stubName)->type,
"bad type");
auto rtv = tv.realize(typeCtx->forceFind(stubName)->type);
seqassertn(rtv, "realization of {} failed", stubName);
auto pr = pendingRealizations; // copy it as it might be modified
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
auto f = functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir;
typeCtx->age = oldAge;
return f;
};
// if (!pythonExt)
// return;
if (!pyModule)
pyModule = std::make_shared<ir::PyModule>();
#define N std::make_shared
using namespace ast;
// def wrapper(self: cobj, arg: cobj) -> cobj
@ -248,66 +316,9 @@ void Cache::populatePythonModule() {
auto &fna = functions[fnn].ast;
if (fna->hasAttr("autogenerated"))
continue;
std::vector<Param> params;
std::vector<ExprPtr> args;
auto sctx = imports[MAIN_IMPORT].ctx;
bool isMethod = fna->hasAttr(Attr::Method);
LOG("[py] {}: {}.{} => {}", isMethod ? "method" : "classm", cn, n, fnn);
params = {Param{sctx->generateCanonicalName("self"), N<IdExpr>("cobj")},
Param{sctx->generateCanonicalName("args"), N<IdExpr>("cobj")}};
if (fna->args.size() > 2) {
params.back().type = N<InstantiateExpr>(N<IdExpr>("Ptr"), params.back().type);
params.push_back(
Param{sctx->generateCanonicalName("nargs"), N<IdExpr>("int")});
}
ExprPtr po = N<IdExpr>(params[0].name);
po = N<CallExpr>(N<DotExpr>(N<IdExpr>(cn), "__from_py__"), po);
if (isMethod)
args.push_back(po);
if (fna->args.size() > 1 + isMethod) {
for (size_t ai = isMethod; ai < fna->args.size(); ai++) {
ExprPtr po =
N<IndexExpr>(N<IdExpr>(params[1].name), N<IntExpr>(ai - isMethod));
if (fna->args[ai].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[ai].type->clone(), "__from_py__"),
po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
} else if (fna->args.size() == 1 + isMethod) {
ExprPtr po = N<IdExpr>(params[1].name);
if (fna->args[1].type) {
po = N<CallExpr>(N<DotExpr>(fna->args[1].type->clone(), "__from_py__"), po);
} else {
po = N<CallExpr>(N<IdExpr>("pyobj"), po);
}
args.push_back(po);
}
auto stubName = sctx->generateCanonicalName(fmt::format("_py.{}.{}", cn, n));
auto node =
N<FunctionStmt>(stubName, N<IdExpr>("cobj"), params,
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(N<DotExpr>(
N<CallExpr>(N<IdExpr>(fnn), args), "__to_py__")))),
Attr({Attr::ForceRealize}));
functions[node->name].ast = node;
int oldAge = typeCtx->age;
typeCtx->age = 99999;
auto tv = TypecheckVisitor(typeCtx);
auto tnode = tv.transform(node);
auto rtv = tv.realize(typeCtx->forceFind(stubName)->type);
seqassertn(rtv, "realization of {} failed", stubName);
auto pr = pendingRealizations; // copy it as it might be modified
for (auto &fn : pr)
TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone());
auto f =
functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir;
typeCtx->age = oldAge;
auto f = getFn(fnn, cn);
if (!f)
continue;
if (n == "__repr__") {
py.repr = f;
} else if (n == "__add__") {
@ -406,8 +417,9 @@ void Cache::populatePythonModule() {
py.init = f;
} else {
py.methods.push_back(ir::PyFunction{n, fna->getDocstr(), f,
isMethod ? ir::PyFunction::Type::METHOD
: ir::PyFunction::Type::CLASS});
fna->hasAttr(Attr::Method)
? ir::PyFunction::Type::METHOD
: ir::PyFunction::Type::CLASS});
}
// LOG(">| [{}] {}", functions[stubName].realizations.size(), *f);
}
@ -428,6 +440,15 @@ void Cache::populatePythonModule() {
pyModule->types.push_back(py);
}
#undef N
for (const auto &[fn, f] : functions)
if (f.isToplevel) {
auto fnn = overloads[f.rootName].back().name; // last overload
LOG("[py] functn {} => {}", rev(fn), fnn);
auto ir = getFn(fnn);
pyModule->functions.push_back(ir::PyFunction{rev(fn), f.ast->getDocstr(), ir,
ir::PyFunction::Type::TOPLEVEL});
}
}
} // namespace codon::ast

View File

@ -191,9 +191,10 @@ struct Cache : public std::enable_shared_from_this<Cache> {
types::FuncTypePtr type;
/// Module information
std::string module;
std::string rootName = "";
bool isToplevel = false;
Function() : ast(nullptr), origAst(nullptr), type(nullptr) {}
Function() : ast(nullptr), origAst(nullptr), type(nullptr), rootName(""), isToplevel(false) {}
};
/// Function lookup table that maps a canonical function identifier to the
/// corresponding Function instance.

View File

@ -166,8 +166,6 @@ struct SimplifyContext : public Context<SimplifyItem> {
bool allowTypeOf;
/// Set if all assignments should not be dominated later on.
bool avoidDomination = false;
/// Canonical names of functions that should be exported (e.g., for Python use)
std::unordered_set<std::string> makeExport;
public:
SimplifyContext(std::string filename, Cache *cache);

View File

@ -278,7 +278,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ctx->cache->functions[canonicalName].ast = f;
ctx->cache->functions[canonicalName].origAst =
std::static_pointer_cast<FunctionStmt>(stmt->clone());
ctx->cache->functions[canonicalName].module = ctx->getModule();
ctx->cache->functions[canonicalName].isToplevel =
ctx->getModule().empty() && ctx->isGlobal();
ctx->cache->functions[canonicalName].rootName = rootName;
// Expression to be used if function binding is modified by captures or decorators
ExprPtr finalExpr = nullptr;
@ -315,10 +317,6 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
} else {
resultStmt = f;
}
if (ctx->isGlobal() && ctx->getModule().empty()) {
ctx->makeExport.insert(f->name);
}
}
/// Make a capturing anonymous function with the provided suite and argument names.

View File

@ -106,25 +106,6 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
ctx->scope.stmts[ctx->scope.blocks.back()].begin(),
ctx->scope.stmts[ctx->scope.blocks.back()].end());
suite->stmts.push_back(n);
auto exports = ctx->makeExport;
for (auto &fn : exports) {
std::vector<Param> args;
std::vector<ExprPtr> callArgs;
auto t = N<IdExpr>("pyobj");
t->markType();
for (auto &a : ctx->cache->functions[fn].ast->args) {
args.push_back(
Param{ctx->cache->rev(a.name), a.type ? a.type->clone() : t->clone()});
callArgs.push_back(N<IdExpr>(ctx->cache->rev(a.name)));
}
// TODO: what to do in case of overrides?
auto ast = N<FunctionStmt>("._py_" + ctx->cache->rev(fn), nullptr, args,
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<IdExpr>(ctx->cache->rev(fn)), callArgs))),
Attr({Attr::Export}));
suite->stmts.push_back(SimplifyVisitor(ctx, preamble).transform(ast));
}
#undef N
if (!ctx->cache->errors.empty())

View File

@ -318,6 +318,7 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
// Case: transform `pyobj.member` to `pyobj._getattr("member")`
if (typ->is("pyobj")) {
LOG("-> /p {}", expr->toString());
return transform(
N<CallExpr>(N<DotExpr>(expr->expr, "_getattr"), N<StringExpr>(expr->member)));
}