Various bug fixes (#185)

* Fix #183

* Fix #162; Fix #135

* Fix #155

* Fix #191

* Fix #187

* Fix #189

* Fix vtable init; Fix failing tests on Linux

* Fix #190

* Fix #156

* Fix union routing

* Format

---------

Co-authored-by: A. R. Shajii <ars@ars.me>
pull/218/head
Ibrahim Numanagić 2023-02-05 15:53:15 -08:00 committed by GitHub
parent 28ebb2e84d
commit 5f13644751
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 260 additions and 54 deletions

View File

@ -282,8 +282,20 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Expression to be used if function binding is modified by captures or decorators
ExprPtr finalExpr = nullptr;
// If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)`
if (!captures.empty())
if (!captures.empty()) {
finalExpr = N<CallExpr>(N<IdExpr>(stmt->name), partialArgs);
// Add updated self reference in case function is recursive!
auto pa = partialArgs;
for (auto &a : pa) {
if (!a.name.empty())
a.value = N<IdExpr>(a.name);
else
a.value = clone(a.value);
}
f->suite = N<SuiteStmt>(
N<AssignStmt>(N<IdExpr>(rootName), N<CallExpr>(N<IdExpr>(rootName), pa)),
suite);
}
// Parse remaining decorators
for (auto i = stmt->decorators.size(); i-- > 0;) {

View File

@ -170,7 +170,10 @@ void TranslateVisitor::visit(StringExpr *expr) {
void TranslateVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value);
seqassert(val, "cannot find '{}'", expr->value);
if (auto *v = val->getVar())
if (expr->value == "__vtable_size__")
result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2,
getType(expr->getType()));
else if (auto *v = val->getVar())
result = make<ir::VarValue>(expr, v);
else if (auto *f = val->getFunc())
result = make<ir::VarValue>(expr, f);

View File

@ -783,16 +783,31 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
.type->getStatic()
->evaluate()
.getString();
std::vector<TypePtr> args{typ};
std::vector<std::pair<std::string, TypePtr>> args{{"", typ}};
if (expr->expr->isId("hasattr:0")) {
// Case: the first hasattr overload allows passing argument types via *args
auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple");
for (auto &a : tup->items) {
transformType(a);
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back(a->getType());
args.push_back({"", a->getType()});
}
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall();
auto kwCls =
in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name);
seqassert(kwCls, "cannot find {}",
expr->args[2].value->getType()->getClass()->name);
for (size_t i = 0; i < kw->args.size(); i++) {
auto &a = kw->args[i].value;
transform(a);
if (!a->getType()->getClass())
return nullptr;
args.push_back({kwCls->fields[i].name, a->getType()});
}
}

View File

