Refactor CallExpr routing

typecheck-v2
Ibrahim Numanagić 2024-08-08 21:26:06 -07:00
parent 6315dcc3c9
commit f02f6371fc
12 changed files with 426 additions and 425 deletions

View File

@ -45,6 +45,13 @@ std::string Expr::wrapType(const std::string &sexpr) const {
type && !done ? format(" #:type \"{}\"", type->debugString(2)) : "");
return s;
}
Expr *Expr::operator<<(const types::TypePtr &t) {
seqassert(type, "lhs is nullptr");
if ((*type) << t) {
E(Error::TYPE_UNIFY, getSrcInfo(), type->prettyString(), t->prettyString());
}
return this;
}
Param::Param(std::string name, Expr *type, Expr *defaultValue, int status)
: name(std::move(name)), type(type), defaultValue(defaultValue) {

View File

@ -57,6 +57,8 @@ struct Expr : public AcceptorExtend<Expr, ASTNode> {
static const char NodeId;
SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr);
Expr *operator<<(const types::TypePtr &t);
protected:
/// Add a type to S-expression string.
std::string wrapType(const std::string &sexpr) const;

View File

@ -56,4 +56,16 @@ char Type::isStaticType() {
return 0;
}
Type *Type::operator<<(const TypePtr &t) {
seqassert(t, "rhs is nullptr");
types::Type::Unification undo;
if (unify(t.get(), &undo) >= 0) {
return this;
} else {
undo.undo();
return nullptr;
}
}
} // namespace codon::ast::types

View File

@ -118,6 +118,8 @@ public:
virtual bool is(const std::string &s);
char isStaticType();
Type *operator<<(const std::shared_ptr<Type> &t);
protected:
Cache *cache;
explicit Type(const std::shared_ptr<Type> &);

View File

@ -43,7 +43,7 @@ private:
/// The absolute path of the current module.
std::string filename;
/// SrcInfo stack used for obtaining source information of the current expression.
std::vector<SrcInfo> srcInfos;
std::vector<ASTNode *> nodeStack;
public:
explicit Context(std::string filename) : filename(std::move(filename)) {
@ -120,9 +120,14 @@ protected:
public:
/* SrcInfo helpers */
void pushSrcInfo(SrcInfo s) { srcInfos.emplace_back(std::move(s)); }
void popSrcInfo() { srcInfos.pop_back(); }
SrcInfo getSrcInfo() const { return srcInfos.back(); }
void pushNode(ASTNode *n) { nodeStack.emplace_back(n); }
void popNode() { nodeStack.pop_back(); }
ASTNode *getLastNode() const { return nodeStack.back(); }
ASTNode *getParentNode() const {
assert(nodeStack.size() > 1);
return nodeStack[nodeStack.size() - 2];
}
SrcInfo getSrcInfo() const { return nodeStack.back()->getSrcInfo(); }
};
} // namespace codon::ast

View File

