1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Merge branch 'fn-dispatch' into auto-class-deduction

This commit is contained in:
Ibrahim Numanagić 2022-01-11 12:12:46 -08:00
commit d3fa986dbc
8 changed files with 247 additions and 16 deletions

View File

@ -133,6 +133,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// Realization lookup table that maps a realized class name to the corresponding
/// ClassRealization instance.
std::unordered_map<std::string, std::shared_ptr<ClassRealization>> realizations;
/// List of inherited class. We also keep the number of fields each of inherited
/// class.
std::vector<std::pair<std::string, int>> parentClasses;
Class() : ast(nullptr), originalAst(nullptr) {}
};

View File

@ -88,9 +88,10 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
}
// Reserve the following static identifiers.
for (auto name : {"staticlen", "compile_error", "isinstance", "hasattr", "type",
"TypeVar", "Callable", "argv", "super"})
"TypeVar", "Callable", "argv", "super", "superf"})
stdlib->generateCanonicalName(name);
stdlib->add(SimplifyItem::Var, "super", "super", true);
stdlib->add(SimplifyItem::Var, "superf", "superf", true);
// This code must be placed in a preamble (these are not POD types but are
// referenced by the various preamble Function.N and Tuple.N stubs)

View File

@ -550,6 +550,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
typeAst = ctx->bases[ctx->bases.size() - 2].ast;
ctx->bases.back().selfName = name;
attr.set(".changedSelf");
attr.set(Attr::Method);
}
if (attr.has(Attr::C)) {
@ -796,6 +797,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
std::vector<std::unordered_map<std::string, ExprPtr>> substitutions;
std::vector<int> argSubstitutions;
std::unordered_set<std::string> seenMembers;
std::vector<int> baseASTsFields;
for (auto &baseClass : stmt->baseClasses) {
std::string bcName;
std::vector<ExprPtr> subs;
@ -836,6 +838,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
if (!extension)
ctx->cache->classes[canonicalName].fields.push_back({a.name, nullptr});
}
baseASTsFields.push_back(args.size());
}
// Add generics, if any, to the context.
@ -956,6 +959,9 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
ctx->moduleName.module);
ctx->cache->classes[canonicalName].ast =
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), attr);
for (int i = 0; i < baseASTs.size(); i++)
ctx->cache->classes[canonicalName].parentClasses.push_back(
{baseASTs[i]->name, baseASTsFields[i]});
std::vector<StmtPtr> fns;
ExprPtr codeType = ctx->bases.back().ast->clone();
std::vector<std::string> magics{};

View File

@ -301,6 +301,10 @@ private:
const std::vector<types::FuncTypePtr> &methods,
const std::vector<CallExpr::Arg> &args);
ExprPtr transformSuper(const CallExpr *expr);
std::vector<types::ClassTypePtr> getSuperTypes(const types::ClassTypePtr &cls);
private:
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b,
bool undoOnSuccess = false);

View File