@ -382,11 +382,11 @@ StmtPtr TypecheckVisitor::prepareVTables() {
// def class_init_vtables():
// return __internal__.class_make_n_vtables(<NUM_REALIZATIONS> + 1)
auto &initAllVT = ctx->cache->functions[rep];
auto suite = N<SuiteStmt>(
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.class_make_n_vtables:0"),
N<IntExpr>(ctx->cache->classRealizationCnt + 1))));
auto suite = N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_make_n_vtables:0"), N<IdExpr>("__vtable_size__"))));
initAllVT.ast->suite = suite;
auto typ = initAllVT.realizations.begin()->second->type;
LOG_REALIZE("[poly] {} : {}", typ, *suite);
typ->ast = initAllVT.ast.get();
auto fx = realizeFunc(typ.get(), true);
@ -402,30 +402,36 @@ StmtPtr TypecheckVisitor::prepareVTables() {
suite = N<SuiteStmt>();
for (auto &[_, cls] : ctx->cache->classes) {
for (auto &[r, real] : cls.realizations) {
size_t vtSz = 0;
for (auto &[base, vtable] : real->vtables) {
if (!vtable.ir)
vtSz += vtable.table.size();
}
auto var = initFn.ast->args[0].name;
// p.__setitem__(real.ID) = Ptr[cobj](real.vtables.size() + 2)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(var), "__setitem__"), N<IntExpr>(real->id),
N<CallExpr>(NT<InstantiateExpr>(NT<IdExpr>("Ptr"),
std::vector<ExprPtr>{NT<IdExpr>("cobj")}),
N<IntExpr>(vtSz + 2)))));
// __internal__.class_set_typeinfo(p[real.ID], real.ID)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)), N<IntExpr>(real->id))));
vtSz = 0;
for (auto &[base, vtable] : real->vtables) {
if (!vtable.ir) {
auto var = initFn.ast->args[0].name;
// p.__setitem__(real.ID) = Ptr[cobj](real.vtables.size() + 2)
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(var), "__setitem__"), N<IntExpr>(real->id),
N<CallExpr>(NT<InstantiateExpr>(NT<IdExpr>("Ptr"),
std::vector<ExprPtr>{NT<IdExpr>("cobj")}),
N<IntExpr>(vtable.table.size() + 2)))));
// __internal__.class_set_typeinfo(p[real.ID], real.ID)
suite->stmts.push_back(N<ExprStmt>(
N<CallExpr>(N<IdExpr>("__internal__.class_set_typeinfo:0"),
N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)),
N<IntExpr>(real->id))));
for (auto &[k, v] : vtable.table) {
auto &[fn, id] = v;
std::vector<ExprPtr> ids;
for (auto &t : fn->getArgTypes())
ids.push_back(NT<IdExpr>(t->realizedName()));
// p[real.ID].__setitem__(f.ID, Function[<TYPE_F>](f).__raw__())
LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, fn);
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(N<IndexExpr>(N<IdExpr>(var), N<IntExpr>(real->id)),
"__setitem__"),
N<IntExpr>(id),
N<IntExpr>(vtSz + id),
N<CallExpr>(N<DotExpr>(
N<CallExpr>(
NT<InstantiateExpr>(
@ -438,12 +444,14 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<IdExpr>(fn->realizedName())),
"__raw__")))));
}
vtSz += vtable.table.size();
}
}
}
}
initFn.ast->suite = suite;
typ = initFn.realizations.begin()->second->type;
LOG_REALIZE("[poly] {} : {}", typ, suite->toString(2));
typ->ast = initFn.ast.get();
realizeFunc(typ.get(), true);
@ -469,6 +477,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<DotExpr>(N<IdExpr>(clsTyp->realizedName()), "__vtable_id__"))));
}
LOG_REALIZE("[poly] {} : {}", t, *suite);
initObjFns.ast->suite = suite;
t->ast = initObjFns.ast.get();
realizeFunc(t.get(), true);
@ -502,6 +511,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
N<DotExpr>(NT<InstantiateExpr>(
NT<IdExpr>(format("{}{}", TYPE_TUPLE, types.size())), types),
"__elemsize__"));
LOG_REALIZE("[poly] {} : {}", t, *suite);
initDist.ast->suite = suite;
t->ast = initDist.ast.get();
realizeFunc(t.get(), true);
@ -802,8 +812,8 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
type->getArgTypes()[0]->getHeterogenousTuple()) {
// Special case: do not realize auto-generated heterogenous __getitem__
E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
} else if (startswith(ast->name, "Function.__call__")) {
// Special case: Function.__call__
} else if (startswith(ast->name, "Function.__call_internal__")) {
// Special case: Function.__call_internal__
/// TODO: move to IR one day
std::vector<StmtPtr> items;
items.push_back(nullptr);
@ -826,6 +836,14 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
ll.push_back(format("ret {{}} %{}", as.size()));
items[0] = N<ExprStmt>(N<StringExpr>(combine2(ll, "\n")));
ast->suite = N<SuiteStmt>(items);
} else if (startswith(ast->name, "Union.__new__:0")) {
auto unionType = type->funcParent->getUnion();
seqassert(unionType, "expected union, got {}", type->funcParent);
StmtPtr suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.new_union:0"), N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) {
// Special case: __internal__.new_union
// def __internal__.new_union(value, U[T0, ..., TN]):
@ -910,21 +928,29 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto suite = N<SuiteStmt>();
int tag = 0;
for (auto &t : unionTypes) {
auto callee =
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName);
auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1)));
auto kwargs = N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)));
std::vector<CallExpr::Arg> callArgs;
ExprPtr check =
N<CallExpr>(N<IdExpr>("hasattr"), NT<IdExpr>(t->realizedName()),
N<StringExpr>(fnName), args->clone(), kwargs->clone());
suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag)),
N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName),
N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1))),
N<KeywordStarExpr>(N<IdExpr>(ast->args[3].name.substr(2)))))));
N<BinaryExpr>(
check, "&&",
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"),
N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag))),
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(callee, args, kwargs)))));
tag++;
}
suite->stmts.push_back(
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {

View File

@ -657,8 +657,12 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
if (!lt->is("pyobj") && rt->is("pyobj")) {
// Special case: `obj op pyobj` -> `rhs.__rmagic__(lhs)` on lhs
// Assumes that pyobj implements all left and right magics
return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, format("__{}__", rightMagic)),
expr->lexpr));
auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r");
return transform(
N<StmtExpr>(N<AssignStmt>(N<IdExpr>(l), expr->lexpr),
N<AssignStmt>(N<IdExpr>(r), expr->rexpr),
N<CallExpr>(N<DotExpr>(N<IdExpr>(r), format("__{}__", rightMagic)),
N<IdExpr>(l))));
}
if (lt->getUnion()) {
// Special case: `union op obj` -> `union.__magic__(rhs)`
@ -667,19 +671,24 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
}
// Normal operations: check if `lhs.__magic__(lhs, rhs)` exists
auto method = findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr});
// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
if (!method && (method = findBestMethod(rt, format("__{}__", rightMagic),
{expr->rexpr, expr->lexpr}))) {
swap(expr->lexpr, expr->rexpr);
}
if (method) {
if (auto method =
findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr})) {
// Normal case: `__magic__(lhs, rhs)`
return transform(
N<CallExpr>(N<IdExpr>(method->ast->name), expr->lexpr, expr->rexpr));
}
// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
if (auto method = findBestMethod(rt, format("__{}__", rightMagic),
{expr->rexpr, expr->lexpr})) {
auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r");
return transform(N<StmtExpr>(
N<AssignStmt>(N<IdExpr>(l), expr->lexpr),
N<AssignStmt>(N<IdExpr>(r), expr->rexpr),
N<CallExpr>(N<IdExpr>(method->ast->name), N<IdExpr>(r), N<IdExpr>(l))));
}
// 145
return nullptr;
}
@ -745,14 +754,18 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
sliceAdjustIndices(sz, &start, &stop, step);
// Generate a sub-tuple
auto var = N<IdExpr>(ctx->cache->getTemporaryVar("tup"));
auto ass = N<AssignStmt>(var, expr);
std::vector<ExprPtr> te;
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(expr), classItem->fields[i].name));
te.push_back(N<DotExpr>(clone(var), classItem->fields[i].name));
}
return {true, transform(N<CallExpr>(
N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te))};
ExprPtr e = transform(N<StmtExpr>(
std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te)));
return {true, e};
}
return {false, nullptr};

