Refactor CallExpr routing [wip]

typecheck-v2
Ibrahim Numanagić 2024-08-12 08:00:17 -07:00
parent 1ef300b5b8
commit 94b623d174
6 changed files with 40 additions and 71 deletions

View File

@ -195,7 +195,7 @@ Func *Module::getOrRealizeFunc(const std::string &funcName,
return cache->realizeFunction(func, arg, gens); return cache->realizeFunction(func, arg, gens);
} catch (const exc::ParserException &e) { } catch (const exc::ParserException &e) {
for (int i = 0; i < e.messages.size(); i++) for (int i = 0; i < e.messages.size(); i++)
LOG_IR("getOrRealizeFunc parser error at {}: {}", e.locations[i], e.messages[i]); LOG("getOrRealizeFunc parser error at {}: {}", e.locations[i], e.messages[i]);
return nullptr; return nullptr;
} }
} }

View File

@ -424,7 +424,7 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
ctx->cache->functions[name].ast = ast; ctx->cache->functions[name].ast = ast;
ctx->cache->functions[name].type = typ; ctx->cache->functions[name].type = typ;
ast->setDone(); ast->setDone();
prependStmts->push_back(ast); // prependStmts->push_back(ast);
return typ; return typ;
} }

View File

@ -27,10 +27,11 @@ void TypecheckVisitor::visit(LambdaExpr *expr) {
params.emplace_back(s); params.emplace_back(s);
auto f = auto f =
N<FunctionStmt>(name, nullptr, params, N<SuiteStmt>(N<ReturnStmt>(expr->expr))); N<FunctionStmt>(name, nullptr, params, N<SuiteStmt>(N<ReturnStmt>(expr->expr)));
transform(f);
if (auto a = expr->getAttribute(Attr::Bindings)) if (auto a = expr->getAttribute(Attr::Bindings))
f->setAttribute(Attr::Bindings, a->clone()); f->setAttribute(Attr::Bindings, a->clone());
resultExpr = resultExpr =
transform(N<StmtExpr>(f, N<CallExpr>(N<IdExpr>(name), N<EllipsisExpr>()))); transform(N<CallExpr>(N<IdExpr>(name), N<EllipsisExpr>()));
} }
/// Unify the function return type with `Generator[?]`. /// Unify the function return type with `Generator[?]`.

View File

@ -376,6 +376,9 @@ class Partial:
%1 = insertvalue { {=T}, {=K} } %0, {=K} %kwargs, 1 %1 = insertvalue { {=T}, {=K} } %0, {=K} %kwargs, 1
ret { {=T}, {=K} } %1 ret { {=T}, {=K} } %1
# def __new__(M: Static[str], F: type) -> Partial[M, Tuple[Tuple], NamedTuple[0,Tuple], F]:
# return Partial.__new__((), NamedTuple[0,Tuple](), M=M, F=F)
def __repr__(self): def __repr__(self):
return __magic__.repr_partial(self) return __magic__.repr_partial(self)
@ -385,3 +388,7 @@ class Partial:
@property @property
def __fn_name__(self): def __fn_name__(self):
return F.T.__name__ return F.T.__name__
def __raw__(self):
# TODO: better error message
return F.T.__raw__()

View File

@ -868,3 +868,9 @@ a = 1.0
b = 2.0 b = 2.0
c = fox(a, b) c = fox(a, b)
print(math.log(c) / 2) #: 0 print(math.log(c) / 2) #: 0
#%% repeated_lambda,barebones
def acc(i, func=lambda a, b: a + b):
return i + func(i, i)
print acc(1) #: 3
print acc('i') #: iii

View File

