1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Fix Python calls; add staticenumerate

This commit is contained in:
Ibrahim Numanagić 2023-03-14 22:32:34 -07:00
parent 938ab8dee4
commit 66f9e4fd23
7 changed files with 95 additions and 48 deletions

View File

@ -330,8 +330,6 @@ void Cache::populatePythonModule() {
auto fna = functions[canonicalName].ast;
bool isMethod = fna->hasAttr(Attr::Method);
std::string call = pyWrap + ".wrap_multiple";
if (isMethod)
call += "_method";
bool isMagic = false;
if (startswith(n, "__") && endswith(n, "__")) {
if (auto i = in(classes[pyWrap].methods,
@ -346,6 +344,7 @@ void Cache::populatePythonModule() {
auto generics = std::vector<types::TypePtr>{tc};
if (!isMagic) {
generics.push_back(std::make_shared<types::StaticType>(this, n));
generics.push_back(std::make_shared<types::StaticType>(this, (int)isMethod));
}
auto f = realizeIR(functions[fnName].type, generics);
if (!f)
@ -451,6 +450,7 @@ void Cache::populatePythonModule() {
: ir::PyFunction::Type::CLASS,
// always use FASTCALL for now; works even for 0- or 1- arg methods
2});
py.methods.back().keywords = true;
}
}

View File

@ -78,6 +78,9 @@ SimplifyVisitor::transformTupleGenerator(const std::vector<CallExpr::Arg> &args)
E(Error::CALL_TUPLE_COMPREHENSION, args[0].value);
auto var = clone(g->loops[0].vars);
auto ex = clone(g->expr);
ctx->enterConditionalBlock();
ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}});
if (auto i = var->getId()) {
ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo());
var = transform(var);
@ -89,6 +92,11 @@ SimplifyVisitor::transformTupleGenerator(const std::vector<CallExpr::Arg> &args)
auto head = transform(N<AssignStmt>(clone(g->loops[0].vars), clone(var)));
ex = N<StmtExpr>(head, transform(ex));
}
ctx->leaveConditionalBlock();
// Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars)
ctx->findDominatingBinding(var);
ctx->getBase()->loops.pop_back();
return N<GeneratorExpr>(
GeneratorExpr::Generator, ex,
std::vector<GeneratorBody>{{var, transform(g->loops[0].gen), {}}});

View File

@ -966,9 +966,30 @@ ExprPtr TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) {
auto n = fn->ast->args[i].name;
trimStars(n);
n = ctx->cache->rev(n);
v.push_back(N<TupleExpr>(std::vector<ExprPtr>{N<IntExpr>(i), N<StringExpr>(n)}));
v.push_back(N<StringExpr>(n));
}
return transform(N<TupleExpr>(v));
} else if (expr->expr->isId("std.internal.static.fn_has_default")) {
expr->staticValue.type = StaticValue::INT;
auto fn = expr->args[0].value->type->getFunc();
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
seqassert(idx, "expected a static integer");
auto &args = fn->ast->args;
if (*idx < 0 || *idx >= args.size())
error("argument out of bounds");
return transform(N<IntExpr>(args[*idx].defaultValue != nullptr));
} else if (expr->expr->isId("std.internal.static.fn_get_default")) {
auto fn = expr->args[0].value->type->getFunc();
if (!fn)
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
seqassert(idx, "expected a static integer");
auto &args = fn->ast->args;
if (*idx < 0 || *idx >= args.size())
error("argument out of bounds");
return transform(args[*idx].defaultValue);
} else {
return nullptr;
}

View File

