mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Fix Python calls; add staticenumerate
This commit is contained in:
parent
938ab8dee4
commit
66f9e4fd23
@ -330,8 +330,6 @@ void Cache::populatePythonModule() {
|
||||
auto fna = functions[canonicalName].ast;
|
||||
bool isMethod = fna->hasAttr(Attr::Method);
|
||||
std::string call = pyWrap + ".wrap_multiple";
|
||||
if (isMethod)
|
||||
call += "_method";
|
||||
bool isMagic = false;
|
||||
if (startswith(n, "__") && endswith(n, "__")) {
|
||||
if (auto i = in(classes[pyWrap].methods,
|
||||
@ -346,6 +344,7 @@ void Cache::populatePythonModule() {
|
||||
auto generics = std::vector<types::TypePtr>{tc};
|
||||
if (!isMagic) {
|
||||
generics.push_back(std::make_shared<types::StaticType>(this, n));
|
||||
generics.push_back(std::make_shared<types::StaticType>(this, (int)isMethod));
|
||||
}
|
||||
auto f = realizeIR(functions[fnName].type, generics);
|
||||
if (!f)
|
||||
@ -451,6 +450,7 @@ void Cache::populatePythonModule() {
|
||||
: ir::PyFunction::Type::CLASS,
|
||||
// always use FASTCALL for now; works even for 0- or 1- arg methods
|
||||
2});
|
||||
py.methods.back().keywords = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,6 +78,9 @@ SimplifyVisitor::transformTupleGenerator(const std::vector<CallExpr::Arg> &args)
|
||||
E(Error::CALL_TUPLE_COMPREHENSION, args[0].value);
|
||||
auto var = clone(g->loops[0].vars);
|
||||
auto ex = clone(g->expr);
|
||||
|
||||
ctx->enterConditionalBlock();
|
||||
ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}});
|
||||
if (auto i = var->getId()) {
|
||||
ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo());
|
||||
var = transform(var);
|
||||
@ -89,6 +92,11 @@ SimplifyVisitor::transformTupleGenerator(const std::vector<CallExpr::Arg> &args)
|
||||
auto head = transform(N<AssignStmt>(clone(g->loops[0].vars), clone(var)));
|
||||
ex = N<StmtExpr>(head, transform(ex));
|
||||
}
|
||||
ctx->leaveConditionalBlock();
|
||||
// Dominate loop variables
|
||||
for (auto &var : ctx->getBase()->getLoop()->seenVars)
|
||||
ctx->findDominatingBinding(var);
|
||||
ctx->getBase()->loops.pop_back();
|
||||
return N<GeneratorExpr>(
|
||||
GeneratorExpr::Generator, ex,
|
||||
std::vector<GeneratorBody>{{var, transform(g->loops[0].gen), {}}});
|
||||
|
@ -966,9 +966,30 @@ ExprPtr TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) {
|
||||
auto n = fn->ast->args[i].name;
|
||||
trimStars(n);
|
||||
n = ctx->cache->rev(n);
|
||||
v.push_back(N<TupleExpr>(std::vector<ExprPtr>{N<IntExpr>(i), N<StringExpr>(n)}));
|
||||
v.push_back(N<StringExpr>(n));
|
||||
}
|
||||
return transform(N<TupleExpr>(v));
|
||||
} else if (expr->expr->isId("std.internal.static.fn_has_default")) {
|
||||
expr->staticValue.type = StaticValue::INT;
|
||||
auto fn = expr->args[0].value->type->getFunc();
|
||||
if (!fn)
|
||||
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
|
||||
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
|
||||
seqassert(idx, "expected a static integer");
|
||||
auto &args = fn->ast->args;
|
||||
if (*idx < 0 || *idx >= args.size())
|
||||
error("argument out of bounds");
|
||||
return transform(N<IntExpr>(args[*idx].defaultValue != nullptr));
|
||||
} else if (expr->expr->isId("std.internal.static.fn_get_default")) {
|
||||
auto fn = expr->args[0].value->type->getFunc();
|
||||
if (!fn)
|
||||
error("expected a function, got '{}'", expr->args[0].value->type->prettyString());
|
||||
auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type);
|
||||
seqassert(idx, "expected a static integer");
|
||||
auto &args = fn->ast->args;
|
||||
if (*idx < 0 || *idx >= args.size())
|
||||
error("argument out of bounds");
|
||||
return transform(args[*idx].defaultValue);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -237,7 +237,7 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
|
||||
} else {
|
||||
error("bad call to fn_overloads");
|
||||
}
|
||||
} else if (iter && startswith(iter->value, "std.internal.static.fn_args")) {
|
||||
} else if (iter && startswith(iter->value, "std.internal.builtin.staticenumerate")) {
|
||||
auto &suiteVec = stmt->suite->getSuite()->stmts;
|
||||
int validI = 0;
|
||||
for (; validI < suiteVec.size(); validI++) {
|
||||
@ -252,17 +252,15 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
|
||||
|
||||
if (auto fna = ctx->getFunctionArgs(iter->type)) {
|
||||
auto [generics, args] = *fna;
|
||||
auto typ = ctx->extractFunction(args[0]);
|
||||
auto typ = args[0]->getRecord();
|
||||
if (!typ)
|
||||
error("fn_args needs a function");
|
||||
for (size_t i = 0; i < typ->ast->args.size(); i++) {
|
||||
error("fn_args needs a tuple");
|
||||
for (size_t i = 0; i < typ->args.size(); i++) {
|
||||
suiteVec[0]->getAssign()->rhs = N<IntExpr>(i);
|
||||
suiteVec[0]->getAssign()->type =
|
||||
NT<IndexExpr>(NT<IdExpr>("Static"), NT<IdExpr>("int"));
|
||||
suiteVec[1]->getAssign()->rhs =
|
||||
N<StringExpr>(ctx->cache->rev(typ->ast->args[i].name));
|
||||
suiteVec[1]->getAssign()->type =
|
||||
NT<IndexExpr>(NT<IdExpr>("Static"), NT<IdExpr>("str"));
|
||||
N<IndexExpr>(stmt->iter->getCall()->args[0].value->clone(), N<IntExpr>(i));
|
||||
block->stmts.push_back(fn("", nullptr));
|
||||
}
|
||||
} else {
|
||||
@ -282,7 +280,6 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
|
||||
transform(N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(true)),
|
||||
N<WhileStmt>(N<IdExpr>(loopVar), block)));
|
||||
ctx->blockLevel--;
|
||||
// LOG("-> {} :: {}", getSrcInfo(), loop->toString(2));
|
||||
return loop;
|
||||
}
|
||||
|
||||
|
@ -221,6 +221,11 @@ def enumerate(x, start: int = 0):
|
||||
yield (i, a)
|
||||
i += 1
|
||||
|
||||
def staticenumerate(tup):
|
||||
i = -1
|
||||
return tuple(((i := i + 1), t) for t in tup)
|
||||
i
|
||||
|
||||
def echo(x):
|
||||
"""
|
||||
Print and return argument
|
||||
|
@ -1443,6 +1443,17 @@ def _____(): __pyenv__ # make it global!
|
||||
|
||||
|
||||
import internal.static as _S
|
||||
|
||||
|
||||
class _PyWrapError(Static[PyError]):
|
||||
def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)):
|
||||
super().__init__("_PyWrapError", message)
|
||||
self.pytype = pytype
|
||||
|
||||
def __init__(self, e: PyError):
|
||||
self.__init__("_PyWrapError", e.message, e.pytype)
|
||||
|
||||
|
||||
class _PyWrap:
|
||||
def _wrap_arg(arg: cobj):
|
||||
return pyobj(arg, steal=True)
|
||||
@ -1451,7 +1462,7 @@ class _PyWrap:
|
||||
if _S.fn_can_call(fn, *args):
|
||||
try:
|
||||
return map(fn, args)
|
||||
except PyError:
|
||||
except PyError as e:
|
||||
pass
|
||||
raise PyError("cannot dispatch " + F)
|
||||
|
||||
@ -1524,7 +1535,7 @@ class _PyWrap:
|
||||
a = tuple(
|
||||
_PyWrap._wrap_arg(obj) if i == 0 else
|
||||
(kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
|
||||
for i, n in _S.fn_args(fn)
|
||||
for i, n in staticenumerate(_S.fn_args(fn))
|
||||
)
|
||||
if ai + 1 != args.__len__():
|
||||
continue
|
||||
@ -1545,7 +1556,7 @@ class _PyWrap:
|
||||
a = tuple(
|
||||
_PyWrap._wrap_arg(obj) if i == 0 else
|
||||
(kwds[n] if kwds and n in kwds else args[(ai := ai + 1)])
|
||||
for i, n in _S.fn_args(fn)
|
||||
for i, n in staticenumerate(_S.fn_args(fn))
|
||||
)
|
||||
if ai + 1 != args.__len__():
|
||||
continue
|
||||
@ -1640,43 +1651,40 @@ class _PyWrap:
|
||||
# print('[c] iter')
|
||||
return _PyWrap.IterWrap._init(obj, T)
|
||||
|
||||
def wrap_multiple_method(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
|
||||
# print(f'[c] method: {T.__class__.__name__} {F} {obj} {args} {nargs}')
|
||||
def _err() -> pyobj:
|
||||
raise PyError("argument mismatch")
|
||||
|
||||
def wrap_multiple(
|
||||
obj: cobj, _args: cobj, nargs: int, _kwds: cobj, T: type, F: Static[str],
|
||||
M: Static[int] = 1
|
||||
):
|
||||
args = _PyWrap._wrap_arg(_args)
|
||||
kwds = _PyWrap._wrap_arg(_kwds) if _kwds != cobj() else None
|
||||
for fn in _S.fn_overloads(T, F):
|
||||
try:
|
||||
ai = -1
|
||||
an = tuple(
|
||||
_PyWrap._wrap_arg(obj) if i == 0 else
|
||||
(_PyWrap._wrap_arg(args[i]) if i < nargs else _err())
|
||||
for i, _ in _S.fn_args(fn)
|
||||
)
|
||||
if len(an) != nargs + 1:
|
||||
_err()
|
||||
if _S.fn_can_call(fn, *an):
|
||||
return fn(*an).__to_py__()
|
||||
except PyError:
|
||||
pass
|
||||
raise PyError("cannot dispatch " + F)
|
||||
a = [pyobj(cobj(), steal=True) for _ in _S.fn_args(fn)]
|
||||
ai, ki = 0, 0
|
||||
for i, n in staticenumerate(_S.fn_args(fn)):
|
||||
if M == 1 and i == 0:
|
||||
print('le method')
|
||||
a[i] = _PyWrap._wrap_arg(obj)
|
||||
continue
|
||||
if ai < nargs:
|
||||
a[i] = args[ai]
|
||||
ai += 1
|
||||
elif kwds and n in kwds:
|
||||
a[i] = kwds[n]
|
||||
ki += 1
|
||||
else:
|
||||
if _S.fn_has_default(fn, i):
|
||||
a[i] = _S.fn_get_default(fn, i)
|
||||
else:
|
||||
raise PyError("missing argument")
|
||||
if ai < nargs:
|
||||
raise PyError("too many *args")
|
||||
if kwds and ki < kwds.__len__():
|
||||
raise PyError("too many **kwargs")
|
||||
ta = tuple(unwrap(a[i]) for i, _ in staticenumerate(_S.fn_args(fn)))
|
||||
|
||||
def wrap_multiple(obj: cobj, args: Ptr[cobj], nargs: int, T: type, F: Static[str]):
|
||||
# print(f'[c] nonmethod: {T.__class__.__name__} {F} {obj} {args} {nargs}')
|
||||
def _err() -> pyobj:
|
||||
raise PyError("argument mismatch")
|
||||
|
||||
for fn in _S.fn_overloads(T, F):
|
||||
try:
|
||||
ai = -1
|
||||
an = tuple(
|
||||
_PyWrap._wrap_arg(args[i]) if i < nargs else _err()
|
||||
for i, _ in _S.fn_args(fn)
|
||||
)
|
||||
if len(an) != nargs:
|
||||
_err()
|
||||
if _S.fn_can_call(fn, *an):
|
||||
return fn(*an).__to_py__()
|
||||
if _S.fn_can_call(fn, *ta):
|
||||
return fn(*ta).__to_py__()
|
||||
except PyError:
|
||||
pass
|
||||
raise PyError("cannot dispatch " + F)
|
||||
@ -1707,3 +1715,5 @@ class _PyWrap:
|
||||
if obj.head.pytype != pytype:
|
||||
raise TypeError("Python object has incompatible type")
|
||||
return obj.data
|
||||
|
||||
|
||||
|
@ -15,5 +15,11 @@ def fn_arg_get_type(F, i: Static[int]):
|
||||
def fn_can_call(F, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def fn_has_default(F, i: Static[int]):
|
||||
pass
|
||||
|
||||
def fn_get_default(F, i: Static[int]):
|
||||
pass
|
||||
|
||||
def class_args(T: type):
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user