@ -540,9 +540,6 @@ def test_accumulate_from_cpython():
assert list(accumulate(List[int](), initial=100)) == [100] assert list(accumulate(List[int](), initial=100)) == [100]
test_accumulate_from_cpython()
@test @test
def test_chain_from_cpython(): def test_chain_from_cpython():
assert list(chain("abc", "def")) == list("abcdef") assert list(chain("abc", "def")) == list("abcdef")
@ -551,9 +548,6 @@ def test_chain_from_cpython():
assert list(take(4, chain("abc", "def"))) == list("abcd") assert list(take(4, chain("abc", "def"))) == list("abcd")
test_chain_from_cpython()
@test @test
def test_chain_from_iterable_from_cpython(): def test_chain_from_iterable_from_cpython():
assert list(chain.from_iterable(["abc", "def"])) == list("abcdef") assert list(chain.from_iterable(["abc", "def"])) == list("abcdef")
@ -562,9 +556,6 @@ def test_chain_from_iterable_from_cpython():
assert take(4, chain.from_iterable(["abc", "def"])) == list("abcd") assert take(4, chain.from_iterable(["abc", "def"])) == list("abcd")
test_chain_from_iterable_from_cpython()
@test @test
def test_combinations_from_cpython(): def test_combinations_from_cpython():
f = lambda x: x # hack to get non-static argument f = lambda x: x # hack to get non-static argument
@ -645,7 +636,6 @@ def test_combinations_from_cpython():
e for e in values if e in c e for e in values if e in c
] # comb is a subsequence of the input iterable ] # comb is a subsequence of the input iterable
test_combinations_from_cpython() # takes long time to typecheck
@test @test
@ -741,9 +731,6 @@ def test_combinations_with_replacement_from_cpython():
] # comb is a subsequence of the input iterable ] # comb is a subsequence of the input iterable
test_combinations_with_replacement_from_cpython()
@test @test
def test_permutations_from_cpython(): def test_permutations_from_cpython():
f = lambda x: x # hack to get non-static argument f = lambda x: x # hack to get non-static argument
@ -813,9 +800,6 @@ def test_permutations_from_cpython():
assert result == list(permutations(values, r)) assert result == list(permutations(values, r))
test_permutations_from_cpython()
@extend @extend
class List: class List:
def __lt__(self, other: List[T]): def __lt__(self, other: List[T]):
@ -952,9 +936,6 @@ def test_combinatorics_from_cpython():
assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm
test_combinatorics_from_cpython() # TODO: takes FOREVER to typecheck
@test @test
def test_compress_from_cpython(): def test_compress_from_cpython():
assert list(compress(data="ABCDEF", selectors=[1, 0, 1, 0, 1, 1])) == list("ACEF") assert list(compress(data="ABCDEF", selectors=[1, 0, 1, 0, 1, 1])) == list("ACEF")
@ -969,9 +950,6 @@ def test_compress_from_cpython():
assert list(compress(data, selectors)) == [1, 3, 5] * n assert list(compress(data, selectors)) == [1, 3, 5] * n
test_compress_from_cpython()
@test @test
def test_count_from_cpython(): def test_count_from_cpython():
assert lzip("abc", count()) == [("a", 0), ("b", 1), ("c", 2)] assert lzip("abc", count()) == [("a", 0), ("b", 1), ("c", 2)]
@ -982,9 +960,6 @@ def test_count_from_cpython():
assert take(3, count(3.25)) == [3.25, 4.25, 5.25] assert take(3, count(3.25)) == [3.25, 4.25, 5.25]
test_count_from_cpython()
@test @test
def test_count_with_stride_from_cpython(): def test_count_with_stride_from_cpython():
assert lzip("abc", count(2, 3)) == [("a", 2), ("b", 5), ("c", 8)] assert lzip("abc", count(2, 3)) == [("a", 2), ("b", 5), ("c", 8)]
@ -996,9 +971,6 @@ def test_count_with_stride_from_cpython():
assert take(3, count(2.0, 1.25)) == [2.0, 3.25, 4.5] assert take(3, count(2.0, 1.25)) == [2.0, 3.25, 4.5]
test_count_with_stride_from_cpython()
@test @test
def test_cycle_from_cpython(): def test_cycle_from_cpython():
assert take(10, cycle("abc")) == list("abcabcabca") assert take(10, cycle("abc")) == list("abcabcabca")
@ -1006,9 +978,6 @@ def test_cycle_from_cpython():
assert list(islice(cycle(gen3()), 10)) == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0] assert list(islice(cycle(gen3()), 10)) == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
test_cycle_from_cpython()
@test @test
def test_groupby_from_cpython(): def test_groupby_from_cpython():
# Check whether it accepts arguments correctly # Check whether it accepts arguments correctly
@ -1075,9 +1044,6 @@ def test_groupby_from_cpython():
assert r == [(5, "a"), (2, "r"), (2, "b")] assert r == [(5, "a"), (2, "r"), (2, "b")]
test_groupby_from_cpython()
@test @test
def test_filter_from_cpython(): def test_filter_from_cpython():
assert list(filter(isEven, range(6))) == [0, 2, 4] assert list(filter(isEven, range(6))) == [0, 2, 4]
@ -1086,9 +1052,6 @@ def test_filter_from_cpython():
assert take(4, filter(isEven, count())) == [0, 2, 4, 6] assert take(4, filter(isEven, count())) == [0, 2, 4, 6]
test_filter_from_cpython()
@test @test
def test_filterfalse_from_cpython(): def test_filterfalse_from_cpython():
assert list(filterfalse(isEven, range(6))) == [1, 3, 5] assert list(filterfalse(isEven, range(6))) == [1, 3, 5]
@ -1097,9 +1060,6 @@ def test_filterfalse_from_cpython():
assert take(4, filterfalse(isEven, count())) == [1, 3, 5, 7] assert take(4, filterfalse(isEven, count())) == [1, 3, 5, 7]
test_filterfalse_from_cpython()
@test @test
def test_zip_from_cpython(): def test_zip_from_cpython():
ans = [(x, y) for x, y in zip("abc", count())] ans = [(x, y) for x, y in zip("abc", count())]
@ -1112,9 +1072,6 @@ def test_zip_from_cpython():
assert [pair for pair in zip("abc", "def")] == lzip("abc", "def") assert [pair for pair in zip("abc", "def")] == lzip("abc", "def")
test_zip_from_cpython()
@test @test
def test_ziplongest_from_cpython(): def test_ziplongest_from_cpython():
for args in ( for args in (
@ -1142,9 +1099,6 @@ def test_ziplongest_from_cpython():
) )
test_ziplongest_from_cpython()
@test @test
def test_product_from_cpython(): def test_product_from_cpython():
for args, result in ( for args, result in (
@ -1175,9 +1129,6 @@ def test_product_from_cpython():
) )
test_product_from_cpython()
@test @test
def test_repeat_from_cpython(): def test_repeat_from_cpython():
assert list(repeat(object="a", times=3)) == ["a", "a", "a"] assert list(repeat(object="a", times=3)) == ["a", "a", "a"]
@ -1188,9 +1139,6 @@ def test_repeat_from_cpython():
assert list(repeat("a", -3)) == [] assert list(repeat("a", -3)) == []
test_repeat_from_cpython()
@test @test
def test_map_from_cpython(): def test_map_from_cpython():
power = lambda a, b: a ** b power = lambda a, b: a ** b
@ -1201,9 +1149,6 @@ def test_map_from_cpython():
assert list(map(tupleize, List[int]())) == [] assert list(map(tupleize, List[int]())) == []
test_map_from_cpython()
@test @test
def test_starmap_from_cpython(): def test_starmap_from_cpython():
power = lambda a, b: a ** b power = lambda a, b: a ** b
@ -1213,9 +1158,6 @@ def test_starmap_from_cpython():
assert list(starmap(power, [(4, 5)])) == [4 ** 5] assert list(starmap(power, [(4, 5)])) == [4 ** 5]
test_starmap_from_cpython()
@test @test
def test_islice_from_cpython(): def test_islice_from_cpython():
for args in ( # islice(args) should agree with range(args) for args in ( # islice(args) should agree with range(args)
@ -1243,9 +1185,6 @@ def test_islice_from_cpython():
assert list(islice(range(10), 1, None, 2)) == list(range(1, 10, 2)) assert list(islice(range(10), 1, None, 2)) == list(range(1, 10, 2))
test_islice_from_cpython()
@test @test
def test_takewhile_from_cpython(): def test_takewhile_from_cpython():
data = [1, 3, 5, 20, 2, 4, 6, 8] data = [1, 3, 5, 20, 2, 4, 6, 8]
@ -1255,9 +1194,6 @@ def test_takewhile_from_cpython():
assert list(t) == [1, 1, 1] assert list(t) == [1, 1, 1]
test_takewhile_from_cpython()
@test @test
def test_dropwhile_from_cpython(): def test_dropwhile_from_cpython():
data = [1, 3, 5, 20, 2, 4, 6, 8] data = [1, 3, 5, 20, 2, 4, 6, 8]
@ -1265,9 +1201,6 @@ def test_dropwhile_from_cpython():
assert list(dropwhile(underten, List[int]())) == [] assert list(dropwhile(underten, List[int]())) == []
test_dropwhile_from_cpython()
@test @test
def test_tee_from_cpython(): def test_tee_from_cpython():
import random import random
@ -1315,5 +1248,27 @@ def test_tee_from_cpython():
assert list(a) == list(range(100, 2000)) assert list(a) == list(range(100, 2000))
assert list(c) == list(range(2, 2000)) assert list(c) == list(range(2, 2000))
test_accumulate_from_cpython()
test_chain_from_cpython()
test_chain_from_iterable_from_cpython()
test_combinations_from_cpython() # takes long time to typecheck
test_combinations_with_replacement_from_cpython()
test_permutations_from_cpython()
test_combinatorics_from_cpython() # TODO: takes FOREVER to typecheck
test_compress_from_cpython()
test_count_from_cpython()
test_count_with_stride_from_cpython()
test_cycle_from_cpython()
test_groupby_from_cpython()
test_filter_from_cpython()
test_filterfalse_from_cpython()
test_zip_from_cpython()
test_ziplongest_from_cpython()
test_product_from_cpython()
test_repeat_from_cpython()
test_map_from_cpython()
test_starmap_from_cpython()
test_islice_from_cpython()
test_takewhile_from_cpython()
test_dropwhile_from_cpython()
test_tee_from_cpython() test_tee_from_cpython()