@ -1038,20 +1038,22 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
seenNames.insert(i.name);
}
if (expr->expr->isId("super")) {
if (expr->expr->isId("superf")) {
if (ctx->bases.back().supers.empty())
error("no matching super methods are available");
error("no matching superf methods are available");
auto parentCls = ctx->bases.back().type->getFunc()->funcParent;
auto m =
findMatchingMethods(parentCls ? CAST(parentCls, types::ClassType) : nullptr,
ctx->bases.back().supers, expr->args);
if (m.empty())
error("no matching super methods are available");
error("no matching superf methods are available");
// LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(),
// m[0]->toString());
ExprPtr e = N<CallExpr>(N<IdExpr>(m[0]->ast->name), expr->args);
return transform(e, false, true);
}
if (expr->expr->isId("super"))
return transformSuper(expr);
bool isPartial = !expr->args.empty() && expr->args.back().value->getEllipsis() &&
!expr->args.back().value->getEllipsis()->isPipeArg &&
@ -1421,8 +1423,15 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
expr->args[1].value =
transformType(expr->args[1].value, /*disableActivation*/ true);
auto t = expr->args[1].value->type;
auto unifyOK = typ->unify(t.get(), nullptr) >= 0;
return {true, transform(N<BoolExpr>(unifyOK))};
auto hierarchy = getSuperTypes(typ->getClass());
for (auto &tx: hierarchy) {
auto unifyOK = tx->unify(t.get(), nullptr) >= 0;
if (unifyOK) {
return {true, transform(N<BoolExpr>(true))};
}
}
return {true, transform(N<BoolExpr>(false))};
}
}
} else if (val == "staticlen") {
@ -1946,5 +1955,91 @@ types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) {
return typ;
}
ExprPtr TypecheckVisitor::transformSuper(const CallExpr *expr) {
// For now, we just support casting to the _FIRST_ overload (i.e. empty super())
if (!expr->args.empty())
error("super does not take arguments");
if (ctx->bases.empty() || !ctx->bases.back().type)
error("no parent classes available");
auto fptyp = ctx->bases.back().type->getFunc();
if (!fptyp || !fptyp->ast->hasAttr(Attr::Method))
error("no parent classes available");
if (fptyp->args.size() < 2)
error("no parent classes available");
ClassTypePtr typ = fptyp->args[1]->getClass();
auto &cands = ctx->cache->classes[typ->name].parentClasses;
if (cands.empty())
error("no parent classes available");
// if (typ->getRecord())
// error("cannot use super on tuple types");
// find parent typ
// unify top N args with parent typ args
// realize & do bitcast
// call bitcast() . method
auto name = cands[0].first;
int fields = cands[0].second;
auto val = ctx->find(name);
seqassert(val, "cannot find '{}'", name);
auto ftyp = ctx->instantiate(expr, val->type)->getClass();
if (typ->getRecord()) {
std::vector<ExprPtr> members;
for (int i = 0; i < fields; i++)
members.push_back(N<DotExpr>(N<IdExpr>(fptyp->ast->args[0].name),
ctx->cache->classes[typ->name].fields[i].name));
ExprPtr e = transform(
N<CallExpr>(N<IdExpr>(format(TYPE_TUPLE "{}", members.size())), members));
unify(e->type, ftyp);
e->type = ftyp;
return e;
} else {
for (int i = 0; i < fields; i++) {
auto t = ctx->cache->classes[typ->name].fields[i].type;
t = ctx->instantiate(expr, t, typ.get());
auto ft = ctx->cache->classes[name].fields[i].type;
ft = ctx->instantiate(expr, ft, ftyp.get());
unify(t, ft);
}
ExprPtr typExpr = N<IdExpr>(name);
typExpr->setType(ftyp);
auto self = fptyp->ast->args[0].name;
ExprPtr e = transform(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "to_class_ptr"),
N<CallExpr>(N<DotExpr>(N<IdExpr>(self), "__raw__")), typExpr));
return e;
}
}
std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) {
std::vector<ClassTypePtr> result;
if (!cls)
return result;
result.push_back(cls);
int start = 0;
for (auto &cand: ctx->cache->classes[cls->name].parentClasses) {
auto name = cand.first;
int fields = cand.second;
auto val = ctx->find(name);
seqassert(val, "cannot find '{}'", name);
auto ftyp = ctx->instantiate(nullptr, val->type)->getClass();
for (int i = start; i < fields; i++) {
auto t = ctx->cache->classes[cls->name].fields[i].type;
t = ctx->instantiate(nullptr, t, cls.get());
auto ft = ctx->cache->classes[name].fields[i].type;
ft = ctx->instantiate(nullptr, ft, ftyp.get());
unify(t, ft);
}
start += fields;
for (auto &t: getSuperTypes(ftyp))
result.push_back(t);
}
return result;
}
} // namespace ast
} // namespace codon

View File

@ -125,6 +125,24 @@ class __internal__:
def opt_ref_invert[T](what: Optional[T]) -> T:
ret i8* %what
@pure
@llvm
def to_class_ptr[T](ptr: Ptr[byte]) -> T:
%0 = bitcast i8* %ptr to {=T}
ret {=T} %0
@pure
def _tuple_offsetof(x, field: Static[int]):
@llvm
def _llvm_offsetof(T: type, idx: Static[int], TE: type) -> int:
%a = alloca {=T}
%b = getelementptr inbounds {=T}, {=T}* %a, i64 0, i32 {=idx}
%base = ptrtoint {=T}* %a to i64
%elem = ptrtoint {=TE}* %b to i64
%offset = sub i64 %elem, %base
ret i64 %offset
return _llvm_offsetof(type(x), field, type(x[field]))
def raw_type_str(p: Ptr[byte], name: str) -> str:
pstr = p.__repr__()
# '<[name] at [pstr]>'

