1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Fix polymorphism

This commit is contained in:
Ibrahim Numanagić 2024-03-18 16:15:36 -07:00
parent f4fe8ec18f
commit e737536b38
23 changed files with 204 additions and 123 deletions

View File

@ -57,6 +57,10 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
}
TranslateVisitor(cache->codegenCtx).transform(stmts);
for (auto &[_, f]: cache->functions)
TranslateVisitor(cache->codegenCtx).transform(f.ast);
cache->populatePythonModule();
return main;
}
@ -174,7 +178,7 @@ void TranslateVisitor::visit(StringExpr *expr) {
void TranslateVisitor::visit(IdExpr *expr) {
auto val = ctx->find(expr->value);
seqassert(val, "cannot find '{}'", expr->value);
if (expr->value == "__vtable_size__") {
if (expr->value == "__vtable_size__.0") {
// LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2);
result = make<ir::IntConst>(expr, ctx->cache->classRealizationCnt + 2,
getType(expr->getType()));
@ -438,7 +442,6 @@ void TranslateVisitor::visit(AssignStmt *stmt) {
auto isGlobal = in(ctx->cache->globals, var);
ir::Var *v = nullptr;
if (!stmt->lhs->type->isInstantiated() || (stmt->lhs->type->is("type"))) {
// LOG("{} {}", getSrcInfo(), stmt->toString(0));
return; // type aliases/fn aliases etc
@ -697,9 +700,8 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt
} else {
seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}",
ss[i]->getExpr()->toString());
literals.emplace_back(getType(
ctx->cache->typeCtx->getType(
ss[i]->getExpr()->expr->getType())));
literals.emplace_back(
getType(ctx->cache->typeCtx->getType(ss[i]->getExpr()->expr->getType())));
}
}
bool isDeclare = true;

View File

@ -357,10 +357,11 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
}
// Special case: cls.__id__
if (expr->expr->type->is("type") && expr->member == "__id__") {
if (auto c = realize(expr->expr->type))
if (auto c = realize(getType(expr->expr))) {
return transform(N<IntExpr>(ctx->cache->classes[c->getClass()->name]
.realizations[c->getClass()->realizedName()]
->id));
}
return nullptr;
}
@ -618,7 +619,9 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
// If overload is ambiguous, route through a dispatch function
std::string name;
if (auto dot = expr->getDot()) {
name = ctx->cache->getMethod(getType(dot->expr)->getClass(), dot->member);
auto methods = ctx->findMethod(getType(dot->expr)->getClass()->name, dot->member, false);
seqassert(!methods.empty(), "unknown method");
name = ctx->cache->functions[methods.back()->ast->name].rootName;
} else {
name = expr->getId()->value;
}

View File

@ -209,7 +209,7 @@ StmtPtr TypecheckVisitor::transformUpdate(AssignStmt *stmt) {
void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
transform(stmt->lhs);
if (auto lhsClass = stmt->lhs->getType()->getClass()) {
if (auto lhsClass = getType(stmt->lhs)->getClass()) {
auto member = ctx->findMember(lhsClass->name, stmt->member);
if (!member && stmt->lhs->type->is("type")) {

View File

@ -471,9 +471,8 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
expr->args.pop_back();
if (!part.args)
part.args = transform(N<TupleExpr>()); // use ()
if (!part.kwArgs) {
if (!part.kwArgs)
part.kwArgs = transform(N<CallExpr>(N<IdExpr>("NamedTuple"))); // use NamedTuple()
}
}
// Unify function type generics with the provided generics
@ -805,7 +804,8 @@ ExprPtr TypecheckVisitor::transformSuper() {
self->type = typ;
auto typExpr = N<IdExpr>(superTyp->name);
typExpr->setType(superTyp);
typExpr->setType(ctx->instantiateGeneric(ctx->getType("type"), {superTyp}));
// LOG("-> {:c} : {:c} {:c}", typ, vCands[1], typExpr->type);
return transform(N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_super"),
self, typExpr, N<IntExpr>(1)));
}
@ -819,7 +819,6 @@ ExprPtr TypecheckVisitor::transformSuper() {
members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name));
ExprPtr e =
transform(N<CallExpr>(N<IdExpr>(generateTuple(members.size())), members));
auto ft = getClassFieldTypes(superTyp);
for (size_t i = 0; i < ft.size(); i++)
unify(

View File

@ -122,7 +122,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
if (auto st = getStaticGeneric(a.type.get())) {
if (st > 3) transform(a.type); // error check
if (st > 3)
transform(a.type); // error check
generic->isStatic = st;
auto val = ctx->addVar(genName, varName, generic);
val->generic = true;
@ -197,8 +198,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// : ctx->generateCanonicalName(a.name);
args.emplace_back(varName, transformType(clean_clone(a.type)),
transform(clone(a.defaultValue), true));
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{
varName, types::TypePtr(nullptr), canonicalName});
ctx->cache->classes[canonicalName].fields.emplace_back(
Cache::Class::ClassField{varName, types::TypePtr(nullptr),
canonicalName});
}
}
}
@ -248,7 +250,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
for (auto &b : staticBaseASTs)
ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name);
ctx->cache->classes[canonicalName].ast->validate();
ctx->cache->classes[canonicalName].module = ctx->getModule();
ctx->cache->classes[canonicalName].module = ctx->moduleName.path;
// Codegen default magic methods
// __new__ must be the first
@ -260,11 +262,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
for (auto &base : staticBaseASTs) {
for (auto &mm : ctx->cache->classes[base->name].methods)
for (auto &mf : ctx->cache->overloads[mm.second]) {
auto f = ctx->cache->functions[mf].origAst;
const auto &fp = ctx->cache->functions[mf];
auto f = fp.origAst;
if (f && !f->attributes.has("autogenerated")) {
ctx->addBlock();
addClassGenerics(base);
fnStmts.push_back(transform(clean_clone(f)));
// since functions can come from other modules
// make sure to transform them in their respective module
// however makle sure to add/pop generics :/
auto cf = clean_clone(f);
if (!ctx->isStdlibLoading && fp.module != ctx->moduleName.path) {
auto ictx = ctx->cache->imports[fp.module].ctx;
TypeContext::BaseGuard br(ictx.get(), canonicalName);
ictx->getBase()->type = typ;
ictx->addBlock();
auto tv = TypecheckVisitor(ictx);
tv.addClassGenerics(typ, true);
cf = std::dynamic_pointer_cast<FunctionStmt>(tv.transform(cf));
ictx->popBlock();
} else {
cf = std::dynamic_pointer_cast<FunctionStmt>(transform(cf));
}
fnStmts.push_back(cf);
ctx->popBlock();
}
}
@ -328,13 +347,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// Debug information
// LOG("[class] {} -> {:c} / {}", canonicalName, typ,
// ctx->cache->classes[canonicalName].fields.size());
// if (auto r = typ->getRecord())
// for (auto &tx: r->args)
// LOG(" ... {:c}", tx);
// for (auto &m : ctx->cache->classes[canonicalName].fields)
// LOG(" - member: {}: {:D}", m.name, m.type);
// LOG(" - member: {}: {:c}", m.name, m.type);
// for (auto &m : ctx->cache->classes[canonicalName].methods)
// LOG(" - method: {}: {}", m.first, m.second);
// for (auto &m : ctx->cache->classes[canonicalName].mro)
// LOG(" - mro: {:c}", m);
// LOG("");
// ctx->dump();
} catch (const exc::ParserException &) {
@ -395,7 +413,10 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
name = clsTyp->name;
asts.push_back(clsTyp);
Cache::Class *cachedCls = in(ctx->cache->classes, name);
mro.push_back(cachedCls->mro);
auto rootMro = cachedCls->mro;
for (auto &t : rootMro)
t = ctx->instantiate(t, clsTyp)->getClass();
mro.push_back(rootMro);
// Sanity checks
if (attr.has(Attr::Tuple) && typeAst)
@ -438,9 +459,7 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
transform(clean_clone(a.defaultValue)));
ctx->cache->classes[canonicalName].fields.emplace_back(Cache::Class::ClassField{
name, getType(args.back().type),
ctx->cache->classes[ast->name].fields[ai].baseClass
}
);
ctx->cache->classes[ast->name].fields[ai].baseClass});
ai++;
}
}
@ -455,7 +474,8 @@ TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
if (ctx->cache->classes[canonicalName].mro.empty()) {
E(Error::CLASS_BAD_MRO, getSrcInfo());
} else if (ctx->cache->classes[canonicalName].mro.size() > 1) {
// LOG("[mro] {} -> {}", canonicalName, ctx->cache->classes[canonicalName].mro);
// for (auto &t: ctx->cache->classes[canonicalName].mro)
// LOG("[mro] {} -> {:c}", canonicalName, t);
}
}
return asts;
@ -753,11 +773,19 @@ int TypecheckVisitor::generateKwId(const std::vector<std::string> &names) {
}
}
void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp) {
void TypecheckVisitor::addClassGenerics(const types::ClassTypePtr &clsTyp,
bool instantiate) {
auto addGen = [&](auto g) {
auto t = g.type;
if (instantiate)
if (auto l = g.type->getLink())
if (l->kind == LinkType::Generic) {
auto lx = std::make_shared<LinkType>(*l);
lx->kind = LinkType::Unbound;
t = lx;
}
if (t->getClass() && !t->getStatic() && !t->is("type"))
t = ctx->instantiateGeneric(ctx->getType("type"), {t});
t = ctx->instantiateGeneric(ctx->getType("type"), {t});
ctx->addVar(ctx->cache->rev(g.name), g.name, t)->generic = true;
};
for (auto &g : clsTyp->hiddenGenerics)

View File

@ -138,6 +138,10 @@ std::string TypeContext::getModule() const {
return base;
}
std::string TypeContext::getModulePath() const {
return moduleName.path;
}
void TypeContext::dump() { dump(0); }
bool TypeContext::isCanonicalName(const std::string &name) const {

View File

@ -205,6 +205,8 @@ public:
std::string getBaseName() const;
/// Return the current module.
std::string getModule() const;
/// Return the current module path.
std::string getModulePath() const;
/// Pretty-print the current context state.
void dump() override;

View File

@ -49,6 +49,8 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
stmt->expr = partializeFunction(stmt->expr->type->getFunc());
}
if (!ctx->getBase()->returnType->isStaticType() && stmt->expr->type->getStatic())
stmt->expr->type = stmt->expr->type->getStatic()->getNonStaticType();
unify(ctx->getBase()->returnType, stmt->expr->type);
} else {
// Just set the expr for the translation stage. However, do not unify the return
@ -397,7 +399,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Make function AST and cache it for later realization
auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes);
ctx->cache->functions[canonicalName].module = ctx->getModule();
ctx->cache->functions[canonicalName].module = ctx->moduleName.path;
ctx->cache->functions[canonicalName].ast = f;
ctx->cache->functions[canonicalName].origAst = stmt_clone;
ctx->cache->functions[canonicalName].isToplevel =

View File

@ -77,7 +77,7 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
transform(N<AssignStmt>(
N<IdExpr>(name),
N<CallExpr>(N<IdExpr>("Import"), N<StringExpr>(file->path),
N<StringExpr>(file->path), N<StringExpr>(file->module)))));
N<StringExpr>(file->module), N<StringExpr>(file->path)))));
} else if (stmt->what->isId("*")) {
// Case: from foo import *
seqassert(stmt->as.empty(), "renamed star-import");
@ -202,7 +202,7 @@ StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Exp
auto canonical = ctx->generateCanonicalName(name);
auto typ = transformType(clone(type));
auto val = ctx->addVar(altName.empty() ? name : altName, canonical,
std::make_shared<types::LinkType>(typ->type->getClass()));
std::make_shared<types::LinkType>(getType(typ)->getClass()));
auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, typ);
s->lhs->setAttr(ExprAttr::ExternVar);
s->lhs->setType(val->type);