@ -27,7 +27,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
}
auto o = in(ctx->cache->overloads, val->canonicalName);
if (expr->getType()->getUnbound() && o && o->size() > 1) {
// LOG("dispatch: {}", val->canonicalName);
val = ctx->forceFind(getDispatch(val->canonicalName)->ast->name);
}
@ -72,8 +71,144 @@ void TypecheckVisitor::visit(IdExpr *expr) {
/// `a.B.c` -> canonical name of `c` in class `a.B`
/// `python.foo` -> internal.python._get_identifier("foo")
/// Other cases are handled during the type checking.
/// See @c transformDot for details.
void TypecheckVisitor::visit(DotExpr *expr) { resultExpr = transformDot(expr); }
/// Transform a dot expression. Select the best method overload if possible.
/// @example
/// `obj.__class__` -> `type(obj)`
/// `cls.__name__` -> `"class"` (same for functions)
/// `obj.method` -> `cls.method(obj, ...)` or
/// `cls.method(obj)` if method has `@property` attribute
/// @c getClassMember examples:
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
/// `optional.member` -> `unwrap(optional).member`
/// `pyobj.member` -> `pyobj._getattr("member")`
/// @return nullptr if no transformation was made
/// See @c getClassMember and @c getBestOverload
void TypecheckVisitor::visit(DotExpr *expr) {
// First flatten the imports:
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
CallExpr *parentCall = cast<CallExpr>(ctx->getParentNode());
std::vector<std::string> chain;
Expr *root = expr;
for (; cast<DotExpr>(root); root = cast<DotExpr>(root)->getExpr())
chain.push_back(cast<DotExpr>(root)->getMember());
Expr *nexpr = expr;
if (auto id = cast<IdExpr>(root)) {
// Case: a.bar.baz
chain.push_back(id->getValue());
std::reverse(chain.begin(), chain.end());
auto [pos, val] = getImport(chain);
if (!val) {
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
ctx->getBase()->pyCaptures->insert(chain[0]);
nexpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
} else if (val->getModule() == "std.python") {
nexpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[pos++])));
} else if (val->getModule() == ctx->getModule() && pos == 1) {
nexpr = transform(N<IdExpr>(chain[0]), true);
} else {
nexpr = N<IdExpr>(val->canonicalName);
}
while (pos < chain.size())
nexpr = N<DotExpr>(nexpr, chain[pos++]);
}
if (!cast<DotExpr>(nexpr)) {
resultExpr = transform(nexpr);
return;
} else {
expr->expr = cast<DotExpr>(nexpr)->getExpr();
expr->member = cast<DotExpr>(nexpr)->getMember();
}
// Special case: obj.__class__
if (expr->getMember() == "__class__") {
/// TODO: prevent cls.__class__ and type(cls)
resultExpr = transform(N<CallExpr>(N<IdExpr>("type"), expr->getExpr()));
return;
}
expr->expr = transform(expr->getExpr());
// Special case: fn.__name__
// Should go before cls.__name__ to allow printing generic functions
if (ctx->getType(expr->getExpr()->getType())->getFunc() &&
expr->getMember() == "__name__") {
resultExpr = transform(
N<StringExpr>(ctx->getType(expr->getExpr()->getType())->prettyString()));
return;
}
// Special case: fn.__llvm_name__ or obj.__llvm_name__
if (expr->getMember() == "__llvm_name__") {
if (realize(expr->getExpr()->getType()))
resultExpr = transform(N<StringExpr>(expr->getExpr()->getType()->realizedName()));
return;
}
// Special case: cls.__name__
if (expr->getExpr()->getType()->is("type") && expr->getMember() == "__name__") {
if (realize(expr->getExpr()->getType()))
resultExpr = transform(
N<StringExpr>(ctx->getType(expr->getExpr()->getType())->prettyString()));
return;
}
// Special case: expr.__is_static__
if (expr->getMember() == "__is_static__") {
if (expr->getExpr()->isDone())
resultExpr =
transform(N<BoolExpr>(bool(expr->getExpr()->getType()->getStatic())));
return;
}
// Special case: cls.__id__
if (expr->getExpr()->getType()->is("type") && expr->getMember() == "__id__") {
if (auto c = realize(getType(expr->getExpr())))
resultExpr =
transform(N<IntExpr>(ctx->cache->getClass(c->getClass())
->realizations[c->getClass()->realizedName()]
->id));
return;
}
// Ensure that the type is known (otherwise wait until it becomes known)
auto typ = getType(expr->getExpr()) ? getType(expr->getExpr())->getClass() : nullptr;
if (!typ)
return;
// Check if this is a method or member access
auto methods = ctx->findMethod(typ.get(), expr->getMember());
if (methods.empty()) {
resultExpr = getClassMember(expr);
} else {
auto bestMethod =
methods.size() > 1
? getDispatch(
ctx->cache->functions[methods.front()->ast->getName()].rootName)
: methods.front();
Expr *e = N<IdExpr>(bestMethod->ast->getName());
e->setType(unify(expr->getType(), ctx->instantiate(bestMethod, typ)));;
if (expr->getExpr()->getType()->is("type")) {
// Static access: `cls.method`
} else if (parentCall && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
!bestMethod->ast->hasAttribute(Attr::Property)) {
// Instance access: `obj.method`
parentCall->items.insert(parentCall->items.begin(), expr->getExpr());
} else {
// Instance access: `obj.method`
// Transform y.method to a partial call `type(obj).method(args, ...)`
std::vector<Expr *> methodArgs;
// Do not add self if a method is marked with @staticmethod
if (!bestMethod->ast->hasAttribute(Attr::StaticMethod))
methodArgs.push_back(expr->getExpr());
// If a method is marked with @property, just call it directly
if (!bestMethod->ast->hasAttribute(Attr::Property))
methodArgs.push_back(N<EllipsisExpr>(EllipsisExpr::PARTIAL));
e = N<CallExpr>(e, methodArgs);
}
resultExpr = transform(e);
}
}
/// Access identifiers from outside of the current function/class scope.
/// Either use them as-is (globals), capture them if allowed (nonlocals),
@ -125,32 +260,6 @@ bool TypecheckVisitor::checkCapture(const TypeContext::Item &val) {
if (crossCaptureBoundary)
E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), ctx->cache->rev(val->canonicalName));
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
// and capturing is enabled
// auto captures = ctx->getBase()->captures;
// if (captures && !in(*captures, val->canonicalName)) {
// // Captures are transformed to function arguments; generate new name for that
// // argument
// Expr * typ = nullptr;
// if (val->isType())
// typ = N<IdExpr>("type");
// if (auto st = val->isStatic())
// typ = N<IndexExpr>(N<IdExpr>("Static"),
// N<IdExpr>(st == StaticValue::INT ? "int" : "str"));
// auto [newName, _] = (*captures)[val->canonicalName] = {
// ctx->generateCanonicalName(val->canonicalName), typ};
// ctx->cache->reverseIdentifierLookup[newName] = newName;
// // Add newly generated argument to the context
// std::shared_ptr<TypecheckItem> newVal = nullptr;
// if (val->isType())
// newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, val->type);
// else
// newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, val->type);
// newVal->baseName = ctx->getBaseName();
// newVal->canShadow = false; // todo)) needed here? remove noshadow on fn
// boundaries? newVal->scope = ctx->getBase()->scope; return true;
// }
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
// and capturing is *not* enabled
E(Error::ID_NONLOCAL, getSrcInfo(), ctx->cache->rev(val->canonicalName));
@ -291,8 +400,10 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
auto baseType = getFuncTypeBase(2);
auto typ = std::make_shared<FuncType>(baseType, ast, 0);
typ->funcParent = ctx->cache->functions[overloads[0]].type->funcParent;
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel - 1));
ctx->addFunc(name, name, typ);
// LOG("-[D]-> {} / {}", typ->debugString(2), typ->funcParent->debugString(2));
overloads.insert(overloads.begin(), name);
ctx->cache->functions[name].ast = ast;
@ -302,186 +413,6 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
return typ;
}
/// Transform a dot expression. Select the best method overload if possible.
/// @param args (optional) list of class method arguments used to select the best
/// overload. nullptr if not available.
/// @example
/// `obj.__class__` -> `type(obj)`
/// `cls.__name__` -> `"class"` (same for functions)
/// `obj.method` -> `cls.method(obj, ...)` or
/// `cls.method(obj)` if method has `@property` attribute
/// @c getClassMember examples:
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
/// `optional.member` -> `unwrap(optional).member`
/// `pyobj.member` -> `pyobj._getattr("member")`
/// @return nullptr if no transformation was made
/// See @c getClassMember and @c getBestOverload
Expr *TypecheckVisitor::transformDot(DotExpr *expr, std::vector<CallArg> *args) {
// First flatten the imports:
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
if (!expr->getType())
expr->setType(ctx->getUnbound());
std::vector<std::string> chain;
Expr *root = expr;
for (; cast<DotExpr>(root); root = cast<DotExpr>(root)->expr)
chain.push_back(cast<DotExpr>(root)->member);
Expr *nexpr = expr;
if (auto id = cast<IdExpr>(root)) {
// Case: a.bar.baz
chain.push_back(id->getValue());
std::reverse(chain.begin(), chain.end());
auto [pos, val] = getImport(chain);
if (!val) {
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
ctx->getBase()->pyCaptures->insert(chain[0]);
nexpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
} else if (val->getModule() == "std.python") {
nexpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[pos++])));
} else if (val->getModule() == ctx->getModule() && pos == 1) {
nexpr = transform(N<IdExpr>(chain[0]), true);
} else {
nexpr = N<IdExpr>(val->canonicalName);
}
while (pos < chain.size())
nexpr = N<DotExpr>(nexpr, chain[pos++]);
}
if (!cast<DotExpr>(nexpr)) {
if (args) {
nexpr = transform(nexpr);
if (auto id = cast<IdExpr>(nexpr))
if (endswith(id->getValue(), ":dispatch"))
if (auto bestMethod = getBestOverload(id, args)) {
auto t = id->getType();
nexpr = N<IdExpr>(bestMethod->ast->name);
nexpr->setType(ctx->instantiate(bestMethod));
}
return nexpr;
} else {
return transform(nexpr);
}
} else {
expr->expr = cast<DotExpr>(nexpr)->expr;
expr->member = cast<DotExpr>(nexpr)->member;
}
// Special case: obj.__class__
if (expr->member == "__class__") {
/// TODO: prevent cls.__class__ and type(cls)
return N<CallExpr>(N<IdExpr>("type"), expr->expr);
}
expr->expr = transform(expr->expr);
// Special case: fn.__name__
// Should go before cls.__name__ to allow printing generic functions
if (ctx->getType(expr->expr->getType())->getFunc() && expr->member == "__name__") {
return transform(
N<StringExpr>(ctx->getType(expr->expr->getType())->prettyString()));
}
// Special case: fn.__llvm_name__ or obj.__llvm_name__
if (expr->member == "__llvm_name__") {
if (realize(expr->expr->getType()))
return transform(N<StringExpr>(expr->expr->getType()->realizedName()));
return nullptr;
}
// Special case: cls.__name__
if (expr->expr->getType()->is("type") && expr->member == "__name__") {
if (realize(expr->expr->getType()))
return transform(
N<StringExpr>(ctx->getType(expr->expr->getType())->prettyString()));
return nullptr;
}
// Special case: expr.__is_static__
if (expr->member == "__is_static__") {
if (expr->expr->isDone())
return transform(N<BoolExpr>(bool(expr->expr->getType()->getStatic())));
return nullptr;
}
// Special case: cls.__id__
if (expr->expr->getType()->is("type") && expr->member == "__id__") {
if (auto c = realize(getType(expr->expr))) {
return transform(N<IntExpr>(ctx->cache->getClass(c->getClass())
->realizations[c->getClass()->realizedName()]
->id));
}
return nullptr;
}
// Ensure that the type is known (otherwise wait until it becomes known)
if (!getType(expr->expr))
return nullptr;
auto typ = getType(expr->expr)->getClass();
if (!typ)
return nullptr;
// Check if this is a method or member access
if (ctx->findMethod(typ.get(), expr->member).empty())
return getClassMember(expr, args);
auto bestMethod = getBestOverload(expr, args);
if (args) {
unify(expr->getType(), ctx->instantiate(bestMethod, typ));
// A function is deemed virtual if it is marked as such and
// if a base class has a RTTI
auto cls = ctx->cache->getClass(typ);
bool isVirtual = in(cls->virtuals, expr->member);
isVirtual &= cls->rtti;
isVirtual &= !expr->expr->getType()->is("type");
if (isVirtual && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
!bestMethod->ast->hasAttribute(Attr::Property)) {
// Special case: route the call through a vtable
if (realize(expr->getType())) {
auto fn = expr->getType()->getFunc();
auto vid = getRealizationID(typ.get(), fn.get());
// Function[Tuple[TArg1, TArg2, ...], TRet]
std::vector<Expr *> ids;
for (auto &t : fn->getArgTypes())
ids.push_back(N<IdExpr>(t->realizedName()));
auto fnType = N<InstantiateExpr>(
N<IdExpr>("Function"),
std::vector<Expr *>{N<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), ids),
N<IdExpr>(fn->getRetType()->realizedName())});
// Function[Tuple[TArg1, TArg2, ...],TRet](
// __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
// )
auto e = N<CallExpr>(
fnType, N<IndexExpr>(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"),
"class_get_rtti_vtable"),
expr->expr),
N<IntExpr>(vid)));
return transform(e);
}
}
}
// Check if a method is a static or an instance method and transform accordingly
if (expr->expr->getType()->is("type") || args) {
// Static access: `cls.method`
Expr *e = N<IdExpr>(bestMethod->ast->name);
e->setType(unify(expr->getType(), ctx->instantiate(bestMethod, typ)));
return transform(e); // Realize if needed
} else {
// Instance access: `obj.method`
// Transform y.method to a partial call `type(obj).method(args, ...)`
std::vector<Expr *> methodArgs;
// Do not add self if a method is marked with @staticmethod
if (!bestMethod->ast->hasAttribute(Attr::StaticMethod))
methodArgs.push_back(expr->expr);
// If a method is marked with @property, just call it directly
if (!bestMethod->ast->hasAttribute(Attr::Property))
methodArgs.push_back(N<EllipsisExpr>(EllipsisExpr::PARTIAL));
auto e = transform(N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs));
unify(expr->getType(), e->getType());
return e;
}
}
/// Select the requested class member.
/// @param args (optional) list of class method arguments used to select the best
/// overload if the member is optional. nullptr if not available.
@ -489,7 +420,7 @@ Expr *TypecheckVisitor::transformDot(DotExpr *expr, std::vector<CallArg> *args)
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
/// `optional.member` -> `unwrap(optional).member`
/// `pyobj.member` -> `pyobj._getattr("member")`
Expr *TypecheckVisitor::getClassMember(DotExpr *expr, std::vector<CallArg> *args) {
Expr *TypecheckVisitor::getClassMember(DotExpr *expr) {
auto typ = getType(expr->expr)->getClass();
seqassert(typ, "not a class");
@ -547,11 +478,8 @@ Expr *TypecheckVisitor::getClassMember(DotExpr *expr, std::vector<CallArg> *args
// Case: transform `optional.member` to `unwrap(optional).member`
if (typ->is(TYPE_OPTIONAL)) {
auto dot = N<DotExpr>(transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr->expr)),
expr->member);
dot->setType(ctx->getUnbound()); // as dot is not transformed
if (auto d = transformDot(dot, args))
return d;
auto dot = transform(N<DotExpr>(
transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr->expr)), expr->member));
return dot;
}
@ -594,121 +522,107 @@ TypePtr TypecheckVisitor::findSpecialMember(const std::string &member) {
/// @param methods List of available methods.
/// @param args (optional) list of class method arguments used to select the best
/// overload if the member is optional. nullptr if not available.
FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, std::vector<CallArg> *args) {
// Prepare the list of method arguments if possible
std::unique_ptr<std::vector<CallArg>> methodArgs;
// FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, std::vector<CallArg> *args)
// {
// // Prepare the list of method arguments if possible
// std::unique_ptr<std::vector<CallArg>> methodArgs;
if (args) {
// Case: method overloads (DotExpr)
bool addSelf = true;
if (auto dot = cast<DotExpr>(expr)) {
auto cls = getType(dot->expr)->getClass();
auto methods = ctx->findMethod(cls.get(), dot->member, false);
if (!methods.empty() && methods.front()->ast->hasAttribute(Attr::StaticMethod))
addSelf = false;
}
// if (args) {
// methodArgs = std::make_unique<std::vector<CallArg>>();
// for (auto &a : *args)
// methodArgs->push_back(a);
// }
// // else {
// // // Partially deduced type thus far
// // auto typeSoFar = expr->getType() ? getType(expr)->getClass() : nullptr;
// // if (typeSoFar && typeSoFar->getFunc()) {
// // // Case: arguments available from the previous type checking round
// // methodArgs = std::make_unique<std::vector<CallArg>>();
// // if (cast<DotExpr>(expr) &&
// // !cast<DotExpr>(expr)->expr->getType()->is("type")) { // Add `self`
// // auto n = N<NoneExpr>();
// // n->setType(cast<DotExpr>(expr)->expr->getType());
// // methodArgs->push_back({"", n});
// // }
// // for (auto &a : typeSoFar->getFunc()->getArgTypes()) {
// // auto n = N<NoneExpr>();
// // n->setType(a);
// // methodArgs->push_back({"", n});
// // }
// // }
// // }
// Case: arguments explicitly provided (by CallExpr)
if (addSelf && cast<DotExpr>(expr) &&
!cast<DotExpr>(expr)->expr->getType()->is("type")) {
// Add `self` as the first argument
args->insert(args->begin(), {"", cast<DotExpr>(expr)->expr});
}
methodArgs = std::make_unique<std::vector<CallArg>>();
for (auto &a : *args)
methodArgs->push_back(a);
} else {
// Partially deduced type thus far
auto typeSoFar = expr->getType() ? getType(expr)->getClass() : nullptr;
if (typeSoFar && typeSoFar->getFunc()) {
// Case: arguments available from the previous type checking round
methodArgs = std::make_unique<std::vector<CallArg>>();
if (cast<DotExpr>(expr) &&
!cast<DotExpr>(expr)->expr->getType()->is("type")) { // Add `self`
auto n = N<NoneExpr>();
n->setType(cast<DotExpr>(expr)->expr->getType());
methodArgs->push_back({"", n});
}
for (auto &a : typeSoFar->getFunc()->getArgTypes()) {
auto n = N<NoneExpr>();
n->setType(a);
methodArgs->push_back({"", n});
}
}
}
// std::vector<FuncTypePtr> m;
// // Use the provided arguments to select the best method
// if (auto dot = cast<DotExpr>(expr)) {
// // Case: method overloads (DotExpr)
// auto methods =
// ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
// if (methodArgs)
// m = findMatchingMethods(getType(dot->expr)->getClass(), methods, *methodArgs);
// } else if (auto id = cast<IdExpr>(expr)) {
// // Case: function overloads (IdExpr)
// std::vector<types::FuncTypePtr> methods;
// auto key = id->getValue();
// if (endswith(key, ":dispatch"))
// key = key.substr(0, key.size() - 9);
// for (auto &m : ctx->cache->overloads[key])
// if (!endswith(m, ":dispatch"))
// methods.push_back(ctx->cache->functions[m].type);
// std::reverse(methods.begin(), methods.end());
// m = findMatchingMethods(nullptr, methods, *methodArgs);
// }
bool goDispatch = methodArgs == nullptr;
if (!goDispatch) {
std::vector<FuncTypePtr> m;
// Use the provided arguments to select the best method
if (auto dot = cast<DotExpr>(expr)) {
// Case: method overloads (DotExpr)
auto methods =
ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
m = findMatchingMethods(getType(dot->expr)->getClass(), methods, *methodArgs);
} else if (auto id = cast<IdExpr>(expr)) {
// Case: function overloads (IdExpr)
std::vector<types::FuncTypePtr> methods;
auto key = id->getValue();
if (endswith(key, ":dispatch"))
key = key.substr(0, key.size() - 9);
for (auto &m : ctx->cache->overloads[key])
if (!endswith(m, ":dispatch"))
methods.push_back(ctx->cache->functions[m].type);
std::reverse(methods.begin(), methods.end());
m = findMatchingMethods(nullptr, methods, *methodArgs);
}
// bool goDispatch = false;
// if (m.size() == 1) {
// return m[0];
// } else if (m.size() > 1) {
// for (auto &a : *methodArgs) {
// if (auto u = a.value->getType()->getUnbound()) {
// goDispatch = true;
// }
// }
// if (!goDispatch)
// return m[0];
// }
if (m.size() == 1) {
return m[0];
} else if (m.size() > 1) {
for (auto &a : *methodArgs) {
if (auto u = a.value->getType()->getUnbound()) {
goDispatch = true;
}
}
if (!goDispatch)
return m[0];
}
}
// if (goDispatch) {
// // If overload is ambiguous, route through a dispatch function
// std::string name;
// if (auto dot = cast<DotExpr>(expr)) {
// auto methods =
// ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
// seqassert(!methods.empty(), "unknown method");
// name = ctx->cache->functions[methods.back()->ast->name].rootName;
// } else {
// name = cast<IdExpr>(expr)->getValue();
// }
// auto t = getDispatch(name);
// }
if (goDispatch) {
// If overload is ambiguous, route through a dispatch function
std::string name;
if (auto dot = cast<DotExpr>(expr)) {
auto methods =
ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
seqassert(!methods.empty(), "unknown method");
name = ctx->cache->functions[methods.back()->ast->name].rootName;
} else {
name = cast<IdExpr>(expr)->getValue();
}
return getDispatch(name);
}
// // Print a nice error message
// std::string argsNice;
// if (methodArgs) {
// std::vector<std::string> a;
// for (auto &t : *methodArgs)
// a.emplace_back(fmt::format("{}", t.value->getType()->getStatic()
// ? t.value->getClassType()->name
// : t.value->getType()->prettyString()));
// argsNice = fmt::format("({})", fmt::join(a, ", "));
// }
// Print a nice error message
std::string argsNice;
if (methodArgs) {
std::vector<std::string> a;
for (auto &t : *methodArgs)
a.emplace_back(fmt::format("{}", t.value->getType()->getStatic()
? t.value->getClassType()->name
: t.value->getType()->prettyString()));
argsNice = fmt::format("({})", fmt::join(a, ", "));
}
// if (auto dot = cast<DotExpr>(expr)) {
// // Debug:
// // *args = std::vector<CallArg>(args->begin()+1, args->end());
// getBestOverload(expr, args);
// E(Error::DOT_NO_ATTR_ARGS, expr, getType(dot->expr)->prettyString(), dot->member,
// argsNice);
// } else {
// E(Error::FN_NO_ATTR_ARGS, expr, ctx->cache->rev(cast<IdExpr>(expr)->getValue()),
// argsNice);
// }
if (auto dot = cast<DotExpr>(expr)) {
// Debug:
// *args = std::vector<CallArg>(args->begin()+1, args->end());
// getBestOverload(expr, args);
E(Error::DOT_NO_ATTR_ARGS, expr, getType(dot->expr)->prettyString(), dot->member,
argsNice);
} else {
E(Error::FN_NO_ATTR_ARGS, expr, ctx->cache->rev(cast<IdExpr>(expr)->getValue()),
argsNice);
}
return nullptr;
}
// return nullptr;
// }
} // namespace codon::ast