View File

@ -600,6 +600,30 @@ print hasattr(int, "__getitem__")
print hasattr([1, 2], "__getitem__", str)
#: False
#%% isinstance_inheritance,barebones
class AX[T]:
a: T
def __init__(self, a: T):
self.a = a
class Side:
def __init__(self):
pass
class BX[T,U](AX[T], Side):
b: U
def __init__(self, a: T, b: U):
super().__init__(a)
self.b = b
class CX[T,U](BX[T,U]):
c: int
def __init__(self, a: T, b: U):
super().__init__(a, b)
self.c = 1
c = CX('a', False)
print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side)
#: True True True True
print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int])
#: True False False
#%% staticlen_err,barebones
print staticlen([1, 2]) #! List[int] is not a tuple type
@ -679,3 +703,83 @@ def foo(x: Callable[[1,2], 3]): pass #! unexpected static type
#%% static_unify_2,barebones
def foo(x: List[1]): pass #! cannot unify T and 1
#%% super,barebones
class A[T]:
a: T
def __init__(self, t: T):
self.a = t
def foo(self):
return f'A:{self.a}'
class B(A[str]):
b: int
def __init__(self):
super().__init__('s')
self.b = 6
def baz(self):
return f'{super().foo()}::{self.b}'
b = B()
print b.foo() #: A:s
print b.baz() #: A:s::6
class AX[T]:
a: T
def __init__(self, a: T):
self.a = a
def foo(self):
return f'[AX:{self.a}]'
class BX[T,U](AX[T]):
b: U
def __init__(self, a: T, b: U):
print super().__class__
super().__init__(a)
self.b = b
def foo(self):
return f'[BX:{super().foo()}:{self.b}]'
class CX[T,U](BX[T,U]):
c: int
def __init__(self, a: T, b: U):
print super().__class__
super().__init__(a, b)
self.c = 1
def foo(self):
return f'CX:{super().foo()}:{self.c}'
c = CX('a', False)
print c.__class__, c.foo()
#: BX[str,bool]
#: AX[str]
#: CX[str,bool] CX:[BX:[AX:a]:False]:1
#%% super_tuple,barebones
@tuple
class A[T]:
a: T
x: int
def __new__(a: T) -> A[T]:
return (a, 1)
def foo(self):
return f'A:{self.a}'
@tuple
class B(A[str]):
b: int
def __new__() -> B:
return (*(A('s')), 6)
def baz(self):
return f'{super().foo()}::{self.b}'
b = B()
print b.foo() #: A:s
print b.baz() #: A:s::6
#%% super_error,barebones
class A:
def __init__(self):
super().__init__()
a = A()
#! no parent classes available
#! while realizing A.__init__:1 (arguments A.__init__:1[A])
#%% super_error_2,barebones
super().foo(1) #! no parent classes available

View File

@ -299,19 +299,19 @@ def foo2(x):
foo2(1) #: 2
foo2('s') #: 1
#%% super,barebones
#%% superf,barebones
class Foo:
def foo(a):
# super(a)
# superf(a)
print 'foo-1', a
def foo(a: int):
super(a)
superf(a)
print 'foo-2', a
def foo(a: str):
super(a)
superf(a)
print 'foo-3', a
def foo(a):
super(a)
superf(a)
print 'foo-4', a
Foo.foo(1)
#: foo-1 1
@ -324,21 +324,21 @@ class Bear:
@extend
class Bear:
def woof(x):
return super(x) + f' bear w--f {x}'
return superf(x) + f' bear w--f {x}'
print Bear.woof('!')
#: bear woof ! bear w--f !
class PolarBear(Bear):
def woof():
return 'polar ' + super('@')
return 'polar ' + superf('@')
print PolarBear.woof()
#: polar bear woof @ bear w--f @
#%% super_error,barebones
#%% superf_error,barebones
class Foo:
def foo(a):
super(a)
superf(a)
print 'foo-1', a
Foo.foo(1)
#! no matching super methods are available
#! no matching superf methods are available
#! while realizing Foo.foo:0