View File

@ -394,8 +394,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
// Use NoneType as the return type when the return type is not specified and
// function has no return statement
if (!ast->ret && type->getRetType()->getUnbound())
if (!ast->ret && type->getRetType()->getUnbound()) {
unify(type->getRetType(), ctx->getType("NoneType"));
}
// LOG("-> {} {}", key, ret->toString(2));
}
// Realize the return type
@ -583,8 +584,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
for (auto &[_, real] : cls.realizations) {
auto &vtable = real->vtables[baseCls];
auto ct =
ctx->instantiate(ctx->forceFind(clsName)->type, cp->getClass())->getClass();
auto ct = ctx->instantiate(ctx->getType(clsName), cp->getClass())->getClass();
std::vector<types::TypePtr> args = fp->getArgTypes();
args[0] = ct;
auto m = findBestMethod(ct, fnName, args);
@ -605,7 +605,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
// Thunk name: _thunk.<BASE>.<FN>.<ARGS>
auto thunkName =
format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, "."));
if (in(ctx->cache->functions, thunkName))
if (in(ctx->cache->functions, thunkName+":0"))
continue;
// Thunk contents:
@ -614,27 +614,24 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
// __internal__.class_base_to_derived(self, <BASE>, <DERIVED>),
// <ARGS...>)
std::vector<Param> fnArgs;
fnArgs.emplace_back(fp->ast->args[0].name, N<IdExpr>(cp->realizedName()),
nullptr);
fnArgs.emplace_back("self", N<IdExpr>(cp->realizedName()), nullptr);
for (size_t i = 1; i < args.size(); i++)
fnArgs.emplace_back(fp->ast->args[i].name, N<IdExpr>(args[i]->realizedName()),
nullptr);
fnArgs.emplace_back(ctx->cache->rev(fp->ast->args[i].name),
N<IdExpr>(args[i]->realizedName()), nullptr);
std::vector<ExprPtr> callArgs;
callArgs.emplace_back(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "class_base_to_derived"),
N<IdExpr>(fp->ast->args[0].name), N<IdExpr>(cp->realizedName()),
N<IdExpr>("self"), N<IdExpr>(cp->realizedName()),
N<IdExpr>(real->type->realizedName())));
for (size_t i = 1; i < args.size(); i++)
callArgs.emplace_back(N<IdExpr>(fp->ast->args[i].name));
callArgs.emplace_back(N<IdExpr>(ctx->cache->rev(fp->ast->args[i].name)));
auto thunkAst = N<FunctionStmt>(
thunkName, nullptr, fnArgs,
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(N<IdExpr>(m->ast->name), callArgs))),
Attr({"std.internal.attributes.inline", Attr::ForceRealize}));
auto &thunkFn = ctx->cache->functions[thunkAst->name];
thunkFn.ast = clone(thunkAst);
Attr({"std.internal.attributes.inline"}));
thunkAst = std::dynamic_pointer_cast<FunctionStmt>(transform(thunkAst));
transform(thunkAst);
prependStmts->push_back(thunkAst);
auto &thunkFn = ctx->cache->functions[thunkAst->name];
auto ti = ctx->instantiate(thunkFn.type)->getFunc();
auto tm = realizeFunc(ti.get(), true);
seqassert(tm, "bad thunk {}", thunkFn.type);
@ -651,8 +648,11 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
auto realizedName = t->ClassType::realizedName();
if (!in(ctx->cache->classes[t->name].realizations, realizedName))
realize(t->getClass());
if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir)
if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) {
if (ctx->cache->classes[t->name].rtti)
ir::cast<ir::types::RefType>(l)->setPolymorphic();
return l;
}
auto forceFindIRType = [&](const TypePtr &tt) {
auto t = tt->getClass();

View File

@ -289,7 +289,7 @@ TypecheckVisitor::transformStaticLoopCall(
if (vars.size() != 1)
error("expected one item");
for (auto &a : args) {
stmt->rhs = a.value;
stmt->rhs = transform(clean_clone(a.value));
if (auto st = stmt->rhs->type->getStatic()) {
stmt->type = N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>(st->name));
} else {

View File

@ -543,7 +543,12 @@ void ScopingVisitor::visit(GlobalStmt *stmt) {
}
void ScopingVisitor::visit(FunctionStmt *stmt) {
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo());
bool isOverload = false;
for (auto &d: stmt->decorators)
if (d->isId("overload"))
isOverload = true;
if (!isOverload)
visitName(stmt->name, true, stmt->shared_from_this(), stmt->getSrcInfo());
auto c = std::make_shared<ScopingVisitor::Context>();
c->cache = ctx->cache;

View File

@ -56,25 +56,29 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
/// @c transformBinaryInplaceMagic for details.
/// Also evaluate static expressions. See @c evaluateStaticBinary for details.
void TypecheckVisitor::visit(BinaryExpr *expr) {
// Transform lexpr and rexpr. Ignore Nones for now
if (!(startswith(expr->op, "is") && expr->lexpr->getNone()))
transform(expr->lexpr, true);
if (!(startswith(expr->op, "is") && expr->rexpr->getNone()))
transform(expr->rexpr, true);
transform(expr->lexpr, true);
transform(expr->rexpr, true);
static std::unordered_map<int, std::unordered_set<std::string>> staticOps = {
{1,
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&",
"|", "^"}},
{2, {"==", "!=", "+"}},
{3,
{"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
if (expr->lexpr->type->isStaticType() &&
expr->lexpr->type->isStaticType() == expr->rexpr->type->isStaticType() &&
in(staticOps[expr->lexpr->type->isStaticType()], expr->op)) {
// Handle static expressions
resultExpr = evaluateStaticBinary(expr);
} else if (auto e = transformBinarySimple(expr)) {
{3, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}};
if (expr->lexpr->type->isStaticType() && expr->rexpr->type->isStaticType()) {
auto l = expr->lexpr->type->isStaticType();
auto r = expr->rexpr->type->isStaticType();
bool isStatic = l == r && in(staticOps[l], expr->op);
if (!isStatic && ((l == 1 && r == 3) || (r == 1 && l == 3)) &&
in(staticOps[1], expr->op))
isStatic = true;
if (isStatic) {
resultExpr = evaluateStaticBinary(expr);
return;
}
}
if (auto e = transformBinarySimple(expr)) {
// Case: simple binary expressions
resultExpr = e;
} else if (expr->lexpr->getType()->getUnbound() ||
@ -264,7 +268,8 @@ void TypecheckVisitor::visit(PipeExpr *expr) {
void TypecheckVisitor::visit(IndexExpr *expr) {
if (expr->expr->isId("Static")) {
// Special case: static types. Ensure that static is supported
if (!expr->index->isId("int") && !expr->index->isId("str") && !expr->index->isId("bool"))
if (!expr->index->isId("int") && !expr->index->isId("str") &&
!expr->index->isId("bool"))
E(Error::BAD_STATIC_TYPE, expr->index);
auto typ = ctx->getUnbound();
typ->isStatic = getStaticGeneric(expr);
@ -364,19 +369,16 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
unify(expr->type, typ);
} else {
for (size_t i = 0; i < expr->typeParams.size(); i++) {
// transform(expr->typeParams[i]);
transformType(expr->typeParams[i]);
auto t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
getType(expr->typeParams[i]));
// if (expr->typeParams[i]->type->isStaticType() &&
// generics[i].type->isStaticType()) {
// t = ctx->instantiate(expr->typeParams[i]->type);
// } else {
// if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
// transformType(expr->typeParams[i]);
// if (!expr->typeParams[i]->type->is("type"))
// E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
// }
if (expr->typeParams[i]->type->isStaticType() !=
generics[i].type->isStaticType()) {
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
transformType(expr->typeParams[i]);
if (!expr->typeParams[i]->type->is("type"))
E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
}
if (isUnion)
typ->getUnion()->addType(t);
else
@ -454,7 +456,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
value = !bool(value);
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
if (expr->op == "!")
return transform(N<IntExpr>(bool(value)));
return transform(N<BoolExpr>(value));
else
return transform(N<IntExpr>(value));
} else {
@ -469,9 +471,10 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
/// Division and modulus implementations.
std::pair<int64_t, int64_t> divMod(const std::shared_ptr<TypeContext> &ctx, int64_t a,
int64_t b) {
if (!b)
if (!b) {
E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo());
if (ctx->cache->pythonCompat) {
return {0, 0};
} else if (ctx->cache->pythonCompat) {
// Use Python implementation.
int64_t d = a / b;
int64_t m = a - d * b;
@ -511,7 +514,7 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
expr->rexpr->type->getStrStatic()->value;
bool value = expr->op == "==" ? eq : !eq;
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
return transform(N<IntExpr>(value));
return transform(N<BoolExpr>(value));
} else {
// Cannot be evaluated yet: just set the type
expr->type->getUnbound()->isStatic = 1;
@ -522,8 +525,12 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
// Case: static integers
if (expr->lexpr->type->getStatic() && expr->rexpr->type->getStatic()) {
int64_t lvalue = expr->lexpr->type->getIntStatic() ? expr->lexpr->type->getIntStatic()->value : expr->lexpr->type->getBoolStatic()->value;
int64_t rvalue = expr->rexpr->type->getIntStatic() ? expr->rexpr->type->getIntStatic()->value : expr->rexpr->type->getBoolStatic()->value;
int64_t lvalue = expr->lexpr->type->getIntStatic()
? expr->lexpr->type->getIntStatic()->value
: expr->lexpr->type->getBoolStatic()->value;
int64_t rvalue = expr->rexpr->type->getIntStatic()
? expr->rexpr->type->getIntStatic()->value
: expr->rexpr->type->getBoolStatic()->value;
if (expr->op == "<")
lvalue = lvalue < rvalue;
else if (expr->op == "<=")
@ -596,7 +603,7 @@ ExprPtr TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) {
return transform(N<CallExpr>(N<DotExpr>(expr->rexpr, "__contains__"), expr->lexpr));
} else if (expr->op == "is") {
if (expr->lexpr->getNone() && expr->rexpr->getNone())
return transform(N<IntExpr>(1));
return transform(N<BoolExpr>(true));
else if (expr->lexpr->getNone())
return transform(N<BinaryExpr>(expr->rexpr, "is", expr->lexpr));
} else if (expr->op == "is not") {
@ -613,17 +620,17 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
// Case: `is None` expressions
if (expr->rexpr->getNone()) {
if (expr->lexpr->getType()->is("NoneType"))
return transform(N<IntExpr>(1));
return transform(N<BoolExpr>(true));
if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) {
// lhs is not optional: `return False`
return transform(N<IntExpr>(0));
return transform(N<BoolExpr>(false));
} else {
// Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType
auto g = expr->lexpr->getType()->getClass();
for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass())
;
if (g->generics[0].type->is("NoneType"))
return transform(N<IntExpr>(1));
return transform(N<BoolExpr>(true));
// lhs is optional: `return lhs.__has__().__invert__()`
return transform(N<CallExpr>(
@ -640,7 +647,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
return nullptr;
}
if (expr->lexpr->type->is("type") && expr->rexpr->type->is("type"))
return transform(N<IntExpr>(lc->realizedName() == rc->realizedName()));
return transform(N<BoolExpr>(lc->realizedName() == rc->realizedName()));
if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) {
// Both reference types: `return lhs.__raw__() == rhs.__raw__()`
return transform(
@ -659,7 +666,7 @@ ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) {
}
if (lc->realizedName() != rc->realizedName()) {
// tuple names do not match: `return False`
return transform(N<IntExpr>(0));
return transform(N<BoolExpr>(false));
}
// Same tuple types: `return lhs == rhs`
return transform(N<BinaryExpr>(expr->lexpr, "==", expr->rexpr));

View File

@ -579,6 +579,20 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
} else {
expr = p;
}
} else if (expectedClass && expectedClass->name == "Function" && exprClass &&
exprClass->getPartial() &&
exprClass->generics[2].type->getClass()->generics.size() == 1 &&
exprClass->generics[2]
.type->getClass()
->generics[0]
.type->getClass()
->generics.empty() &&
exprClass->generics[3]
.type->getClass()
->generics[0]
.type->getClass()
->generics.empty()) {
expr = transform(N<IdExpr>(exprClass->getPartialFunc()->ast->name));
} else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass &&
!expectedClass->getUnion()) {
// Extract union types via __internal__.get_union
@ -696,7 +710,8 @@ types::TypePtr TypecheckVisitor::getType(const ExprPtr &e) {
return t;
}
std::vector<types::TypePtr> TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) {
std::vector<types::TypePtr>
TypecheckVisitor::getClassFieldTypes(const types::ClassTypePtr &cls) {
std::vector<types::TypePtr> result;
ctx->addBlock();
addClassGenerics(cls);

View File

@ -247,7 +247,7 @@ private: // Node typechecking rules
bool);
std::string generateTuple(size_t);
int generateKwId(const std::vector<std::string> & = {});
void addClassGenerics(const types::ClassTypePtr &);
void addClassGenerics(const types::ClassTypePtr &, bool instantiate = false);
/* The rest (typecheck.cpp) */
void visit(SuiteStmt *) override;

View File

@ -15,7 +15,11 @@ from internal.types.float import *
from internal.types.byte import *
from internal.types.generator import *
from internal.types.optional import *
import internal.c_stubs as _C
from internal.format import *
from internal.internal import *
from internal.types.slice import *
from internal.types.range import *
from internal.types.complex import *
@ -28,26 +32,21 @@ from internal.types.collections.set import *
from internal.types.collections.dict import *
from internal.types.collections.tuple import *
# Extended core library
import internal.c_stubs as _C
from internal.format import *
from internal.builtin import *
from internal.builtin import _jit_display
from internal.str import *
from internal.sort import sorted
# # from openmp import Ident as __OMPIdent, for_par
# # from gpu import _gpu_loop_outline_template
# from internal.file import File, gzFile, open, gzopen
# from pickle import pickle, unpickle
# from internal.dlopen import dlsym as _dlsym
# import internal.python
# from internal.python import PyError
from openmp import Ident as __OMPIdent, for_par
from gpu import _gpu_loop_outline_template
from internal.file import File, gzFile, open, gzopen
from pickle import pickle, unpickle
from internal.dlopen import dlsym as _dlsym
import internal.python
from internal.python import PyError
# # if __py_numerics__:
# # import internal.pynumerics
# # if __py_extension__:
# # internal.python.ensure_initialized()
if __py_numerics__:
import internal.pynumerics
if __py_extension__:
internal.python.ensure_initialized()

View File

@ -174,7 +174,7 @@ class Import:
P: Static[str]
@llvm
def __new__(P: Static[str], name: str, path: str) -> Import[P]:
def __new__(path: str, name: str, P: Static[str]) -> Import[P]:
%0 = insertvalue { {=str}, {=str} } undef, {=str} %path, 0
%1 = insertvalue { {=str}, {=str} } %0, {=str} %name, 1
ret { {=str}, {=str} } %1

View File

@ -68,6 +68,10 @@ class __internal__:
__vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])))
__internal__.class_populate_vtables()
# def print(a):
# from C import seq_print(str)
# seq_print(a.__repr__())
def class_populate_vtables() -> None:
"""
Populate content of vtables. Compiler generated.
@ -91,7 +95,8 @@ class __internal__:
def class_set_rtti_vtable(id: int, sz: int, T: type):
if not __has_rtti__(T):
compile_error("class is not polymorphic")
__vtables__[id] = Ptr[cobj](sz + 1)
p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj))
__vtables__[id] = Ptr[cobj](p)
__internal__.class_set_typeinfo(__vtables__[id], id)
def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type):

View File

@ -192,8 +192,9 @@ Jar = Ptr[byte]
@extend
class NoneType:
@llvm
def __new__() -> NoneType:
return ()
ret {} {}
def __eq__(self, other: NoneType):
return True

View File

@ -466,3 +466,15 @@ def test_mandelbrot():
return (MAX, N, pixels, scale(N, -2, 0.4))
k(pixels, grid=(N*N)//1024, block=1024)
test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4)
#%% id_shadow_overload_call,barebones
def foo():
def bar():
return -1
def xo():
return bar()
@overload # w/o this this fails because xo cannot capture bar
def bar(a):
return a
bar(1)
foo()

View File

@ -92,12 +92,6 @@ class F[T: Static[float]]:
pass
#! expected 'int' or 'str' (only integers and strings can be static)
#%% class_err_10,barebones
def foo[T]():
class A:
x: T
#! name 'T' cannot be captured
#%% class_err_11,barebones
def foo(x):
class A:

View File

@ -103,7 +103,7 @@ for i in range(10):
#%% for_error,barebones
for i in 1:
pass
#! 'int' object has no attribute '__iter__'
#! '1' object has no attribute '__iter__'
#%% for_void,barebones
def foo(): yield

View File

@ -3,7 +3,7 @@
a, b = False, 1
print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1
#%% binary,barebones
#%% binary_simple,barebones
x, y = 1, 0
c = [1, 2, 3]
@ -356,16 +356,19 @@ print Foo[int, 3, 4](), Foo[int, 5, 4]()
#%% static_int,barebones
def foo(n: Static[int]):
print n
@overload
def foo(n: Static[bool]):
print n
a: Static[int] = 5
foo(a < 1) #: 0
foo(a <= 1) #: 0
foo(a > 1) #: 1
foo(a >= 1) #: 1
foo(a == 1) #: 0
foo(a != 1) #: 1
foo(a and 1) #: 1
foo(a or 1) #: 1
foo(a < 1) #: False
foo(a <= 1) #: False
foo(a > 1) #: True
foo(a >= 1) #: True
foo(a == 1) #: False
foo(a != 1) #: True
foo(a and 1) #: True
foo(a or 1) #: True
foo(a + 1) #: 6
foo(a - 1) #: 4
foo(a * 1) #: 5