Add support for @pycapture

pull/335/head
Ibrahim Numanagić 2023-02-10 18:21:48 -08:00
parent 946972df22
commit 92f9a274e7
7 changed files with 33 additions and 6 deletions

View File

@ -19,8 +19,14 @@ void SimplifyVisitor::visit(IdExpr *expr) {
return;
}
auto val = ctx->findDominatingBinding(expr->value);
if (!val)
if (!val && ctx->getBase()->pyCaptures) {
ctx->getBase()->pyCaptures->insert(expr->value);
resultExpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(expr->value));
return;
} else if (!val) {
E(Error::ID_NOT_FOUND, expr, expr->value);
}
// If we are accessing an outside variable, capture it or raise an error
auto captured = checkCapture(val);
@ -107,7 +113,11 @@ void SimplifyVisitor::visit(DotExpr *expr) {
std::reverse(chain.begin(), chain.end());
auto p = getImport(chain);
if (p.second->getModule() == "std.python") {
if (!p.second) {
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
ctx->getBase()->pyCaptures->insert(chain[0]);
resultExpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
} else if (p.second->getModule() == "std.python") {
resultExpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[p.first++])));
@ -238,7 +248,7 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
for (auto i = chain.size(); i-- > importEnd;) {
if (fctx->getModule() == "std.python" && importEnd < chain.size()) {
// Special case: importing from Python.
// Fake SimplifyItem that inidcates std.python access
// Fake SimplifyItem that indicates std.python access
val = std::make_shared<SimplifyItem>(SimplifyItem::Var, "", "",
fctx->getModule(), std::vector<int>{});
return {importEnd, val};
@ -250,8 +260,11 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
}
}
}
if (itemName.empty() && importName.empty())
if (itemName.empty() && importName.empty()) {
if (ctx->getBase()->pyCaptures)
return {1, nullptr};
E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]);
}
if (itemName.empty())
E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd],
ctx->cache->imports[importName].moduleName);

View File

@ -25,7 +25,7 @@ SimplifyContext::SimplifyContext(std::string filename, Cache *cache)
SimplifyContext::Base::Base(std::string name, Attr *attributes)
: name(move(name)), attributes(attributes), deducedMembers(nullptr), selfName(),
captures(nullptr) {}
captures(nullptr), pyCaptures(nullptr) {}
void SimplifyContext::add(const std::string &name, const SimplifyContext::Item &var) {
auto v = find(name);

View File

@ -112,6 +112,10 @@ struct SimplifyContext : public Context<SimplifyItem> {
/// function after processing) and their types (indicating if they are a type, a
/// static or a variable).
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> *captures;
/// Map of identifiers that are to be fetched from Python.
std::unordered_set<std::string> *pyCaptures;
/// Scope that defines the base.
std::vector<int> scope;

View File

@ -161,6 +161,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
StmtPtr suite = nullptr;
ExprPtr ret = nullptr;
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> captures;
std::unordered_set<std::string> pyCaptures;
{
// Set up the base
SimplifyContext::BaseGuard br(ctx.get(), canonicalName);
@ -239,6 +240,8 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
} else {
if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember)
ctx->getBase()->captures = &captures;
if (stmt->attributes.has("std.internal.attributes.pycapture"))
ctx->getBase()->pyCaptures = &pyCaptures;
suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite);
}
}

View File

@ -318,7 +318,6 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
// Case: transform `pyobj.member` to `pyobj._getattr("member")`
if (typ->is("pyobj")) {
LOG("-> /p {}", expr->toString());
return transform(
N<CallExpr>(N<DotExpr>(expr->expr, "_getattr"), N<StringExpr>(expr->member)));
}

View File

@ -32,6 +32,10 @@ def no_side_effect():
def nocapture():
pass
@__attribute__
def pycapture():
pass
@__attribute__
def derives():
pass

View File

@ -978,3 +978,7 @@ class Optional:
return Optional[T]()
else:
return Optional[T](T.__from_py__(o))
__pyenv__: Optional[pyobj] = None
def _____(): __pyenv__ # make it global!