View File

@ -200,6 +200,21 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
return m.empty() ? nullptr : m[0];
}
/// Select the best method indicated of an object that matches the given argument
/// types. See @c findMatchingMethods for details.
types::FuncTypePtr TypecheckVisitor::findBestMethod(
const ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args) {
std::vector<CallExpr::Arg> callArgs;
for (auto &[n, a] : args) {
callArgs.push_back({n, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
auto methods = ctx->findMethod(typ->name, member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
/// Select the best method among the provided methods given the list of arguments.
/// See @c reorderNamedArgs for details.
std::vector<types::FuncTypePtr>

View File

@ -210,6 +210,9 @@ private:
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member,
const std::vector<ExprPtr> &args);
types::FuncTypePtr
findBestMethod(const types::ClassTypePtr &typ, const std::string &member,
const std::vector<std::pair<std::string, types::TypePtr>> &args);
std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods,

View File

@ -3,7 +3,7 @@
# distutils: language=c++
# cython: language_level=3
# cython: c_string_type=unicode
# cython: c_string_encoding=ascii
# cython: c_string_encoding=utf8
from libcpp.string cimport string
from libcpp.vector cimport vector

View File

@ -108,7 +108,9 @@ class Ref[T]:
@__internal__
@tuple
class Union[TU]:
pass
# compiler-generated
def __new__(val):
TU
# dummy
@__internal__
@ -153,7 +155,7 @@ def isinstance(obj, what):
def overload():
pass
def hasattr(obj, attr: Static[str], *args):
def hasattr(obj, attr: Static[str], *args, **kwargs):
"""Special handling"""
pass

View File

@ -11,6 +11,8 @@ from C import seq_print(str)
from C import exit(int)
from C import malloc(int) -> cobj as c_malloc
__vtable_size__ = 0
@extend
class __internal__:
@pure
@ -438,8 +440,11 @@ class Function:
return __internal__.raw_type_str(self.__raw__(), "function")
@llvm
def __call__(self, *args) -> TR:
def __call_internal__(self: Function[T, TR], args: T) -> TR:
noop # compiler will populate this one
def __call__(self, *args) -> TR:
return Function.__call_internal__(self, args)
__vtables__ = __internal__.class_init_vtables()
def _____(): __vtables__ # make it global!

View File

@ -119,9 +119,13 @@ class Int:
@pure
@llvm
def __floordiv__(self, other: Int[N]) -> Int[N]:
def _floordiv(self, other: Int[N]) -> Int[N]:
%0 = sdiv i{=N} %self, %other
ret i{=N} %0
def __floordiv__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("division is not supported on Int[N] when N > 128")
return self._floordiv(other)
@pure
@llvm
@ -133,9 +137,13 @@ class Int:
@pure
@llvm
def __mod__(self, other: Int[N]) -> Int[N]:
def _mod(self, other: Int[N]) -> Int[N]:
%0 = srem i{=N} %self, %other
ret i{=N} %0
def __mod__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("modulus is not supported on Int[N] when N > 128")
return self._mod(other)
def __divmod__(self, other: Int[N]) -> Tuple[Int[N], Int[N]]:
d = self // other
@ -344,9 +352,13 @@ class UInt:
@pure
@llvm
def __floordiv__(self, other: UInt[N]) -> UInt[N]:
def _floordiv(self, other: UInt[N]) -> UInt[N]:
%0 = udiv i{=N} %self, %other
ret i{=N} %0
def __floordiv__(self, other: UInt[N]) -> UInt[N]:
if N > 128:
compile_error("division is not supported on UInt[N] when N > 128")
return self._floordiv(other)
@pure
@llvm
@ -358,9 +370,13 @@ class UInt:
@pure
@llvm
def __mod__(self, other: UInt[N]) -> UInt[N]:
def _mod(self, other: UInt[N]) -> UInt[N]:
%0 = urem i{=N} %self, %other
ret i{=N} %0
def __mod__(self, other: UInt[N]) -> UInt[N]:
if N > 128:
compile_error("modulus is not supported on UInt[N] when N > 128")
return self._mod(other)
def __divmod__(self, other: UInt[N]) -> Tuple[UInt[N], UInt[N]]:
return (self // other, self % other)

View File

@ -806,6 +806,18 @@ print X(1) + Y(2) #: 5
print Y(1) + X(2) #: 4
class A:
def __radd__(self, n: int):
return 0
def f():
print('f')
return 1
def g():
print('g')
return A()
f() + g()
#: f
#: g
#%% magic_2,barebones
@tuple
@ -1232,3 +1244,14 @@ def foo():
foo()
#! name 'x' is not defined
#! name 'b' is not defined
#%% capture_recursive,barebones
def f(x: int) -> int:
z = 2 * x
def g(y: int) -> int:
if y == 0:
return 1
else:
return g(y - 1) * z
return g(4)
print(f(3)) #: 1296

View File

@ -208,6 +208,22 @@ print a[1] #: 2s
print a[0:2], a[:2], a[1:] #: (1, '2s') (1, '2s') ('2s', 3.3)
print a[0:3:2], a[-1:] #: (1, 3.3) (3.3)
#%% static_index_side,barebones
def foo(a):
print(a)
return a
print (foo(2), foo(1))[::-1]
#: 2
#: 1
#: (1, 2)
print (foo(1), foo(2), foo(3), foo(4))[2]
#: 1
#: 2
#: 3
#: 4
#: 3
#%% static_index_lenient,barebones
a = (1, 2)
print a[3:5] #: ()

View File

@ -1477,6 +1477,20 @@ x : Union[A,B,C] = A()
print x.foo(), x.foo().__class__.__name__
#: 1 Union[List[bool],int,str]
xx = Union[int, str](0)
print(xx) #: 0
#%% union_error,barebones
a: Union[int, str] = 123
print(123 == a) #: True
print(a == 123) #: True
try:
a = "foo"
print(a == 123)
except TypeError:
print("oops", a) #: oops 'foo'
#%% generator_capture_nonglobal,barebones
# Issue #49
def foo(iter):
@ -1775,3 +1789,46 @@ class Div(BinOp):
expr : Expr = Mul(Const(3), Add(Const(10), Const(5)))
print(expr.eval()) #: 45
#%% polymorphism_4
class A(object):
a: int
def __init__(self, a: int):
self.a = a
def test_a(self, n: int):
print("test_a:A", n)
def test(self, n: int):
print("test:A", n)
def test2(self, n: int):
print("test2:A", n)
class B(A):
b: int
def __init__(self, a: int, b: int):
super().__init__(a)
self.b = b
def test(self, n: int):
print("test:B", n)
def test2(self, n: int):
print("test2:B", n)
class C(B):
pass
b = B(1, 2)
b.test_a(1)
b.test(1)
#: test_a:A 1
#: test:B 1
a: A = b
a.test(1)
a.test2(2)
#: test:B 1
#: test2:B 2