mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
Fix polymorphism
This commit is contained in:
parent
f4fe8ec18f
commit
e737536b38
@ -57,6 +57,10 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
|
||||
}
|
||||
|
||||
TranslateVisitor(cache->codegenCtx).transform(stmts);
|
||||
|
||||
for (auto &[_, f]: cache->functions)
|
||||
TranslateVisitor(cache->codegenCtx).transform(f.ast);
|
||||
|
||||
cache->populatePythonModule();
|
||||
return main;
|
||||
}
|
||||
@ -174,7 +178,7 @@ void TranslateVisitor::visit(StringExpr *expr) {
|
||||
void TranslateVisitor::visit(IdExpr *expr) {
|
||||
auto val = ctx->find(expr->value);
|
||||
seqassert(val, "cannot find '{}'", expr->value);
|
||||
if (expr->value == "__vtable_size__") {
|
||||
if (expr->value == "__vtable_size__.0") {
|
||||
// LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2);
|
||||
result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2,
|
||||
getType(expr->getType()));
|
||||
@ -438,7 +442,6 @@ void TranslateVisitor::visit(AssignStmt *stmt) {
|
||||
auto isGlobal = in(ctx->cache->globals, var);
|
||||
ir::Var *v = nullptr;
|
||||
|
||||
|
||||
if (!stmt->lhs->type->isInstantiated() || (stmt->lhs->type->is("type"))) {
|
||||
// LOG("{} {}", getSrcInfo(), stmt->toString(0));
|
||||
return; // type aliases/fn aliases etc
|
||||
@ -697,9 +700,8 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt
|
||||
} else {
|
||||
seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}",
|
||||
ss[i]->getExpr()->toString());
|
||||
literals.emplace_back(getType(
|
||||
ctx->cache->typeCtx->getType(
|
||||
ss[i]->getExpr()->expr->getType())));
|
||||
literals.emplace_back(
|
||||
getType(ctx->cache->typeCtx->getType(ss[i]->getExpr()->expr->getType())));
|
||||
}
|
||||
}
|
||||
bool isDeclare = true;
|
||||
|
@ -357,10 +357,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
|
||||
}
|
||||
// Special case: cls.__id__
|
||||
if (expr->expr->type->is("type") && expr->member == "__id__") {
|
||||
if (auto c = realize(expr->expr->type))
|
||||
if (auto c = realize(getType(expr->expr))) {
|
||||
return transform(N<IntExpr>(ctx->cache->classes[c->getClass()->name]
|
||||
.realizations[c->getClass()->realizedName()]
|
||||
->id));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -618,7 +619,9 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
|
||||
// If overload is ambiguous, route through a dispatch function
|
||||
std::string name;
|
||||
if (auto dot = expr->getDot()) {
|
||||
name = ctx->cache->getMethod(getType(dot->expr)->getClass(), dot->member);
|
||||
auto methods = ctx->findMethod(getType(dot->expr)->getClass()->name, dot->member, false);
|
||||
seqassert(!methods.empty(), "unknown method");
|
||||
name = ctx->cache->functions[methods.back()->ast->name].rootName;
|
||||
} else {
|
||||
name = expr->getId()->value;
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ StmtPtr TypecheckVisitor::transformUpdate(AssignStmt *stmt) {
|
||||
void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
|
||||
transform(stmt->lhs);
|
||||
|
||||
if (auto lhsClass = stmt->lhs->getType()->getClass()) {
|
||||
if (auto lhsClass = getType(stmt->lhs)->getClass()) {
|
||||
auto member = ctx->findMember(lhsClass->name, stmt->member);
|
||||
|
||||
if (!member && stmt->lhs->type->is("type")) {
|
||||
|
@ -471,9 +471,8 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
||||
expr->args.pop_back();
|
||||
if (!part.args)
|
||||
part.args = transform(N<TupleExpr>()); // use ()
|
||||
if (!part.kwArgs) {
|
||||
if (!part.kwArgs)
|
||||
part.kwArgs = transform(N<CallExpr>(N<IdExpr>("NamedTuple"))); // use NamedTuple()
|
||||
}
|
||||
}
|
||||
|
||||
// Unify function type generics with the provided generics
|
||||
@ -805,7 +804,8 @@ ExprPtr TypecheckVisitor::transformSuper() {
|
||||
self->type = typ;
|
||||
|
||||
auto typExpr = N<IdExpr>(superTyp->name);
|
||||
typExpr->setType(superTyp);
|
||||
typExpr->setType(ctx->instantiateGeneric(ctx->getType("type"), {superTyp}));
|
||||
// LOG("-> {:c} : {:c} {:c}", typ, vCands[1], typExpr->type);
|
||||
return transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_super"),
|
||||
self, typExpr, N<IntExpr>(1)));
|
||||
}
|
||||
@ -819,7 +819,6 @@ ExprPtr TypecheckVisitor::transformSuper() {
|
||||
members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name));
|
||||
ExprPtr e =
|
||||
transform(N<CallExpr>(N<IdExpr>(generateTuple(members.size())), members));
|
||||
|
||||
auto ft = getClassFieldTypes(superTyp);
|
||||
for (size_t i = 0; i < ft.size(); i++)
|
||||
unify(
|
||||
|
@ -122,7 +122,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
||||
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
|
||||
}
|
||||
if (auto st = getStaticGeneric(a.type.get())) {
|
||||
if (st > 3) transform(a.type); // error check
|
||||
if (st > 3)
|
||||
transform(a.type); // error check
|
||||
generic->isStatic = st;
|
||||
auto val = ctx->addVar(genName, varName, generic);
|
||||
val->generic = true;
|
||||
@ -197,8 +198,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
||||
// : ctx->generateCanonicalName(a.name);
|
||||
args.emplace_back(varName, transformType(clean_clone(a.type)),
|
||||
transform(clone(a.defaultValue), true));
|
||||
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{
|
||||
varName, types::TypePtr(nullptr), canonicalName});
|
||||
ctx->cache->classes[canonicalName].fields.emplace_back(
|
||||
Cache::Class::ClassField{varName, types::TypePtr(nullptr),
|
||||
canonicalName});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -248,7 +250,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
||||
for (auto &b : staticBaseASTs)
|
||||
ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name);
|
||||
ctx->cache->classes[canonicalName].ast->validate();
|
||||
ctx->cache->classes[canonicalName].module = ctx->getModule();
|
||||
ctx->cache->classes[canonicalName].module = ctx->moduleName.path;
|
||||
|
||||
// Codegen default magic methods
|
||||
// __new__ must be the first
|
||||
@ -260,11 +262,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
||||
for (auto &base : staticBaseASTs) {
|
||||
for (auto &mm : ctx->cache->classes[base->name].methods)
|
||||
for (auto &mf : ctx->cache->overloads[mm.second]) {
|
||||
auto f = ctx->cache->functions[mf].origAst;
|
||||
const auto &fp = ctx->cache->functions[mf];
|
||||
auto f = fp.origAst;
|
||||
if (f && !f->attributes.has("autogenerated")) {
|
||||
ctx->addBlock();
|
||||
addClassGenerics(base);
|
||||
fnStmts.push_back(transform(clean_clone(f)));
|
||||
// since functions can come from other modules
|
||||
// make sure to transform them in their respective module
|
||||
// however makle sure to add/pop generics :/
|
||||
auto cf = clean_clone(f);
|
||||
if (!ctx->isStdlibLoading && fp.module != ctx->moduleName.path) {
|
||||
auto ictx = ctx->cache->imports[fp.module].ctx;
|
||||
TypeContext::BaseGuard br(ictx.get(), canonicalName);
|
||||
ictx->getBase()->type = typ;
|
||||
ictx->addBlock();
|
||||
auto tv = TypecheckVisitor(ictx);
|
||||
tv.addClassGenerics(typ, true);
|
||||
cf = std::dynamic_pointer_cast<FunctionStmt>(tv.transform(cf));
|
||||
ictx->popBlock();
|
||||
} else {
|
||||
cf = std::dynamic_pointer_cast<FunctionStmt>(transform(cf));
|
||||
}
|
||||
fnStmts.push_back(cf);
|
||||
ctx->popBlock();
|
||||
}
|
||||
}
|
||||
@ -328,13 +347,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
||||
// Debug information
|
||||
// LOG("[class] {} -> {:c} / {}", canonicalName, typ,
|
||||
// ctx->cache->classes[canonicalName].fields.size());
|
||||
// if (auto r = typ->getRecord())
|
||||
// for (auto &tx: r->args)
|
||||
// LOG(" ... {:c}", tx);
|
||||
// for (auto &m : ctx->cache->classes[canonicalName].fields)
|
||||
// LOG(" - member: {}: {:D}", m.name, m.type);
|
||||
// LOG(" - member: {}: {:c}", m.name, m.type);
|
||||
// for (auto &m : ctx->cache->classes[canonicalName].methods)
|
||||
// LOG(" - method: {}: {}", m.first, m.second);
|
||||
// for (auto &m : ctx->cache->classes[canonicalName].mro)
|
||||
// LOG(" - mro: {:c}", m);
|
||||
// LOG("");
|
||||
// ctx->dump();
|
||||
} catch (const exc::ParserException &) {
|
||||
@ -395,7 +413,10 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
|
||||
name = clsTyp->name;
|
||||
asts.push_back(clsTyp);
|
||||
Cache::Class *cachedCls = in(ctx->cache->classes, name);
|
||||
mro.push_back(cachedCls->mro);
|
||||
auto rootMro = cachedCls->mro;
|
||||
for (auto &t : rootMro)
|
||||
t = ctx->instantiate(t, clsTyp)->getClass();
|
||||
mro.push_back(rootMro);
|
||||
|
||||
// Sanity checks
|
||||
if (attr.has(Attr::Tuple) && typeAst)
|
||||
@ -438,9 +459,7 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
|
||||
transform(clean_clone(a.defaultValue)));
|
||||
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{
|
||||
name, getType(args.back().type),
|
||||
ctx->cache->classes[ast->name].fields[ai].baseClass
|
||||
}
|
||||
);
|
||||
ctx->cache->classes[ast->name].fields[ai].baseClass});
|
||||
ai++;
|
||||
}
|
||||
}
|
||||
@ -455,7 +474,8 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
|
||||
if (ctx->cache->classes[canonicalName].mro.empty()) {
|
||||
E(Error::CLASS_BAD_MRO, getSrcInfo());
|
||||
} else if (ctx->cache->classes[canonicalName].mro.size() > 1) {
|
||||
// LOG("[mro] {} -> {}", canonicalName, ctx->cache->classes[canonicalName].mro);
|
||||
// for (auto &t: ctx->cache->classes[canonicalName].mro)
|
||||
// LOG("[mro] {} -> {:c}", canonicalName, t);
|
||||
}
|
||||
}
|
||||
return asts;
|
||||
@ -753,11 +773,19 @@ int TypecheckVisitor::generateKwId(const std::vector<std::string> &names) {
|
||||
}
|
||||
}
|
||||
|
||||
void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp) {
|
||||
void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp,
|
||||
bool instantiate) {
|
||||
auto addGen = [&](auto g) {
|
||||
auto t = g.type;
|
||||
if (instantiate)
|
||||
if (auto l = g.type->getLink())
|
||||
if (l->kind == LinkType::Generic) {
|
||||
auto lx = std::make_shared<LinkType>(*l);
|
||||
lx->kind = LinkType::Unbound;
|
||||
t = lx;
|
||||
}
|
||||
if (t->getClass() && !t->getStatic() && !t->is("type"))
|
||||
t = ctx->instantiateGeneric(ctx->getType("type"), {t});
|
||||
t = ctx->instantiateGeneric(ctx->getType("type"), {t});
|
||||
ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true;
|
||||
};
|
||||
for (auto &g : clsTyp->hiddenGenerics)
|
||||
|
@ -138,6 +138,10 @@ std::string TypeContext::getModule() const {
|
||||
return base;
|
||||
}
|
||||
|
||||
std::string TypeContext::getModulePath() const {
|
||||
return moduleName.path;
|
||||
}
|
||||
|
||||
void TypeContext::dump() { dump(0); }
|
||||
|
||||
bool TypeContext::isCanonicalName(const std::string &name) const {
|
||||
|
@ -205,6 +205,8 @@ public:
|
||||
std::string getBaseName() const;
|
||||
/// Return the current module.
|
||||
std::string getModule() const;
|
||||
/// Return the current module path.
|
||||
std::string getModulePath() const;
|
||||
/// Pretty-print the current context state.
|
||||
void dump() override;
|
||||
|
||||
|
@ -49,6 +49,8 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
|
||||
stmt->expr = partializeFunction(stmt->expr->type->getFunc());
|
||||
}
|
||||
|
||||
if (!ctx->getBase()->returnType->isStaticType() && stmt->expr->type->getStatic())
|
||||
stmt->expr->type = stmt->expr->type->getStatic()->getNonStaticType();
|
||||
unify(ctx->getBase()->returnType, stmt->expr->type);
|
||||
} else {
|
||||
// Just set the expr for the translation stage. However, do not unify the return
|
||||
@ -397,7 +399,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
|
||||
|
||||
// Make function AST and cache it for later realization
|
||||
auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes);
|
||||
ctx->cache->functions[canonicalName].module = ctx->getModule();
|
||||
ctx->cache->functions[canonicalName].module = ctx->moduleName.path;
|
||||
ctx->cache->functions[canonicalName].ast = f;
|
||||
ctx->cache->functions[canonicalName].origAst = stmt_clone;
|
||||
ctx->cache->functions[canonicalName].isToplevel =
|
||||
|
@ -77,7 +77,7 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
|
||||
transform(N<AssignStmt>(
|
||||
N<IdExpr>(name),
|
||||
N<CallExpr>(N<IdExpr>("Import"), N<StringExpr>(file->path),
|
||||
N<StringExpr>(file->path), N<StringExpr>(file->module)))));
|
||||
N<StringExpr>(file->module), N<StringExpr>(file->path)))));
|
||||
} else if (stmt->what->isId("*")) {
|
||||
// Case: from foo import *
|
||||
seqassert(stmt->as.empty(), "renamed star-import");
|
||||
@ -202,7 +202,7 @@ StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Exp
|
||||
auto canonical = ctx->generateCanonicalName(name);
|
||||
auto typ = transformType(clone(type));
|
||||
auto val = ctx->addVar(altName.empty() ? name : altName, canonical,
|
||||
std::make_shared<types::LinkType>(typ->type->getClass()));
|
||||
std::make_shared<types::LinkType>(getType(typ)->getClass()));
|
||||
auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, typ);
|
||||
s->lhs->setAttr(ExprAttr::ExternVar);
|
||||
s->lhs->setType(val->type);
|
||||
|
@ -394,8 +394,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
|
||||
|
||||
// Use NoneType as the return type when the return type is not specified and
|
||||
// function has no return statement
|
||||
if (!ast->ret && type->getRetType()->getUnbound())
|
||||
if (!ast->ret && type->getRetType()->getUnbound()) {
|
||||
unify(type->getRetType(), ctx->getType("NoneType"));
|
||||
}
|
||||
// LOG("-> {} {}", key, ret->toString(2));
|
||||
}
|
||||
// Realize the return type
|
||||
@ -583,8 +584,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
|
||||
for (auto &[_, real] : cls.realizations) {
|
||||
auto &vtable = real->vtables[baseCls];
|
||||
|
||||
auto ct =
|
||||
ctx->instantiate(ctx->forceFind(clsName)->type, cp->getClass())->getClass();
|
||||
auto ct = ctx->instantiate(ctx->getType(clsName), cp->getClass())->getClass();
|
||||
std::vector<types::TypePtr> args = fp->getArgTypes();
|
||||
args[0] = ct;
|
||||
auto m = findBestMethod(ct, fnName, args);
|
||||
@ -605,7 +605,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
|
||||
// Thunk name: _thunk.<BASE>.<FN>.<ARGS>
|
||||
auto thunkName =
|
||||
format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, "."));
|
||||
if (in(ctx->cache->functions, thunkName))
|
||||
if (in(ctx->cache->functions, thunkName+":0"))
|
||||
continue;
|
||||
|
||||
// Thunk contents:
|
||||
@ -614,27 +614,24 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
|
||||
// __internal__.class_base_to_derived(self, <BASE>, <DERIVED>),
|
||||
// <ARGS...>)
|
||||
std::vector<Param> fnArgs;
|
||||
fnArgs.emplace_back(fp->ast->args[0].name, N<IdExpr>(cp->realizedName()),
|
||||
nullptr);
|
||||
fnArgs.emplace_back("self", N<IdExpr>(cp->realizedName()), nullptr);
|
||||
for (size_t i = 1; i < args.size(); i++)
|
||||
fnArgs.emplace_back(fp->ast->args[i].name, N<IdExpr>(args[i]->realizedName()),
|
||||
nullptr);
|
||||
fnArgs.emplace_back(ctx->cache->rev(fp->ast->args[i].name),
|
||||
N<IdExpr>(args[i]->realizedName()), nullptr);
|
||||
std::vector<ExprPtr> callArgs;
|
||||
callArgs.emplace_back(
|
||||
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_base_to_derived"),
|
||||
N<IdExpr>(fp->ast->args[0].name), N<IdExpr>(cp->realizedName()),
|
||||
N<IdExpr>("self"), N<IdExpr>(cp->realizedName()),
|
||||
N<IdExpr>(real->type->realizedName())));
|
||||
for (size_t i = 1; i < args.size(); i++)
|
||||
callArgs.emplace_back(N<IdExpr>(fp->ast->args[i].name));
|
||||
callArgs.emplace_back(N<IdExpr>(ctx->cache->rev(fp->ast->args[i].name)));
|
||||
auto thunkAst = N<FunctionStmt>(
|
||||
thunkName, nullptr, fnArgs,
|
||||
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(N<IdExpr>(m->ast->name), callArgs))),
|
||||
Attr({"std.internal.attributes.inline", Attr::ForceRealize}));
|
||||
auto &thunkFn = ctx->cache->functions[thunkAst->name];
|
||||
thunkFn.ast = clone(thunkAst);
|
||||
Attr({"std.internal.attributes.inline"}));
|
||||
thunkAst = std::dynamic_pointer_cast<FunctionStmt>(transform(thunkAst));
|
||||
|
||||
transform(thunkAst);
|
||||
prependStmts->push_back(thunkAst);
|
||||
auto &thunkFn = ctx->cache->functions[thunkAst->name];
|
||||
auto ti = ctx->instantiate(thunkFn.type)->getFunc();
|
||||
auto tm = realizeFunc(ti.get(), true);
|
||||
seqassert(tm, "bad thunk {}", thunkFn.type);
|
||||
@ -651,8 +648,11 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
||||
auto realizedName = t->ClassType::realizedName();
|
||||
if (!in(ctx->cache->classes[t->name].realizations, realizedName))
|
||||
realize(t->getClass());
|
||||
if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir)
|
||||
if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) {
|
||||
if (ctx->cache->classes[t->name].rtti)
|
||||
ir::cast<ir::types::RefType>(l)->setPolymorphic();
|
||||
return l;
|
||||
}
|
||||
|
||||
auto forceFindIRType = [&](const TypePtr &tt) {
|
||||
auto t = tt->getClass();
|
||||
|
@ -289,7 +289,7 @@ TypecheckVisitor::transformStaticLoopCall(
|
||||
if (vars.size() != 1)
|
||||
error("expected one item");
|
||||
for (auto &a : args) {
|
||||
stmt->rhs = a.value;
|
||||
stmt->rhs = transform(clean_clone(a.value));
|
||||
if (auto st = stmt->rhs->type->getStatic()) {
|
||||
stmt->type = N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>(st->name));
|
||||
} else {
|
||||
|
@ -543,7 +543,12 @@ void ScopingVisitor::visit(GlobalStmt *stmt) {
|
||||
}
|
||||
|
||||
void ScopingVisitor::visit(FunctionStmt *stmt) {
|
||||
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo());
|
||||
bool isOverload = false;
|
||||
for (auto &d: stmt->decorators)
|
||||
if (d->isId("overload"))
|
||||
isOverload = true;
|
||||
if (!isOverload)
|
||||
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo());
|
||||
|
||||
auto c = std::make_shared<ScopingVisitor::Context>();
|
||||
c->cache = ctx->cache;
|
||||
|
@ -56,25 +56,29 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
|
||||
/// @c transformBinaryInplaceMagic for details.
|
||||
/// Also evaluate static expressions. See @c evaluateStaticBinary for details.
|
||||
void TypecheckVisitor::visit(BinaryExpr *expr) {
|
||||
// Transform lexpr and rexpr. Ignore Nones for now
|
||||
if (!(startswith(expr->op, "is") && expr->lexpr->getNone()))
|
||||
transform(expr->lexpr, true);
|
||||
if (!(startswith(expr->op, "is") && expr->rexpr->getNone()))
|
||||
transform(expr->rexpr, true);
|
||||
transform(expr->lexpr, true);
|
||||
transform(expr->rexpr, true);
|
||||
|
||||
static std::unordered_map<int, std::unordered_set<std::string>> staticOps = {
|
||||
{1,
|
||||
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&",
|
||||
"|", "^"}},
|
||||
{2, {"==", "!=", "+"}},
|
||||
{3,
|
||||
{"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
|
||||
if (expr->lexpr->type->isStaticType() &&
|
||||
expr->lexpr->type->isStaticType() == expr->rexpr->type->isStaticType() &&
|
||||
in(staticOps[expr->lexpr->type->isStaticType()], expr->op)) {
|
||||
// Handle static expressions
|
||||
resultExpr = evaluateStaticBinary(expr);
|
||||
} else if (auto e = transformBinarySimple(expr)) {
|
||||
{3, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
|
||||
if (expr->lexpr->type->isStaticType() && expr->rexpr->type->isStaticType()) {
|
||||
auto l = expr->lexpr->type->isStaticType();
|
||||
auto r = expr->rexpr->type->isStaticType();
|
||||
bool isStatic = l == r && in(staticOps[l], expr->op);
|
||||
if (!isStatic && ((l == 1 && r == 3) || (r == 1 && l == 3)) &&
|
||||
in(staticOps[1], expr->op))
|
||||
isStatic = true;
|
||||
if (isStatic) {
|
||||
resultExpr = evaluateStaticBinary(expr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto e = transformBinarySimple(expr)) {
|
||||
// Case: simple binary expressions
|
||||
resultExpr = e;
|
||||
} else if (expr->lexpr->getType()->getUnbound() ||
|
||||
@ -264,7 +268,8 @@ void TypecheckVisitor::visit(PipeExpr *expr) {
|
||||
void TypecheckVisitor::visit(IndexExpr *expr) {
|
||||
if (expr->expr->isId("Static")) {
|
||||
// Special case: static types. Ensure that static is supported
|
||||
if (!expr->index->isId("int") && !expr->index->isId("str") && !expr->index->isId("bool"))
|
||||
if (!expr->index->isId("int") && !expr->index->isId("str") &&
|
||||
!expr->index->isId("bool"))
|
||||
E(Error::BAD_STATIC_TYPE, expr->index);
|
||||
auto typ = ctx->getUnbound();
|
||||
typ->isStatic = getStaticGeneric(expr);
|
||||
@ -364,19 +369,16 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
||||
unify(expr->type, typ);
|
||||
} else {
|
||||
for (size_t i = 0; i < expr->typeParams.size(); i++) {
|
||||
// transform(expr->typeParams[i]);
|
||||
transformType(expr->typeParams[i]);
|
||||
auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
|
||||
getType(expr->typeParams[i]));
|
||||
// if (expr->typeParams[i]->type->isStaticType() &&
|
||||
// generics[i].type->isStaticType()) {
|
||||
// t = ctx->instantiate(expr->typeParams[i]->type);
|
||||
// } else {
|
||||
// if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
|
||||
// transformType(expr->typeParams[i]);
|
||||
// if (!expr->typeParams[i]->type->is("type"))
|
||||
// E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
|
||||
// }
|
||||
if (expr->typeParams[i]->type->isStaticType() !=
|
||||
generics[i].type->isStaticType()) {
|
||||
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
|
||||
transformType(expr->typeParams[i]);
|
||||
if (!expr->typeParams[i]->type->is("type"))
|
||||
E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
|
||||
}
|
||||
if (isUnion)
|
||||
typ->getUnion()->addType(t);
|
||||
else
|
||||
@ -454,7 +456,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
||||
value = !bool(value);
|
||||
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
||||
if (expr->op == "!")
|
||||
return transform(N<IntExpr>(bool(value)));
|
||||
return transform(N<BoolExpr>(value));
|
||||
else
|
||||
return transform(N<IntExpr>(value));
|
||||
} else {
|
||||
@ -469,9 +471,10 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
||||
/// Division and modulus implementations.
|
||||
std::pair<int64_t, int64_t> divMod(const std::shared_ptr<TypeContext> &ctx, int64_t a,
|
||||
int64_t b) {
|
||||
if (!b)
|
||||
if (!b) {
|
||||
E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo());
|
||||
if (ctx->cache->pythonCompat) {
|
||||
return {0, 0};
|
||||
} else if (ctx->cache->pythonCompat) {
|
||||
// Use Python implementation.
|
||||
int64_t d = a / b;
|
||||
int64_t m = a - d * b;
|
||||
@ -511,7 +514,7 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
||||
expr->rexpr->type->getStrStatic()->value;
|
||||
bool value = expr->op == "==" ? eq : !eq;
|
||||
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
|
||||
return transform(N<IntExpr>(value));
|
||||
return transform(N<BoolExpr>(value));
|
||||
} else {
|
||||
// Cannot be evaluated yet: just set the type
|
||||
expr->type->getUnbound()->isStatic = 1;
|
||||
@ -522,8 +525,12 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
||||
|
||||
// Case: static integers
|
||||
if (expr->lexpr->type->getStatic() && expr->rexpr->type->getStatic()) {
|
||||
int64_t lvalue = expr->lexpr->type->getIntStatic() ? expr->lexpr->type->getIntStatic()->value : expr->lexpr->type->getBoolStatic()->value;
|
||||
int64_t rvalue = expr->rexpr->type->getIntStatic() ? expr->rexpr->type->getIntStatic()->value : expr->rexpr->type->getBoolStatic()->value;
|
||||
int64_t lvalue = expr->lexpr->type->getIntStatic()
|
||||
? expr->lexpr->type->getIntStatic()->value
|
||||
: expr->lexpr->type->getBoolStatic()->value;
|
||||
int64_t rvalue = expr->rexpr->type->getIntStatic()
|
||||
? expr->rexpr->type->getIntStatic()->value
|
||||
: expr->rexpr->type->getBoolStatic()->value;
|
||||
if (expr->op == "<")
|
||||
lvalue = lvalue < rvalue;
|
||||
else if (expr->op == "<=")
|
||||
@ -596,7 +603,7 @@ ExprPtr TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) {
|
||||
return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, "__contains__"), expr->lexpr));
|
||||
} else if (expr->op == "is") {
|
||||
if (expr->lexpr->getNone() && expr->rexpr->getNone())
|
||||
return transform(N<IntExpr>(1));
|
||||
return transform(N<BoolExpr>(true));
|
||||
else if (expr->lexpr->getNone())
|
||||
return transform(N<BinaryExpr>(expr->rexpr, "is", expr->lexpr));
|
||||
} else if (expr->op == "is not") {
|
||||
@ -613,17 +620,17 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
|
||||
// Case: `is None` expressions
|
||||
if (expr->rexpr->getNone()) {
|
||||
if (expr->lexpr->getType()->is("NoneType"))
|
||||
return transform(N<IntExpr>(1));
|
||||
return transform(N<BoolExpr>(true));
|
||||
if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) {
|
||||
// lhs is not optional: `return False`
|
||||
return transform(N<IntExpr>(0));
|
||||
return transform(N<BoolExpr>(false));
|
||||
} else {
|
||||
// Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType
|
||||
auto g = expr->lexpr->getType()->getClass();
|
||||
for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass())
|
||||
;
|
||||
if (g->generics[0].type->is("NoneType"))
|
||||
return transform(N<IntExpr>(1));
|
||||
return transform(N<BoolExpr>(true));
|
||||
|
||||
// lhs is optional: `return lhs.__has__().__invert__()`
|
||||
return transform(N<CallExpr>(
|
||||
@ -640,7 +647,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (expr->lexpr->type->is("type") && expr->rexpr->type->is("type"))
|
||||
return transform(N<IntExpr>(lc->realizedName() == rc->realizedName()));
|
||||
return transform(N<BoolExpr>(lc->realizedName() == rc->realizedName()));
|
||||
if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) {
|
||||
// Both reference types: `return lhs.__raw__() == rhs.__raw__()`
|
||||
return transform(
|
||||
@ -659,7 +666,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
|
||||
}
|
||||
if (lc->realizedName() != rc->realizedName()) {
|
||||
// tuple names do not match: `return False`
|
||||
return transform(N<IntExpr>(0));
|
||||
return transform(N<BoolExpr>(false));
|
||||
}
|
||||
// Same tuple types: `return lhs == rhs`
|
||||
return transform(N<BinaryExpr>(expr->lexpr, "==", expr->rexpr));
|
||||
|
@ -579,6 +579,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
|
||||
} else {
|
||||
expr = p;
|
||||
}
|
||||
} else if (expectedClass && expectedClass->name == "Function" && exprClass &&
|
||||
exprClass->getPartial() &&
|
||||
exprClass->generics[2].type->getClass()->generics.size() == 1 &&
|
||||
exprClass->generics[2]
|
||||
.type->getClass()
|
||||
->generics[0]
|
||||
.type->getClass()
|
||||
->generics.empty() &&
|
||||
exprClass->generics[3]
|
||||
.type->getClass()
|
||||
->generics[0]
|
||||
.type->getClass()
|
||||
->generics.empty()) {
|
||||
expr = transform(N<IdExpr>(exprClass->getPartialFunc()->ast->name));
|
||||
} else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass &&
|
||||
!expectedClass->getUnion()) {
|
||||
// Extract union types via __internal__.get_union
|
||||
@ -696,7 +710,8 @@ types::TypePtr TypecheckVisitor::getType(const ExprPtr &e) {
|
||||
return t;
|
||||
}
|
||||
|
||||
std::vector<types::TypePtr> TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) {
|
||||
std::vector<types::TypePtr>
|
||||
TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) {
|
||||
std::vector<types::TypePtr> result;
|
||||
ctx->addBlock();
|
||||
addClassGenerics(cls);
|
||||
|
@ -247,7 +247,7 @@ private: // Node typechecking rules
|
||||
bool);
|
||||
std::string generateTuple(size_t);
|
||||
int generateKwId(const std::vector<std::string> & = {});
|
||||
void addClassGenerics(const types::ClassTypePtr &);
|
||||
void addClassGenerics(const types::ClassTypePtr &, bool instantiate = false);
|
||||
|
||||
/* The rest (typecheck.cpp) */
|
||||
void visit(SuiteStmt *) override;
|
||||
|
@ -15,7 +15,11 @@ from internal.types.float import *
|
||||
from internal.types.byte import *
|
||||
from internal.types.generator import *
|
||||
from internal.types.optional import *
|
||||
|
||||
import internal.c_stubs as _C
|
||||
from internal.format import *
|
||||
from internal.internal import *
|
||||
|
||||
from internal.types.slice import *
|
||||
from internal.types.range import *
|
||||
from internal.types.complex import *
|
||||
@ -28,26 +32,21 @@ from internal.types.collections.set import *
|
||||
from internal.types.collections.dict import *
|
||||
from internal.types.collections.tuple import *
|
||||
|
||||
# Extended core library
|
||||
|
||||
import internal.c_stubs as _C
|
||||
from internal.format import *
|
||||
from internal.builtin import *
|
||||
|
||||
from internal.builtin import _jit_display
|
||||
from internal.str import *
|
||||
|
||||
from internal.sort import sorted
|
||||
|
||||
# # from openmp import Ident as __OMPIdent, for_par
|
||||
# # from gpu import _gpu_loop_outline_template
|
||||
# from internal.file import File, gzFile, open, gzopen
|
||||
# from pickle import pickle, unpickle
|
||||
# from internal.dlopen import dlsym as _dlsym
|
||||
# import internal.python
|
||||
# from internal.python import PyError
|
||||
from openmp import Ident as __OMPIdent, for_par
|
||||
from gpu import _gpu_loop_outline_template
|
||||
from internal.file import File, gzFile, open, gzopen
|
||||
from pickle import pickle, unpickle
|
||||
from internal.dlopen import dlsym as _dlsym
|
||||
import internal.python
|
||||
from internal.python import PyError
|
||||
|
||||
# # if __py_numerics__:
|
||||
# # import internal.pynumerics
|
||||
# # if __py_extension__:
|
||||
# # internal.python.ensure_initialized()
|
||||
if __py_numerics__:
|
||||
import internal.pynumerics
|
||||
if __py_extension__:
|
||||
internal.python.ensure_initialized()
|
||||
|
@ -174,7 +174,7 @@ class Import:
|
||||
P: Static[str]
|
||||
|
||||
@llvm
|
||||
def __new__(P: Static[str], name: str, path: str) -> Import[P]:
|
||||
def __new__(path: str, name: str, P: Static[str]) -> Import[P]:
|
||||
%0 = insertvalue { {=str}, {=str} } undef, {=str} %path, 0
|
||||
%1 = insertvalue { {=str}, {=str} } %0, {=str} %name, 1
|
||||
ret { {=str}, {=str} } %1
|
||||
|
@ -68,6 +68,10 @@ class __internal__:
|
||||
__vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])))
|
||||
__internal__.class_populate_vtables()
|
||||
|
||||
# def print(a):
|
||||
# from C import seq_print(str)
|
||||
# seq_print(a.__repr__())
|
||||
|
||||
def class_populate_vtables() -> None:
|
||||
"""
|
||||
Populate content of vtables. Compiler generated.
|
||||
@ -91,7 +95,8 @@ class __internal__:
|
||||
def class_set_rtti_vtable(id: int, sz: int, T: type):
|
||||
if not __has_rtti__(T):
|
||||
compile_error("class is not polymorphic")
|
||||
__vtables__[id] = Ptr[cobj](sz + 1)
|
||||
p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj))
|
||||
__vtables__[id] = Ptr[cobj](p)
|
||||
__internal__.class_set_typeinfo(__vtables__[id], id)
|
||||
|
||||
def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type):
|
||||
|
@ -192,8 +192,9 @@ Jar = Ptr[byte]
|
||||
|
||||
@extend
|
||||
class NoneType:
|
||||
@llvm
|
||||
def __new__() -> NoneType:
|
||||
return ()
|
||||
ret {} {}
|
||||
|
||||
def __eq__(self, other: NoneType):
|
||||
return True
|
||||
|
@ -466,3 +466,15 @@ def test_mandelbrot():
|
||||
return (MAX, N, pixels, scale(N, -2, 0.4))
|
||||
k(pixels, grid=(N*N)//1024, block=1024)
|
||||
test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4)
|
||||
|
||||
#%% id_shadow_overload_call,barebones
|
||||
def foo():
|
||||
def bar():
|
||||
return -1
|
||||
def xo():
|
||||
return bar()
|
||||
@overload # w/o this this fails because xo cannot capture bar
|
||||
def bar(a):
|
||||
return a
|
||||
bar(1)
|
||||
foo()
|
||||
|
@ -92,12 +92,6 @@ class F[T: Static[float]]:
|
||||
pass
|
||||
#! expected 'int' or 'str' (only integers and strings can be static)
|
||||
|
||||
#%% class_err_10,barebones
|
||||
def foo[T]():
|
||||
class A:
|
||||
x: T
|
||||
#! name 'T' cannot be captured
|
||||
|
||||
#%% class_err_11,barebones
|
||||
def foo(x):
|
||||
class A:
|
||||
|
@ -103,7 +103,7 @@ for i in range(10):
|
||||
#%% for_error,barebones
|
||||
for i in 1:
|
||||
pass
|
||||
#! 'int' object has no attribute '__iter__'
|
||||
#! '1' object has no attribute '__iter__'
|
||||
|
||||
#%% for_void,barebones
|
||||
def foo(): yield
|
||||
|
@ -3,7 +3,7 @@
|
||||
a, b = False, 1
|
||||
print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1
|
||||
|
||||
#%% binary,barebones
|
||||
#%% binary_simple,barebones
|
||||
x, y = 1, 0
|
||||
c = [1, 2, 3]
|
||||
|
||||
@ -356,16 +356,19 @@ print Foo[int, 3, 4](), Foo[int, 5, 4]()
|
||||
#%% static_int,barebones
|
||||
def foo(n: Static[int]):
|
||||
print n
|
||||
@overload
|
||||
def foo(n: Static[bool]):
|
||||
print n
|
||||
|
||||
a: Static[int] = 5
|
||||
foo(a < 1) #: 0
|
||||
foo(a <= 1) #: 0
|
||||
foo(a > 1) #: 1
|
||||
foo(a >= 1) #: 1
|
||||
foo(a == 1) #: 0
|
||||
foo(a != 1) #: 1
|
||||
foo(a and 1) #: 1
|
||||
foo(a or 1) #: 1
|
||||
foo(a < 1) #: False
|
||||
foo(a <= 1) #: False
|
||||
foo(a > 1) #: True
|
||||
foo(a >= 1) #: True
|
||||
foo(a == 1) #: False
|
||||
foo(a != 1) #: True
|
||||
foo(a and 1) #: True
|
||||
foo(a or 1) #: True
|
||||
foo(a + 1) #: 6
|
||||
foo(a - 1) #: 4
|
||||
foo(a * 1) #: 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user