View File

@ -243,8 +243,8 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
// Case: setters
auto setters = ctx->findMethod(lhsClass.get(), format(".set_{}", stmt->member));
if (!setters.empty()) {
resultStmt = transform(N<ExprStmt>(
N<CallExpr>(N<IdExpr>(setters[0]->ast->name), stmt->lhs, stmt->rhs)));
resultStmt = transform(N<ExprStmt>(N<CallExpr>(
N<IdExpr>(setters.front()->ast->getName()), stmt->lhs, stmt->rhs)));
return;
}
// Case: class variables

View File

@ -56,66 +56,75 @@ void TypecheckVisitor::visit(EllipsisExpr *expr) {
/// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments ,
/// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details.
void TypecheckVisitor::visit(CallExpr *expr) {
// Special case!
if (expr->getType()->getUnbound() && cast<IdExpr>(expr->expr)) {
auto callExpr = cast<IdExpr>(transform(clean_clone(expr->expr)));
if (callExpr && callExpr->getValue() == "std.collections.namedtuple.0") {
resultExpr = transformNamedTuple(expr);
return;
} else if (callExpr && callExpr->getValue() == "std.functools.partial.0:0") {
resultExpr = transformFunctoolsPartial(expr);
return;
} else if (callExpr && callExpr->getValue() == "tuple" && expr->size() == 1 &&
cast<GeneratorExpr>(expr->begin()->value)) {
resultExpr = transformTupleGenerator(expr);
return;
}
}
// Transform and expand arguments. Return early if it cannot be done yet
ctx->addBlock();
if (expr->expr->getType())
if (auto f = expr->expr->getType()->getFunc())
addFunctionGenerics(f.get());
auto a = transformCallArgs(expr->items);
ctx->popBlock();
if (!a)
return;
auto orig = expr->toString(0);
// Check if this call is partial call
PartialCallData part{
!expr->items.empty() && cast<EllipsisExpr>(expr->items.back().value) &&
cast<EllipsisExpr>(expr->items.back().value)->mode == EllipsisExpr::PARTIAL};
// Transform the callee
if (!part.isPartial) {
// Intercept method calls (e.g. `obj.method`) for faster compilation (because it
// avoids partial calls). This intercept passes the call arguments to
// @c transformDot to select the best overload as well
if (auto dot = cast<DotExpr>(expr->expr)) {
// Pick the best method overload
if (auto edt = transformDot(dot, &expr->items))
expr->expr = edt;
} else if (auto id = cast<IdExpr>(expr->expr)) {
expr->expr = transform(expr->expr);
id = cast<IdExpr>(expr->expr);
// Pick the best function overload
if (endswith(id->getValue(), ":dispatch")) {
if (auto bestMethod = getBestOverload(id, &expr->items)) {
auto t = id->getType();
expr->expr = N<IdExpr>(bestMethod->ast->name);
expr->expr->setType(ctx->instantiate(bestMethod));
}
}
}
}
expr->expr = transform(expr->expr);
expr->expr = transform(expr->getExpr());
if (expr->getExpr()->getType()->getUnbound())
return; // delay
auto [calleeFn, newExpr] = getCalleeFn(expr, part);
if ((resultExpr = newExpr))
return;
if (!calleeFn)
return;
// ctx->addBlock();
// if (expr->getExpr()->getType())
// if (auto f = expr->getExpr()->getType()->getFunc())
// addFunctionGenerics(f.get());
auto a = transformCallArgs(expr->items);
// ctx->popBlock();
if (!a)
return;
// Early dispatch modifier
if (endswith(calleeFn->ast->getName(), ":dispatch")) {
// LOG("-> {}", calleeFn->debugString(2));
std::vector<FuncTypePtr> m;
if (auto id = cast<IdExpr>(expr->getExpr())) {
// Case: function overloads (IdExpr)
std::vector<types::FuncTypePtr> methods;
auto key = id->getValue();
if (endswith(key, ":dispatch"))
key = key.substr(0, key.size() - 9);
for (auto &m : ctx->cache->overloads[key])
if (!endswith(m, ":dispatch"))
methods.push_back(ctx->cache->functions[m].type);
std::reverse(methods.begin(), methods.end());
m = findMatchingMethods(calleeFn->funcParent ? calleeFn->funcParent->getClass()
: nullptr,
methods, expr->items);
}
bool doDispatch = m.size() == 0;
if (m.size() > 1) {
for (auto &a : *expr) {
if (auto u = a.value->getType()->getUnbound())
doDispatch = true;
}
}
if (!doDispatch) {
calleeFn = ctx->instantiate(m.front(), calleeFn->funcParent
? calleeFn->funcParent->getClass()
: nullptr)
->getFunc();
auto e = N<IdExpr>(calleeFn->ast->getName());
e->setType(calleeFn);
if (cast<IdExpr>(expr->getExpr())) {
expr->expr = e;
} else {
expr->expr = N<StmtExpr>(N<ExprStmt>(expr->getExpr()), e);
}
expr->getExpr()->setType(calleeFn);
} else {
LOG("-> {}", expr->toString(0));
}
}
// Handle named and default arguments
if ((resultExpr = callReorderArguments(calleeFn, expr, part)))
return;
@ -163,12 +172,48 @@ void TypecheckVisitor::visit(CallExpr *expr) {
}
call->setAttribute(Attr::ExprPartial);
resultExpr = transform(call);
// LOG("{}: {} --------> {}", getSrcInfo(), orig, resultExpr->toString(0));
} else {
// Case: normal function call
unify(expr->getType(), calleeFn->getRetType());
if (done)
expr->setDone();
}
// unify(expr->getType(), ctx->instantiate(bestMethod, typ));
// A function is deemed virtual if it is marked as such and
// if a base class has a RTTI
// auto cls = ctx->cache->getClass(typ);
// bool isVirtual = in(cls->virtuals, expr->member);
// isVirtual &= cls->rtti;
// isVirtual &= !expr->expr->getType()->is("type");
// if (isVirtual && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
// !bestMethod->ast->hasAttribute(Attr::Property)) {
// // Special case: route the call through a vtable
// if (realize(expr->getType())) {
// auto fn = expr->getType()->getFunc();
// auto vid = getRealizationID(typ.get(), fn.get());
// // Function[Tuple[TArg1, TArg2, ...], TRet]
// std::vector<Expr *> ids;
// for (auto &t : fn->getArgTypes())
// ids.push_back(N<IdExpr>(t->realizedName()));
// auto fnType = N<InstantiateExpr>(
// N<IdExpr>("Function"),
// std::vector<Expr *>{N<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), ids),
// N<IdExpr>(fn->getRetType()->realizedName())});
// // Function[Tuple[TArg1, TArg2, ...],TRet](
// // __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
// // )
// auto e = N<CallExpr>(
// fnType, N<IndexExpr>(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"),
// "class_get_rtti_vtable"),
// expr->expr),
// N<IntExpr>(vid)));
// return transform(e);
// }
// }
}
/// Transform call arguments. Expand *args and **kwargs to the list of @c CallArg
@ -257,30 +302,30 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallArg> &args) {
/// (when needed; otherwise nullptr).
std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
PartialCallData &part) {
auto callee = expr->expr->getClassType();
auto callee = expr->getExpr()->getClassType();
if (!callee) {
// Case: unknown callee, wait until it becomes known
return {nullptr, nullptr};
}
if (expr->expr->getType()->is("type")) {
auto typ = expr->expr->getClassType();
if (!isId(expr->expr, "type"))
if (expr->getExpr()->getType()->is("type")) {
auto typ = expr->getExpr()->getClassType();
if (!isId(expr->getExpr(), "type"))
typ = typ->generics[0].type->getClass();
if (!typ)
return {nullptr, nullptr};
auto clsName = typ->name;
if (typ->isRecord()) {
// Case: tuple constructor. Transform to: `T.__new__(args)`
return {nullptr,
transform(N<CallExpr>(N<DotExpr>(expr->expr, "__new__"), expr->items))};
return {nullptr, transform(N<CallExpr>(N<DotExpr>(expr->getExpr(), "__new__"),
expr->items))};
}
// Case: reference type constructor. Transform to
// `ctr = T.__new__(); v.__init__(args)`
Expr *var = N<IdExpr>(ctx->cache->getTemporaryVar("ctr"));
auto newInit =
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->expr, "__new__")));
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->getExpr(), "__new__")));
auto e = N<StmtExpr>(N<SuiteStmt>(newInit), clone(var));
auto init =
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->items));
@ -288,20 +333,21 @@ std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
return {nullptr, transform(e)};
}
auto calleeFn = callee->getFunc();
if (auto partType = callee->getPartial()) {
auto mask = partType->getPartialMask();
auto func = ctx->instantiate(partType->getPartialFunc()->generalize(0))->getFunc();
// Case: calling partial object `p`. Transform roughly to
// `part = callee; partial_fn(*part.args, args...)`
Expr *var = N<IdExpr>(part.var = ctx->cache->getTemporaryVar("partcall"));
expr->expr = transform(
N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr), N<IdExpr>(func->ast->name)));
if (!partType->isPartialEmpty()) {
// Case: calling partial object `p`. Transform roughly to
// `part = callee; partial_fn(*part.args, args...)`
Expr *var = N<IdExpr>(part.var = ctx->cache->getTemporaryVar("partcall"));
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->getExpr()),
N<IdExpr>(func->ast->name)));
}
// Ensure that we got a function
calleeFn = expr->expr->getType()->getFunc();
seqassert(calleeFn, "not a function: {}", expr->expr->getType());
auto calleeFn = expr->getExpr()->getType()->getFunc();
seqassert(calleeFn, "not a function: {}", expr->getExpr()->getType());
// Unify partial generics with types known thus far
auto knownArgTypes = partType->generics[1].type->getClass();
@ -316,12 +362,13 @@ std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
}
part.known = mask;
return {calleeFn, nullptr};
} else if (!calleeFn) {
} else if (!callee->getFunc()) {
// Case: callee is not a function. Try __call__ method instead
return {nullptr,
transform(N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->items))};
return {nullptr, transform(N<CallExpr>(N<DotExpr>(expr->getExpr(), "__call__"),
expr->items))};
} else {
return {callee->getFunc(), nullptr};
}
return {calleeFn, nullptr};
}
/// Reorder the call arguments to match the signature order. Ensure that every @c
@ -668,6 +715,22 @@ std::pair<bool, Expr *> TypecheckVisitor::transformSpecialCall(CallExpr *expr) {
} else {
return transformInternalStaticFn(expr);
}
// Special case!
// if (expr->getType()->getUnbound() && cast<IdExpr>(expr->expr)) {
// auto callExpr = cast<IdExpr>(transform(clean_clone(expr->expr)));
// if (callExpr && callExpr->getValue() == "std.collections.namedtuple.0") {
// resultExpr = transformNamedTuple(expr);
// return;
// } else if (callExpr && callExpr->getValue() == "std.functools.partial.0:0") {
// resultExpr = transformFunctoolsPartial(expr);
// return;
// } else if (callExpr && callExpr->getValue() == "tuple" && expr->size() == 1 &&
// cast<GeneratorExpr>(expr->begin()->value)) {
// resultExpr = transformTupleGenerator(expr);
// return;
// }
// }
}
/// Transform `tuple(i for i in tup)` into a GeneratorExpr that will be handled during