@ -237,7 +237,7 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
} else {
error("bad call to fn_overloads");
}
} else if (iter && startswith(iter->value, "std.internal.static.fn_args")) {
} else if (iter && startswith(iter->value, "std.internal.builtin.staticenumerate")) {
auto &suiteVec = stmt->suite->getSuite()->stmts;
int validI = 0;
for (; validI < suiteVec.size(); validI++) {
@ -252,17 +252,15 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
if (auto fna = ctx->getFunctionArgs(iter->type)) {
auto [generics, args] = *fna;
auto typ = ctx->extractFunction(args[0]);
auto typ = args[0]->getRecord();
if (!typ)
error("fn_args needs a function");
for (size_t i = 0; i < typ->ast->args.size(); i++) {
error("fn_args needs a tuple");
for (size_t i = 0; i < typ->args.size(); i++) {
suiteVec[0]->getAssign()->rhs = N<IntExpr>(i);
suiteVec[0]->getAssign()->type =
NT<IndexExpr>(NT<IdExpr>("Static"), NT<IdExpr>("int"));
suiteVec[1]->getAssign()->rhs =
N<StringExpr>(ctx->cache->rev(typ->ast->args[i].name));
suiteVec[1]->getAssign()->type =
NT<IndexExpr>(NT<IdExpr>("Static"), NT<IdExpr>("str"));
N<IndexExpr>(stmt->iter->getCall()->args[0].value->clone(), N<IntExpr>(i));
block->stmts.push_back(fn("", nullptr));
}
} else {
@ -282,7 +280,6 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
transform(N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(true)),
N<WhileStmt>(N<IdExpr>(loopVar), block)));
ctx->blockLevel--;
// LOG("-> {} :: {}", getSrcInfo(), loop->toString(2));
return loop;
}

View File

@ -221,6 +221,11 @@ def enumerate(x, start: int = 0):
yield (i, a)
i += 1
def staticenumerate(tup):
i = -1
return tuple(((i := i + 1), t) for t in tup)
i
def echo(x):
"""
Print and return argument

View File

@ -1443,6 +1443,17 @@ def _____(): __pyenv__ # make it global!
import internal.static as _S
class _PyWrapError(Static[PyError]):
def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)):
super().__init__("_PyWrapError", message)
self.pytype = pytype
def __init__(self, e: PyError):
self.__init__("_PyWrapError", e.message, e.pytype)
class _PyWrap:
def _wrap_arg(arg: cobj):
return pyobj(arg, steal=True)
@ -1451,7 +1462,7 @@ class _PyWrap:
if _S.fn_can_call(fn, *args):
try:
return map(fn, args)
except PyError:
except PyError as e:
pass
raise PyError("cannot dispatch " + F)
@ -1524,7 +1535,7 @@ class _PyWrap:
a = tuple(
_PyWrap._wrap_arg(obj) if i == 0 else
(kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
for i, n in _S.fn_args(fn)
for i, n in staticenumerate(_S.fn_args(fn))
)
if ai + 1 != args.__len__():
continue
@ -1545,7 +1556,7 @@ class _PyWrap:
a = tuple(
_PyWrap._wrap_arg(obj) if i == 0 else
(kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
for i, n in _S.fn_args(fn)
for i, n in staticenumerate(_S.fn_args(fn))
)
if ai + 1 != args.__len__():
continue
@ -1640,43 +1651,40 @@ class _PyWrap:
# print('[c] iter')
return _PyWrap.IterWrap._init(obj, T)
def wrap_multiple_method(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
# print(f'[c] method: {T.__class__.__name__} {F} {obj} {args} {nargs}')
def _err() -> pyobj:
raise PyError("argument mismatch")
def wrap_multiple(
obj: cobj, _args: cobj, nargs: int, _kwds: cobj, T: type, F: Static[str],
M: Static[int] = 1
):
args = _PyWrap._wrap_arg(_args)
kwds = _PyWrap._wrap_arg(_kwds) if _kwds != cobj() else None
for fn in _S.fn_overloads(T, F):
try:
ai = -1
an = tuple(
_PyWrap._wrap_arg(obj) if i == 0 else
(_PyWrap._wrap_arg(args[i]) if i < nargs else _err())
for i, _ in _S.fn_args(fn)
)
if len(an) != nargs + 1:
_err()
if _S.fn_can_call(fn, *an):
return fn(*an).__to_py__()
except PyError:
pass
raise PyError("cannot dispatch " + F)
a = [pyobj(cobj(), steal=True) for _ in _S.fn_args(fn)]
ai, ki = 0, 0
for i, n in staticenumerate(_S.fn_args(fn)):
if M == 1 and i == 0:
print('le method')
a[i] = _PyWrap._wrap_arg(obj)
continue
if ai < nargs:
a[i] = args[ai]
ai += 1
elif kwds and n in kwds:
a[i] = kwds[n]
ki += 1
else:
if _S.fn_has_default(fn, i):
a[i] = _S.fn_get_default(fn, i)
else:
raise PyError("missing argument")
if ai < nargs:
raise PyError("too many *args")
if kwds and ki < kwds.__len__():
raise PyError("too many **kwargs")
ta = tuple(unwrap(a[i]) for i, _ in staticenumerate(_S.fn_args(fn)))
def wrap_multiple(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
# print(f'[c] nonmethod: {T.__class__.__name__} {F} {obj} {args} {nargs}')
def _err() -> pyobj:
raise PyError("argument mismatch")
for fn in _S.fn_overloads(T, F):
try:
ai = -1
an = tuple(
_PyWrap._wrap_arg(args[i]) if i < nargs else _err()
for i, _ in _S.fn_args(fn)
)
if len(an) != nargs:
_err()
if _S.fn_can_call(fn, *an):
return fn(*an).__to_py__()
if _S.fn_can_call(fn, *ta):
return fn(*ta).__to_py__()
except PyError:
pass
raise PyError("cannot dispatch " + F)
@ -1707,3 +1715,5 @@ class _PyWrap:
if obj.head.pytype != pytype:
raise TypeError("Python object has incompatible type")
return obj.data

View File

@ -15,5 +15,11 @@ def fn_arg_get_type(F, i: Static[int]):
def fn_can_call(F, *args, **kwargs):
pass
def fn_has_default(F, i: Static[int]):
pass
def fn_get_default(F, i: Static[int]):
pass
def class_args(T: type):
pass