pull/185/head
Ibrahim Numanagić 2023-01-30 17:26:47 -08:00
parent 88316988f0
commit 8ea7993302
3 changed files with 16 additions and 1 deletions

View File

@ -838,6 +838,15 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
ll.push_back(format("ret {{}} %{}", as.size()));
items[0] = N<ExprStmt>(N<StringExpr>(combine2(ll, "\n")));
ast->suite = N<SuiteStmt>(items);
} else if (startswith(ast->name, "Union.__new__:0")) {
auto unionType = type->funcParent->getUnion();
seqassert(unionType, "expected union, got {}", type->funcParent);
StmtPtr suite = N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.new_union:0"),
N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) {
// Special case: __internal__.new_union
// def __internal__.new_union(value, U[T0, ..., TN]):
@ -876,6 +885,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
}
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("compile_error"), N<StringExpr>("invalid union constructor"))));
LOG("-> {}", suite->toString(2));
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union:0")) {
// Special case: __internal__.get_union

View File

@ -108,7 +108,9 @@ class Ref[T]:
@__internal__
@tuple
class Union[TU]:
pass
# compiler-generated
def __new__(val):
TU
# dummy
@__internal__

View File

@ -1477,6 +1477,9 @@ x : Union[A,B,C] = A()
print x.foo(), x.foo().__class__.__name__
#: 1 Union[List[bool],int,str]
xx = Union[int, str](0)
print(xx) #: 0
#%% generator_capture_nonglobal,barebones
# Issue #49
def foo(iter):