mirror of https://github.com/exaloop/codon.git
Refactor CallExpr routing
parent
6315dcc3c9
commit
f02f6371fc
|
@ -45,6 +45,13 @@ std::string Expr::wrapType(const std::string &sexpr) const {
|
|||
type && !done ? format(" #:type \"{}\"", type->debugString(2)) : "");
|
||||
return s;
|
||||
}
|
||||
Expr *Expr::operator<<(const types::TypePtr &t) {
|
||||
seqassert(type, "lhs is nullptr");
|
||||
if ((*type) << t) {
|
||||
E(Error::TYPE_UNIFY, getSrcInfo(), type->prettyString(), t->prettyString());
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
Param::Param(std::string name, Expr *type, Expr *defaultValue, int status)
|
||||
: name(std::move(name)), type(type), defaultValue(defaultValue) {
|
||||
|
|
|
@ -57,6 +57,8 @@ struct Expr : public AcceptorExtend<Expr, ASTNode> {
|
|||
static const char NodeId;
|
||||
SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr);
|
||||
|
||||
Expr *operator<<(const types::TypePtr &t);
|
||||
|
||||
protected:
|
||||
/// Add a type to S-expression string.
|
||||
std::string wrapType(const std::string &sexpr) const;
|
||||
|
|
|
@ -56,4 +56,16 @@ char Type::isStaticType() {
|
|||
return 0;
|
||||
}
|
||||
|
||||
Type *Type::operator<<(const TypePtr &t) {
|
||||
seqassert(t, "rhs is nullptr");
|
||||
types::Type::Unification undo;
|
||||
if (unify(t.get(), &undo) >= 0) {
|
||||
return this;
|
||||
} else {
|
||||
undo.undo();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace codon::ast::types
|
||||
|
|
|
@ -118,6 +118,8 @@ public:
|
|||
virtual bool is(const std::string &s);
|
||||
char isStaticType();
|
||||
|
||||
Type *operator<<(const std::shared_ptr<Type> &t);
|
||||
|
||||
protected:
|
||||
Cache *cache;
|
||||
explicit Type(const std::shared_ptr<Type> &);
|
||||
|
|
|
@ -43,7 +43,7 @@ private:
|
|||
/// The absolute path of the current module.
|
||||
std::string filename;
|
||||
/// SrcInfo stack used for obtaining source information of the current expression.
|
||||
std::vector<SrcInfo> srcInfos;
|
||||
std::vector<ASTNode *> nodeStack;
|
||||
|
||||
public:
|
||||
explicit Context(std::string filename) : filename(std::move(filename)) {
|
||||
|
@ -120,9 +120,14 @@ protected:
|
|||
|
||||
public:
|
||||
/* SrcInfo helpers */
|
||||
void pushSrcInfo(SrcInfo s) { srcInfos.emplace_back(std::move(s)); }
|
||||
void popSrcInfo() { srcInfos.pop_back(); }
|
||||
SrcInfo getSrcInfo() const { return srcInfos.back(); }
|
||||
void pushNode(ASTNode *n) { nodeStack.emplace_back(n); }
|
||||
void popNode() { nodeStack.pop_back(); }
|
||||
ASTNode *getLastNode() const { return nodeStack.back(); }
|
||||
ASTNode *getParentNode() const {
|
||||
assert(nodeStack.size() > 1);
|
||||
return nodeStack[nodeStack.size() - 2];
|
||||
}
|
||||
SrcInfo getSrcInfo() const { return nodeStack.back()->getSrcInfo(); }
|
||||
};
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -27,7 +27,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
|
|||
}
|
||||
auto o = in(ctx->cache->overloads, val->canonicalName);
|
||||
if (expr->getType()->getUnbound() && o && o->size() > 1) {
|
||||
// LOG("dispatch: {}", val->canonicalName);
|
||||
val = ctx->forceFind(getDispatch(val->canonicalName)->ast->name);
|
||||
}
|
||||
|
||||
|
@ -72,8 +71,144 @@ void TypecheckVisitor::visit(IdExpr *expr) {
|
|||
/// `a.B.c` -> canonical name of `c` in class `a.B`
|
||||
/// `python.foo` -> internal.python._get_identifier("foo")
|
||||
/// Other cases are handled during the type checking.
|
||||
/// See @c transformDot for details.
|
||||
void TypecheckVisitor::visit(DotExpr *expr) { resultExpr = transformDot(expr); }
|
||||
/// Transform a dot expression. Select the best method overload if possible.
|
||||
/// @example
|
||||
/// `obj.__class__` -> `type(obj)`
|
||||
/// `cls.__name__` -> `"class"` (same for functions)
|
||||
/// `obj.method` -> `cls.method(obj, ...)` or
|
||||
/// `cls.method(obj)` if method has `@property` attribute
|
||||
/// @c getClassMember examples:
|
||||
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
|
||||
/// `optional.member` -> `unwrap(optional).member`
|
||||
/// `pyobj.member` -> `pyobj._getattr("member")`
|
||||
/// @return nullptr if no transformation was made
|
||||
/// See @c getClassMember and @c getBestOverload
|
||||
|
||||
void TypecheckVisitor::visit(DotExpr *expr) {
|
||||
// First flatten the imports:
|
||||
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
|
||||
|
||||
CallExpr *parentCall = cast<CallExpr>(ctx->getParentNode());
|
||||
|
||||
std::vector<std::string> chain;
|
||||
Expr *root = expr;
|
||||
for (; cast<DotExpr>(root); root = cast<DotExpr>(root)->getExpr())
|
||||
chain.push_back(cast<DotExpr>(root)->getMember());
|
||||
|
||||
Expr *nexpr = expr;
|
||||
if (auto id = cast<IdExpr>(root)) {
|
||||
// Case: a.bar.baz
|
||||
chain.push_back(id->getValue());
|
||||
std::reverse(chain.begin(), chain.end());
|
||||
auto [pos, val] = getImport(chain);
|
||||
if (!val) {
|
||||
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
|
||||
ctx->getBase()->pyCaptures->insert(chain[0]);
|
||||
nexpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
|
||||
} else if (val->getModule() == "std.python") {
|
||||
nexpr = transform(N<CallExpr>(
|
||||
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
|
||||
N<StringExpr>(chain[pos++])));
|
||||
} else if (val->getModule() == ctx->getModule() && pos == 1) {
|
||||
nexpr = transform(N<IdExpr>(chain[0]), true);
|
||||
} else {
|
||||
nexpr = N<IdExpr>(val->canonicalName);
|
||||
}
|
||||
while (pos < chain.size())
|
||||
nexpr = N<DotExpr>(nexpr, chain[pos++]);
|
||||
}
|
||||
if (!cast<DotExpr>(nexpr)) {
|
||||
resultExpr = transform(nexpr);
|
||||
return;
|
||||
} else {
|
||||
expr->expr = cast<DotExpr>(nexpr)->getExpr();
|
||||
expr->member = cast<DotExpr>(nexpr)->getMember();
|
||||
}
|
||||
|
||||
// Special case: obj.__class__
|
||||
if (expr->getMember() == "__class__") {
|
||||
/// TODO: prevent cls.__class__ and type(cls)
|
||||
resultExpr = transform(N<CallExpr>(N<IdExpr>("type"), expr->getExpr()));
|
||||
return;
|
||||
}
|
||||
expr->expr = transform(expr->getExpr());
|
||||
|
||||
// Special case: fn.__name__
|
||||
// Should go before cls.__name__ to allow printing generic functions
|
||||
if (ctx->getType(expr->getExpr()->getType())->getFunc() &&
|
||||
expr->getMember() == "__name__") {
|
||||
resultExpr = transform(
|
||||
N<StringExpr>(ctx->getType(expr->getExpr()->getType())->prettyString()));
|
||||
return;
|
||||
}
|
||||
// Special case: fn.__llvm_name__ or obj.__llvm_name__
|
||||
if (expr->getMember() == "__llvm_name__") {
|
||||
if (realize(expr->getExpr()->getType()))
|
||||
resultExpr = transform(N<StringExpr>(expr->getExpr()->getType()->realizedName()));
|
||||
return;
|
||||
}
|
||||
// Special case: cls.__name__
|
||||
if (expr->getExpr()->getType()->is("type") && expr->getMember() == "__name__") {
|
||||
if (realize(expr->getExpr()->getType()))
|
||||
resultExpr = transform(
|
||||
N<StringExpr>(ctx->getType(expr->getExpr()->getType())->prettyString()));
|
||||
return;
|
||||
}
|
||||
// Special case: expr.__is_static__
|
||||
if (expr->getMember() == "__is_static__") {
|
||||
if (expr->getExpr()->isDone())
|
||||
resultExpr =
|
||||
transform(N<BoolExpr>(bool(expr->getExpr()->getType()->getStatic())));
|
||||
return;
|
||||
}
|
||||
// Special case: cls.__id__
|
||||
if (expr->getExpr()->getType()->is("type") && expr->getMember() == "__id__") {
|
||||
if (auto c = realize(getType(expr->getExpr())))
|
||||
resultExpr =
|
||||
transform(N<IntExpr>(ctx->cache->getClass(c->getClass())
|
||||
->realizations[c->getClass()->realizedName()]
|
||||
->id));
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure that the type is known (otherwise wait until it becomes known)
|
||||
auto typ = getType(expr->getExpr()) ? getType(expr->getExpr())->getClass() : nullptr;
|
||||
if (!typ)
|
||||
return;
|
||||
|
||||
// Check if this is a method or member access
|
||||
auto methods = ctx->findMethod(typ.get(), expr->getMember());
|
||||
if (methods.empty()) {
|
||||
resultExpr = getClassMember(expr);
|
||||
} else {
|
||||
auto bestMethod =
|
||||
methods.size() > 1
|
||||
? getDispatch(
|
||||
ctx->cache->functions[methods.front()->ast->getName()].rootName)
|
||||
: methods.front();
|
||||
Expr *e = N<IdExpr>(bestMethod->ast->getName());
|
||||
e->setType(unify(expr->getType(), ctx->instantiate(bestMethod, typ)));;
|
||||
if (expr->getExpr()->getType()->is("type")) {
|
||||
// Static access: `cls.method`
|
||||
} else if (parentCall && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
|
||||
!bestMethod->ast->hasAttribute(Attr::Property)) {
|
||||
// Instance access: `obj.method`
|
||||
parentCall->items.insert(parentCall->items.begin(), expr->getExpr());
|
||||
} else {
|
||||
// Instance access: `obj.method`
|
||||
// Transform y.method to a partial call `type(obj).method(args, ...)`
|
||||
std::vector<Expr *> methodArgs;
|
||||
// Do not add self if a method is marked with @staticmethod
|
||||
if (!bestMethod->ast->hasAttribute(Attr::StaticMethod))
|
||||
methodArgs.push_back(expr->getExpr());
|
||||
// If a method is marked with @property, just call it directly
|
||||
if (!bestMethod->ast->hasAttribute(Attr::Property))
|
||||
methodArgs.push_back(N<EllipsisExpr>(EllipsisExpr::PARTIAL));
|
||||
e = N<CallExpr>(e, methodArgs);
|
||||
}
|
||||
resultExpr = transform(e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Access identifiers from outside of the current function/class scope.
|
||||
/// Either use them as-is (globals), capture them if allowed (nonlocals),
|
||||
|
@ -125,32 +260,6 @@ bool TypecheckVisitor::checkCapture(const TypeContext::Item &val) {
|
|||
if (crossCaptureBoundary)
|
||||
E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), ctx->cache->rev(val->canonicalName));
|
||||
|
||||
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
|
||||
// and capturing is enabled
|
||||
// auto captures = ctx->getBase()->captures;
|
||||
// if (captures && !in(*captures, val->canonicalName)) {
|
||||
// // Captures are transformed to function arguments; generate new name for that
|
||||
// // argument
|
||||
// Expr * typ = nullptr;
|
||||
// if (val->isType())
|
||||
// typ = N<IdExpr>("type");
|
||||
// if (auto st = val->isStatic())
|
||||
// typ = N<IndexExpr>(N<IdExpr>("Static"),
|
||||
// N<IdExpr>(st == StaticValue::INT ? "int" : "str"));
|
||||
// auto [newName, _] = (*captures)[val->canonicalName] = {
|
||||
// ctx->generateCanonicalName(val->canonicalName), typ};
|
||||
// ctx->cache->reverseIdentifierLookup[newName] = newName;
|
||||
// // Add newly generated argument to the context
|
||||
// std::shared_ptr<TypecheckItem> newVal = nullptr;
|
||||
// if (val->isType())
|
||||
// newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, val->type);
|
||||
// else
|
||||
// newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, val->type);
|
||||
// newVal->baseName = ctx->getBaseName();
|
||||
// newVal->canShadow = false; // todo)) needed here? remove noshadow on fn
|
||||
// boundaries? newVal->scope = ctx->getBase()->scope; return true;
|
||||
// }
|
||||
|
||||
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
|
||||
// and capturing is *not* enabled
|
||||
E(Error::ID_NONLOCAL, getSrcInfo(), ctx->cache->rev(val->canonicalName));
|
||||
|
@ -291,8 +400,10 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
|
|||
|
||||
auto baseType = getFuncTypeBase(2);
|
||||
auto typ = std::make_shared<FuncType>(baseType, ast, 0);
|
||||
typ->funcParent = ctx->cache->functions[overloads[0]].type->funcParent;
|
||||
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel - 1));
|
||||
ctx->addFunc(name, name, typ);
|
||||
// LOG("-[D]-> {} / {}", typ->debugString(2), typ->funcParent->debugString(2));
|
||||
|
||||
overloads.insert(overloads.begin(), name);
|
||||
ctx->cache->functions[name].ast = ast;
|
||||
|
@ -302,186 +413,6 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
|
|||
return typ;
|
||||
}
|
||||
|
||||
/// Transform a dot expression. Select the best method overload if possible.
|
||||
/// @param args (optional) list of class method arguments used to select the best
|
||||
/// overload. nullptr if not available.
|
||||
/// @example
|
||||
/// `obj.__class__` -> `type(obj)`
|
||||
/// `cls.__name__` -> `"class"` (same for functions)
|
||||
/// `obj.method` -> `cls.method(obj, ...)` or
|
||||
/// `cls.method(obj)` if method has `@property` attribute
|
||||
/// @c getClassMember examples:
|
||||
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
|
||||
/// `optional.member` -> `unwrap(optional).member`
|
||||
/// `pyobj.member` -> `pyobj._getattr("member")`
|
||||
/// @return nullptr if no transformation was made
|
||||
/// See @c getClassMember and @c getBestOverload
|
||||
Expr *TypecheckVisitor::transformDot(DotExpr *expr, std::vector<CallArg> *args) {
|
||||
// First flatten the imports:
|
||||
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
|
||||
|
||||
if (!expr->getType())
|
||||
expr->setType(ctx->getUnbound());
|
||||
|
||||
std::vector<std::string> chain;
|
||||
Expr *root = expr;
|
||||
for (; cast<DotExpr>(root); root = cast<DotExpr>(root)->expr)
|
||||
chain.push_back(cast<DotExpr>(root)->member);
|
||||
|
||||
Expr *nexpr = expr;
|
||||
if (auto id = cast<IdExpr>(root)) {
|
||||
// Case: a.bar.baz
|
||||
chain.push_back(id->getValue());
|
||||
std::reverse(chain.begin(), chain.end());
|
||||
auto [pos, val] = getImport(chain);
|
||||
if (!val) {
|
||||
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
|
||||
ctx->getBase()->pyCaptures->insert(chain[0]);
|
||||
nexpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
|
||||
} else if (val->getModule() == "std.python") {
|
||||
nexpr = transform(N<CallExpr>(
|
||||
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
|
||||
N<StringExpr>(chain[pos++])));
|
||||
} else if (val->getModule() == ctx->getModule() && pos == 1) {
|
||||
nexpr = transform(N<IdExpr>(chain[0]), true);
|
||||
} else {
|
||||
nexpr = N<IdExpr>(val->canonicalName);
|
||||
}
|
||||
while (pos < chain.size())
|
||||
nexpr = N<DotExpr>(nexpr, chain[pos++]);
|
||||
}
|
||||
if (!cast<DotExpr>(nexpr)) {
|
||||
if (args) {
|
||||
nexpr = transform(nexpr);
|
||||
if (auto id = cast<IdExpr>(nexpr))
|
||||
if (endswith(id->getValue(), ":dispatch"))
|
||||
if (auto bestMethod = getBestOverload(id, args)) {
|
||||
auto t = id->getType();
|
||||
nexpr = N<IdExpr>(bestMethod->ast->name);
|
||||
nexpr->setType(ctx->instantiate(bestMethod));
|
||||
}
|
||||
return nexpr;
|
||||
} else {
|
||||
return transform(nexpr);
|
||||
}
|
||||
} else {
|
||||
expr->expr = cast<DotExpr>(nexpr)->expr;
|
||||
expr->member = cast<DotExpr>(nexpr)->member;
|
||||
}
|
||||
|
||||
// Special case: obj.__class__
|
||||
if (expr->member == "__class__") {
|
||||
/// TODO: prevent cls.__class__ and type(cls)
|
||||
return N<CallExpr>(N<IdExpr>("type"), expr->expr);
|
||||
}
|
||||
expr->expr = transform(expr->expr);
|
||||
|
||||
// Special case: fn.__name__
|
||||
// Should go before cls.__name__ to allow printing generic functions
|
||||
if (ctx->getType(expr->expr->getType())->getFunc() && expr->member == "__name__") {
|
||||
return transform(
|
||||
N<StringExpr>(ctx->getType(expr->expr->getType())->prettyString()));
|
||||
}
|
||||
// Special case: fn.__llvm_name__ or obj.__llvm_name__
|
||||
if (expr->member == "__llvm_name__") {
|
||||
if (realize(expr->expr->getType()))
|
||||
return transform(N<StringExpr>(expr->expr->getType()->realizedName()));
|
||||
return nullptr;
|
||||
}
|
||||
// Special case: cls.__name__
|
||||
if (expr->expr->getType()->is("type") && expr->member == "__name__") {
|
||||
if (realize(expr->expr->getType()))
|
||||
return transform(
|
||||
N<StringExpr>(ctx->getType(expr->expr->getType())->prettyString()));
|
||||
return nullptr;
|
||||
}
|
||||
// Special case: expr.__is_static__
|
||||
if (expr->member == "__is_static__") {
|
||||
if (expr->expr->isDone())
|
||||
return transform(N<BoolExpr>(bool(expr->expr->getType()->getStatic())));
|
||||
return nullptr;
|
||||
}
|
||||
// Special case: cls.__id__
|
||||
if (expr->expr->getType()->is("type") && expr->member == "__id__") {
|
||||
if (auto c = realize(getType(expr->expr))) {
|
||||
return transform(N<IntExpr>(ctx->cache->getClass(c->getClass())
|
||||
->realizations[c->getClass()->realizedName()]
|
||||
->id));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Ensure that the type is known (otherwise wait until it becomes known)
|
||||
if (!getType(expr->expr))
|
||||
return nullptr;
|
||||
auto typ = getType(expr->expr)->getClass();
|
||||
if (!typ)
|
||||
return nullptr;
|
||||
|
||||
// Check if this is a method or member access
|
||||
if (ctx->findMethod(typ.get(), expr->member).empty())
|
||||
return getClassMember(expr, args);
|
||||
auto bestMethod = getBestOverload(expr, args);
|
||||
|
||||
if (args) {
|
||||
unify(expr->getType(), ctx->instantiate(bestMethod, typ));
|
||||
|
||||
// A function is deemed virtual if it is marked as such and
|
||||
// if a base class has a RTTI
|
||||
auto cls = ctx->cache->getClass(typ);
|
||||
bool isVirtual = in(cls->virtuals, expr->member);
|
||||
isVirtual &= cls->rtti;
|
||||
isVirtual &= !expr->expr->getType()->is("type");
|
||||
if (isVirtual && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
|
||||
!bestMethod->ast->hasAttribute(Attr::Property)) {
|
||||
// Special case: route the call through a vtable
|
||||
if (realize(expr->getType())) {
|
||||
auto fn = expr->getType()->getFunc();
|
||||
auto vid = getRealizationID(typ.get(), fn.get());
|
||||
|
||||
// Function[Tuple[TArg1, TArg2, ...], TRet]
|
||||
std::vector<Expr *> ids;
|
||||
for (auto &t : fn->getArgTypes())
|
||||
ids.push_back(N<IdExpr>(t->realizedName()));
|
||||
auto fnType = N<InstantiateExpr>(
|
||||
N<IdExpr>("Function"),
|
||||
std::vector<Expr *>{N<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), ids),
|
||||
N<IdExpr>(fn->getRetType()->realizedName())});
|
||||
// Function[Tuple[TArg1, TArg2, ...],TRet](
|
||||
// __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
|
||||
// )
|
||||
auto e = N<CallExpr>(
|
||||
fnType, N<IndexExpr>(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"),
|
||||
"class_get_rtti_vtable"),
|
||||
expr->expr),
|
||||
N<IntExpr>(vid)));
|
||||
return transform(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if a method is a static or an instance method and transform accordingly
|
||||
if (expr->expr->getType()->is("type") || args) {
|
||||
// Static access: `cls.method`
|
||||
Expr *e = N<IdExpr>(bestMethod->ast->name);
|
||||
e->setType(unify(expr->getType(), ctx->instantiate(bestMethod, typ)));
|
||||
return transform(e); // Realize if needed
|
||||
} else {
|
||||
// Instance access: `obj.method`
|
||||
// Transform y.method to a partial call `type(obj).method(args, ...)`
|
||||
std::vector<Expr *> methodArgs;
|
||||
// Do not add self if a method is marked with @staticmethod
|
||||
if (!bestMethod->ast->hasAttribute(Attr::StaticMethod))
|
||||
methodArgs.push_back(expr->expr);
|
||||
// If a method is marked with @property, just call it directly
|
||||
if (!bestMethod->ast->hasAttribute(Attr::Property))
|
||||
methodArgs.push_back(N<EllipsisExpr>(EllipsisExpr::PARTIAL));
|
||||
auto e = transform(N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs));
|
||||
unify(expr->getType(), e->getType());
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
/// Select the requested class member.
|
||||
/// @param args (optional) list of class method arguments used to select the best
|
||||
/// overload if the member is optional. nullptr if not available.
|
||||
|
@ -489,7 +420,7 @@ Expr *TypecheckVisitor::transformDot(DotExpr *expr, std::vector<CallArg> *args)
|
|||
/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value)
|
||||
/// `optional.member` -> `unwrap(optional).member`
|
||||
/// `pyobj.member` -> `pyobj._getattr("member")`
|
||||
Expr *TypecheckVisitor::getClassMember(DotExpr *expr, std::vector<CallArg> *args) {
|
||||
Expr *TypecheckVisitor::getClassMember(DotExpr *expr) {
|
||||
auto typ = getType(expr->expr)->getClass();
|
||||
seqassert(typ, "not a class");
|
||||
|
||||
|
@ -547,11 +478,8 @@ Expr *TypecheckVisitor::getClassMember(DotExpr *expr, std::vector<CallArg> *args
|
|||
|
||||
// Case: transform `optional.member` to `unwrap(optional).member`
|
||||
if (typ->is(TYPE_OPTIONAL)) {
|
||||
auto dot = N<DotExpr>(transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr->expr)),
|
||||
expr->member);
|
||||
dot->setType(ctx->getUnbound()); // as dot is not transformed
|
||||
if (auto d = transformDot(dot, args))
|
||||
return d;
|
||||
auto dot = transform(N<DotExpr>(
|
||||
transform(N<CallExpr>(N<IdExpr>(FN_UNWRAP), expr->expr)), expr->member));
|
||||
return dot;
|
||||
}
|
||||
|
||||
|
@ -594,121 +522,107 @@ TypePtr TypecheckVisitor::findSpecialMember(const std::string &member) {
|
|||
/// @param methods List of available methods.
|
||||
/// @param args (optional) list of class method arguments used to select the best
|
||||
/// overload if the member is optional. nullptr if not available.
|
||||
FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, std::vector<CallArg> *args) {
|
||||
// Prepare the list of method arguments if possible
|
||||
std::unique_ptr<std::vector<CallArg>> methodArgs;
|
||||
// FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, std::vector<CallArg> *args)
|
||||
// {
|
||||
// // Prepare the list of method arguments if possible
|
||||
// std::unique_ptr<std::vector<CallArg>> methodArgs;
|
||||
|
||||
if (args) {
|
||||
// Case: method overloads (DotExpr)
|
||||
bool addSelf = true;
|
||||
if (auto dot = cast<DotExpr>(expr)) {
|
||||
auto cls = getType(dot->expr)->getClass();
|
||||
auto methods = ctx->findMethod(cls.get(), dot->member, false);
|
||||
if (!methods.empty() && methods.front()->ast->hasAttribute(Attr::StaticMethod))
|
||||
addSelf = false;
|
||||
}
|
||||
// if (args) {
|
||||
// methodArgs = std::make_unique<std::vector<CallArg>>();
|
||||
// for (auto &a : *args)
|
||||
// methodArgs->push_back(a);
|
||||
// }
|
||||
// // else {
|
||||
// // // Partially deduced type thus far
|
||||
// // auto typeSoFar = expr->getType() ? getType(expr)->getClass() : nullptr;
|
||||
// // if (typeSoFar && typeSoFar->getFunc()) {
|
||||
// // // Case: arguments available from the previous type checking round
|
||||
// // methodArgs = std::make_unique<std::vector<CallArg>>();
|
||||
// // if (cast<DotExpr>(expr) &&
|
||||
// // !cast<DotExpr>(expr)->expr->getType()->is("type")) { // Add `self`
|
||||
// // auto n = N<NoneExpr>();
|
||||
// // n->setType(cast<DotExpr>(expr)->expr->getType());
|
||||
// // methodArgs->push_back({"", n});
|
||||
// // }
|
||||
// // for (auto &a : typeSoFar->getFunc()->getArgTypes()) {
|
||||
// // auto n = N<NoneExpr>();
|
||||
// // n->setType(a);
|
||||
// // methodArgs->push_back({"", n});
|
||||
// // }
|
||||
// // }
|
||||
// // }
|
||||
|
||||
// Case: arguments explicitly provided (by CallExpr)
|
||||
if (addSelf && cast<DotExpr>(expr) &&
|
||||
!cast<DotExpr>(expr)->expr->getType()->is("type")) {
|
||||
// Add `self` as the first argument
|
||||
args->insert(args->begin(), {"", cast<DotExpr>(expr)->expr});
|
||||
}
|
||||
methodArgs = std::make_unique<std::vector<CallArg>>();
|
||||
for (auto &a : *args)
|
||||
methodArgs->push_back(a);
|
||||
} else {
|
||||
// Partially deduced type thus far
|
||||
auto typeSoFar = expr->getType() ? getType(expr)->getClass() : nullptr;
|
||||
if (typeSoFar && typeSoFar->getFunc()) {
|
||||
// Case: arguments available from the previous type checking round
|
||||
methodArgs = std::make_unique<std::vector<CallArg>>();
|
||||
if (cast<DotExpr>(expr) &&
|
||||
!cast<DotExpr>(expr)->expr->getType()->is("type")) { // Add `self`
|
||||
auto n = N<NoneExpr>();
|
||||
n->setType(cast<DotExpr>(expr)->expr->getType());
|
||||
methodArgs->push_back({"", n});
|
||||
}
|
||||
for (auto &a : typeSoFar->getFunc()->getArgTypes()) {
|
||||
auto n = N<NoneExpr>();
|
||||
n->setType(a);
|
||||
methodArgs->push_back({"", n});
|
||||
}
|
||||
}
|
||||
}
|
||||
// std::vector<FuncTypePtr> m;
|
||||
// // Use the provided arguments to select the best method
|
||||
// if (auto dot = cast<DotExpr>(expr)) {
|
||||
// // Case: method overloads (DotExpr)
|
||||
// auto methods =
|
||||
// ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
|
||||
// if (methodArgs)
|
||||
// m = findMatchingMethods(getType(dot->expr)->getClass(), methods, *methodArgs);
|
||||
// } else if (auto id = cast<IdExpr>(expr)) {
|
||||
// // Case: function overloads (IdExpr)
|
||||
// std::vector<types::FuncTypePtr> methods;
|
||||
// auto key = id->getValue();
|
||||
// if (endswith(key, ":dispatch"))
|
||||
// key = key.substr(0, key.size() - 9);
|
||||
// for (auto &m : ctx->cache->overloads[key])
|
||||
// if (!endswith(m, ":dispatch"))
|
||||
// methods.push_back(ctx->cache->functions[m].type);
|
||||
// std::reverse(methods.begin(), methods.end());
|
||||
// m = findMatchingMethods(nullptr, methods, *methodArgs);
|
||||
// }
|
||||
|
||||
bool goDispatch = methodArgs == nullptr;
|
||||
if (!goDispatch) {
|
||||
std::vector<FuncTypePtr> m;
|
||||
// Use the provided arguments to select the best method
|
||||
if (auto dot = cast<DotExpr>(expr)) {
|
||||
// Case: method overloads (DotExpr)
|
||||
auto methods =
|
||||
ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
|
||||
m = findMatchingMethods(getType(dot->expr)->getClass(), methods, *methodArgs);
|
||||
} else if (auto id = cast<IdExpr>(expr)) {
|
||||
// Case: function overloads (IdExpr)
|
||||
std::vector<types::FuncTypePtr> methods;
|
||||
auto key = id->getValue();
|
||||
if (endswith(key, ":dispatch"))
|
||||
key = key.substr(0, key.size() - 9);
|
||||
for (auto &m : ctx->cache->overloads[key])
|
||||
if (!endswith(m, ":dispatch"))
|
||||
methods.push_back(ctx->cache->functions[m].type);
|
||||
std::reverse(methods.begin(), methods.end());
|
||||
m = findMatchingMethods(nullptr, methods, *methodArgs);
|
||||
}
|
||||
// bool goDispatch = false;
|
||||
// if (m.size() == 1) {
|
||||
// return m[0];
|
||||
// } else if (m.size() > 1) {
|
||||
// for (auto &a : *methodArgs) {
|
||||
// if (auto u = a.value->getType()->getUnbound()) {
|
||||
// goDispatch = true;
|
||||
// }
|
||||
// }
|
||||
// if (!goDispatch)
|
||||
// return m[0];
|
||||
// }
|
||||
|
||||
if (m.size() == 1) {
|
||||
return m[0];
|
||||
} else if (m.size() > 1) {
|
||||
for (auto &a : *methodArgs) {
|
||||
if (auto u = a.value->getType()->getUnbound()) {
|
||||
goDispatch = true;
|
||||
}
|
||||
}
|
||||
if (!goDispatch)
|
||||
return m[0];
|
||||
}
|
||||
}
|
||||
// if (goDispatch) {
|
||||
// // If overload is ambiguous, route through a dispatch function
|
||||
// std::string name;
|
||||
// if (auto dot = cast<DotExpr>(expr)) {
|
||||
// auto methods =
|
||||
// ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
|
||||
// seqassert(!methods.empty(), "unknown method");
|
||||
// name = ctx->cache->functions[methods.back()->ast->name].rootName;
|
||||
// } else {
|
||||
// name = cast<IdExpr>(expr)->getValue();
|
||||
// }
|
||||
// auto t = getDispatch(name);
|
||||
// }
|
||||
|
||||
if (goDispatch) {
|
||||
// If overload is ambiguous, route through a dispatch function
|
||||
std::string name;
|
||||
if (auto dot = cast<DotExpr>(expr)) {
|
||||
auto methods =
|
||||
ctx->findMethod(getType(dot->expr)->getClass().get(), dot->member, false);
|
||||
seqassert(!methods.empty(), "unknown method");
|
||||
name = ctx->cache->functions[methods.back()->ast->name].rootName;
|
||||
} else {
|
||||
name = cast<IdExpr>(expr)->getValue();
|
||||
}
|
||||
return getDispatch(name);
|
||||
}
|
||||
// // Print a nice error message
|
||||
// std::string argsNice;
|
||||
// if (methodArgs) {
|
||||
// std::vector<std::string> a;
|
||||
// for (auto &t : *methodArgs)
|
||||
// a.emplace_back(fmt::format("{}", t.value->getType()->getStatic()
|
||||
// ? t.value->getClassType()->name
|
||||
// : t.value->getType()->prettyString()));
|
||||
// argsNice = fmt::format("({})", fmt::join(a, ", "));
|
||||
// }
|
||||
|
||||
// Print a nice error message
|
||||
std::string argsNice;
|
||||
if (methodArgs) {
|
||||
std::vector<std::string> a;
|
||||
for (auto &t : *methodArgs)
|
||||
a.emplace_back(fmt::format("{}", t.value->getType()->getStatic()
|
||||
? t.value->getClassType()->name
|
||||
: t.value->getType()->prettyString()));
|
||||
argsNice = fmt::format("({})", fmt::join(a, ", "));
|
||||
}
|
||||
// if (auto dot = cast<DotExpr>(expr)) {
|
||||
// // Debug:
|
||||
// // *args = std::vector<CallArg>(args->begin()+1, args->end());
|
||||
// getBestOverload(expr, args);
|
||||
// E(Error::DOT_NO_ATTR_ARGS, expr, getType(dot->expr)->prettyString(), dot->member,
|
||||
// argsNice);
|
||||
// } else {
|
||||
// E(Error::FN_NO_ATTR_ARGS, expr, ctx->cache->rev(cast<IdExpr>(expr)->getValue()),
|
||||
// argsNice);
|
||||
// }
|
||||
|
||||
if (auto dot = cast<DotExpr>(expr)) {
|
||||
// Debug:
|
||||
// *args = std::vector<CallArg>(args->begin()+1, args->end());
|
||||
// getBestOverload(expr, args);
|
||||
E(Error::DOT_NO_ATTR_ARGS, expr, getType(dot->expr)->prettyString(), dot->member,
|
||||
argsNice);
|
||||
} else {
|
||||
E(Error::FN_NO_ATTR_ARGS, expr, ctx->cache->rev(cast<IdExpr>(expr)->getValue()),
|
||||
argsNice);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
// return nullptr;
|
||||
// }
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -243,8 +243,8 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
|
|||
// Case: setters
|
||||
auto setters = ctx->findMethod(lhsClass.get(), format(".set_{}", stmt->member));
|
||||
if (!setters.empty()) {
|
||||
resultStmt = transform(N<ExprStmt>(
|
||||
N<CallExpr>(N<IdExpr>(setters[0]->ast->name), stmt->lhs, stmt->rhs)));
|
||||
resultStmt = transform(N<ExprStmt>(N<CallExpr>(
|
||||
N<IdExpr>(setters.front()->ast->getName()), stmt->lhs, stmt->rhs)));
|
||||
return;
|
||||
}
|
||||
// Case: class variables
|
||||
|
|
|
@ -56,66 +56,75 @@ void TypecheckVisitor::visit(EllipsisExpr *expr) {
|
|||
/// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments ,
|
||||
/// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details.
|
||||
void TypecheckVisitor::visit(CallExpr *expr) {
|
||||
// Special case!
|
||||
if (expr->getType()->getUnbound() && cast<IdExpr>(expr->expr)) {
|
||||
auto callExpr = cast<IdExpr>(transform(clean_clone(expr->expr)));
|
||||
if (callExpr && callExpr->getValue() == "std.collections.namedtuple.0") {
|
||||
resultExpr = transformNamedTuple(expr);
|
||||
return;
|
||||
} else if (callExpr && callExpr->getValue() == "std.functools.partial.0:0") {
|
||||
resultExpr = transformFunctoolsPartial(expr);
|
||||
return;
|
||||
} else if (callExpr && callExpr->getValue() == "tuple" && expr->size() == 1 &&
|
||||
cast<GeneratorExpr>(expr->begin()->value)) {
|
||||
resultExpr = transformTupleGenerator(expr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Transform and expand arguments. Return early if it cannot be done yet
|
||||
ctx->addBlock();
|
||||
if (expr->expr->getType())
|
||||
if (auto f = expr->expr->getType()->getFunc())
|
||||
addFunctionGenerics(f.get());
|
||||
auto a = transformCallArgs(expr->items);
|
||||
|
||||
ctx->popBlock();
|
||||
if (!a)
|
||||
return;
|
||||
auto orig = expr->toString(0);
|
||||
|
||||
// Check if this call is partial call
|
||||
PartialCallData part{
|
||||
!expr->items.empty() && cast<EllipsisExpr>(expr->items.back().value) &&
|
||||
cast<EllipsisExpr>(expr->items.back().value)->mode == EllipsisExpr::PARTIAL};
|
||||
// Transform the callee
|
||||
if (!part.isPartial) {
|
||||
// Intercept method calls (e.g. `obj.method`) for faster compilation (because it
|
||||
// avoids partial calls). This intercept passes the call arguments to
|
||||
// @c transformDot to select the best overload as well
|
||||
if (auto dot = cast<DotExpr>(expr->expr)) {
|
||||
// Pick the best method overload
|
||||
if (auto edt = transformDot(dot, &expr->items))
|
||||
expr->expr = edt;
|
||||
} else if (auto id = cast<IdExpr>(expr->expr)) {
|
||||
expr->expr = transform(expr->expr);
|
||||
id = cast<IdExpr>(expr->expr);
|
||||
// Pick the best function overload
|
||||
if (endswith(id->getValue(), ":dispatch")) {
|
||||
if (auto bestMethod = getBestOverload(id, &expr->items)) {
|
||||
auto t = id->getType();
|
||||
expr->expr = N<IdExpr>(bestMethod->ast->name);
|
||||
expr->expr->setType(ctx->instantiate(bestMethod));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
expr->expr = transform(expr->expr);
|
||||
|
||||
expr->expr = transform(expr->getExpr());
|
||||
if (expr->getExpr()->getType()->getUnbound())
|
||||
return; // delay
|
||||
|
||||
auto [calleeFn, newExpr] = getCalleeFn(expr, part);
|
||||
if ((resultExpr = newExpr))
|
||||
return;
|
||||
if (!calleeFn)
|
||||
return;
|
||||
|
||||
// ctx->addBlock();
|
||||
// if (expr->getExpr()->getType())
|
||||
// if (auto f = expr->getExpr()->getType()->getFunc())
|
||||
// addFunctionGenerics(f.get());
|
||||
auto a = transformCallArgs(expr->items);
|
||||
// ctx->popBlock();
|
||||
if (!a)
|
||||
return;
|
||||
|
||||
// Early dispatch modifier
|
||||
if (endswith(calleeFn->ast->getName(), ":dispatch")) {
|
||||
// LOG("-> {}", calleeFn->debugString(2));
|
||||
std::vector<FuncTypePtr> m;
|
||||
if (auto id = cast<IdExpr>(expr->getExpr())) {
|
||||
// Case: function overloads (IdExpr)
|
||||
std::vector<types::FuncTypePtr> methods;
|
||||
auto key = id->getValue();
|
||||
if (endswith(key, ":dispatch"))
|
||||
key = key.substr(0, key.size() - 9);
|
||||
for (auto &m : ctx->cache->overloads[key])
|
||||
if (!endswith(m, ":dispatch"))
|
||||
methods.push_back(ctx->cache->functions[m].type);
|
||||
std::reverse(methods.begin(), methods.end());
|
||||
m = findMatchingMethods(calleeFn->funcParent ? calleeFn->funcParent->getClass()
|
||||
: nullptr,
|
||||
methods, expr->items);
|
||||
}
|
||||
bool doDispatch = m.size() == 0;
|
||||
if (m.size() > 1) {
|
||||
for (auto &a : *expr) {
|
||||
if (auto u = a.value->getType()->getUnbound())
|
||||
doDispatch = true;
|
||||
}
|
||||
}
|
||||
if (!doDispatch) {
|
||||
calleeFn = ctx->instantiate(m.front(), calleeFn->funcParent
|
||||
? calleeFn->funcParent->getClass()
|
||||
: nullptr)
|
||||
->getFunc();
|
||||
auto e = N<IdExpr>(calleeFn->ast->getName());
|
||||
e->setType(calleeFn);
|
||||
if (cast<IdExpr>(expr->getExpr())) {
|
||||
expr->expr = e;
|
||||
} else {
|
||||
expr->expr = N<StmtExpr>(N<ExprStmt>(expr->getExpr()), e);
|
||||
}
|
||||
expr->getExpr()->setType(calleeFn);
|
||||
} else {
|
||||
LOG("-> {}", expr->toString(0));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle named and default arguments
|
||||
if ((resultExpr = callReorderArguments(calleeFn, expr, part)))
|
||||
return;
|
||||
|
@ -163,12 +172,48 @@ void TypecheckVisitor::visit(CallExpr *expr) {
|
|||
}
|
||||
call->setAttribute(Attr::ExprPartial);
|
||||
resultExpr = transform(call);
|
||||
|
||||
// LOG("{}: {} --------> {}", getSrcInfo(), orig, resultExpr->toString(0));
|
||||
} else {
|
||||
// Case: normal function call
|
||||
unify(expr->getType(), calleeFn->getRetType());
|
||||
if (done)
|
||||
expr->setDone();
|
||||
}
|
||||
|
||||
// unify(expr->getType(), ctx->instantiate(bestMethod, typ));
|
||||
// A function is deemed virtual if it is marked as such and
|
||||
// if a base class has a RTTI
|
||||
// auto cls = ctx->cache->getClass(typ);
|
||||
// bool isVirtual = in(cls->virtuals, expr->member);
|
||||
// isVirtual &= cls->rtti;
|
||||
// isVirtual &= !expr->expr->getType()->is("type");
|
||||
// if (isVirtual && !bestMethod->ast->hasAttribute(Attr::StaticMethod) &&
|
||||
// !bestMethod->ast->hasAttribute(Attr::Property)) {
|
||||
// // Special case: route the call through a vtable
|
||||
// if (realize(expr->getType())) {
|
||||
// auto fn = expr->getType()->getFunc();
|
||||
// auto vid = getRealizationID(typ.get(), fn.get());
|
||||
|
||||
// // Function[Tuple[TArg1, TArg2, ...], TRet]
|
||||
// std::vector<Expr *> ids;
|
||||
// for (auto &t : fn->getArgTypes())
|
||||
// ids.push_back(N<IdExpr>(t->realizedName()));
|
||||
// auto fnType = N<InstantiateExpr>(
|
||||
// N<IdExpr>("Function"),
|
||||
// std::vector<Expr *>{N<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), ids),
|
||||
// N<IdExpr>(fn->getRetType()->realizedName())});
|
||||
// // Function[Tuple[TArg1, TArg2, ...],TRet](
|
||||
// // __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
|
||||
// // )
|
||||
// auto e = N<CallExpr>(
|
||||
// fnType, N<IndexExpr>(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"),
|
||||
// "class_get_rtti_vtable"),
|
||||
// expr->expr),
|
||||
// N<IntExpr>(vid)));
|
||||
// return transform(e);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/// Transform call arguments. Expand *args and **kwargs to the list of @c CallArg
|
||||
|
@ -257,30 +302,30 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallArg> &args) {
|
|||
/// (when needed; otherwise nullptr).
|
||||
std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
|
||||
PartialCallData &part) {
|
||||
auto callee = expr->expr->getClassType();
|
||||
auto callee = expr->getExpr()->getClassType();
|
||||
if (!callee) {
|
||||
// Case: unknown callee, wait until it becomes known
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
if (expr->expr->getType()->is("type")) {
|
||||
auto typ = expr->expr->getClassType();
|
||||
if (!isId(expr->expr, "type"))
|
||||
if (expr->getExpr()->getType()->is("type")) {
|
||||
auto typ = expr->getExpr()->getClassType();
|
||||
if (!isId(expr->getExpr(), "type"))
|
||||
typ = typ->generics[0].type->getClass();
|
||||
if (!typ)
|
||||
return {nullptr, nullptr};
|
||||
auto clsName = typ->name;
|
||||
if (typ->isRecord()) {
|
||||
// Case: tuple constructor. Transform to: `T.__new__(args)`
|
||||
return {nullptr,
|
||||
transform(N<CallExpr>(N<DotExpr>(expr->expr, "__new__"), expr->items))};
|
||||
return {nullptr, transform(N<CallExpr>(N<DotExpr>(expr->getExpr(), "__new__"),
|
||||
expr->items))};
|
||||
}
|
||||
|
||||
// Case: reference type constructor. Transform to
|
||||
// `ctr = T.__new__(); v.__init__(args)`
|
||||
Expr *var = N<IdExpr>(ctx->cache->getTemporaryVar("ctr"));
|
||||
auto newInit =
|
||||
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->expr, "__new__")));
|
||||
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->getExpr(), "__new__")));
|
||||
auto e = N<StmtExpr>(N<SuiteStmt>(newInit), clone(var));
|
||||
auto init =
|
||||
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->items));
|
||||
|
@ -288,20 +333,21 @@ std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
|
|||
return {nullptr, transform(e)};
|
||||
}
|
||||
|
||||
auto calleeFn = callee->getFunc();
|
||||
if (auto partType = callee->getPartial()) {
|
||||
auto mask = partType->getPartialMask();
|
||||
auto func = ctx->instantiate(partType->getPartialFunc()->generalize(0))->getFunc();
|
||||
|
||||
// Case: calling partial object `p`. Transform roughly to
|
||||
// `part = callee; partial_fn(*part.args, args...)`
|
||||
Expr *var = N<IdExpr>(part.var = ctx->cache->getTemporaryVar("partcall"));
|
||||
expr->expr = transform(
|
||||
N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr), N<IdExpr>(func->ast->name)));
|
||||
if (!partType->isPartialEmpty()) {
|
||||
// Case: calling partial object `p`. Transform roughly to
|
||||
// `part = callee; partial_fn(*part.args, args...)`
|
||||
Expr *var = N<IdExpr>(part.var = ctx->cache->getTemporaryVar("partcall"));
|
||||
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->getExpr()),
|
||||
N<IdExpr>(func->ast->name)));
|
||||
}
|
||||
|
||||
// Ensure that we got a function
|
||||
calleeFn = expr->expr->getType()->getFunc();
|
||||
seqassert(calleeFn, "not a function: {}", expr->expr->getType());
|
||||
auto calleeFn = expr->getExpr()->getType()->getFunc();
|
||||
seqassert(calleeFn, "not a function: {}", expr->getExpr()->getType());
|
||||
|
||||
// Unify partial generics with types known thus far
|
||||
auto knownArgTypes = partType->generics[1].type->getClass();
|
||||
|
@ -316,12 +362,13 @@ std::pair<FuncTypePtr, Expr *> TypecheckVisitor::getCalleeFn(CallExpr *expr,
|
|||
}
|
||||
part.known = mask;
|
||||
return {calleeFn, nullptr};
|
||||
} else if (!calleeFn) {
|
||||
} else if (!callee->getFunc()) {
|
||||
// Case: callee is not a function. Try __call__ method instead
|
||||
return {nullptr,
|
||||
transform(N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->items))};
|
||||
return {nullptr, transform(N<CallExpr>(N<DotExpr>(expr->getExpr(), "__call__"),
|
||||
expr->items))};
|
||||
} else {
|
||||
return {callee->getFunc(), nullptr};
|
||||
}
|
||||
return {calleeFn, nullptr};
|
||||
}
|
||||
|
||||
/// Reorder the call arguments to match the signature order. Ensure that every @c
|
||||
|
@ -668,6 +715,22 @@ std::pair<bool, Expr *> TypecheckVisitor::transformSpecialCall(CallExpr *expr) {
|
|||
} else {
|
||||
return transformInternalStaticFn(expr);
|
||||
}
|
||||
|
||||
// Special case!
|
||||
// if (expr->getType()->getUnbound() && cast<IdExpr>(expr->expr)) {
|
||||
// auto callExpr = cast<IdExpr>(transform(clean_clone(expr->expr)));
|
||||
// if (callExpr && callExpr->getValue() == "std.collections.namedtuple.0") {
|
||||
// resultExpr = transformNamedTuple(expr);
|
||||
// return;
|
||||
// } else if (callExpr && callExpr->getValue() == "std.functools.partial.0:0") {
|
||||
// resultExpr = transformFunctoolsPartial(expr);
|
||||
// return;
|
||||
// } else if (callExpr && callExpr->getValue() == "tuple" && expr->size() == 1 &&
|
||||
// cast<GeneratorExpr>(expr->begin()->value)) {
|
||||
// resultExpr = transformTupleGenerator(expr);
|
||||
// return;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/// Transform `tuple(i for i in tup)` into a GeneratorExpr that will be handled during
|
||||
|
|
|
@ -30,7 +30,9 @@ TypeContext::TypeContext(Cache *cache, std::string filename)
|
|||
: Context<TypecheckItem>(std::move(filename)), cache(cache) {
|
||||
bases.emplace_back();
|
||||
scope.emplace_back(0);
|
||||
pushSrcInfo(cache->generateSrcInfo()); // Always have srcInfo() around
|
||||
auto e = cache->N<NoneExpr>();
|
||||
e->setSrcInfo(cache->generateSrcInfo());
|
||||
pushNode(e); // Always have srcInfo() around
|
||||
}
|
||||
|
||||
void TypeContext::add(const std::string &name, const TypeContext::Item &var) {
|
||||
|
@ -272,10 +274,6 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
|
|||
break;
|
||||
if (idx == cm.size())
|
||||
cm.push_back(key);
|
||||
// if (idx)
|
||||
// LOG("--> {}: realize {}: {} / {}", getSrcInfo(), ft->debugString(2), idx,
|
||||
// key);
|
||||
ft->index = idx;
|
||||
}
|
||||
}
|
||||
if (t->getUnion() && !t->getUnion()->isSealed()) {
|
||||
|
|
|
@ -31,16 +31,13 @@ using namespace types;
|
|||
/// @return a
|
||||
TypePtr TypecheckVisitor::unify(const TypePtr &a, const TypePtr &b) {
|
||||
seqassert(a, "lhs is nullptr");
|
||||
seqassert(b, "rhs is nullptr");
|
||||
types::Type::Unification undo;
|
||||
if (a->unify(b.get(), &undo) >= 0) {
|
||||
return a;
|
||||
} else {
|
||||
undo.undo();
|
||||
if (!((*a) << b)) {
|
||||
types::Type::Unification undo;
|
||||
a->unify(b.get(), &undo);
|
||||
E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString());
|
||||
return nullptr;
|
||||
}
|
||||
a->unify(b.get(), &undo);
|
||||
E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString());
|
||||
return nullptr;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Infer all types within a Stmt *. Implements the LTS-DI typechecking.
|
||||
|
|
|
@ -183,13 +183,12 @@ Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) {
|
|||
if (!expr->getType())
|
||||
expr->setType(ctx->getUnbound());
|
||||
|
||||
auto typ = expr->getType();
|
||||
if (!expr->isDone()) {
|
||||
TypecheckVisitor v(ctx, preamble, prependStmts);
|
||||
v.setSrcInfo(expr->getSrcInfo());
|
||||
ctx->pushSrcInfo(expr->getSrcInfo());
|
||||
ctx->pushNode(expr);
|
||||
expr->accept(v);
|
||||
ctx->popSrcInfo();
|
||||
ctx->popNode();
|
||||
if (v.resultExpr) {
|
||||
for (auto it = expr->attributes_begin(); it != expr->attributes_end(); ++it) {
|
||||
const auto *attr = expr->getAttribute(*it);
|
||||
|
@ -197,20 +196,23 @@ Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) {
|
|||
v.resultExpr->setAttribute(*it, attr->clone());
|
||||
}
|
||||
v.resultExpr->setOrigExpr(expr);
|
||||
// unify(expr->getType(), v.resultExpr->getType());
|
||||
expr = v.resultExpr;
|
||||
if (!expr->getType())
|
||||
expr->setType(ctx->getUnbound());
|
||||
}
|
||||
if (!allowTypes && expr && expr->getType()->is("type"))
|
||||
E(Error::UNEXPECTED_TYPE, expr, "type");
|
||||
if (!expr->getType())
|
||||
expr->setType(ctx->getUnbound());
|
||||
unify(typ, expr->getType());
|
||||
// unify(typ, expr->getType());
|
||||
if (expr->isDone())
|
||||
ctx->changedNodes++;
|
||||
}
|
||||
realize(typ);
|
||||
if (expr)
|
||||
if (expr) {
|
||||
if (auto p = realize(expr->getType()))
|
||||
unify(expr->getType(), p);
|
||||
LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), *(expr),
|
||||
expr->isDone() ? "[done]" : "");
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
|
@ -265,9 +267,9 @@ Stmt *TypecheckVisitor::transform(Stmt *stmt) {
|
|||
if (!stmt->toString(-1).empty())
|
||||
LOG_TYPECHECK("> [{}] [{}:{}] {}", getSrcInfo(), ctx->getBaseName(),
|
||||
ctx->getBase()->iteration, stmt->toString(-1));
|
||||
ctx->pushSrcInfo(stmt->getSrcInfo());
|
||||
ctx->pushNode(stmt);
|
||||
stmt->accept(v);
|
||||
ctx->popSrcInfo();
|
||||
ctx->popNode();
|
||||
if (v.resultStmt)
|
||||
stmt = v.resultStmt;
|
||||
if (!v.prependStmts->empty()) {
|
||||
|
|
|
@ -82,10 +82,9 @@ private: // Node typechecking rules
|
|||
bool checkCapture(const TypeContext::Item &);
|
||||
void visit(DotExpr *) override;
|
||||
std::pair<size_t, TypeContext::Item> getImport(const std::vector<std::string> &);
|
||||
Expr *transformDot(DotExpr *, std::vector<CallArg> * = nullptr);
|
||||
Expr *getClassMember(DotExpr *, std::vector<CallArg> *);
|
||||
Expr *getClassMember(DotExpr *);
|
||||
types::TypePtr findSpecialMember(const std::string &);
|
||||
types::FuncTypePtr getBestOverload(Expr *, std::vector<CallArg> *);
|
||||
types::FuncTypePtr getBestOverload(Expr *);
|
||||
types::FuncTypePtr getDispatch(const std::string &);
|
||||
|
||||
/* Collection and comprehension expressions (collections.cpp) */
|
||||
|
|
Loading…
Reference in New Issue