View File

@ -30,7 +30,9 @@ TypeContext::TypeContext(Cache *cache, std::string filename)
: Context<TypecheckItem>(std::move(filename)), cache(cache) {
bases.emplace_back();
scope.emplace_back(0);
pushSrcInfo(cache->generateSrcInfo()); // Always have srcInfo() around
auto e = cache->N<NoneExpr>();
e->setSrcInfo(cache->generateSrcInfo());
pushNode(e); // Always have srcInfo() around
}
void TypeContext::add(const std::string &name, const TypeContext::Item &var) {
@ -272,10 +274,6 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
break;
if (idx == cm.size())
cm.push_back(key);
// if (idx)
// LOG("--> {}: realize {}: {} / {}", getSrcInfo(), ft->debugString(2), idx,
// key);
ft->index = idx;
}
}
if (t->getUnion() && !t->getUnion()->isSealed()) {

View File

@ -31,16 +31,13 @@ using namespace types;
/// @return a
TypePtr TypecheckVisitor::unify(const TypePtr &a, const TypePtr &b) {
seqassert(a, "lhs is nullptr");
seqassert(b, "rhs is nullptr");
types::Type::Unification undo;
if (a->unify(b.get(), &undo) >= 0) {
return a;
} else {
undo.undo();
if (!((*a) << b)) {
types::Type::Unification undo;
a->unify(b.get(), &undo);
E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString());
return nullptr;
}
a->unify(b.get(), &undo);
E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString());
return nullptr;
return a;
}
/// Infer all types within a Stmt *. Implements the LTS-DI typechecking.

