mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Bugfixes 2023-08 (#440)
* Fix type argument overload issue; Fix Cython version for CI * Add __contains__ for kwargs * Add get() for kwargs * Add static <<, >> and unary ~ * Fix CI * Fix OpenMP "ordered" clause * Fix static ~ * Fix Cython 3 issues * Fix Python MANIFEST.in --------- Co-authored-by: A. R. Shajii <ars@ars.me>
This commit is contained in:
parent
7198a0971a
commit
750bb28c9c
@ -151,19 +151,35 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
|
|||||||
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
|
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
|
||||||
std::vector<ExprPtr>{N<IdExpr>("tuple")});
|
std::vector<ExprPtr>{N<IdExpr>("tuple")});
|
||||||
|
|
||||||
// Add getItem for KwArgs:
|
// Add helpers for KwArgs:
|
||||||
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
|
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
|
||||||
|
// `def __contains__(self, key: Static[str]): return hasattr(self, key)`
|
||||||
auto getItem = N<FunctionStmt>(
|
auto getItem = N<FunctionStmt>(
|
||||||
"__getitem__", nullptr,
|
"__getitem__", nullptr,
|
||||||
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
||||||
N<IdExpr>("str"))}},
|
N<IdExpr>("str"))}},
|
||||||
N<SuiteStmt>(N<ReturnStmt>(
|
N<SuiteStmt>(N<ReturnStmt>(
|
||||||
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
||||||
|
auto contains = N<FunctionStmt>(
|
||||||
|
"__contains__", nullptr,
|
||||||
|
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
||||||
|
N<IdExpr>("str"))}},
|
||||||
|
N<SuiteStmt>(N<ReturnStmt>(
|
||||||
|
N<CallExpr>(N<IdExpr>("hasattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
||||||
|
auto getDef = N<FunctionStmt>(
|
||||||
|
"get", nullptr,
|
||||||
|
std::vector<Param>{
|
||||||
|
Param{"self"},
|
||||||
|
Param{"key", N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("str"))},
|
||||||
|
Param{"default", nullptr, N<CallExpr>(N<IdExpr>("NoneType"))}},
|
||||||
|
N<SuiteStmt>(N<ReturnStmt>(
|
||||||
|
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "kwargs_get"),
|
||||||
|
N<IdExpr>("self"), N<IdExpr>("key"), N<IdExpr>("default")))));
|
||||||
if (startswith(typeName, TYPE_KWTUPLE))
|
if (startswith(typeName, TYPE_KWTUPLE))
|
||||||
stmt->getClass()->suite = getItem;
|
stmt->getClass()->suite = N<SuiteStmt>(getItem, contains, getDef);
|
||||||
|
|
||||||
// Add getItem for KwArgs:
|
// Add repr for KwArgs:
|
||||||
// `def __repr__(self,): return __magic__.repr_partial(self)`
|
// `def __repr__(self): return __magic__.repr_partial(self)`
|
||||||
auto repr = N<FunctionStmt>(
|
auto repr = N<FunctionStmt>(
|
||||||
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
|
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
|
||||||
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
|
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
|
||||||
|
@ -22,7 +22,8 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
|
|||||||
transform(expr->expr);
|
transform(expr->expr);
|
||||||
|
|
||||||
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
||||||
staticOps = {{StaticValue::INT, {"-", "+", "!"}}, {StaticValue::STRING, {"@"}}};
|
staticOps = {{StaticValue::INT, {"-", "+", "!", "~"}},
|
||||||
|
{StaticValue::STRING, {"@"}}};
|
||||||
// Handle static expressions
|
// Handle static expressions
|
||||||
if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) {
|
if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) {
|
||||||
resultExpr = evaluateStaticUnary(expr);
|
resultExpr = evaluateStaticUnary(expr);
|
||||||
@ -62,7 +63,7 @@ void TypecheckVisitor::visit(BinaryExpr *expr) {
|
|||||||
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
|
||||||
staticOps = {{StaticValue::INT,
|
staticOps = {{StaticValue::INT,
|
||||||
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//",
|
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//",
|
||||||
"%", "&", "|", "^"}},
|
"%", "&", "|", "^", ">>", "<<"}},
|
||||||
{StaticValue::STRING, {"==", "!=", "+"}}};
|
{StaticValue::STRING, {"==", "!=", "+"}}};
|
||||||
if (expr->lexpr->isStatic() && expr->rexpr->isStatic() &&
|
if (expr->lexpr->isStatic() && expr->rexpr->isStatic() &&
|
||||||
expr->lexpr->staticValue.type == expr->rexpr->staticValue.type &&
|
expr->lexpr->staticValue.type == expr->rexpr->staticValue.type &&
|
||||||
@ -370,13 +371,15 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Case: static integers
|
// Case: static integers
|
||||||
if (expr->op == "-" || expr->op == "+" || expr->op == "!") {
|
if (expr->op == "-" || expr->op == "+" || expr->op == "!" || expr->op == "~") {
|
||||||
if (expr->expr->staticValue.evaluated) {
|
if (expr->expr->staticValue.evaluated) {
|
||||||
int64_t value = expr->expr->staticValue.getInt();
|
int64_t value = expr->expr->staticValue.getInt();
|
||||||
if (expr->op == "+")
|
if (expr->op == "+")
|
||||||
;
|
;
|
||||||
else if (expr->op == "-")
|
else if (expr->op == "-")
|
||||||
value = -value;
|
value = -value;
|
||||||
|
else if (expr->op == "~")
|
||||||
|
value = ~value;
|
||||||
else
|
else
|
||||||
value = !bool(value);
|
value = !bool(value);
|
||||||
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
||||||
@ -484,6 +487,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
|||||||
lvalue = lvalue & rvalue;
|
lvalue = lvalue & rvalue;
|
||||||
else if (expr->op == "|")
|
else if (expr->op == "|")
|
||||||
lvalue = lvalue | rvalue;
|
lvalue = lvalue | rvalue;
|
||||||
|
else if (expr->op == ">>")
|
||||||
|
lvalue = lvalue >> rvalue;
|
||||||
|
else if (expr->op == "<<")
|
||||||
|
lvalue = lvalue << rvalue;
|
||||||
else if (expr->op == "//")
|
else if (expr->op == "//")
|
||||||
lvalue = divMod(ctx, lvalue, rvalue).first;
|
lvalue = divMod(ctx, lvalue, rvalue).first;
|
||||||
else if (expr->op == "%")
|
else if (expr->op == "%")
|
||||||
|
@ -253,7 +253,7 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
|||||||
auto score = ctx->reorderNamedArgs(
|
auto score = ctx->reorderNamedArgs(
|
||||||
fn.get(), args,
|
fn.get(), args,
|
||||||
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
||||||
for (int si = 0; si < slots.size(); si++) {
|
for (int si = 0, gi = 0; si < slots.size(); si++) {
|
||||||
if (fn->ast->args[si].status == Param::Generic) {
|
if (fn->ast->args[si].status == Param::Generic) {
|
||||||
if (slots[si].empty()) {
|
if (slots[si].empty()) {
|
||||||
// is this "real" type?
|
// is this "real" type?
|
||||||
@ -263,8 +263,13 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
|||||||
}
|
}
|
||||||
reordered.push_back({nullptr, 0});
|
reordered.push_back({nullptr, 0});
|
||||||
} else {
|
} else {
|
||||||
|
seqassert(gi < fn->funcGenerics.size(), "bad fn");
|
||||||
|
if (!fn->funcGenerics[gi].type->isStaticType() &&
|
||||||
|
!args[slots[si][0]].value->isType())
|
||||||
|
return -1;
|
||||||
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
||||||
}
|
}
|
||||||
|
gi++;
|
||||||
} else if (si == s || si == k || slots[si].size() != 1) {
|
} else if (si == s || si == k || slots[si].size() != 1) {
|
||||||
// Ignore *args, *kwargs and default arguments
|
// Ignore *args, *kwargs and default arguments
|
||||||
reordered.push_back({nullptr, 0});
|
reordered.push_back({nullptr, 0});
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
include codon/*.pxd
|
@ -216,7 +216,7 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
|||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return _jit.run_wrapper(
|
return _jit.run_wrapper(
|
||||||
obj_name, types, f.__module__, pyvars, args, 1 if debug else 0
|
obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0
|
||||||
)
|
)
|
||||||
except JITError:
|
except JITError:
|
||||||
_reset_jit()
|
_reset_jit()
|
||||||
|
@ -65,7 +65,7 @@ else:
|
|||||||
|
|
||||||
jit_extension = Extension(
|
jit_extension = Extension(
|
||||||
"codon.codon_jit",
|
"codon.codon_jit",
|
||||||
sources=["codon/jit.pyx", "codon/jit.pxd"],
|
sources=["codon/jit.pyx"],
|
||||||
libraries=libraries,
|
libraries=libraries,
|
||||||
language="c++",
|
language="c++",
|
||||||
extra_compile_args=["-w"],
|
extra_compile_args=["-w"],
|
||||||
|
@ -435,6 +435,12 @@ class __internal__:
|
|||||||
e.col = col
|
e.col = col
|
||||||
return e
|
return e
|
||||||
|
|
||||||
|
def kwargs_get(kw, key: Static[str], default):
|
||||||
|
if hasattr(kw, key):
|
||||||
|
return getattr(kw, key)
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
@extend
|
@extend
|
||||||
class __magic__:
|
class __magic__:
|
||||||
|
@ -136,8 +136,8 @@ def _master_end(loc_ref: Ptr[Ident], gtid: int):
|
|||||||
__kmpc_end_master(loc_ref, i32(gtid))
|
__kmpc_end_master(loc_ref, i32(gtid))
|
||||||
|
|
||||||
def _ordered_begin(loc_ref: Ptr[Ident], gtid: int):
|
def _ordered_begin(loc_ref: Ptr[Ident], gtid: int):
|
||||||
from C import __kmpc_ordered(Ptr[Ident], i32) -> i32
|
from C import __kmpc_ordered(Ptr[Ident], i32)
|
||||||
return int(__kmpc_ordered(loc_ref, i32(gtid)))
|
__kmpc_ordered(loc_ref, i32(gtid))
|
||||||
|
|
||||||
def _ordered_end(loc_ref: Ptr[Ident], gtid: int):
|
def _ordered_end(loc_ref: Ptr[Ident], gtid: int):
|
||||||
from C import __kmpc_end_ordered(Ptr[Ident], i32)
|
from C import __kmpc_end_ordered(Ptr[Ident], i32)
|
||||||
@ -781,11 +781,11 @@ def ordered(func):
|
|||||||
def _wrapper(*args, **kwargs):
|
def _wrapper(*args, **kwargs):
|
||||||
gtid = get_thread_num()
|
gtid = get_thread_num()
|
||||||
loc = _default_loc()
|
loc = _default_loc()
|
||||||
if _ordered_begin(loc, gtid) != 0:
|
_ordered_begin(loc, gtid)
|
||||||
try:
|
try:
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
_ordered_end(loc, gtid)
|
_ordered_end(loc, gtid)
|
||||||
|
|
||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
||||||
|
@ -1231,6 +1231,21 @@ def foo(x):
|
|||||||
print foo('hi') #: (3, 2)
|
print foo('hi') #: (3, 2)
|
||||||
print foo('hi', 1) #: (2, 'hi_1')
|
print foo('hi', 1) #: (2, 'hi_1')
|
||||||
|
|
||||||
|
|
||||||
|
def fox(a: int, b: int, c: int, dtype: type = int):
|
||||||
|
print('fox 1:', a, b, c)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fox(a: int, b: int, dtype: type = int):
|
||||||
|
print('fox 2:', a, b, dtype.__class__.__name__)
|
||||||
|
|
||||||
|
fox(1, 2, float)
|
||||||
|
#: fox 2: 1 2 float
|
||||||
|
fox(1, 2)
|
||||||
|
#: fox 2: 1 2 int
|
||||||
|
fox(1, 2, 3)
|
||||||
|
#: fox 1: 1 2 3
|
||||||
|
|
||||||
#%% fn_shadow,barebones
|
#%% fn_shadow,barebones
|
||||||
def foo(x):
|
def foo(x):
|
||||||
return 1, x
|
return 1, x
|
||||||
|
@ -889,6 +889,20 @@ def test_omp_collapse():
|
|||||||
|
|
||||||
assert A6 == B6
|
assert A6 == B6
|
||||||
|
|
||||||
|
@test
|
||||||
|
def test_omp_ordered(N: int = 1000):
|
||||||
|
@omp.ordered
|
||||||
|
def f(A, i):
|
||||||
|
A.append(i)
|
||||||
|
|
||||||
|
A = []
|
||||||
|
|
||||||
|
@par(schedule='dynamic', chunk_size=1, num_threads=2, ordered=True)
|
||||||
|
for i in range(N):
|
||||||
|
f(A, i)
|
||||||
|
|
||||||
|
assert A == list(range(N))
|
||||||
|
|
||||||
test_omp_api()
|
test_omp_api()
|
||||||
test_omp_schedules()
|
test_omp_schedules()
|
||||||
test_omp_ranges()
|
test_omp_ranges()
|
||||||
@ -901,3 +915,4 @@ test_omp_transform(111.1, 222.2, 333.3)
|
|||||||
test_omp_nested()
|
test_omp_nested()
|
||||||
test_omp_corner_cases()
|
test_omp_corner_cases()
|
||||||
test_omp_collapse()
|
test_omp_collapse()
|
||||||
|
test_omp_ordered()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user