mirror of https://github.com/exaloop/codon.git
Bugfixes 2023-08 (#440)
* Fix type argument overload issue; Fix Cython version for CI * Add __contains__ for kwargs * Add get() for kwargs * Add static <<, >> and unary ~ * Fix CI * Fix OpenMP "ordered" clause * Fix static ~ * Fix Cython 3 issues * Fix Python MANIFEST.in --------- Co-authored-by: A. R. Shajii <ars@ars.me>auto-jit
parent
7198a0971a
commit
750bb28c9c
|
@ -151,19 +151,35 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
|
|||
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
|
||||
std::vector<ExprPtr>{N<IdExpr>("tuple")});
|
||||
|
||||
// Add getItem for KwArgs:
|
||||
// Add helpers for KwArgs:
|
||||
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
|
||||
// `def __contains__(self, key: Static[str]): return hasattr(self, key)`
|
||||
auto getItem = N<FunctionStmt>(
|
||||
"__getitem__", nullptr,
|
||||
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
||||
N<IdExpr>("str"))}},
|
||||
N<SuiteStmt>(N<ReturnStmt>(
|
||||
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
||||
auto contains = N<FunctionStmt>(
|
||||
"__contains__", nullptr,
|
||||
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
||||
N<IdExpr>("str"))}},
|
||||
N<SuiteStmt>(N<ReturnStmt>(
|
||||
N<CallExpr>(N<IdExpr>("hasattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
||||
auto getDef = N<FunctionStmt>(
|
||||
"get", nullptr,
|
||||
std::vector<Param>{
|
||||
Param{"self"},
|
||||
Param{"key", N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("str"))},
|
||||
Param{"default", nullptr, N<CallExpr>(N<IdExpr>("NoneType"))}},
|
||||
N<SuiteStmt>(N<ReturnStmt>(
|
||||
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "kwargs_get"),
|
||||
N<IdExpr>("self"), N<IdExpr>("key"), N<IdExpr>("default")))));
|
||||
if (startswith(typeName, TYPE_KWTUPLE))
|
||||
stmt->getClass()->suite = getItem;
|
||||
stmt->getClass()->suite = N<SuiteStmt>(getItem, contains, getDef);
|
||||
|
||||
// Add getItem for KwArgs:
|
||||
// `def __repr__(self,): return __magic__.repr_partial(self)`
|
||||
// Add repr for KwArgs:
|
||||
// `def __repr__(self): return __magic__.repr_partial(self)`
|
||||
auto repr = N<FunctionStmt>(
|
||||
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
|
||||
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
|
||||
|
|
|
@ -22,7 +22,8 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
|
|||
transform(expr->expr);
|
||||
|
||||
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
||||
staticOps = {{StaticValue::INT, {"-", "+", "!"}}, {StaticValue::STRING, {"@"}}};
|
||||
staticOps = {{StaticValue::INT, {"-", "+", "!", "~"}},
|
||||
{StaticValue::STRING, {"@"}}};
|
||||
// Handle static expressions
|
||||
if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) {
|
||||
resultExpr = evaluateStaticUnary(expr);
|
||||
|
@ -62,7 +63,7 @@ void TypecheckVisitor::visit(BinaryExpr *expr) {
|
|||
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
||||
staticOps = {{StaticValue::INT,
|
||||
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//",
|
||||
"%", "&", "|", "^"}},
|
||||
"%", "&", "|", "^", ">>", "<<"}},
|
||||
{StaticValue::STRING, {"==", "!=", "+"}}};
|
||||
if (expr->lexpr->isStatic() && expr->rexpr->isStatic() &&
|
||||
expr->lexpr->staticValue.type == expr->rexpr->staticValue.type &&
|
||||
|
@ -370,13 +371,15 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
|||
}
|
||||
|
||||
// Case: static integers
|
||||
if (expr->op == "-" || expr->op == "+" || expr->op == "!") {
|
||||
if (expr->op == "-" || expr->op == "+" || expr->op == "!" || expr->op == "~") {
|
||||
if (expr->expr->staticValue.evaluated) {
|
||||
int64_t value = expr->expr->staticValue.getInt();
|
||||
if (expr->op == "+")
|
||||
;
|
||||
else if (expr->op == "-")
|
||||
value = -value;
|
||||
else if (expr->op == "~")
|
||||
value = ~value;
|
||||
else
|
||||
value = !bool(value);
|
||||
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
||||
|
@ -484,6 +487,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
|||
lvalue = lvalue & rvalue;
|
||||
else if (expr->op == "|")
|
||||
lvalue = lvalue | rvalue;
|
||||
else if (expr->op == ">>")
|
||||
lvalue = lvalue >> rvalue;
|
||||
else if (expr->op == "<<")
|
||||
lvalue = lvalue << rvalue;
|
||||
else if (expr->op == "//")
|
||||
lvalue = divMod(ctx, lvalue, rvalue).first;
|
||||
else if (expr->op == "%")
|
||||
|
|
|
@ -253,7 +253,7 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
|||
auto score = ctx->reorderNamedArgs(
|
||||
fn.get(), args,
|
||||
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
||||
for (int si = 0; si < slots.size(); si++) {
|
||||
for (int si = 0, gi = 0; si < slots.size(); si++) {
|
||||
if (fn->ast->args[si].status == Param::Generic) {
|
||||
if (slots[si].empty()) {
|
||||
// is this "real" type?
|
||||
|
@ -263,8 +263,13 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
|||
}
|
||||
reordered.push_back({nullptr, 0});
|
||||
} else {
|
||||
seqassert(gi < fn->funcGenerics.size(), "bad fn");
|
||||
if (!fn->funcGenerics[gi].type->isStaticType() &&
|
||||
!args[slots[si][0]].value->isType())
|
||||
return -1;
|
||||
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
||||
}
|
||||
gi++;
|
||||
} else if (si == s || si == k || slots[si].size() != 1) {
|
||||
// Ignore *args, *kwargs and default arguments
|
||||
reordered.push_back({nullptr, 0});
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
include codon/*.pxd
|
|
@ -216,7 +216,7 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
|||
file=sys.stderr,
|
||||
)
|
||||
return _jit.run_wrapper(
|
||||
obj_name, types, f.__module__, pyvars, args, 1 if debug else 0
|
||||
obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0
|
||||
)
|
||||
except JITError:
|
||||
_reset_jit()
|
||||
|
|
|
@ -65,7 +65,7 @@ else:
|
|||
|
||||
jit_extension = Extension(
|
||||
"codon.codon_jit",
|
||||
sources=["codon/jit.pyx", "codon/jit.pxd"],
|
||||
sources=["codon/jit.pyx"],
|
||||
libraries=libraries,
|
||||
language="c++",
|
||||
extra_compile_args=["-w"],
|
||||
|
|
|
@ -435,6 +435,12 @@ class __internal__:
|
|||
e.col = col
|
||||
return e
|
||||
|
||||
def kwargs_get(kw, key: Static[str], default):
|
||||
if hasattr(kw, key):
|
||||
return getattr(kw, key)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
@extend
|
||||
class __magic__:
|
||||
|
|
|
@ -136,8 +136,8 @@ def _master_end(loc_ref: Ptr[Ident], gtid: int):
|
|||
__kmpc_end_master(loc_ref, i32(gtid))
|
||||
|
||||
def _ordered_begin(loc_ref: Ptr[Ident], gtid: int):
|
||||
from C import __kmpc_ordered(Ptr[Ident], i32) -> i32
|
||||
return int(__kmpc_ordered(loc_ref, i32(gtid)))
|
||||
from C import __kmpc_ordered(Ptr[Ident], i32)
|
||||
__kmpc_ordered(loc_ref, i32(gtid))
|
||||
|
||||
def _ordered_end(loc_ref: Ptr[Ident], gtid: int):
|
||||
from C import __kmpc_end_ordered(Ptr[Ident], i32)
|
||||
|
@ -781,11 +781,11 @@ def ordered(func):
|
|||
def _wrapper(*args, **kwargs):
|
||||
gtid = get_thread_num()
|
||||
loc = _default_loc()
|
||||
if _ordered_begin(loc, gtid) != 0:
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
finally:
|
||||
_ordered_end(loc, gtid)
|
||||
_ordered_begin(loc, gtid)
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
finally:
|
||||
_ordered_end(loc, gtid)
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
|
|
@ -1231,6 +1231,21 @@ def foo(x):
|
|||
print foo('hi') #: (3, 2)
|
||||
print foo('hi', 1) #: (2, 'hi_1')
|
||||
|
||||
|
||||
def fox(a: int, b: int, c: int, dtype: type = int):
|
||||
print('fox 1:', a, b, c)
|
||||
|
||||
@overload
|
||||
def fox(a: int, b: int, dtype: type = int):
|
||||
print('fox 2:', a, b, dtype.__class__.__name__)
|
||||
|
||||
fox(1, 2, float)
|
||||
#: fox 2: 1 2 float
|
||||
fox(1, 2)
|
||||
#: fox 2: 1 2 int
|
||||
fox(1, 2, 3)
|
||||
#: fox 1: 1 2 3
|
||||
|
||||
#%% fn_shadow,barebones
|
||||
def foo(x):
|
||||
return 1, x
|
||||
|
|
|
@ -889,6 +889,20 @@ def test_omp_collapse():
|
|||
|
||||
assert A6 == B6
|
||||
|
||||
@test
|
||||
def test_omp_ordered(N: int = 1000):
|
||||
@omp.ordered
|
||||
def f(A, i):
|
||||
A.append(i)
|
||||
|
||||
A = []
|
||||
|
||||
@par(schedule='dynamic', chunk_size=1, num_threads=2, ordered=True)
|
||||
for i in range(N):
|
||||
f(A, i)
|
||||
|
||||
assert A == list(range(N))
|
||||
|
||||
test_omp_api()
|
||||
test_omp_schedules()
|
||||
test_omp_ranges()
|
||||
|
@ -901,3 +915,4 @@ test_omp_transform(111.1, 222.2, 333.3)
|
|||
test_omp_nested()
|
||||
test_omp_corner_cases()
|
||||
test_omp_collapse()
|
||||
test_omp_ordered()
|
||||
|
|
Loading…
Reference in New Issue