View File

@ -183,13 +183,12 @@ Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) {
if (!expr->getType())
expr->setType(ctx->getUnbound());
auto typ = expr->getType();
if (!expr->isDone()) {
TypecheckVisitor v(ctx, preamble, prependStmts);
v.setSrcInfo(expr->getSrcInfo());
ctx->pushSrcInfo(expr->getSrcInfo());
ctx->pushNode(expr);
expr->accept(v);
ctx->popSrcInfo();
ctx->popNode();
if (v.resultExpr) {
for (auto it = expr->attributes_begin(); it != expr->attributes_end(); ++it) {
const auto *attr = expr->getAttribute(*it);
@ -197,20 +196,23 @@ Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) {
v.resultExpr->setAttribute(*it, attr->clone());
}
v.resultExpr->setOrigExpr(expr);
// unify(expr->getType(), v.resultExpr->getType());
expr = v.resultExpr;
if (!expr->getType())
expr->setType(ctx->getUnbound());
}
if (!allowTypes && expr && expr->getType()->is("type"))
E(Error::UNEXPECTED_TYPE, expr, "type");
if (!expr->getType())
expr->setType(ctx->getUnbound());
unify(typ, expr->getType());
// unify(typ, expr->getType());
if (expr->isDone())
ctx->changedNodes++;
}
realize(typ);
if (expr)
if (expr) {
if (auto p = realize(expr->getType()))
unify(expr->getType(), p);
LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), *(expr),
expr->isDone() ? "[done]" : "");
}
return expr;
}
@ -265,9 +267,9 @@ Stmt *TypecheckVisitor::transform(Stmt *stmt) {
if (!stmt->toString(-1).empty())
LOG_TYPECHECK("> [{}] [{}:{}] {}", getSrcInfo(), ctx->getBaseName(),
ctx->getBase()->iteration, stmt->toString(-1));
ctx->pushSrcInfo(stmt->getSrcInfo());
ctx->pushNode(stmt);
stmt->accept(v);
ctx->popSrcInfo();
ctx->popNode();
if (v.resultStmt)
stmt = v.resultStmt;
if (!v.prependStmts->empty()) {

View File

@ -82,10 +82,9 @@ private: // Node typechecking rules
bool checkCapture(const TypeContext::Item &);
void visit(DotExpr *) override;
std::pair<size_t, TypeContext::Item> getImport(const std::vector<std::string> &);
Expr *transformDot(DotExpr *, std::vector<CallArg> * = nullptr);
Expr *getClassMember(DotExpr *, std::vector<CallArg> *);
Expr *getClassMember(DotExpr *);
types::TypePtr findSpecialMember(const std::string &);
types::FuncTypePtr getBestOverload(Expr *, std::vector<CallArg> *);
types::FuncTypePtr getBestOverload(Expr *);
types::FuncTypePtr getDispatch(const std::string &);
/* Collection and comprehension expressions (collections.cpp) */