diff --git a/codon/cir/attribute.cpp b/codon/cir/attribute.cpp index 2012044a..ffaf0677 100644 --- a/codon/cir/attribute.cpp +++ b/codon/cir/attribute.cpp @@ -12,6 +12,8 @@ namespace ir { const std::string StringValueAttribute::AttributeName = "svAttribute"; +const std::string IntValueAttribute::AttributeName = "i64Attribute"; + const std::string StringListAttribute::AttributeName = "slAttribute"; std::ostream &StringListAttribute::doFormat(std::ostream &os) const { diff --git a/codon/cir/attribute.h b/codon/cir/attribute.h index 441dcf44..3282b754 100644 --- a/codon/cir/attribute.h +++ b/codon/cir/attribute.h @@ -329,6 +329,23 @@ private: std::ostream &doFormat(std::ostream &os) const override; }; +struct IntValueAttribute : public Attribute { + static const std::string AttributeName; + + int64_t value; + + IntValueAttribute() = default; + /// Constructs a IntValueAttribute. + explicit IntValueAttribute(int64_t value) : value(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + std::ostream &doFormat(std::ostream &os) const override { return os << value; } +}; + } // namespace ir std::map> diff --git a/codon/cir/base.h b/codon/cir/base.h index 6f6dd2c4..d3d9599b 100644 --- a/codon/cir/base.h +++ b/codon/cir/base.h @@ -161,6 +161,10 @@ public: AttributeType *getAttribute(const std::string &key) { return static_cast(getAttribute(key)); } + template + const AttributeType *getAttribute(const std::string &key) const { + return static_cast(getAttribute(key)); + } void eraseAttribute(const std::string &key) { attributes.erase(key); } void cloneAttributesFrom(Node *n) { attributes = codon::clone(n->attributes); } diff --git a/codon/cir/instr.cpp b/codon/cir/instr.cpp index 0a88bbad..b3b1a814 100644 --- a/codon/cir/instr.cpp +++ b/codon/cir/instr.cpp @@ -28,10 +28,7 @@ types::Type *Instr::doGetType() const { return getModule()->getNoneType(); } const char AssignInstr::NodeId = 0; AssignInstr::AssignInstr(Var *lhs, Value *rhs, std::string name) - : AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) { - if (!lhs->getType()) - LOG("->"); -} + : AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {} int AssignInstr::doReplaceUsedValue(id_t id, Value *newValue) { if (rhs->getId() == id) { diff --git a/codon/compiler/error.cpp b/codon/compiler/error.cpp index 5a2900e2..8ad18ff5 100644 --- a/codon/compiler/error.cpp +++ b/codon/compiler/error.cpp @@ -16,6 +16,14 @@ SrcInfo::SrcInfo() : SrcInfo("", 0, 0, 0) {} bool SrcInfo::operator==(const SrcInfo &src) const { return id == src.id; } +bool SrcInfo::operator<(const SrcInfo &src) const { + return std::tie(file, line, col) < std::tie(src.file, src.line, src.col); +} + +bool SrcInfo::operator<=(const SrcInfo &src) const { + return std::tie(file, line, col) <= std::tie(src.file, src.line, src.col); +} + namespace error { char ParserErrorInfo::ID = 0; diff --git a/codon/parser/ast/attr.cpp b/codon/parser/ast/attr.cpp index d5ff4518..07bbcb9d 100644 --- a/codon/parser/ast/attr.cpp +++ b/codon/parser/ast/attr.cpp @@ -57,5 +57,6 @@ const std::string Attr::ExprOrderedCall = "exprOrderedCall"; const std::string Attr::ExprExternVar = "exprExternVar"; const std::string Attr::ExprDominatedUndefCheck = "exprDominatedUndefCheck"; const std::string Attr::ExprDominatedUsed = "exprDominatedUsed"; +const std::string Attr::ExprTime = "exprTime"; } // namespace codon::ast diff --git a/codon/parser/ast/attr.h b/codon/parser/ast/attr.h index 22ad7cc3..bfd6b44d 100644 --- a/codon/parser/ast/attr.h +++ b/codon/parser/ast/attr.h @@ -62,5 +62,6 @@ struct Attr { const static std::string ExprExternVar; const static std::string ExprDominatedUndefCheck; const static std::string ExprDominatedUsed; + const static std::string ExprTime; }; } // namespace codon::ast diff --git a/codon/parser/ast/error.h b/codon/parser/ast/error.h index a32c2df9..cd715211 100644 --- a/codon/parser/ast/error.h +++ b/codon/parser/ast/error.h @@ -27,6 +27,8 @@ struct SrcInfo { SrcInfo(); SrcInfo(std::string file, int line, int col, int len); bool operator==(const SrcInfo &src) const; + bool operator<(const SrcInfo &src) const; + bool operator<=(const SrcInfo &src) const; }; class ErrorMessage { diff --git a/codon/parser/ast/node.h b/codon/parser/ast/node.h index aa88eba1..a02279ca 100644 --- a/codon/parser/ast/node.h +++ b/codon/parser/ast/node.h @@ -55,6 +55,9 @@ struct ASTNode : public ir::Node { void setAttribute(const std::string &key, const std::string &value) { attributes[key] = std::make_unique(value); } + void setAttribute(const std::string &key, int64_t value) { + attributes[key] = std::make_unique(value); + } void setAttribute(const std::string &key) { attributes[key] = std::make_unique(); } diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 962d5605..6924c299 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -27,6 +27,11 @@ Stmt::Stmt(const Stmt &expr, bool clean) : AcceptorExtend(expr) { if (clean) done = false; } +std::string Stmt::wrapStmt(const std::string &s) const { + // if (auto a = ir::Node::getAttribute(Attr::ExprTime)) + // return format("{}%%{}", s, a->value); + return s; +} SuiteStmt::SuiteStmt(std::vector stmts) : AcceptorExtend(), Items(std::move(stmts)) {} @@ -44,7 +49,8 @@ std::string SuiteStmt::toString(int indent) const { is.insert(findStar(is), "*"); s += (i ? pad : "") + is; } - return format("({}suite{})", (isDone() ? "*" : ""), (s.empty() ? s : " " + pad + s)); + return wrapStmt( + format("({}suite{})", (isDone() ? "*" : ""), (s.empty() ? s : " " + pad + s))); } void SuiteStmt::flatten() { std::vector ns; @@ -71,17 +77,17 @@ SuiteStmt *SuiteStmt::wrap(Stmt *s) { } BreakStmt::BreakStmt(const BreakStmt &stmt, bool clean) : AcceptorExtend(stmt, clean) {} -std::string BreakStmt::toString(int indent) const { return "(break)"; } +std::string BreakStmt::toString(int indent) const { return wrapStmt("(break)"); } ContinueStmt::ContinueStmt(const ContinueStmt &stmt, bool clean) : AcceptorExtend(stmt, clean) {} -std::string ContinueStmt::toString(int indent) const { return "(continue)"; } +std::string ContinueStmt::toString(int indent) const { return wrapStmt("(continue)"); } ExprStmt::ExprStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} ExprStmt::ExprStmt(const ExprStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} std::string ExprStmt::toString(int indent) const { - return format("(expr {})", expr->toString(indent)); + return wrapStmt(format("(expr {})", expr->toString(indent))); } AssignStmt::AssignStmt(Expr *lhs, Expr *rhs, Expr *type, UpdateMode update) @@ -91,16 +97,16 @@ AssignStmt::AssignStmt(const AssignStmt &stmt, bool clean) rhs(ast::clone(stmt.rhs, clean)), type(ast::clone(stmt.type, clean)), update(stmt.update) {} std::string AssignStmt::toString(int indent) const { - return format("({} {}{}{})", update != Assign ? "update" : "assign", - lhs->toString(indent), rhs ? " " + rhs->toString(indent) : "", - type ? format(" #:type {}", type->toString(indent)) : ""); + return wrapStmt(format("({} {}{}{})", update != Assign ? "update" : "assign", + lhs->toString(indent), rhs ? " " + rhs->toString(indent) : "", + type ? format(" #:type {}", type->toString(indent)) : "")); } DelStmt::DelStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} DelStmt::DelStmt(const DelStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} std::string DelStmt::toString(int indent) const { - return format("(del {})", expr->toString(indent)); + return wrapStmt(format("(del {})", expr->toString(indent))); } PrintStmt::PrintStmt(std::vector items, bool noNewline) @@ -109,21 +115,21 @@ PrintStmt::PrintStmt(const PrintStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), noNewline(stmt.noNewline) {} std::string PrintStmt::toString(int indent) const { - return format("(print {}{})", noNewline ? "#:inline " : "", combine(items)); + return wrapStmt(format("(print {}{})", noNewline ? "#:inline " : "", combine(items))); } ReturnStmt::ReturnStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} ReturnStmt::ReturnStmt(const ReturnStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} std::string ReturnStmt::toString(int indent) const { - return expr ? format("(return {})", expr->toString(indent)) : "(return)"; + return wrapStmt(expr ? format("(return {})", expr->toString(indent)) : "(return)"); } YieldStmt::YieldStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} YieldStmt::YieldStmt(const YieldStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} std::string YieldStmt::toString(int indent) const { - return expr ? format("(yield {})", expr->toString(indent)) : "(yield)"; + return wrapStmt(expr ? format("(yield {})", expr->toString(indent)) : "(yield)"); } AssertStmt::AssertStmt(Expr *expr, Expr *message) @@ -132,8 +138,8 @@ AssertStmt::AssertStmt(const AssertStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)), message(ast::clone(stmt.message, clean)) {} std::string AssertStmt::toString(int indent) const { - return format("(assert {}{})", expr->toString(indent), - message ? message->toString(indent) : ""); + return wrapStmt(format("(assert {}{})", expr->toString(indent), + message ? message->toString(indent) : "")); } WhileStmt::WhileStmt(Expr *cond, Stmt *suite, Stmt *elseSuite) @@ -145,15 +151,17 @@ WhileStmt::WhileStmt(const WhileStmt &stmt, bool clean) elseSuite(ast::clone(stmt.elseSuite, clean)) {} std::string WhileStmt::toString(int indent) const { if (indent == -1) - return format("(while {})", cond->toString(indent)); + return wrapStmt(format("(while {})", cond->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - if (elseSuite && elseSuite->firstInBlock()) - return format("(while-else {}{}{}{}{})", cond->toString(indent), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, - elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); - else - return format("(while {}{}{})", cond->toString(indent), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); + if (elseSuite && elseSuite->firstInBlock()) { + return wrapStmt( + format("(while-else {}{}{}{}{})", cond->toString(indent), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } else { + return wrapStmt(format("(while {}{}{})", cond->toString(indent), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } } ForStmt::ForStmt(Expr *var, Expr *iter, Stmt *suite, Stmt *elseSuite, Expr *decorator, @@ -171,7 +179,7 @@ ForStmt::ForStmt(const ForStmt &stmt, bool clean) std::string ForStmt::toString(int indent) const { auto vs = var->toString(indent); if (indent == -1) - return format("(for {} {})", vs, iter->toString(indent)); + return wrapStmt(format("(for {} {})", vs, iter->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string attr; @@ -179,13 +187,15 @@ std::string ForStmt::toString(int indent) const { attr += " " + decorator->toString(indent); if (!attr.empty()) attr = " #:attr" + attr; - if (elseSuite && elseSuite->firstInBlock()) - return format("(for-else {} {}{}{}{}{}{})", vs, iter->toString(indent), attr, pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, - elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); - else - return format("(for {} {}{}{}{})", vs, iter->toString(indent), attr, pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); + if (elseSuite && elseSuite->firstInBlock()) { + return wrapStmt( + format("(for-else {} {}{}{}{}{}{})", vs, iter->toString(indent), attr, pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } else { + return wrapStmt(format("(for {} {}{}{}{})", vs, iter->toString(indent), attr, pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } } IfStmt::IfStmt(Expr *cond, Stmt *ifSuite, Stmt *elseSuite) @@ -197,13 +207,13 @@ IfStmt::IfStmt(const IfStmt &stmt, bool clean) elseSuite(ast::clone(stmt.elseSuite, clean)) {} std::string IfStmt::toString(int indent) const { if (indent == -1) - return format("(if {})", cond->toString(indent)); + return wrapStmt(format("(if {})", cond->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - return format("(if {}{}{}{})", cond->toString(indent), pad, - ifSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), - elseSuite - ? pad + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : ""); + return wrapStmt(format( + "(if {}{}{}{})", cond->toString(indent), pad, + ifSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), + elseSuite ? pad + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) + : "")); } MatchCase::MatchCase(Expr *pattern, Expr *guard, Stmt *suite) @@ -220,7 +230,7 @@ MatchStmt::MatchStmt(const MatchStmt &stmt, bool clean) expr(ast::clone(stmt.expr, clean)) {} std::string MatchStmt::toString(int indent) const { if (indent == -1) - return format("(match {})", expr->toString(indent)); + return wrapStmt(format("(match {})", expr->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : ""; std::vector s; @@ -229,14 +239,13 @@ std::string MatchStmt::toString(int indent) const { c.guard ? " #:guard " + c.guard->toString(indent) : "", pad + padExtra, c.suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2))); - return format("(match {}{}{})", expr->toString(indent), pad, join(s, pad)); + return wrapStmt(format("(match {}{}{})", expr->toString(indent), pad, join(s, pad))); } ImportStmt::ImportStmt(Expr *from, Expr *what, std::vector args, Expr *ret, std::string as, size_t dots, bool isFunction) : AcceptorExtend(), from(from), what(what), as(std::move(as)), dots(dots), - args(std::move(args)), ret(ret), isFunction(isFunction) { -} + args(std::move(args)), ret(ret), isFunction(isFunction) {} ImportStmt::ImportStmt(const ImportStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), from(ast::clone(stmt.from, clean)), what(ast::clone(stmt.what, clean)), as(stmt.as), dots(stmt.dots), @@ -246,12 +255,12 @@ std::string ImportStmt::toString(int indent) const { std::vector va; for (auto &a : args) va.push_back(a.toString(indent)); - return format("(import {}{}{}{}{}{})", from ? from->toString(indent) : "", - as.empty() ? "" : format(" #:as '{}", as), - what ? format(" #:what {}", what->toString(indent)) : "", - dots ? format(" #:dots {}", dots) : "", - va.empty() ? "" : format(" #:args ({})", join(va)), - ret ? format(" #:ret {}", ret->toString(indent)) : ""); + return wrapStmt(format("(import {}{}{}{}{}{})", from ? from->toString(indent) : "", + as.empty() ? "" : format(" #:as '{}", as), + what ? format(" #:what {}", what->toString(indent)) : "", + dots ? format(" #:dots {}", dots) : "", + va.empty() ? "" : format(" #:args ({})", join(va)), + ret ? format(" #:ret {}", ret->toString(indent)) : "")); } ExceptStmt::ExceptStmt(const std::string &var, Expr *exc, Stmt *suite) @@ -262,9 +271,10 @@ ExceptStmt::ExceptStmt(const ExceptStmt &stmt, bool clean) std::string ExceptStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : ""; - return format("(catch {}{}{}{})", !var.empty() ? format("#:var '{}", var) : "", - exc ? format(" #:exc {}", exc->toString(indent)) : "", pad + padExtra, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2)); + return wrapStmt( + format("(catch {}{}{}{})", !var.empty() ? format("#:var '{}", var) : "", + exc ? format(" #:exc {}", exc->toString(indent)) : "", pad + padExtra, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2))); } TryStmt::TryStmt(Stmt *suite, std::vector excepts, Stmt *finally) @@ -275,17 +285,17 @@ TryStmt::TryStmt(const TryStmt &stmt, bool clean) suite(ast::clone(stmt.suite, clean)), finally(ast::clone(stmt.finally, clean)) {} std::string TryStmt::toString(int indent) const { if (indent == -1) - return format("(try)"); + return wrapStmt(format("(try)")); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::vector s; for (auto &i : items) s.push_back(i->toString(indent)); - return format( + return wrapStmt(format( "(try{}{}{}{}{})", pad, suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, join(s, pad), finally ? format("{}{}", pad, finally->toString(indent >= 0 ? indent + INDENT_SIZE : -1)) - : ""); + : "")); } ThrowStmt::ThrowStmt(Expr *expr, Expr *from, bool transformed) @@ -294,8 +304,8 @@ ThrowStmt::ThrowStmt(const ThrowStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)), from(ast::clone(stmt.from, clean)), transformed(stmt.transformed) {} std::string ThrowStmt::toString(int indent) const { - return format("(throw{}{})", expr ? " " + expr->toString(indent) : "", - from ? format(" :from {}", from->toString(indent)) : ""); + return wrapStmt(format("(throw{}{})", expr ? " " + expr->toString(indent) : "", + from ? format(" :from {}", from->toString(indent)) : "")); } GlobalStmt::GlobalStmt(std::string var, bool nonLocal) @@ -303,14 +313,13 @@ GlobalStmt::GlobalStmt(std::string var, bool nonLocal) GlobalStmt::GlobalStmt(const GlobalStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), var(stmt.var), nonLocal(stmt.nonLocal) {} std::string GlobalStmt::toString(int indent) const { - return format("({} '{})", nonLocal ? "nonlocal" : "global", var); + return wrapStmt(format("({} '{})", nonLocal ? "nonlocal" : "global", var)); } FunctionStmt::FunctionStmt(std::string name, Expr *ret, std::vector args, Stmt *suite, std::vector decorators) : AcceptorExtend(), Items(std::move(args)), name(std::move(name)), ret(ret), - suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)) { -} + suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)) {} FunctionStmt::FunctionStmt(const FunctionStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), name(stmt.name), ret(ast::clone(stmt.ret, clean)), @@ -326,13 +335,13 @@ std::string FunctionStmt::toString(int indent) const { if (a) dec.push_back(format("(dec {})", a->toString(indent))); if (indent == -1) - return format("(fn '{} ({}){})", name, join(as, " "), - ret ? " #:ret " + ret->toString(indent) : ""); - return format("(fn '{} ({}){}{}{}{})", name, join(as, " "), - ret ? " #:ret " + ret->toString(indent) : "", - dec.empty() ? "" : format(" (dec {})", join(dec, " ")), pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : "(suite)"); + return wrapStmt(format("(fn '{} ({}){})", name, join(as, " "), + ret ? " #:ret " + ret->toString(indent) : "")); + return wrapStmt(format( + "(fn '{} ({}){}{}{}{})", name, join(as, " "), + ret ? " #:ret " + ret->toString(indent) : "", + dec.empty() ? "" : format(" (dec {})", join(dec, " ")), pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)")); } std::string FunctionStmt::signature() const { std::vector s; @@ -452,13 +461,13 @@ std::string ClassStmt::toString(int indent) const { for (auto &a : decorators) attr.push_back(format("(dec {})", a->toString(indent))); if (indent == -1) - return format("(class '{} ({}))", name, as); - return format("(class '{}{}{}{}{}{})", name, - bases.empty() ? "" : format(" (bases {})", join(bases, " ")), - attr.empty() ? "" : format(" (attr {})", join(attr, " ")), - as.empty() ? as : pad + as, pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : "(suite)"); + return wrapStmt(format("(class '{} ({}))", name, as)); + return wrapStmt(format( + "(class '{}{}{}{}{}{})", name, + bases.empty() ? "" : format(" (bases {})", join(bases, " ")), + attr.empty() ? "" : format(" (attr {})", join(attr, " ")), + as.empty() ? as : pad + as, pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)")); } bool ClassStmt::isRecord() const { return hasAttribute(Attr::Tuple); } bool ClassStmt::isClassVar(const Param &p) { @@ -484,7 +493,7 @@ YieldFromStmt::YieldFromStmt(Expr *expr) : AcceptorExtend(), expr(std::move(expr YieldFromStmt::YieldFromStmt(const YieldFromStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} std::string YieldFromStmt::toString(int indent) const { - return format("(yield-from {})", expr->toString(indent)); + return wrapStmt(format("(yield-from {})", expr->toString(indent))); } WithStmt::WithStmt(std::vector items, std::vector vars, @@ -517,9 +526,9 @@ std::string WithStmt::toString(int indent) const { : items[i]->toString(indent)); } if (indent == -1) - return format("(with ({}))", join(as, " ")); - return format("(with ({}){}{})", join(as, " "), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); + return wrapStmt(format("(with ({}))", join(as, " "))); + return wrapStmt(format("(with ({}){}{})", join(as, " "), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); } CustomStmt::CustomStmt(std::string keyword, Expr *expr, Stmt *suite) @@ -530,9 +539,10 @@ CustomStmt::CustomStmt(const CustomStmt &stmt, bool clean) expr(ast::clone(stmt.expr, clean)), suite(ast::clone(stmt.suite, clean)) {} std::string CustomStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - return format("(custom-{} {}{}{})", keyword, - expr ? format(" #:expr {}", expr->toString(indent)) : "", pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : ""); + return wrapStmt( + format("(custom-{} {}{}{})", keyword, + expr ? format(" #:expr {}", expr->toString(indent)) : "", pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "")); } AssignMemberStmt::AssignMemberStmt(Expr *lhs, std::string member, Expr *rhs) @@ -541,8 +551,8 @@ AssignMemberStmt::AssignMemberStmt(const AssignMemberStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), lhs(ast::clone(stmt.lhs, clean)), member(stmt.member), rhs(ast::clone(stmt.rhs, clean)) {} std::string AssignMemberStmt::toString(int indent) const { - return format("(assign-member {} {} {})", lhs->toString(indent), member, - rhs->toString(indent)); + return wrapStmt(format("(assign-member {} {} {})", lhs->toString(indent), member, + rhs->toString(indent))); } CommentStmt::CommentStmt(std::string comment) @@ -550,7 +560,7 @@ CommentStmt::CommentStmt(std::string comment) CommentStmt::CommentStmt(const CommentStmt &stmt, bool clean) : AcceptorExtend(stmt, clean), comment(stmt.comment) {} std::string CommentStmt::toString(int indent) const { - return format("(comment \"{}\")", comment); + return wrapStmt(format("(comment \"{}\")", comment)); } const char Stmt::NodeId = 0; diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index e919abda..ca63405e 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -51,6 +51,8 @@ struct Stmt : public AcceptorExtend { static const char NodeId; SERIALIZE(Stmt, BASE(ASTNode), done); + virtual std::string wrapStmt(const std::string &) const; + private: /// Flag that indicates if all types in a statement are inferred (i.e. if a /// type-checking procedure was successful). diff --git a/codon/parser/visitors/scoping/scoping.cpp b/codon/parser/visitors/scoping/scoping.cpp index e6fd2602..e138ba57 100644 --- a/codon/parser/visitors/scoping/scoping.cpp +++ b/codon/parser/visitors/scoping/scoping.cpp @@ -18,7 +18,7 @@ } #define STOP_ERROR(...) \ do { \ - addError(__VA_ARGS__); \ + addError(__VA_ARGS__); \ return; \ } while (0) @@ -66,6 +66,7 @@ bool ScopingVisitor::transform(Stmt *stmt) { errors.append(v.errors); if (!canContinue()) return false; + stmt->setAttribute(Attr::ExprTime, ++ctx->time); } return true; } @@ -847,8 +848,6 @@ ScopingVisitor::findDominatingBinding(const std::string &name, bool allowShadow) lastGood = i; } } - // if (commonScope != ctx->scope.size()) - // LOG("==> {}: {} / {} vs {}", getSrcInfo(), name, ctx->getScope(), commonScope); seqassert(lastGood != it->end(), "corrupted scoping ({})", name); if (!allowShadow) { // go to the end lastGood = it->end(); diff --git a/codon/parser/visitors/scoping/scoping.h b/codon/parser/visitors/scoping/scoping.h index ab0f5323..4aec2770 100644 --- a/codon/parser/visitors/scoping/scoping.h +++ b/codon/parser/visitors/scoping/scoping.h @@ -93,6 +93,9 @@ class ScopingVisitor : public CallbackASTVisitor { std::vector> renames = {{}}; bool tempScope = false; + + // Time to track positions of assignments and references to them. + int64_t time = 0; }; std::shared_ptr ctx = nullptr; diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index e238ee7a..5edac5c7 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -24,7 +24,7 @@ using namespace types; /// If the identifier of a generic is fully qualified, use its qualified name /// (e.g., replace `Ptr` with `Ptr[byte]`). void TypecheckVisitor::visit(IdExpr *expr) { - auto val = ctx->find(expr->getValue()); + auto val = ctx->find(expr->getValue(), getTime()); if (!val) { E(Error::ID_NOT_FOUND, expr, expr->getValue()); } @@ -61,7 +61,7 @@ void TypecheckVisitor::visit(IdExpr *expr) { if (expr->hasAttribute(Attr::ExprDominatedUndefCheck)) { auto controlVar = fmt::format("{}{}", getUnmangledName(val->canonicalName), VAR_USED_SUFFIX); - if (ctx->find(controlVar)) { + if (ctx->find(controlVar, getTime())) { auto checkStmt = N(N( N(N("__internal__"), "undef"), N(controlVar), N(getUnmangledName(val->canonicalName)))); @@ -295,7 +295,7 @@ TypecheckVisitor::getImport(const std::vector &chain) { TypeContext::Item val = nullptr, importVal = nullptr; for (auto i = chain.size(); i-- > 0;) { auto name = join(chain, "/", 0, i + 1); - val = ctx->find(name); + val = ctx->find(name, getTime()); if (val && val->type->is("Import") && startswith(val->getName(), "%_import_")) { importName = getStrLiteral(val->type.get()); importEnd = i + 1; diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index d3455a70..14401cd3 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -202,7 +202,7 @@ Stmt *TypecheckVisitor::transformAssignment(AssignStmt *stmt, bool mustExist) { if (!e) E(Error::ASSIGN_INVALID, stmt->getLhs()); - auto val = ctx->find(e->getValue()); + auto val = ctx->find(e->getValue(), getTime()); // Make sure that existing values that cannot be shadowed are only updated // mustExist |= val && !ctx->isOuter(val); if (mustExist) { @@ -237,7 +237,7 @@ Stmt *TypecheckVisitor::transformAssignment(AssignStmt *stmt, bool mustExist) { // static check) assign->getLhs()->getType()->getLink()->defaultType = getStdLibType("NoneType")->shared_from_this(); - ctx->getBase()->pendingDefaults.insert( + ctx->getBase()->pendingDefaults[1].insert( assign->getLhs()->getType()->shared_from_this()); } if (stmt->getTypeExpr()) { @@ -248,6 +248,7 @@ Stmt *TypecheckVisitor::transformAssignment(AssignStmt *stmt, bool mustExist) { val = std::make_shared(canonical, ctx->getBaseName(), ctx->getModule(), assign->getLhs()->getType()->shared_from_this(), ctx->getScope()); + val->time = getTime(); val->setSrcInfo(getSrcInfo()); ctx->add(e->getValue(), val); ctx->addAlwaysVisible(val); diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 21c76fc3..e6d482ee 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -150,8 +150,10 @@ void TypecheckVisitor::visit(CallExpr *expr) { ? t.getExpr()->getClassType()->name : t.getExpr()->getType()->prettyString())); auto argsNice = fmt::format("({})", fmt::join(a, ", ")); - E(Error::FN_NO_ATTR_ARGS, expr, getUnmangledName(calleeFn->getFuncName()), - argsNice); + auto name = getUnmangledName(calleeFn->getFuncName()); + if (calleeFn->getParentType() && calleeFn->getParentType()->getClass()) + name = format("{}.{}", calleeFn->getParentType()->getClass()->niceName, name); + E(Error::FN_NO_ATTR_ARGS, expr, name, argsNice); } } @@ -740,11 +742,11 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) { return {true, transformPtr(expr)}; } else if (val == "__array__.__new__:0") { return {true, transformArray(expr)}; - } else if (val == "isinstance") { + } else if (val == "isinstance") { // static return {true, transformIsInstance(expr)}; - } else if (val == "staticlen") { + } else if (val == "staticlen") { // static return {true, transformStaticLen(expr)}; - } else if (val == "hasattr") { + } else if (val == "hasattr") { // static return {true, transformHasAttr(expr)}; } else if (val == "getattr") { return {true, transformGetAttr(expr)}; @@ -758,21 +760,21 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) { return {true, transformRealizedFn(expr)}; } else if (val == "std.internal.static.static_print.0") { return {false, transformStaticPrintFn(expr)}; - } else if (val == "__has_rtti__") { + } else if (val == "__has_rtti__") { // static return {true, transformHasRttiFn(expr)}; } else if (val == "std.collections.namedtuple.0") { return {true, transformNamedTuple(expr)}; } else if (val == "std.functools.partial.0:0") { return {true, transformFunctoolsPartial(expr)}; - } else if (val == "std.internal.static.fn_can_call.0") { + } else if (val == "std.internal.static.fn_can_call.0") { // static return {true, transformStaticFnCanCall(expr)}; - } else if (val == "std.internal.static.fn_arg_has_type.0") { + } else if (val == "std.internal.static.fn_arg_has_type.0") { // static return {true, transformStaticFnArgHasType(expr)}; } else if (val == "std.internal.static.fn_arg_get_type.0") { return {true, transformStaticFnArgGetType(expr)}; } else if (val == "std.internal.static.fn_args.0") { return {true, transformStaticFnArgs(expr)}; - } else if (val == "std.internal.static.fn_has_default.0") { + } else if (val == "std.internal.static.fn_has_default.0") { // static return {true, transformStaticFnHasDefault(expr)}; } else if (val == "std.internal.static.fn_get_default.0") { return {true, transformStaticFnGetDefault(expr)}; diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index f899574e..b4741c03 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -63,7 +63,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { // Find the canonical name and AST of the class that is to be extended if (!ctx->isGlobal() || ctx->isConditional()) E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "class extension"); - auto val = ctx->find(name); + auto val = ctx->find(name, getTime()); if (!val || !val->isType()) E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); typ = val->getName() == TYPE_TYPE ? val->getType()->getClass() diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 81d9101c..d63fb07f 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -121,6 +121,29 @@ TypeContext::Item TypeContext::find(const std::string &name) const { return t; } +TypeContext::Item TypeContext::find(const std::string &name, int64_t time) const { + auto it = map.find(name); + if (it != map.end()) { + for (auto &i : it->second) { + if (i->getBaseName() != getBaseName() || !time || i->getTime() <= time) + return i; + } + } + + // Item is not found in the current module. Time to look in the standard library! + // Note: the standard library items cannot be dominated. + TypeContext::Item t = nullptr; + auto stdlib = cache->imports[STDLIB_IMPORT].ctx; + if (stdlib.get() != this) + t = stdlib->Context::find(name); + + // Maybe we are looking for a canonical identifier? + if (!t && cache->typeCtx.get() != this) + t = cache->typeCtx->Context::find(name); + + return t; +} + TypeContext::Item TypeContext::forceFind(const std::string &name) const { auto f = find(name); seqassert(f, "cannot find '{}'", name); diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 578b1908..945ecebf 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -34,6 +34,10 @@ struct TypecheckItem : public SrcObject { /// Full base scope information std::vector scope = {0}; + /// Specifies at which time the name was added to the context. + /// Used to prevent using later definitions early (can happen in + /// advanced type checking iterations). + int64_t time = 0; /// Set if an identifier is a class or a function generic bool generic = false; @@ -57,6 +61,8 @@ struct TypecheckItem : public SrcObject { types::Type *getType() const { return type.get(); } std::string getName() const { return canonicalName; } + + int64_t getTime() const { return time; } }; /** Context class that tracks identifiers during the typechecking. **/ @@ -136,7 +142,7 @@ struct TypeContext : public Context { }; std::vector loops; - std::set pendingDefaults; + std::map> pendingDefaults; public: Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); } @@ -181,6 +187,9 @@ struct TypeContext : public Context { /// Stack of static loop control variables (used to emulate goto statements). std::vector staticLoops = {}; + /// Current statement time. + int64_t time; + public: explicit TypeContext(Cache *cache, std::string filename = ""); @@ -198,6 +207,9 @@ public: /// Get an item from the context. If the item does not exist, nullptr is returned. Item find(const std::string &name) const override; + /// Get an item from the context before given srcInfo. If the item does not exist, + /// nullptr is returned. + Item find(const std::string &name, int64_t time) const; /// Get an item that exists in the context. If the item does not exist, assertion is /// raised. Item forceFind(const std::string &name) const; diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index e41652f9..3c30c6df 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -31,6 +31,7 @@ void TypecheckVisitor::visit(LambdaExpr *expr) { N(N(expr->getExpr()))); if (auto err = ScopingVisitor::apply(ctx->cache, N(f))) throw exc::ParserException(std::move(err)); + f->setAttribute(Attr::ExprTime, getTime()); // to handle captures properly f = transform(f); if (auto a = expr->getAttribute(Attr::Bindings)) f->setAttribute(Attr::Bindings, a->clone()); @@ -168,7 +169,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { rootName = *n; } else if (stmt->hasAttribute(Attr::Overload)) { // Case 2: function overload - if (auto c = ctx->find(stmt->getName())) { + if (auto c = ctx->find(stmt->getName(), getTime())) { if (c->isFunc() && c->getModule() == ctx->getModule() && c->getBaseName() == ctx->getBaseName()) { rootName = c->canonicalName; @@ -196,7 +197,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) { std::map captures; if (auto b = stmt->getAttribute(Attr::Bindings)) for (auto &[c, t] : b->captures) { - if (auto v = ctx->find(c)) { + if (auto v = ctx->find(c, getTime())) { if (t != BindingsAttribute::CaptureType::Global && !v->isGlobal()) { bool parentClassGeneric = ctx->bases.back().isType() && ctx->bases.back().name == v->getBaseName(); @@ -581,7 +582,7 @@ std::pair TypecheckVisitor::getDecorator(Expr *e) { auto dt = transform(clone(e)); auto id = cast(cast(dt) ? cast(dt)->getExpr() : dt); if (id) { - auto ci = ctx->find(id->getValue()); + auto ci = ctx->find(id->getValue(), getTime()); if (ci && ci->isFunc()) { auto fn = ci->getName(); auto f = getFunction(fn); diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 65ac150e..bedea193 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -109,27 +109,35 @@ Stmt *TypecheckVisitor::inferTypes(Stmt *result, bool isToplevel) { bool anotherRound = false; // Special case: return type might have default as well (e.g., Union) if (auto t = ctx->getBase()->returnType) { - ctx->getBase()->pendingDefaults.insert(t); + ctx->getBase()->pendingDefaults[0].insert(t); } - for (auto &unbound : ctx->getBase()->pendingDefaults) { - if (auto tu = unbound->getUnion()) { - // Seal all dynamic unions after the iteration is over - if (!tu->isSealed()) { - tu->seal(); - anotherRound = true; - } - } else if (auto u = unbound->getLink()) { - types::Type::Unification undo; - if (u->defaultType && - u->unify(extractClassType(u->defaultType.get()), &undo) >= 0) { - anotherRound = true; + // First unify "explicit" generics (whose default type is explicit), + // then "implicit" ones (whose default type is compiler generated, + // e.g. compiler-generated variable placeholders with default NoneType) + for (auto &[level, unbounds] : ctx->getBase()->pendingDefaults) { + if (!unbounds.empty()) { + for (const auto &unbound : unbounds) { + if (auto tu = unbound->getUnion()) { + // Seal all dynamic unions after the iteration is over + if (!tu->isSealed()) { + tu->seal(); + anotherRound = true; + } + } else if (auto u = unbound->getLink()) { + types::Type::Unification undo; + if (u->defaultType && + u->unify(extractClassType(u->defaultType.get()), &undo) >= 0) { + anotherRound = true; + } + } } + unbounds.clear(); + if (anotherRound) + break; } } - ctx->getBase()->pendingDefaults.clear(); if (anotherRound) continue; - // Nothing helps. Return nullptr. return nullptr; } @@ -356,11 +364,21 @@ types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { E(Error::FN_GLOBAL_NOT_FOUND, getSrcInfo(), "global", c); } } - // Add self reference! TODO: maybe remove later when doing contexts? + // Add self [recursive] reference! TODO: maybe remove later when doing contexts? auto pc = ast->getAttribute(Attr::ParentClass); - if (!pc || pc->value.empty()) - ctx->addFunc(getUnmangledName(ast->getName()), ast->getName(), - ctx->forceFind(ast->getName())->type); + if (!pc || pc->value.empty()) { + // Check if we already exist? + bool exists = false; + auto val = ctx->find(getUnmangledName(ast->getName())); + if (val && val->getType()->getFunc()) { + auto fn = getFunction(val->getType()); + exists = fn->rootName == getFunction(type)->rootName; + } + if (!exists) { + ctx->addFunc(getUnmangledName(ast->getName()), ast->getName(), + ctx->forceFind(ast->getName())->type); + } + } for (size_t i = 0, j = 0; hasAst && i < ast->size(); i++) { if ((*ast)[i].isValue()) { auto [_, varName] = (*ast)[i].getNameWithStars(); @@ -407,7 +425,8 @@ types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { // TODO: generalize this further. for (size_t w = ctx->bases.size(); w-- > 0;) if (ctx->bases[w].suite) - LOG("[error=> {}] {}", ctx->bases[w].type->debugString(2), + LOG("[error=> {}] {}", + ctx->bases[w].type ? ctx->bases[w].type->debugString(2) : "-", ctx->bases[w].suite->toString(2)); } if (!isImport) { diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 4a1fd0b9..6bb8167d 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -59,6 +59,50 @@ void TypecheckVisitor::visit(UnaryExpr *expr) { /// Also evaluate static expressions. See @c evaluateStaticBinary for details. void TypecheckVisitor::visit(BinaryExpr *expr) { expr->lexpr = transform(expr->getLhs(), true); + + // Static short-circuit + if (expr->getLhs()->getType()->isStaticType() && expr->op == "&&") { + if (auto tb = expr->getLhs()->getType()->getBoolStatic()) { + if (!tb->value) { + resultExpr = transform(N(false)); + return; + } + } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) { + if (ts->value.empty()) { + resultExpr = transform(N(false)); + return; + } + } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) { + if (!ti->value) { + resultExpr = transform(N(false)); + return; + } + } else { + expr->getType()->getUnbound()->isStatic = 3; + return; + } + } else if (expr->getLhs()->getType()->isStaticType() && expr->op == "||") { + if (auto tb = expr->getLhs()->getType()->getBoolStatic()) { + if (tb->value) { + resultExpr = transform(N(true)); + return; + } + } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) { + if (!ts->value.empty()) { + resultExpr = transform(N(true)); + return; + } + } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) { + if (ti->value) { + resultExpr = transform(N(true)); + return; + } + } else { + expr->getType()->getUnbound()->isStatic = 3; + return; + } + } + expr->rexpr = transform(expr->getRhs(), true); static std::unordered_map> staticOps = { diff --git a/codon/parser/visitors/typecheck/special.cpp b/codon/parser/visitors/typecheck/special.cpp index 48c6ee5c..51db48d8 100644 --- a/codon/parser/visitors/typecheck/special.cpp +++ b/codon/parser/visitors/typecheck/special.cpp @@ -483,7 +483,7 @@ Expr *TypecheckVisitor::transformSuper() { /// the argument is a variable binding. Expr *TypecheckVisitor::transformPtr(CallExpr *expr) { auto id = cast(expr->begin()->getExpr()); - auto val = id ? ctx->find(id->getValue()) : nullptr; + auto val = id ? ctx->find(id->getValue(), getTime()) : nullptr; if (!val || !val->isVar()) E(Error::CALL_PTR_VAR, expr->begin()->getExpr()); @@ -511,6 +511,9 @@ Expr *TypecheckVisitor::transformArray(CallExpr *expr) { /// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type /// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type Expr *TypecheckVisitor::transformIsInstance(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + expr->begin()->value = transform(expr->begin()->getExpr()); auto typ = expr->begin()->getExpr()->getClassType(); if (!typ || !typ->canRealize()) @@ -582,6 +585,9 @@ Expr *TypecheckVisitor::transformIsInstance(CallExpr *expr) { /// Transform staticlen method to a static integer expression. This method supports only /// static strings and tuple types. Expr *TypecheckVisitor::transformStaticLen(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 1; + expr->begin()->value = transform(expr->begin()->getExpr()); auto typ = extractType(expr->begin()->getExpr()); @@ -605,6 +611,9 @@ Expr *TypecheckVisitor::transformStaticLen(CallExpr *expr) { /// This method also supports additional argument types that are used to check /// for a matching overload (not available in Python). Expr *TypecheckVisitor::transformHasAttr(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + auto typ = extractClassType((*expr)[0].getExpr()); if (!typ) return nullptr; @@ -763,6 +772,9 @@ Expr *TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { /// Transform __has_rtti__ to a static boolean that indicates RTTI status of a type. Expr *TypecheckVisitor::transformHasRttiFn(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + auto t = extractFuncGeneric(expr->getExpr()->getType())->getClass(); if (!t) return nullptr; @@ -771,6 +783,9 @@ Expr *TypecheckVisitor::transformHasRttiFn(CallExpr *expr) { // Transform internal.static calls Expr *TypecheckVisitor::transformStaticFnCanCall(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + auto typ = extractClassType((*expr)[0].getExpr()); if (!typ) return nullptr; @@ -800,6 +815,9 @@ Expr *TypecheckVisitor::transformStaticFnCanCall(CallExpr *expr) { } Expr *TypecheckVisitor::transformStaticFnArgHasType(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + auto fn = extractFunction(expr->begin()->getExpr()->getType()); if (!fn) E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", @@ -838,6 +856,9 @@ Expr *TypecheckVisitor::transformStaticFnArgs(CallExpr *expr) { } Expr *TypecheckVisitor::transformStaticFnHasDefault(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + auto fn = extractFunction(expr->begin()->getExpr()->getType()); if (!fn) E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index d4de876e..fae08e3a 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -68,7 +68,7 @@ Stmt *TypecheckVisitor::apply( auto n = tv.inferTypes(suite, true); if (!n) { LOG("[error=>] {}", suite->toString(2)); - E(Error::CUSTOM, suite, "cannot typecheck the program"); + E(Error::CUSTOM, suite->getSrcInfo(), "cannot typecheck the program"); } suite = tv.N(); @@ -275,7 +275,14 @@ Stmt *TypecheckVisitor::transform(Stmt *stmt) { LOG_TYPECHECK("> [{}] [{}:{}] {}", getSrcInfo(), ctx->getBaseName(), ctx->getBase()->iteration, stmt->toString(-1)); ctx->pushNode(stmt); + + int64_t time = 0; + if (auto a = stmt->getAttribute(Attr::ExprTime)) + time = a->value; + auto oldTime = ctx->time; + ctx->time = time; stmt->accept(v); + ctx->time = oldTime; ctx->popNode(); if (v.resultStmt) stmt = v.resultStmt; @@ -546,10 +553,9 @@ bool TypecheckVisitor::wrapExpr(Expr **expr, Type *expectedType, FuncType *calle auto [canWrap, newArgTyp, fn] = canWrapExpr((*expr)->getType(), expectedType, callee, allowUnwrap, cast(*expr)); // TODO: get rid of this line one day! - if ((*expr)->getType()->getStatic() && + if ((*expr)->getType()->isStaticType() && (!expectedType || !expectedType->isStaticType())) - (*expr)->setType( - (*expr)->getType()->getStatic()->getNonStaticType()->shared_from_this()); + (*expr)->setType(getUnderlyingStaticType((*expr)->getType())->shared_from_this()); if (canWrap && fn) *expr = transform(fn(*expr)); return canWrap; @@ -583,10 +589,14 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call std::unordered_set hints = {"Generator", "float", TYPE_OPTIONAL, "pyobj"}; - if (exprType->getStatic() && (!expectedType || !expectedType->isStaticType())) { - exprType = exprType->getStatic()->getNonStaticType(); - exprClass = exprType->getClass(); + if (!expectedType || !expectedType->isStaticType()) { + if (auto c = exprType->isStaticType()) { + exprType = getUnderlyingStaticType(exprType); + exprClass = exprType->getClass(); + type = exprType->shared_from_this(); + } } + if (!exprClass && expectedClass && in(hints, expectedClass->name)) { return {false, nullptr, nullptr}; // argument type not yet known. } @@ -630,8 +640,10 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call return N(N("pyobj"), N(N(expr, "__to_py__"))); }; - } else if (allowUnwrap && expectedClass && exprClass && exprClass->is("pyobj") && - !exprClass->is(expectedClass->name)) { // unwrap pyobj + } + + else if (allowUnwrap && expectedClass && exprClass && exprClass->is("pyobj") && + !exprClass->is(expectedClass->name)) { // unwrap pyobj if (findMethod(expectedClass, "__from_py__").empty()) return {false, nullptr, nullptr}; type = instantiateType(expectedClass); @@ -692,7 +704,9 @@ TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *call } else { return {false, nullptr, nullptr}; } - } else if (exprClass && expectedClass && expectedClass->getUnion()) { + } + + else if (exprClass && expectedClass && expectedClass->getUnion()) { // Make union types via __internal__.new_union if (!expectedClass->getUnion()->isSealed()) { if (!expectedClass->getUnion()->addType(exprClass)) @@ -1072,6 +1086,22 @@ bool TypecheckVisitor::isImportFn(const std::string &s) { return startswith(s, "%_import_"); } +int64_t TypecheckVisitor::getTime() { return ctx->time; } + +types::Type *TypecheckVisitor::getUnderlyingStaticType(types::Type *t) { + if (t->getStatic()) { + return t->getStatic()->getNonStaticType(); + } else if (auto c = t->isStaticType()) { + if (c == 1) + return getStdLibType("int"); + if (c == 2) + return getStdLibType("str"); + if (c == 3) + return getStdLibType("bool"); + } + return t; +} + std::shared_ptr TypecheckVisitor::instantiateUnbound(const SrcInfo &srcInfo, int level) const { auto typ = std::make_shared( @@ -1112,7 +1142,7 @@ types::TypePtr TypecheckVisitor::instantiateType(const SrcInfo &srcInfo, if (auto l = i.second->getLink()) { i.second->setSrcInfo(srcInfo); if (l->defaultType) { - ctx->getBase()->pendingDefaults.insert(i.second); + ctx->getBase()->pendingDefaults[0].insert(i.second); } } } @@ -1164,7 +1194,7 @@ types::TypePtr TypecheckVisitor::instantiateType(const SrcInfo &srcInfo, } if (t->getUnion() && !t->getUnion()->isSealed()) { t->setSrcInfo(srcInfo); - ctx->getBase()->pendingDefaults.insert(t); + ctx->getBase()->pendingDefaults[0].insert(t); } return t; } diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 90c5c677..32c8b966 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -346,6 +346,8 @@ public: std::string getClassMethod(types::Type *typ, const std::string &member); std::string getTemporaryVar(const std::string &s); bool isImportFn(const std::string &s); + int64_t getTime(); + types::Type *getUnderlyingStaticType(types::Type *t); int64_t getIntLiteral(types::Type *t, size_t pos = 0); bool getBoolLiteral(types::Type *t, size_t pos = 0); diff --git a/test/parser/typecheck/test_access.codon b/test/parser/typecheck/test_access.codon index 4282c423..f1f0c1e2 100644 --- a/test/parser/typecheck/test_access.codon +++ b/test/parser/typecheck/test_access.codon @@ -431,6 +431,21 @@ fox(1, 2) fox(1, 2, 3) #: fox 1: 1 2 3 +# Test whether recursive self references override overloads (they shouldn't) + +def arange(start: int, stop: int, step: int): + return (start, stop, step) + +@overload +def arange(stop: int): + return arange(0, stop, 1) + +print(arange(0, 1, 2)) +#: (0, 1, 2) +print(arange(12)) +#: (0, 12, 1) + + #%% fn_shadow,barebones def foo(x): return 1, x diff --git a/test/parser/typecheck/test_op.codon b/test/parser/typecheck/test_op.codon index ba114611..e5df113d 100644 --- a/test/parser/typecheck/test_op.codon +++ b/test/parser/typecheck/test_op.codon @@ -421,6 +421,13 @@ foo2(s[10:50]) #: kl True foo2(s[1:30:3]) #: behk True foo2(s[::-1]) #: lkjihgfedcba True +#%% static_short_circuit,barebones +x = 3.14 +if isinstance(x, List) and x.T is float: + print('is list') +else: + print('not list') #: not list + #%% partial_star_pipe_args,barebones iter(['A', 'C']) |> print #: A diff --git a/test/parser/typecheck/test_typecheck.codon b/test/parser/typecheck/test_typecheck.codon index 8360d2f3..c5a9d412 100644 --- a/test/parser/typecheck/test_typecheck.codon +++ b/test/parser/typecheck/test_typecheck.codon @@ -40,3 +40,29 @@ def foo(x: List[1]): pass #! expected type expression a = 5; b = 3 print a, b #: 5 3 +#%% delayed_instantiation_correct_context,barebones +# Test timing of the statements; ensure that delayed blocks still +# use correct names. +def foo(): + l = [] + + s = 1 # CH1 + if isinstance(l, List[int]): # delay typechecking this block + print(s) #: 1 + # if this is done badly, this print will print 's' + # or result in assertion error + print(s) #: 1 + + s = 's' # CH2 + print(s) #: s + + # instantiate l so that the block above + # is typechecked in the next iteration + l.append(1) +foo() + +# check that this does not mess up comprehensions +# (where variable names are used BEFORE their declaration) +slice_prefixes = [(start, end) + for start, end in [(1, 2), (3, 4)]] +print(slice_prefixes) #: [(1, 2), (3, 4)]