Merge simplify & typecheck [wip]

typecheck-v2
Ibrahim Numanagić 2023-06-25 00:17:52 +02:00
parent bd6be10834
commit 50f0c3803a
37 changed files with 1147 additions and 1032 deletions

View File

@ -4,4 +4,6 @@ WarningsAsErrors: false
HeaderFilterRegex: '(build/.+)|(codon/util/.+)' HeaderFilterRegex: '(build/.+)|(codon/util/.+)'
AnalyzeTemporaryDtors: false AnalyzeTemporaryDtors: false
FormatStyle: llvm FormatStyle: llvm
... CheckOptions:
- key: cppcoreguidelines-macro-usage.CheckCapsOnly
value: '1'

View File

@ -11,7 +11,14 @@
#include "codon/parser/visitors/visitor.h" #include "codon/parser/visitors/visitor.h"
#define ACCEPT_IMPL(T, X) \ #define ACCEPT_IMPL(T, X) \
ExprPtr T::clone() const { return std::make_shared<T>(*this); } \ ExprPtr T::clone() const { \
auto e = std::make_shared<T>(*this); \
e->type = nullptr; \
e->done = false; \
e->attributes = 0; \
return e; \
} \
ExprPtr T::full_clone() const { return std::make_shared<T>(*this); } \
void T::accept(X &visitor) { visitor.visit(this); } void T::accept(X &visitor) { visitor.visit(this); }
using fmt::format; using fmt::format;
@ -126,6 +133,7 @@ IntExpr::IntExpr(const std::string &value, std::string suffix)
std::make_unique<int64_t>(std::stoull(this->value.substr(2), nullptr, 2)); std::make_unique<int64_t>(std::stoull(this->value.substr(2), nullptr, 2));
else else
intValue = std::make_unique<int64_t>(std::stoull(this->value, nullptr, 0)); intValue = std::make_unique<int64_t>(std::stoull(this->value, nullptr, 0));
staticValue = StaticValue(*intValue);
} catch (std::out_of_range &) { } catch (std::out_of_range &) {
intValue = nullptr; intValue = nullptr;
} }

View File

@ -15,6 +15,7 @@ namespace codon::ast {
#define ACCEPT(X) \ #define ACCEPT(X) \
ExprPtr clone() const override; \ ExprPtr clone() const override; \
ExprPtr full_clone() const override; \
void accept(X &visitor) override void accept(X &visitor) override
// Forward declarations // Forward declarations
@ -95,6 +96,8 @@ public:
void validate() const; void validate() const;
/// Deep copy a node. /// Deep copy a node.
virtual std::shared_ptr<Expr> clone() const = 0; virtual std::shared_ptr<Expr> clone() const = 0;
/// Deep copy a node; preserve types/attributes!
virtual std::shared_ptr<Expr> full_clone() const = 0;
/// Accept an AST visitor. /// Accept an AST visitor.
virtual void accept(ASTVisitor &visitor) = 0; virtual void accept(ASTVisitor &visitor) = 0;

View File

@ -30,7 +30,7 @@ struct FunctionStmt;
* A Seq AST statement. * A Seq AST statement.
* Each AST statement is intended to be instantiated as a shared_ptr. * Each AST statement is intended to be instantiated as a shared_ptr.
*/ */
struct Stmt : public codon::SrcObject { struct Stmt : public codon::SrcObject, public std::enable_shared_from_this<Stmt> {
using base_type = Stmt; using base_type = Stmt;
/// Flag that indicates if all types in a statement are inferred (i.e. if a /// Flag that indicates if all types in a statement are inferred (i.e. if a

View File

@ -36,8 +36,12 @@ int LinkType::unify(Type *typ, Unification *undo) {
return -1; return -1;
} else { } else {
// Case 3: Unbound unification // Case 3: Unbound unification
if (isStaticType() != typ->isStaticType()) if (isStaticType() != typ->isStaticType()) {
return -1; if (!isStaticType())
isStatic = typ->isStaticType();
else
return -1;
}
if (auto ts = typ->getStatic()) { if (auto ts = typ->getStatic()) {
if (ts->expr->getId()) if (ts->expr->getId())
return unify(ts->generics[0].type.get(), undo); return unify(ts->generics[0].type.get(), undo);
@ -154,11 +158,12 @@ bool LinkType::isInstantiated() const { return kind == Link && type->isInstantia
std::string LinkType::debugString(char mode) const { std::string LinkType::debugString(char mode) const {
if (kind == Unbound || kind == Generic) { if (kind == Unbound || kind == Generic) {
if (mode == 2) { if (mode == 2) {
return fmt::format("{}{}{}", kind == Unbound ? '?' : '#', id, return fmt::format("{}{}{}{}", genericName.empty() ? "" : genericName + ":",
kind == Unbound ? '?' : '#', id,
trait ? ":" + trait->debugString(mode) : ""); trait ? ":" + trait->debugString(mode) : "");
} } else if (trait) {
if (trait)
return trait->debugString(mode); return trait->debugString(mode);
}
return (genericName.empty() ? (mode ? "?" : "<unknown type>") : genericName); return (genericName.empty() ? (mode ? "?" : "<unknown type>") : genericName);
} }
return type->debugString(mode); return type->debugString(mode);

View File

@ -137,12 +137,11 @@ std::string StaticType::realizedName() const {
StaticValue StaticType::evaluate() const { StaticValue StaticType::evaluate() const {
if (expr->staticValue.evaluated) if (expr->staticValue.evaluated)
return expr->staticValue; return expr->staticValue;
cache->typeCtx->addBlock(); auto ctx = std::make_shared<TypeContext>(cache);
for (auto &g : generics) for (auto &g : generics)
cache->typeCtx->addType(g.name, g.name, getSrcInfo(), g.type); ctx->addType(g.name, g.name, g.type);
auto en = TypecheckVisitor(cache->typeCtx).transform(expr->clone()); auto en = TypecheckVisitor(ctx).transform(expr->clone());
seqassert(en->isStatic() && en->staticValue.evaluated, "{} cannot be evaluated", en); seqassert(en->isStatic() && en->staticValue.evaluated, "{} cannot be evaluated", en);
cache->typeCtx->popBlock();
return en->staticValue; return en->staticValue;
} }
@ -157,8 +156,7 @@ void StaticType::parseExpr(const ExprPtr &e, std::unordered_set<std::string> &se
: genTyp->getStatic()->generics.empty() : genTyp->getStatic()->generics.empty()
? 0 ? 0
: genTyp->getStatic()->generics[0].id; : genTyp->getStatic()->generics[0].id;
generics.emplace_back(ei->value, generics.emplace_back(ei->value, cache->reverseIdentifierLookup[ei->value],
cache->typeCtx->cache->reverseIdentifierLookup[ei->value],
genTyp, id); genTyp, id);
seen.insert(ei->value); seen.insert(ei->value);
} }

View File

@ -190,7 +190,7 @@ TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount,
} }
std::string TypeTrait::debugString(char mode) const { std::string TypeTrait::debugString(char mode) const {
return fmt::format("Trait[{}]", type->debugString(mode)); return fmt::format("Trait[{}]", type->getClass() ? type->getClass()->name : "-");
} }
} // namespace codon::ast::types } // namespace codon::ast::types

View File

@ -16,7 +16,9 @@
namespace codon::ast { namespace codon::ast {
Cache::Cache(std::string argv0) : argv0(std::move(argv0)) {} Cache::Cache(std::string argv0) : argv0(std::move(argv0)) {
typeCtx = std::make_shared<TypeContext>(this, ".root");
}
std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) { std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) {
return fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix, return fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix,
@ -59,17 +61,17 @@ std::string Cache::getContent(const SrcInfo &info) {
types::ClassTypePtr Cache::findClass(const std::string &name) const { types::ClassTypePtr Cache::findClass(const std::string &name) const {
auto f = typeCtx->find(name); auto f = typeCtx->find(name);
if (f && f->kind == TypecheckItem::Type) if (f && f->isType())
return f->type->getClass(); return f->type->getClass();
return nullptr; return nullptr;
} }
types::FuncTypePtr Cache::findFunction(const std::string &name) const { types::FuncTypePtr Cache::findFunction(const std::string &name) const {
auto f = typeCtx->find(name); auto f = typeCtx->find(name);
if (f && f->type && f->kind == TypecheckItem::Func) if (f && f->type && f->isFunc())
return f->type->getFunc(); return f->type->getFunc();
f = typeCtx->find(name + ":0"); f = typeCtx->find(name + ":0");
if (f && f->type && f->kind == TypecheckItem::Func) if (f && f->type && f->isFunc())
return f->type->getFunc(); return f->type->getFunc();
return nullptr; return nullptr;
} }
@ -79,6 +81,7 @@ types::FuncTypePtr Cache::findMethod(types::ClassType *typ, const std::string &m
auto e = std::make_shared<IdExpr>(typ->name); auto e = std::make_shared<IdExpr>(typ->name);
e->type = typ->getClass(); e->type = typ->getClass();
seqassertn(e->type, "not a class"); seqassertn(e->type, "not a class");
auto f = TypecheckVisitor(typeCtx).findBestMethod(e->type->getClass(), member, args); auto f = TypecheckVisitor(typeCtx).findBestMethod(e->type->getClass(), member, args);
return f; return f;
} }

View File

@ -64,7 +64,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
int varCount = 0; int varCount = 0;
/// Holds module import data. /// Holds module import data.
struct Import { struct Module {
/// Relative module name (e.g., `foo.bar`)
std::string name;
/// Absolute filename of an import. /// Absolute filename of an import.
std::string filename; std::string filename;
/// Import typechecking context. /// Import typechecking context.
@ -73,8 +75,6 @@ struct Cache : public std::enable_shared_from_this<Cache> {
std::string importVar; std::string importVar;
/// File content (line:col indexable) /// File content (line:col indexable)
std::vector<std::string> content; std::vector<std::string> content;
/// Relative module name (e.g., `foo.bar`)
std::string moduleName;
}; };
/// Absolute path of seqc executable (if available). /// Absolute path of seqc executable (if available).
@ -85,8 +85,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
ir::Module *module = nullptr; ir::Module *module = nullptr;
/// Table of imported files that maps an absolute filename to a Import structure. /// Table of imported files that maps an absolute filename to a Import structure.
/// By convention, the key of the Codon's standard library is "". /// By convention, the key of the Codon's standard library is ":stdlib:",
std::unordered_map<std::string, Import> imports; /// and the main module is "".
std::unordered_map<std::string, Module> imports;
/// Set of unique (canonical) global identifiers for marking such variables as global /// Set of unique (canonical) global identifiers for marking such variables as global
/// in code-generation step and in JIT. /// in code-generation step and in JIT.
@ -94,10 +95,13 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// Stores class data for each class (type) in the source code. /// Stores class data for each class (type) in the source code.
struct Class { struct Class {
/// Module information
std::string module;
/// Generic (unrealized) class template AST. /// Generic (unrealized) class template AST.
std::shared_ptr<ClassStmt> ast; std::shared_ptr<ClassStmt> ast = nullptr;
/// Non-simplified AST. Used for base class instantiation. /// Non-simplified AST. Used for base class instantiation.
std::shared_ptr<ClassStmt> originalAst; std::shared_ptr<ClassStmt> originalAst = nullptr;
/// Class method lookup table. Each non-canonical name points /// Class method lookup table. Each non-canonical name points
/// to a root function name of a corresponding method. /// to a root function name of a corresponding method.
@ -155,10 +159,7 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// List of statically inherited classes. /// List of statically inherited classes.
std::vector<std::string> staticParentClasses; std::vector<std::string> staticParentClasses;
/// Module information bool hasRTTI() const { return rtti; }
std::string module;
Class() : ast(nullptr), originalAst(nullptr), rtti(false) {}
}; };
/// Class lookup table that maps a canonical class identifier to the corresponding /// Class lookup table that maps a canonical class identifier to the corresponding
/// Class instance. /// Class instance.
@ -166,10 +167,12 @@ struct Cache : public std::enable_shared_from_this<Cache> {
size_t classRealizationCnt = 0; size_t classRealizationCnt = 0;
struct Function { struct Function {
/// Module information
std::string module;
/// Generic (unrealized) function template AST. /// Generic (unrealized) function template AST.
std::shared_ptr<FunctionStmt> ast; std::shared_ptr<FunctionStmt> ast = nullptr;
/// Non-simplified AST. /// Non-simplified AST.
std::shared_ptr<FunctionStmt> origAst; std::shared_ptr<FunctionStmt> origAst = nullptr;
/// A function realization. /// A function realization.
struct FunctionRealization { struct FunctionRealization {
@ -186,15 +189,10 @@ struct Cache : public std::enable_shared_from_this<Cache> {
std::unordered_map<std::string, std::shared_ptr<FunctionRealization>> realizations; std::unordered_map<std::string, std::shared_ptr<FunctionRealization>> realizations;
/// Unrealized function type. /// Unrealized function type.
types::FuncTypePtr type; types::FuncTypePtr type = nullptr;
/// Module information std::string rootName;
std::string rootName = "";
bool isToplevel = false; bool isToplevel = false;
Function()
: ast(nullptr), origAst(nullptr), type(nullptr), rootName(""),
isToplevel(false) {}
}; };
/// Function lookup table that maps a canonical function identifier to the /// Function lookup table that maps a canonical function identifier to the
/// corresponding Function instance. /// corresponding Function instance.
@ -229,7 +227,6 @@ struct Cache : public std::enable_shared_from_this<Cache> {
bool isJit = false; bool isJit = false;
int jitCell = 0; int jitCell = 0;
std::unordered_map<std::string, std::pair<std::string, bool>> replacements;
std::unordered_map<std::string, int> generatedTuples; std::unordered_map<std::string, int> generatedTuples;
std::vector<exc::ParserException> errors; std::vector<exc::ParserException> errors;

View File

@ -104,6 +104,12 @@ const V *in(const std::unordered_map<K, V> &m, const U &item) {
auto f = m.find(item); auto f = m.find(item);
return f != m.end() ? &(f->second) : nullptr; return f != m.end() ? &(f->second) : nullptr;
} }
/// @return True if an item is found in an unordered_map m.
template <typename K, typename V, typename U>
V *in(std::unordered_map<K, V> &m, const U &item) {
auto f = m.find(item);
return f != m.end() ? &(f->second) : nullptr;
}
/// @return vector c transformed by the function f. /// @return vector c transformed by the function f.
template <typename T, typename F> auto vmap(const std::vector<T> &c, F &&f) { template <typename T, typename F> auto vmap(const std::vector<T> &c, F &&f) {
std::vector<typename std::result_of<F(const T &)>::type> ret; std::vector<typename std::result_of<F(const T &)>::type> ret;

View File

@ -266,7 +266,7 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
} }
if (itemName.empty()) if (itemName.empty())
E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd],
ctx->cache->imports[importName].moduleName); ctx->cache->imports[importName].name);
importEnd = itemEnd; importEnd = itemEnd;
} }
return {importEnd, val}; return {importEnd, val};

View File

@ -25,14 +25,20 @@ void TypecheckVisitor::visit(IdExpr *expr) {
if (isTuple(expr->value)) if (isTuple(expr->value))
generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1))); generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1)));
auto val = ctx->findDominatingBinding(expr->value, this); auto val = findDominatingBinding(expr->value, ctx.get());
if (!val && ctx->getBase()->pyCaptures) { if (!val && ctx->getBase()->pyCaptures) {
ctx->getBase()->pyCaptures->insert(expr->value); ctx->getBase()->pyCaptures->insert(expr->value);
resultExpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(expr->value)); resultExpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(expr->value));
return; return;
} else if (!val) { } else if (!val) {
E(Error::ID_NOT_FOUND, expr, expr->value); if (in(ctx->cache->overloads, expr->value))
val = ctx->forceFind(getDispatch(expr->value)->ast->name);
if (!val) {
ctx->dump();
// LOG("=================================================================");
// ctx->cache->typeCtx->dump();
E(Error::ID_NOT_FOUND, expr, expr->value);
}
} }
// If we are accessing an outside variable, capture it or raise an error // If we are accessing an outside variable, capture it or raise an error
@ -54,7 +60,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
// Replace the variable with its canonical name // Replace the variable with its canonical name
expr->value = val->canonicalName; expr->value = val->canonicalName;
val->references.push_back(expr->shared_from_this());
// Mark global as "seen" to prevent later creation of local variables // Mark global as "seen" to prevent later creation of local variables
// with the same name. Example: // with the same name. Example:
@ -99,8 +104,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
} }
} }
// todo)) handle overloads [each overloaded fn is basically a new FnOverload object]
// Set up type // Set up type
unify(expr->type, ctx->instantiate(val->type)); unify(expr->type, ctx->instantiate(val->type));
if (val->type->isStaticType()) { if (val->type->isStaticType()) {
@ -111,10 +114,11 @@ void TypecheckVisitor::visit(IdExpr *expr) {
expr->toString()); expr->toString());
if (s && s->expr->staticValue.evaluated) { if (s && s->expr->staticValue.evaluated) {
// Replace the identifier with static expression // Replace the identifier with static expression
if (s->expr->staticValue.type == StaticValue::STRING) if (s->expr->staticValue.type == StaticValue::STRING) {
resultExpr = transform(N<StringExpr>(s->expr->staticValue.getString())); resultExpr = transform(N<StringExpr>(s->expr->staticValue.getString()));
else } else {
resultExpr = transform(N<IntExpr>(s->expr->staticValue.getInt())); resultExpr = transform(N<IntExpr>(s->expr->staticValue.getInt()));
}
} }
return; return;
} }
@ -135,46 +139,101 @@ void TypecheckVisitor::visit(IdExpr *expr) {
/// `python.foo` -> internal.python._get_identifier("foo") /// `python.foo` -> internal.python._get_identifier("foo")
/// Other cases are handled during the type checking. /// Other cases are handled during the type checking.
/// See @c transformDot for details. /// See @c transformDot for details.
void TypecheckVisitor::visit(DotExpr *expr) { void TypecheckVisitor::visit(DotExpr *expr) { resultExpr = transformDot(expr); }
if (!expr->type) {
// First flatten the imports:
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
std::vector<std::string> chain;
Expr *root = expr;
for (; root->getDot(); root = root->getDot()->expr.get())
chain.push_back(root->getDot()->member);
if (auto id = root->getId()) { /// Get an item from the context. Perform domination analysis for accessing items
// Case: a.bar.baz /// defined in the conditional blocks (i.e., Python scoping).
chain.push_back(id->value); TypeContext::Item TypecheckVisitor::findDominatingBinding(const std::string &name,
std::reverse(chain.begin(), chain.end()); TypeContext *ctx) {
auto [pos, val] = getImport(chain); auto it = ctx->find_all(name);
if (!it) {
if (!val) { return ctx->find(name);
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture"); } else if (ctx->isCanonicalName(name)) {
ctx->getBase()->pyCaptures->insert(chain[0]); return *(it->begin());
resultExpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
} else if (val->getModule() == "std.python") {
resultExpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[pos++])));
} else if (val->getModule() == ctx->getModule() && pos == 1) {
resultExpr = transform(N<IdExpr>(chain[0]), true);
} else {
resultExpr = N<IdExpr>(val->canonicalName);
if (val->isType() && pos == chain.size())
resultExpr->markType();
}
while (pos < chain.size())
resultExpr = N<DotExpr>(resultExpr, chain[pos++]);
resultExpr = transformDot(resultExpr->getDot());
} else {
transform(expr->expr, true);
resultExpr = transformDot(expr);
}
} else {
resultExpr = transformDot(expr);
} }
seqassert(!it->empty(), "corrupted TypecheckContext ({})", name);
// The item is found. Let's see is it accessible now.
std::string canonicalName;
auto lastGood = it->begin();
bool isOutside = (*lastGood)->getBaseName() != ctx->getBaseName();
int prefix = int(ctx->scope.blocks.size());
// Iterate through all bindings with the given name and find the closest binding that
// dominates the current scope.
for (auto i = it->begin(); i != it->end(); i++) {
// Find the longest block prefix between the binding and the current scope.
int p = std::min(prefix, int((*i)->scope.size()));
while (p >= 0 && (*i)->scope[p - 1] != ctx->scope.blocks[p - 1])
p--;
// We reached the toplevel. Break.
if (p < 0)
break;
// We went outside the function scope. Break.
if (!isOutside && (*i)->getBaseName() != ctx->getBaseName())
break;
prefix = p;
lastGood = i;
// The binding completely dominates the current scope. Break.
if ((*i)->scope.size() <= ctx->scope.blocks.size() &&
(*i)->scope.back() == ctx->scope.blocks[(*i)->scope.size() - 1])
break;
}
seqassert(lastGood != it->end(), "corrupted scoping ({})", name);
if (lastGood != it->begin() && !(*lastGood)->isVar())
E(Error::CLASS_INVALID_BIND, getSrcInfo(), name);
bool hasUsed = false;
types::TypePtr type = nullptr;
if ((*lastGood)->scope.size() == prefix) {
// The current scope is dominated by a binding. Use that binding.
canonicalName = (*lastGood)->canonicalName;
type = (*lastGood)->type;
} else {
// The current scope is potentially reachable by multiple bindings that are
// not dominated by a common binding. Create such binding in the scope that
// dominates (covers) all of them.
canonicalName = ctx->generateCanonicalName(name);
auto item = std::make_shared<TypecheckItem>(
canonicalName, (*lastGood)->baseName, (*lastGood)->moduleName,
ctx->getUnbound(getSrcInfo()),
std::vector<int>(ctx->scope.blocks.begin(),
ctx->scope.blocks.begin() + prefix));
item->accessChecked = {(*lastGood)->scope};
type = item->type;
lastGood = it->insert(++lastGood, item);
// Make sure to prepend a binding declaration: `var` and `var__used__ = False`
// to the dominating scope.
ctx->scope.stmts[ctx->scope.blocks[prefix - 1]].push_back(
N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(canonicalName), nullptr, nullptr),
N<AssignStmt>(N<IdExpr>(fmt::format("{}.__used__", canonicalName)),
N<BoolExpr>(false), nullptr)));
// Reached the toplevel? Register the binding as global.
if (prefix == 1) {
ctx->cache->addGlobal(canonicalName);
ctx->cache->addGlobal(fmt::format("{}.__used__", canonicalName));
}
hasUsed = true;
}
// Remove all bindings after the dominant binding.
for (auto i = it->begin(); i != it->end(); i++) {
if (i == lastGood)
break;
if (!(*i)->canDominate())
continue;
// These bindings (and their canonical identifiers) will be replaced by the
// dominating binding during the type checking pass.
ctx->getBase()->replacements[(*i)->canonicalName] = {canonicalName, hasUsed};
ctx->getBase()->replacements[format("{}.__used__", (*i)->canonicalName)] = {
format("{}.__used__", canonicalName), false};
seqassert((*i)->canonicalName != canonicalName, "invalid replacement at {}: {}",
getSrcInfo(), canonicalName);
ctx->removeFromTopStack(name);
}
it->erase(it->begin(), lastGood);
return it->front();
} }
/// Access identifiers from outside of the current function/class scope. /// Access identifiers from outside of the current function/class scope.
@ -219,7 +278,7 @@ bool TypecheckVisitor::checkCapture(const TypeContext::Item &val) {
// Case: a global variable that has not been marked with `global` statement // Case: a global variable that has not been marked with `global` statement
if (val->isVar() && val->getBaseName().empty() && val->scope.size() == 1) { if (val->isVar() && val->getBaseName().empty() && val->scope.size() == 1) {
val->noShadow = true; val->canShadow = false;
if (!val->isStatic()) if (!val->isStatic())
ctx->cache->addGlobal(val->canonicalName); ctx->cache->addGlobal(val->canonicalName);
return false; return false;
@ -247,11 +306,11 @@ bool TypecheckVisitor::checkCapture(const TypeContext::Item &val) {
// Add newly generated argument to the context // Add newly generated argument to the context
std::shared_ptr<TypecheckItem> newVal = nullptr; std::shared_ptr<TypecheckItem> newVal = nullptr;
if (val->isType()) if (val->isType())
newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, val->type);
else else
newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, val->type);
newVal->baseName = ctx->getBaseName(); newVal->baseName = ctx->getBaseName();
newVal->noShadow = true; // todo)) needed here? remove noshadow on fn boundaries? newVal->canShadow = false; // todo)) needed here? remove noshadow on fn boundaries?
newVal->scope = ctx->getBase()->scope; newVal->scope = ctx->getBase()->scope;
return true; return true;
} }
@ -272,9 +331,11 @@ TypecheckVisitor::getImport(const std::vector<std::string> &chain) {
// (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`) // (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`)
TypeContext::Item val = nullptr; TypeContext::Item val = nullptr;
for (auto i = chain.size(); i-- > 0;) { for (auto i = chain.size(); i-- > 0;) {
val = ctx->find(join(chain, "/", 0, i + 1)); auto name = join(chain, "/", 0, i + 1);
if (val && val->isImport()) { val = ctx->find(name);
importName = val->importPath, importEnd = i + 1; if (val && val->type->is("Import") && name != "Import") {
importName = getClassStaticStr(val->type->getClass());
importEnd = i + 1;
break; break;
} }
} }
@ -289,12 +350,16 @@ TypecheckVisitor::getImport(const std::vector<std::string> &chain) {
if (fctx->getModule() == "std.python" && importEnd < chain.size()) { if (fctx->getModule() == "std.python" && importEnd < chain.size()) {
// Special case: importing from Python. // Special case: importing from Python.
// Fake TypecheckItem that indicates std.python access // Fake TypecheckItem that indicates std.python access
val = std::make_shared<TypecheckItem>(TypecheckItem::Var, "", "", val = std::make_shared<TypecheckItem>("", "", fctx->getModule(),
fctx->getModule(), std::vector<int>{}); fctx->getUnbound());
return {importEnd, val}; return {importEnd, val};
} else { } else {
val = fctx->find(join(chain, ".", importEnd, i + 1)); val = fctx->find(join(chain, ".", importEnd, i + 1));
if (val && (importName.empty() || val->isType() || !val->isConditional())) { bool isOverload = val && val->isFunc() &&
in(ctx->cache->overloads, val->canonicalName) &&
ctx->cache->overloads[val->canonicalName].size() > 1;
if (val && !isOverload &&
(importName.empty() || val->isType() || !val->isConditional())) {
itemName = val->canonicalName, itemEnd = i + 1; itemName = val->canonicalName, itemEnd = i + 1;
break; break;
} }
@ -305,9 +370,10 @@ TypecheckVisitor::getImport(const std::vector<std::string> &chain) {
return {1, nullptr}; return {1, nullptr};
E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]); E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]);
} }
if (itemName.empty()) if (itemName.empty()) {
E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd],
ctx->cache->imports[importName].moduleName); ctx->cache->imports[importName].name);
}
importEnd = itemEnd; importEnd = itemEnd;
} }
return {importEnd, val}; return {importEnd, val};
@ -351,7 +417,7 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
auto baseType = getFuncTypeBase(2); auto baseType = getFuncTypeBase(2);
auto typ = std::make_shared<FuncType>(baseType, ast.get()); auto typ = std::make_shared<FuncType>(baseType, ast.get());
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel - 1)); typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel - 1));
ctx->addFunc(name, name, getSrcInfo(), typ); ctx->addFunc(name, name, typ);
overloads.insert(overloads.begin(), name); overloads.insert(overloads.begin(), name);
ctx->cache->functions[name].ast = ast; ctx->cache->functions[name].ast = ast;
@ -376,12 +442,49 @@ types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) {
/// See @c getClassMember and @c getBestOverload /// See @c getClassMember and @c getBestOverload
ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
std::vector<CallExpr::Arg> *args) { std::vector<CallExpr::Arg> *args) {
// First flatten the imports:
// transform Dot(Dot(a, b), c...) to {a, b, c, ...}
std::vector<std::string> chain;
Expr *root = expr;
for (; root->getDot(); root = root->getDot()->expr.get())
chain.push_back(root->getDot()->member);
ExprPtr nexpr = expr->shared_from_this();
if (auto id = root->getId()) {
// Case: a.bar.baz
chain.push_back(id->value);
std::reverse(chain.begin(), chain.end());
auto [pos, val] = getImport(chain);
if (!val) {
seqassert(ctx->getBase()->pyCaptures, "unexpected py capture");
ctx->getBase()->pyCaptures->insert(chain[0]);
nexpr = N<IndexExpr>(N<IdExpr>("__pyenv__"), N<StringExpr>(chain[0]));
} else if (val->getModule() == "std.python") {
nexpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[pos++])));
} else if (val->getModule() == ctx->getModule() && pos == 1) {
nexpr = transform(N<IdExpr>(chain[0]), true);
} else {
nexpr = N<IdExpr>(val->canonicalName);
if (val->isType() && pos == chain.size())
nexpr->markType();
}
while (pos < chain.size())
nexpr = N<DotExpr>(nexpr, chain[pos++]);
}
if (!nexpr->getDot()) {
return transform(nexpr);
} else {
expr->expr = nexpr->getDot()->expr;
expr->member = nexpr->getDot()->member;
}
// Special case: obj.__class__ // Special case: obj.__class__
if (expr->member == "__class__") { if (expr->member == "__class__") {
/// TODO: prevent cls.__class__ and type(cls) /// TODO: prevent cls.__class__ and type(cls)
return transformType(NT<CallExpr>(NT<IdExpr>("type"), expr->expr)); return transformType(NT<CallExpr>(NT<IdExpr>("type"), expr->expr));
} }
transform(expr->expr); transform(expr->expr);
// Special case: fn.__name__ // Special case: fn.__name__
@ -455,7 +558,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
// ) // )
auto e = N<CallExpr>( auto e = N<CallExpr>(
fnType, fnType,
N<IndexExpr>(N<CallExpr>(N<IdExpr>("__internal__.class_get_rtti_vtable:0"), N<IndexExpr>(N<CallExpr>(N<IdExpr>("__internal__.class_get_rtti_vtable"),
expr->expr), expr->expr),
N<IntExpr>(vid))); N<IntExpr>(vid)));
return transform(e); return transform(e);
@ -569,7 +672,7 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
// Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)` // Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)`
if (typ->getUnion()) { if (typ->getUnion()) {
return transform(N<CallExpr>( return transform(N<CallExpr>(
N<IdExpr>("__internal__.get_union_method:0"), N<IdExpr>("__internal__.get_union_method"),
std::vector<CallExpr::Arg>{{"union", expr->expr}, std::vector<CallExpr::Arg>{{"union", expr->expr},
{"method", N<StringExpr>(expr->member)}, {"method", N<StringExpr>(expr->member)},
{"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}})); {"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}}));

View File

@ -48,125 +48,18 @@ void TypecheckVisitor::visit(AssignExpr *expr) {
/// See @c transformAssignment and @c unpackAssignments for more details. /// See @c transformAssignment and @c unpackAssignments for more details.
/// See @c wrapExpr for more examples. /// See @c wrapExpr for more examples.
void TypecheckVisitor::visit(AssignStmt *stmt) { void TypecheckVisitor::visit(AssignStmt *stmt) {
std::vector<StmtPtr> stmts;
if (stmt->rhs && stmt->rhs->getBinary() && stmt->rhs->getBinary()->inPlace) { if (stmt->rhs && stmt->rhs->getBinary() && stmt->rhs->getBinary()->inPlace) {
// Update case: a += b // Update case: a += b
seqassert(!stmt->type, "invalid AssignStmt {}", stmt->toString()); seqassert(!stmt->type, "invalid AssignStmt {}", stmt->toString());
resultStmt = transform(transformAssignment(stmt->lhs, stmt->rhs, nullptr, true)); resultStmt = transformAssignment(stmt->lhs, stmt->rhs, nullptr, true);
} else if (stmt->type) { } else if (!stmt->type && !stmt->lhs->getId()) {
// Type case: `a: T = b, c` (no unpacking)
resultStmt = transform(transformAssignment(stmt->lhs, stmt->rhs, stmt->type));
} else if (!stmt->lhs->getId()) {
// Normal case // Normal case
std::vector<StmtPtr> stmts;
unpackAssignments(stmt->lhs, stmt->rhs, stmts); unpackAssignments(stmt->lhs, stmt->rhs, stmts);
resultStmt = transform(N<SuiteStmt>(stmts)); resultStmt = transform(N<SuiteStmt>(stmts));
} else { } else {
auto assign = transformAssignment(stmt->lhs, stmt->rhs, stmt->type); // Type case: `a: T = b, c` (no unpacking); all other (invalid) cases
resultStmt = transformAssignment(stmt->lhs, stmt->rhs, stmt->type);
// Update statements are handled by @c visitUpdate
if (stmt->isUpdate()) {
transformUpdate(stmt);
return;
}
seqassert(stmt->lhs->getId(), "invalid AssignStmt {}", stmt->lhs);
std::string lhs = stmt->lhs->getId()->value;
// Special case: this assignment has been dominated and is not a true assignment but
// an update of the dominating binding.
if (auto changed = in(ctx->cache->replacements, lhs)) {
while (auto s = in(ctx->cache->replacements, lhs))
lhs = changed->first, changed = s;
if (stmt->rhs && changed->second) {
// Mark the dominating binding as used: `var.__used__ = True`
auto u = N<AssignStmt>(N<IdExpr>(fmt::format("{}.__used__", lhs)),
N<BoolExpr>(true));
u->setUpdate();
prependStmts->push_back(transform(u));
} else if (changed->second && !stmt->rhs) {
// This assignment was a declaration only. Just mark the dominating binding as
// used: `var.__used__ = True`
stmt->lhs = N<IdExpr>(fmt::format("{}.__used__", lhs));
stmt->rhs = N<BoolExpr>(true);
}
seqassert(stmt->rhs, "bad domination statement: '{}'", stmt->toString());
// Change this to the update and follow the update logic
stmt->setUpdate();
transformUpdate(stmt);
return;
}
transform(stmt->rhs);
transformType(stmt->type);
if (!stmt->rhs) {
// Forward declarations (e.g., dominating bindings, C imports etc.).
// The type is unknown and will be deduced later
unify(stmt->lhs->type, ctx->getUnbound(stmt->lhs->getSrcInfo()));
if (stmt->type) {
unify(stmt->lhs->type,
ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType()));
}
ctx->addVar(lhs, lhs, getSrcInfo(), stmt->lhs->type);
if (realize(stmt->lhs->type))
stmt->setDone();
} else if (stmt->type && stmt->type->getType()->isStaticType()) {
// Static assignments (e.g., `x: Static[int] = 5`)
if (!stmt->rhs->isStatic())
E(Error::EXPECTED_STATIC, stmt->rhs);
seqassert(stmt->rhs->staticValue.evaluated, "static not evaluated");
unify(stmt->lhs->type,
unify(stmt->type->type, Type::makeStatic(ctx->cache, stmt->rhs)));
auto val = ctx->addVar(lhs, lhs, getSrcInfo(), stmt->lhs->type);
if (in(ctx->cache->globals, lhs)) {
// Make globals always visible!
ctx->addToplevel(lhs, val);
}
if (realize(stmt->lhs->type))
stmt->setDone();
} else {
// Normal assignments
unify(stmt->lhs->type, ctx->getUnbound());
if (stmt->type) {
unify(stmt->lhs->type,
ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType()));
}
// Check if we can wrap the expression (e.g., `a: float = 3` -> `a = float(3)`)
if (wrapExpr(stmt->rhs, stmt->lhs->getType()))
unify(stmt->lhs->type, stmt->rhs->type);
auto type = stmt->lhs->getType();
auto kind = TypecheckItem::Var;
if (stmt->rhs->isType())
kind = TypecheckItem::Type;
else if (type->getFunc())
kind = TypecheckItem::Func;
// Generalize non-variable types. That way we can support cases like:
// `a = foo(x, ...); a(1); a('s')`
auto val = std::make_shared<TypecheckItem>(kind, ctx->getBaseName(), lhs,
ctx->getModule(), ctx->scope.blocks);
val->setSrcInfo(getSrcInfo());
val->type =
kind != TypecheckItem::Var ? type->generalize(ctx->typecheckLevel - 1) : type;
if (in(ctx->cache->globals, lhs)) {
// Make globals always visible!
ctx->addToplevel(lhs, val);
if (kind != TypecheckItem::Var)
ctx->cache->globals.erase(lhs);
} else if (startswith(ctx->getRealizationBase()->name, "._import_") &&
kind == TypecheckItem::Type) {
// Make import toplevel type aliases (e.g., `a = Ptr[byte]`) visible
ctx->addToplevel(lhs, val);
} else {
ctx->add(lhs, val);
}
if (stmt->lhs->getId() && kind != TypecheckItem::Var) {
// Special case: type/function renames
stmt->rhs->type = nullptr;
stmt->setDone();
} else if (stmt->rhs->isDone() && realize(stmt->lhs->type)) {
stmt->setDone();
}
}
} }
} }
@ -218,11 +111,12 @@ StmtPtr TypecheckVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr
transform(dot->expr, true); transform(dot->expr, true);
// If we are deducing class members, check if we can deduce a member from this // If we are deducing class members, check if we can deduce a member from this
// assignment // assignment
auto deduced = ctx->getClassBase() ? ctx->getClassBase()->deducedMembers : nullptr; // todo)) deduction!
if (deduced && dot->expr->isId(ctx->getBase()->selfName) && // auto deduced = ctx->getClassBase() ? ctx->getClassBase()->deducedMembers :
!in(*deduced, dot->member)) // nullptr; if (deduced && dot->expr->isId(ctx->getBase()->selfName) &&
deduced->push_back(dot->member); // !in(*deduced, dot->member))
return N<AssignMemberStmt>(dot->expr, dot->member, transform(rhs)); // deduced->push_back(dot->member);
return transform(N<AssignMemberStmt>(dot->expr, dot->member, transform(rhs)));
} }
// Case: a (: t) = b // Case: a (: t) = b
@ -243,15 +137,16 @@ StmtPtr TypecheckVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr
auto val = ctx->find(e->value); auto val = ctx->find(e->value);
// Make sure that existing values that cannot be shadowed (e.g. imported globals) are // Make sure that existing values that cannot be shadowed (e.g. imported globals) are
// only updated // only updated
mustExist |= val && val->noShadow && !ctx->isOuter(val); mustExist |= val && !val->canShadow && !ctx->isOuter(val);
if (mustExist) { if (mustExist) {
val = ctx->findDominatingBinding(e->value, this); val = findDominatingBinding(e->value, ctx.get());
if (val && val->isVar() && !ctx->isOuter(val)) { if (val && val->isVar() && !ctx->isOuter(val)) {
auto s = N<AssignStmt>(transform(lhs, false), transform(rhs)); auto s = N<AssignStmt>(lhs, rhs);
if (ctx->getBase()->attributes && ctx->getBase()->attributes->has(Attr::Atomic)) if (ctx->getBase()->attributes && ctx->getBase()->attributes->has(Attr::Atomic))
s->setAtomicUpdate(); s->setAtomicUpdate();
else else
s->setUpdate(); s->setUpdate();
transformUpdate(s.get());
return s; return s;
} else { } else {
E(Error::ASSIGN_LOCAL_REFERENCE, e, e->value); E(Error::ASSIGN_LOCAL_REFERENCE, e, e->value);
@ -264,23 +159,55 @@ StmtPtr TypecheckVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr
// Generate new canonical variable name for this assignment and add it to the context // Generate new canonical variable name for this assignment and add it to the context
auto canonical = ctx->generateCanonicalName(e->value); auto canonical = ctx->generateCanonicalName(e->value);
auto assign = N<AssignStmt>(N<IdExpr>(canonical), rhs, type); auto assign = N<AssignStmt>(N<IdExpr>(canonical), rhs, type);
val = nullptr; unify(assign->lhs->type, ctx->getUnbound(assign->lhs->getSrcInfo()));
if (rhs && rhs->isType()) { if (assign->type) {
val = ctx->addType(e->value, canonical, lhs->getSrcInfo()); unify(assign->lhs->type,
} else { ctx->instantiate(assign->type->getSrcInfo(), assign->type->getType()));
val = ctx->addVar(e->value, canonical, lhs->getSrcInfo());
if (auto st = getStaticGeneric(type.get()))
val->staticType = st;
if (ctx->avoidDomination)
val->avoidDomination = true;
} }
val = std::make_shared<TypecheckItem>(canonical, ctx->getBaseName(), ctx->getModule(),
assign->lhs->type, ctx->scope.blocks);
val->setSrcInfo(getSrcInfo());
if (auto st = getStaticGeneric(assign->type.get()))
val->staticType = st;
if (ctx->avoidDomination)
val->avoidDomination = true;
ctx->Context<TypecheckItem>::add(e->value, val);
ctx->addAlwaysVisible(val);
LOG("added ass/{}: {}", val->isVar() ? "v" : (val->isFunc() ? "f" : "t"),
val->canonicalName);
if (assign->rhs && assign->type && assign->type->getType()->isStaticType()) {
// Static assignments (e.g., `x: Static[int] = 5`)
if (!assign->rhs->isStatic())
E(Error::EXPECTED_STATIC, assign->rhs);
seqassert(assign->rhs->staticValue.evaluated, "static not evaluated");
unify(assign->lhs->type,
unify(assign->type->type, Type::makeStatic(ctx->cache, assign->rhs)));
} else if (assign->rhs) {
// Check if we can wrap the expression (e.g., `a: float = 3` -> `a = float(3)`)
if (wrapExpr(assign->rhs, assign->lhs->getType()))
unify(assign->lhs->type, assign->rhs->type);
if (rhs->isType())
val->type = val->type->getClass();
auto type = assign->lhs->getType();
// Generalize non-variable types. That way we can support cases like:
// `a = foo(x, ...); a(1); a('s')`
if (!val->isVar())
val->type = val->type->generalize(ctx->typecheckLevel - 1);
// todo)) if (in(ctx->cache->globals, lhs)) {
}
if ((!assign->rhs || assign->rhs->isDone()) && realize(assign->lhs->type)) {
assign->setDone();
}
// Clean up seen tags if shadowing a name // Clean up seen tags if shadowing a name
ctx->getBase()->seenGlobalIdentifiers.erase(e->value); ctx->getBase()->seenGlobalIdentifiers.erase(e->value);
// Register all toplevel variables as global in JIT mode // Register all toplevel variables as global in JIT mode
bool isGlobal = (ctx->cache->isJit && val->isGlobal() && !val->isGeneric()) || bool isGlobal = (ctx->cache->isJit && val->isGlobal() && !val->isGeneric()) ||
(canonical == VAR_ARGV); (canonical == VAR_ARGV);
if (isGlobal && !val->isGeneric()) if (isGlobal && val->isVar())
ctx->cache->addGlobal(canonical); ctx->cache->addGlobal(canonical);
return assign; return assign;

View File

@ -19,7 +19,7 @@ void TypecheckVisitor::visit(NoneExpr *expr) {
if (realize(expr->type)) { if (realize(expr->type)) {
// Realize the appropriate `Optional.__new__` for the translation stage // Realize the appropriate `Optional.__new__` for the translation stage
auto cls = expr->type->getClass(); auto cls = expr->type->getClass();
auto f = ctx->forceFind(TYPE_OPTIONAL ".__new__:0")->type; auto f = ctx->forceFind(TYPE_OPTIONAL ".__new__")->type;
auto t = realize(ctx->instantiate(f, cls)->getFunc()); auto t = realize(ctx->instantiate(f, cls)->getFunc());
expr->setDone(); expr->setDone();
} }

View File

@ -337,7 +337,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
} }
ExprPtr e = N<TupleExpr>(extra); ExprPtr e = N<TupleExpr>(extra);
e->setAttr(ExprAttr::StarArgument); e->setAttr(ExprAttr::StarArgument);
if (!expr->expr->isId("hasattr:0")) if (!expr->expr->isId("hasattr"))
e = transform(e); e = transform(e);
if (partial) { if (partial) {
part.args = e; part.args = e;
@ -570,7 +570,8 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
if (!expr->expr->getId()) if (!expr->expr->getId())
return {false, nullptr}; return {false, nullptr};
auto val = expr->expr->getId()->value; auto val = expr->expr->getId()->value;
if (val == "tuple") { if (val == "tuple" && expr->args.size() == 1 &&
CAST(expr->args.front().value, GeneratorExpr)) {
return {true, transformTupleGenerator(expr)}; return {true, transformTupleGenerator(expr)};
} else if (val == "std.collections.namedtuple") { } else if (val == "std.collections.namedtuple") {
return {true, transformNamedTuple(expr)}; return {true, transformNamedTuple(expr)};
@ -578,11 +579,11 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
return {true, transformFunctoolsPartial(expr)}; return {true, transformFunctoolsPartial(expr)};
} else if (val == "superf") { } else if (val == "superf") {
return {true, transformSuperF(expr)}; return {true, transformSuperF(expr)};
} else if (val == "super:0") { } else if (val == "super") {
return {true, transformSuper()}; return {true, transformSuper()};
} else if (val == "__ptr__") { } else if (val == "__ptr__") {
return {true, transformPtr(expr)}; return {true, transformPtr(expr)};
} else if (val == "__array__.__new__:0") { } else if (val == "__array__.__new__") {
return {true, transformArray(expr)}; return {true, transformArray(expr)};
} else if (val == "isinstance") { } else if (val == "isinstance") {
return {true, transformIsInstance(expr)}; return {true, transformIsInstance(expr)};
@ -594,7 +595,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
return {true, transformGetAttr(expr)}; return {true, transformGetAttr(expr)};
} else if (val == "setattr") { } else if (val == "setattr") {
return {true, transformSetAttr(expr)}; return {true, transformSetAttr(expr)};
} else if (val == "type.__new__:0") { } else if (val == "type.__new__") {
return {true, transformTypeFn(expr)}; return {true, transformTypeFn(expr)};
} else if (val == "compile_error") { } else if (val == "compile_error") {
return {true, transformCompileError(expr)}; return {true, transformCompileError(expr)};
@ -605,6 +606,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
} else if (val == "std.internal.static.static_print") { } else if (val == "std.internal.static.static_print") {
return {false, transformStaticPrintFn(expr)}; return {false, transformStaticPrintFn(expr)};
} else if (val == "__has_rtti__") { } else if (val == "__has_rtti__") {
LOG("- rtti has {}", getSrcInfo());
return {true, transformHasRttiFn(expr)}; return {true, transformHasRttiFn(expr)};
} else { } else {
return transformInternalStaticFn(expr); return transformInternalStaticFn(expr);
@ -627,12 +629,12 @@ ExprPtr TypecheckVisitor::transformTupleGenerator(CallExpr *expr) {
ctx->enterConditionalBlock(); ctx->enterConditionalBlock();
ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}}); ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}});
if (auto i = var->getId()) { if (auto i = var->getId()) {
ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo()); ctx->addVar(i->value, ctx->generateCanonicalName(i->value), ctx->getUnbound());
var = transform(var); var = transform(var);
ex = transform(ex); ex = transform(ex);
} else { } else {
std::string varName = ctx->cache->getTemporaryVar("for"); std::string varName = ctx->cache->getTemporaryVar("for");
ctx->addVar(varName, varName, var->getSrcInfo()); ctx->addVar(varName, varName, ctx->getUnbound());
var = N<IdExpr>(varName); var = N<IdExpr>(varName);
auto head = transform(N<AssignStmt>(clone(g->loops[0].vars), clone(var))); auto head = transform(N<AssignStmt>(clone(g->loops[0].vars), clone(var)));
ex = N<StmtExpr>(head, transform(ex)); ex = N<StmtExpr>(head, transform(ex));
@ -640,7 +642,7 @@ ExprPtr TypecheckVisitor::transformTupleGenerator(CallExpr *expr) {
ctx->leaveConditionalBlock(); ctx->leaveConditionalBlock();
// Dominate loop variables // Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars) for (auto &var : ctx->getBase()->getLoop()->seenVars)
ctx->findDominatingBinding(var, this); findDominatingBinding(var, ctx.get());
ctx->getBase()->loops.pop_back(); ctx->getBase()->loops.pop_back();
return N<GeneratorExpr>( return N<GeneratorExpr>(
GeneratorExpr::TupleGenerator, ex, GeneratorExpr::TupleGenerator, ex,
@ -705,7 +707,7 @@ ExprPtr TypecheckVisitor::transformFunctoolsPartial(CallExpr *expr) {
/// cls.foo()``` /// cls.foo()```
/// prints "foo 1" followed by "foo 2" /// prints "foo 1" followed by "foo 2"
ExprPtr TypecheckVisitor::transformSuperF(CallExpr *expr) { ExprPtr TypecheckVisitor::transformSuperF(CallExpr *expr) {
auto func = ctx->getRealizationBase()->type->getFunc(); auto func = ctx->getBase()->type->getFunc();
// Find list of matching superf methods // Find list of matching superf methods
std::vector<types::FuncTypePtr> supers; std::vector<types::FuncTypePtr> supers;
@ -740,9 +742,9 @@ ExprPtr TypecheckVisitor::transformSuperF(CallExpr *expr) {
/// to the first inherited type. /// to the first inherited type.
/// TODO: only an empty super() is currently supported. /// TODO: only an empty super() is currently supported.
ExprPtr TypecheckVisitor::transformSuper() { ExprPtr TypecheckVisitor::transformSuper() {
if (!ctx->getRealizationBase()->type) if (!ctx->getBase()->type)
E(Error::CALL_SUPER_PARENT, getSrcInfo()); E(Error::CALL_SUPER_PARENT, getSrcInfo());
auto funcTyp = ctx->getRealizationBase()->type->getFunc(); auto funcTyp = ctx->getBase()->type->getFunc();
if (!funcTyp || !funcTyp->ast->hasAttr(Attr::Method)) if (!funcTyp || !funcTyp->ast->hasAttr(Attr::Method))
E(Error::CALL_SUPER_PARENT, getSrcInfo()); E(Error::CALL_SUPER_PARENT, getSrcInfo());
if (funcTyp->getArgTypes().empty()) if (funcTyp->getArgTypes().empty())
@ -791,7 +793,7 @@ ExprPtr TypecheckVisitor::transformSuper() {
ExprPtr TypecheckVisitor::transformPtr(CallExpr *expr) { ExprPtr TypecheckVisitor::transformPtr(CallExpr *expr) {
auto id = expr->args[0].value->getId(); auto id = expr->args[0].value->getId();
auto val = id ? ctx->find(id->value) : nullptr; auto val = id ? ctx->find(id->value) : nullptr;
if (!val || val->kind != TypecheckItem::Var) if (!val || !val->isVar())
E(Error::CALL_PTR_VAR, expr->args[0]); E(Error::CALL_PTR_VAR, expr->args[0]);
transform(expr->args[0].value); transform(expr->args[0].value);
@ -859,12 +861,12 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
if (tag == -1) if (tag == -1)
return transform(N<BoolExpr>(false)); return transform(N<BoolExpr>(false));
return transform(N<BinaryExpr>( return transform(N<BinaryExpr>(
N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"), expr->args[0].value), N<CallExpr>(N<IdExpr>("__internal__.union_get_tag"), expr->args[0].value),
"==", N<IntExpr>(tag))); "==", N<IntExpr>(tag)));
} else if (typExpr->type->is("pyobj") && !typExpr->isType()) { } else if (typExpr->type->is("pyobj") && !typExpr->isType()) {
if (typ->is("pyobj")) { if (typ->is("pyobj")) {
expr->staticValue.type = StaticValue::NOT_STATIC; expr->staticValue.type = StaticValue::NOT_STATIC;
return transform(N<CallExpr>(N<IdExpr>("std.internal.python._isinstance:0"), return transform(N<CallExpr>(N<IdExpr>("std.internal.python._isinstance"),
expr->args[0].value, expr->args[1].value)); expr->args[0].value, expr->args[1].value));
} else { } else {
return transform(N<BoolExpr>(false)); return transform(N<BoolExpr>(false));
@ -923,7 +925,7 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
->evaluate() ->evaluate()
.getString(); .getString();
std::vector<std::pair<std::string, TypePtr>> args{{"", typ}}; std::vector<std::pair<std::string, TypePtr>> args{{"", typ}};
if (expr->expr->isId("hasattr:0")) { if (expr->expr->isId("hasattr")) {
// Case: the first hasattr overload allows passing argument types via *args // Case: the first hasattr overload allows passing argument types via *args
auto tup = expr->args[1].value->getTuple(); auto tup = expr->args[1].value->getTuple();
seqassert(tup, "not a tuple"); seqassert(tup, "not a tuple");
@ -933,7 +935,6 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
return nullptr; return nullptr;
args.emplace_back("", a->getType()); args.emplace_back("", a->getType());
} }
auto kwtup = expr->args[2].value->origExpr->getCall();
seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(), seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(),
"expected call: {}", expr->args[2].value->origExpr); "expected call: {}", expr->args[2].value->origExpr);
auto kw = expr->args[2].value->origExpr->getCall(); auto kw = expr->args[2].value->origExpr->getCall();
@ -1075,7 +1076,7 @@ ExprPtr TypecheckVisitor::transformHasRttiFn(CallExpr *expr) {
return nullptr; return nullptr;
auto c = in(ctx->cache->classes, t->name); auto c = in(ctx->cache->classes, t->name);
seqassert(c, "bad class {}", t->name); seqassert(c, "bad class {}", t->name);
return transform(N<BoolExpr>(const_cast<Cache::Class *>(c)->rtti)); return transform(N<BoolExpr>(c->hasRTTI()));
} }
// Transform internal.static calls // Transform internal.static calls
@ -1269,33 +1270,37 @@ std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cl
/// Find all generics on which a function depends on and add them to the current /// Find all generics on which a function depends on and add them to the current
/// context. /// context.
void TypecheckVisitor::addFunctionGenerics(const FuncType *t) { void TypecheckVisitor::addFunctionGenerics(const FuncType *t) {
auto addT = [&](const std::string &name, const types::TypePtr &type) {
TypeContext::Item v = nullptr;
if (auto c = type->getClass()) {
v = ctx->addType(ctx->cache->rev(name), name, c);
} else {
v = ctx->addType(ctx->cache->rev(name), name, type);
v->generic = true;
}
// LOG(" <=> {} :: {} ({}) / {}", type->debugString(2), ctx->cache->rev(name), name,
// v->isType());
ctx->add(name, v);
};
for (auto parent = t->funcParent; parent;) { for (auto parent = t->funcParent; parent;) {
if (auto f = parent->getFunc()) { if (auto f = parent->getFunc()) {
// Add parent function generics // Add parent function generics
for (auto &g : f->funcGenerics) { for (auto &g : f->funcGenerics)
// LOG(" -> {} := {}", g.name, g.type->debugString(true)); addT(g.name, g.type);
ctx->addType(g.name, g.name, getSrcInfo(), g.type);
}
parent = f->funcParent; parent = f->funcParent;
} else { } else {
// Add parent class generics // Add parent class generics
seqassert(parent->getClass(), "not a class: {}", parent); seqassert(parent->getClass(), "not a class: {}", parent);
for (auto &g : parent->getClass()->generics) { for (auto &g : parent->getClass()->generics)
// LOG(" => {} := {}", g.name, g.type->debugString(true)); addT(g.name, g.type);
ctx->addType(g.name, g.name, getSrcInfo(), g.type); for (auto &g : parent->getClass()->hiddenGenerics)
} addT(g.name, g.type);
for (auto &g : parent->getClass()->hiddenGenerics) {
// LOG(" :> {} := {}", g.name, g.type->debugString(true));
ctx->addType(g.name, g.name, getSrcInfo(), g.type);
}
break; break;
} }
} }
// Add function generics // Add function generics
for (auto &g : t->funcGenerics) { for (auto &g : t->funcGenerics)
// LOG(" >> {} := {}", g.name, g.type->debugString(true)); addT(g.name, g.type);
ctx->addType(g.name, g.name, getSrcInfo(), g.type);
}
} }
/// Generate a partial type `Partial.N<mask>` for a given function. /// Generate a partial type `Partial.N<mask>` for a given function.

View File

@ -25,13 +25,14 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
std::vector<Param> &argsToParse = stmt->args; std::vector<Param> &argsToParse = stmt->args;
// classItem will be added later when the scope is different // classItem will be added later when the scope is different
auto classItem = std::make_shared<TypecheckItem>(TypecheckItem::Type, "", "", auto classItem = std::make_shared<TypecheckItem>("", "", ctx->getModule(), nullptr,
ctx->getModule(), ctx->scope.blocks); ctx->scope.blocks);
classItem->setSrcInfo(stmt->getSrcInfo()); classItem->setSrcInfo(stmt->getSrcInfo());
types::ClassTypePtr typ = nullptr; types::ClassTypePtr typ = nullptr;
if (!stmt->attributes.has(Attr::Extend)) { if (!stmt->attributes.has(Attr::Extend)) {
classItem->canonicalName = canonicalName = classItem->canonicalName = canonicalName =
ctx->generateCanonicalName(name, !stmt->attributes.has(Attr::Internal)); ctx->generateCanonicalName(name, !stmt->attributes.has(Attr::Internal),
/* noSuffix*/ stmt->attributes.has(Attr::Internal));
typ = Type::makeType(ctx->cache, canonicalName, name, stmt->isRecord())->getClass(); typ = Type::makeType(ctx->cache, canonicalName, name, stmt->isRecord())->getClass();
if (stmt->isRecord() && stmt->hasAttr("__notuple__")) if (stmt->isRecord() && stmt->hasAttr("__notuple__"))
@ -50,6 +51,9 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
if (!stmt->attributes.has(Attr::Tuple)) { if (!stmt->attributes.has(Attr::Tuple)) {
ctx->add(name, classItem); ctx->add(name, classItem);
ctx->addAlwaysVisible(classItem); ctx->addAlwaysVisible(classItem);
// LOG("added typ/{}: {}",
// classItem->isVar() ? "v" : (classItem->isFunc() ? "f" : "t"),
// classItem->canonicalName);
} }
} else { } else {
// Find the canonical name and AST of the class that is to be extended // Find the canonical name and AST of the class that is to be extended
@ -59,6 +63,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
if (!val || !val->isType()) if (!val || !val->isType())
E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name);
canonicalName = val->canonicalName; canonicalName = val->canonicalName;
typ = val->type->getClass();
const auto &astIter = ctx->cache->classes.find(canonicalName); const auto &astIter = ctx->cache->classes.find(canonicalName);
if (astIter == ctx->cache->classes.end()) { if (astIter == ctx->cache->classes.end()) {
E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name);
@ -74,36 +79,38 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
try { try {
// Add the class base // Add the class base
TypeContext::BaseGuard br(ctx.get(), canonicalName); TypeContext::BaseGuard br(ctx.get(), canonicalName);
ctx->getBase()->type = typ;
// Parse and add class generics // Parse and add class generics
std::vector<Param> args; std::vector<Param> args;
std::pair<StmtPtr, FunctionStmt *> autoDeducedInit{nullptr, nullptr}; std::pair<StmtPtr, FunctionStmt *> autoDeducedInit{nullptr, nullptr};
if (stmt->attributes.has("deduce") && args.empty()) { if (stmt->attributes.has("deduce") && args.empty()) {
// todo)) do this
// Auto-detect generics and fields // Auto-detect generics and fields
autoDeducedInit = autoDeduceMembers(stmt, args); // autoDeducedInit = autoDeduceMembers(stmt, args);
} else if (stmt->attributes.has(Attr::Extend)) {
for (auto &a : argsToParse) {
if (a.status != Param::Generic)
continue;
auto val = ctx->forceFind(a.name);
auto generic = ctx->instantiate(val->type);
generic->getUnbound()->id = val->type->getLink()->id;
ctx->addType(ctx->cache->rev(val->canonicalName), val->canonicalName, generic)
->generic = true;
}
} else { } else {
// Add all generics before parent classes, fields and methods // Add all generics before parent classes, fields and methods
for (auto &a : argsToParse) { for (auto &a : argsToParse) {
if (a.status != Param::Generic) if (a.status != Param::Generic)
continue; continue;
std::string genName, varName;
if (stmt->attributes.has(Attr::Extend))
varName = a.name, genName = ctx->cache->rev(a.name);
else
varName = ctx->generateCanonicalName(a.name), genName = a.name;
auto varName = ctx->generateCanonicalName(a.name), genName = a.name;
auto generic = ctx->getUnbound(); auto generic = ctx->getUnbound();
auto typId = generic->id; auto typId = generic->id;
generic->getLink()->genericName = ctx->cache->rev(a.name); generic->getLink()->genericName = genName;
if (a.defaultValue) { if (a.defaultValue) {
auto defType = transformType(clone(a.defaultValue)); auto defType = transformType(clone(a.defaultValue));
if (a.status == Param::Generic) { generic->defaultType = defType->type;
generic->defaultType = defType->type;
} else {
// Hidden generics can be outright replaced (e.g., `T=int`).
// Unify them immediately.
unify(defType->type, generic);
}
} }
if (auto ti = CAST(a.type, InstantiateExpr)) { if (auto ti = CAST(a.type, InstantiateExpr)) {
// Parse TraitVar // Parse TraitVar
@ -114,17 +121,16 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
else else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l); generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
} }
if (auto st = getStaticGeneric(a.type.get())) { if (auto st = getStaticGeneric(a.type.get())) {
generic->isStatic = true; generic->isStatic = st;
auto val = ctx->addVar(genName, varName, a.type->getSrcInfo(), generic); auto val = ctx->addVar(genName, varName, generic);
val->generic = true; val->generic = true;
val->staticType = st; val->staticType = st;
} else { } else {
ctx->addType(genName, varName, a.type->getSrcInfo(), generic)->generic = true; ctx->addType(genName, varName, generic)->generic = true;
} }
ClassType::Generic g{a.name, ctx->cache->rev(a.name), ClassType::Generic g{varName, genName, generic->generalize(ctx->typecheckLevel),
generic->generalize(ctx->typecheckLevel), typId}; typId};
if (a.status == Param::Generic) { if (a.status == Param::Generic) {
typ->generics.push_back(g); typ->generics.push_back(g);
} else { } else {
@ -132,9 +138,6 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
} }
args.emplace_back(varName, transformType(clone(a.type), false), args.emplace_back(varName, transformType(clone(a.type), false),
transformType(clone(a.defaultValue), false), a.status); transformType(clone(a.defaultValue), false), a.status);
if (!stmt->attributes.has(Attr::Extend) && a.status == Param::Normal)
ctx->cache->classes[canonicalName].fields.push_back(
Cache::Class::ClassField{varName, nullptr, canonicalName});
} }
} }
@ -157,12 +160,12 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
std::vector<ClassStmt *> staticBaseASTs, baseASTs; std::vector<ClassStmt *> staticBaseASTs, baseASTs;
if (!stmt->attributes.has(Attr::Extend)) { if (!stmt->attributes.has(Attr::Extend)) {
staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt->attributes, staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt->attributes,
canonicalName); canonicalName, nullptr, typ);
if (ctx->cache->isJit && !stmt->baseClasses.empty()) if (ctx->cache->isJit && !stmt->baseClasses.empty())
E(Error::CUSTOM, stmt->baseClasses[0], E(Error::CUSTOM, stmt->baseClasses[0],
"inheritance is not yet supported in JIT mode"); "inheritance is not yet supported in JIT mode");
parseBaseClasses(stmt->baseClasses, args, stmt->attributes, canonicalName, parseBaseClasses(stmt->baseClasses, args, stmt->attributes, canonicalName,
transformedTypeAst); transformedTypeAst, typ);
} }
// A ClassStmt will be separated into class variable assignments, method-free // A ClassStmt will be separated into class variable assignments, method-free
@ -172,23 +175,25 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// Collect class fields // Collect class fields
for (auto &a : argsToParse) { for (auto &a : argsToParse) {
if (a.status == Param::Normal) { if (a.status == Param::Normal) {
if (!ClassStmt::isClassVar(a)) { if (ClassStmt::isClassVar(a)) {
args.emplace_back(a.name, transformType(clone(a.type), false),
transform(clone(a.defaultValue), true));
if (!stmt->attributes.has(Attr::Extend)) {
ctx->cache->classes[canonicalName].fields.push_back(
Cache::Class::ClassField{a.name, nullptr, canonicalName});
}
} else if (!stmt->attributes.has(Attr::Extend)) {
// Handle class variables. Transform them later to allow self-references // Handle class variables. Transform them later to allow self-references
auto name = format("{}.{}", canonicalName, a.name); auto name = format("{}.{}", canonicalName, a.name);
prependStmts->push_back(N<AssignStmt>(N<IdExpr>(name), nullptr, nullptr)); // prependStmts->push_back(N<AssignStmt>(N<IdExpr>(name), nullptr, nullptr));
ctx->cache->addGlobal(name); // ctx->cache->addGlobal(name);
auto assign = N<AssignStmt>(N<IdExpr>(name), a.defaultValue, auto assign = N<AssignStmt>(N<IdExpr>(name), a.defaultValue,
a.type ? a.type->getIndex()->index : nullptr); a.type ? a.type->getIndex()->index : nullptr);
assign->setUpdate();
varStmts.push_back(assign); varStmts.push_back(assign);
ctx->cache->classes[canonicalName].classVars[a.name] = name; ctx->cache->classes[canonicalName].classVars[a.name] = name;
} else if (!stmt->attributes.has(Attr::Extend)) {
std::string varName = a.name;
// stmt->attributes.has(Attr::Extend)
// ? a.name
// : ctx->generateCanonicalName(a.name);
args.emplace_back(varName, transformType(clone(a.type), false),
transform(clone(a.defaultValue), true));
LOG(" -> {}", varName);
ctx->cache->classes[canonicalName].fields.push_back(Cache::Class::ClassField{
varName, args.back().type->getType(), canonicalName});
} }
} }
} }
@ -196,48 +201,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// ASTs for member arguments to be used for populating magic methods // ASTs for member arguments to be used for populating magic methods
std::vector<Param> memberArgs; std::vector<Param> memberArgs;
for (auto &a : args) { for (auto &a : args) {
if (a.status == Param::Normal) if (a.status == Param::Normal) {
memberArgs.push_back(a.clone()); memberArgs.push_back(a.clone());
}
} }
// Handle class members // Handle class members
ctx->typecheckLevel++; // to avoid unifying generics early if (!stmt->attributes.has(Attr::Extend)) {
auto &fields = ctx->cache->classes[stmt->name].fields; ctx->typecheckLevel++; // to avoid unifying generics early
for (auto ai = 0, aj = 0; ai < stmt->args.size(); ai++) auto &fields = ctx->cache->classes[canonicalName].fields;
if (stmt->args[ai].status == Param::Normal) { for (auto ai = 0, aj = 0; ai < stmt->args.size(); ai++)
fields[aj].type = transformType(stmt->args[ai].type) if (stmt->args[ai].status == Param::Normal &&
->getType() !ClassStmt::isClassVar(stmt->args[ai])) {
->generalize(ctx->typecheckLevel - 1); fields[aj].type = transformType(stmt->args[ai].type)
fields[aj].type->setSrcInfo(stmt->args[ai].type->getSrcInfo()); ->getType()
if (stmt->isRecord()) ->generalize(ctx->typecheckLevel - 1);
typ->getRecord()->args.push_back(fields[aj].type); fields[aj].type->setSrcInfo(stmt->args[ai].type->getSrcInfo());
aj++; if (stmt->isRecord())
} typ->getRecord()->args.push_back(fields[aj].type);
ctx->typecheckLevel--; aj++;
// Handle MRO
for (auto &m : ctx->cache->classes[stmt->name].mro) {
m = transformType(m);
}
// Generalize generics and remove them from the context
for (const auto &g : args)
if (g.status != Param::Normal) {
auto generic = ctx->forceFind(g.name)->type;
if (g.status == Param::Generic) {
// Generalize generics. Hidden generics are linked to the class generics so
// ignore them
seqassert(generic && generic->getLink() &&
generic->getLink()->kind != types::LinkType::Link,
"generic has been unified");
generic->getLink()->kind = LinkType::Generic;
} }
ctx->remove(g.name); ctx->typecheckLevel--;
} }
// Debug information
LOG_REALIZE("[class] {} -> {}", stmt->name, typ);
for (auto &m : ctx->cache->classes[stmt->name].fields)
LOG_REALIZE(" - member: {}: {}", m.name, m.type);
// Parse class members (arguments) and methods // Parse class members (arguments) and methods
if (!stmt->attributes.has(Attr::Extend)) { if (!stmt->attributes.has(Attr::Extend)) {
@ -246,15 +231,23 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// Ensure that class binding does not shadow anything. // Ensure that class binding does not shadow anything.
// Class bindings cannot be dominated either // Class bindings cannot be dominated either
auto v = ctx->find(name); auto v = ctx->find(name);
if (v && v->noShadow) if (v && !v->canShadow)
E(Error::CLASS_INVALID_BIND, stmt, name); E(Error::CLASS_INVALID_BIND, stmt, name);
ctx->add(name, classItem); ctx->add(name, classItem);
ctx->addAlwaysVisible(classItem); ctx->addAlwaysVisible(classItem);
// LOG("added typ/{}: {}",
// classItem->isVar() ? "v" : (classItem->isFunc() ? "f" : "t"),
// classItem->canonicalName);
} }
// Create a cached AST. // Create a cached AST.
stmt->attributes.module = stmt->attributes.module = ctx->moduleName.status == ImportFile::STDLIB
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", ? STDLIB_IMPORT
ctx->moduleName.module); : ctx->moduleName.path;
;
// format(
// "{}{}",
// ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
// ctx->moduleName.module);
ctx->cache->classes[canonicalName].ast = ctx->cache->classes[canonicalName].ast =
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), stmt->attributes); N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), stmt->attributes);
ctx->cache->classes[canonicalName].ast->baseClasses = stmt->baseClasses; ctx->cache->classes[canonicalName].ast->baseClasses = stmt->baseClasses;
@ -263,6 +256,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
ctx->cache->classes[canonicalName].ast->validate(); ctx->cache->classes[canonicalName].ast->validate();
ctx->cache->classes[canonicalName].module = ctx->getModule(); ctx->cache->classes[canonicalName].module = ctx->getModule();
// Handle MRO
for (auto &m : ctx->cache->classes[canonicalName].mro) {
m = transformType(m);
}
// Codegen default magic methods // Codegen default magic methods
for (auto &m : stmt->attributes.magics) { for (auto &m : stmt->attributes.magics) {
fnStmts.push_back(transform( fnStmts.push_back(transform(
@ -335,6 +333,28 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
} }
} }
} }
// Generalize generics and remove them from the context
for (const auto &g : args)
if (g.status != Param::Normal) {
auto generic = ctx->forceFind(g.name)->type;
if (g.status == Param::Generic) {
// Generalize generics. Hidden generics are linked to the class generics so
// ignore them
seqassert(generic && generic->getLink() &&
generic->getLink()->kind != types::LinkType::Link,
"generic has been unified");
generic->getLink()->kind = LinkType::Generic;
}
ctx->remove(g.name);
}
// Debug information
LOG("[class] {} -> {:D} / {}", canonicalName, typ,
ctx->cache->classes[canonicalName].fields.size());
for (auto &m : ctx->cache->classes[canonicalName].fields)
LOG(" - member: {}: {:D}", m.name, m.type);
for (auto &m : ctx->cache->classes[canonicalName].methods)
LOG(" - method: {}: {}", m.first, m.second);
} catch (const exc::ParserException &) { } catch (const exc::ParserException &) {
if (!stmt->attributes.has(Attr::Tuple)) if (!stmt->attributes.has(Attr::Tuple))
ctx->remove(name); ctx->remove(name);
@ -348,16 +368,18 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
if (!stmt->attributes.has(Attr::Extend)) { if (!stmt->attributes.has(Attr::Extend)) {
auto c = ctx->cache->classes[canonicalName].ast; auto c = ctx->cache->classes[canonicalName].ast;
seqassert(c, "not a class AST for {}", canonicalName); seqassert(c, "not a class AST for {}", canonicalName);
c->setDone();
clsStmts.push_back(c); clsStmts.push_back(c);
} }
clsStmts.insert(clsStmts.end(), fnStmts.begin(), fnStmts.end()); clsStmts.insert(clsStmts.end(), fnStmts.begin(), fnStmts.end());
for (auto &a : varStmts) { for (auto &a : varStmts) {
// Transform class variables here to allow self-references // Transform class variables here to allow self-references
if (auto assign = a->getAssign()) { transform(a);
transform(assign->rhs); // if (auto assign = a->getAssign()) {
transformType(assign->type); // transform(assign->rhs);
} // transformType(assign->type);
// }
clsStmts.push_back(a); clsStmts.push_back(a);
} }
resultStmt = N<SuiteStmt>(clsStmts); resultStmt = N<SuiteStmt>(clsStmts);
@ -368,9 +390,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
/// @param args Class fields that are to be updated with base classes' fields. /// @param args Class fields that are to be updated with base classes' fields.
/// @param typeAst Transformed AST for base class type (e.g., `A[T]`). /// @param typeAst Transformed AST for base class type (e.g., `A[T]`).
/// Only set when dealing with dynamic polymorphism. /// Only set when dealing with dynamic polymorphism.
std::vector<ClassStmt *> TypecheckVisitor::parseBaseClasses( std::vector<ClassStmt *>
std::vector<ExprPtr> &baseClasses, std::vector<Param> &args, const Attr &attr, TypecheckVisitor::parseBaseClasses(std::vector<ExprPtr> &baseClasses,
const std::string &canonicalName, const ExprPtr &typeAst) { std::vector<Param> &args, const Attr &attr,
const std::string &canonicalName,
const ExprPtr &typeAst, types::ClassTypePtr &typ) {
std::vector<ClassStmt *> asts; std::vector<ClassStmt *> asts;
// MAJOR TODO: fix MRO it to work with generic classes (maybe replacements? IDK...) // MAJOR TODO: fix MRO it to work with generic classes (maybe replacements? IDK...)
@ -392,7 +416,7 @@ std::vector<ClassStmt *> TypecheckVisitor::parseBaseClasses(
} }
} }
auto cachedCls = const_cast<Cache::Class *>(in(ctx->cache->classes, name)); Cache::Class *cachedCls = in(ctx->cache->classes, name);
if (!cachedCls) if (!cachedCls)
E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), ctx->cache->rev(name)); E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), ctx->cache->rev(name));
asts.push_back(cachedCls->ast.get()); asts.push_back(cachedCls->ast.get());
@ -418,6 +442,9 @@ std::vector<ClassStmt *> TypecheckVisitor::parseBaseClasses(
nGenerics += a.status == Param::Generic; nGenerics += a.status == Param::Generic;
int si = 0; int si = 0;
for (auto &a : asts.back()->args) { for (auto &a : asts.back()->args) {
if (a.status == Param::Normal)
continue;
if (a.status == Param::Generic) { if (a.status == Param::Generic) {
if (si == subs.size()) if (si == subs.size())
E(Error::GENERICS_MISMATCH, cls, ctx->cache->rev(asts.back()->name), E(Error::GENERICS_MISMATCH, cls, ctx->cache->rev(asts.back()->name),
@ -427,14 +454,40 @@ std::vector<ClassStmt *> TypecheckVisitor::parseBaseClasses(
} else if (a.status == Param::HiddenGeneric) { } else if (a.status == Param::HiddenGeneric) {
args.emplace_back(a); args.emplace_back(a);
} }
if (a.status != Param::Normal) {
if (auto st = getStaticGeneric(a.type.get())) { auto generic = ctx->getUnbound();
auto val = ctx->addVar(a.name, a.name, a.type->getSrcInfo()); auto typId = generic->id;
val->generic = true; generic->getLink()->genericName = ctx->cache->rev(a.name);
val->staticType = st; if (args.back().defaultValue) {
} else { auto defType = transformType(clone(args.back().defaultValue));
ctx->addType(a.name, a.name, a.type->getSrcInfo())->generic = true; // Hidden generics can be outright replaced (e.g., `T=int`).
} // Unify them immediately.
unify(defType->type, generic);
}
if (auto ti = CAST(a.type, InstantiateExpr)) {
// Parse TraitVar
seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation");
auto l = transformType(ti->typeParams[0])->type;
if (l->getLink() && l->getLink()->trait)
generic->getLink()->trait = l->getLink()->trait;
else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
if (auto st = getStaticGeneric(a.type.get())) {
generic->isStatic = st;
auto val = ctx->addVar(a.name, a.name, generic);
val->generic = true;
val->staticType = st;
} else {
ctx->addType(a.name, a.name, generic)->generic = true;
}
ClassType::Generic g{a.name, a.name, generic->generalize(ctx->typecheckLevel),
typId};
if (a.status == Param::Generic) {
typ->generics.push_back(g);
} else {
typ->hiddenGenerics.push_back(g);
} }
} }
if (si != subs.size()) if (si != subs.size())
@ -455,9 +508,10 @@ std::vector<ClassStmt *> TypecheckVisitor::parseBaseClasses(
seqassert(ctx->cache->classes[ast->name].fields[ai].name == a.name, seqassert(ctx->cache->classes[ast->name].fields[ai].name == a.name,
"bad class fields: {} vs {}", "bad class fields: {} vs {}",
ctx->cache->classes[ast->name].fields[ai].name, a.name); ctx->cache->classes[ast->name].fields[ai].name, a.name);
args.emplace_back(name, a.type, a.defaultValue); args.emplace_back(name, transformType(a.type), transform(a.defaultValue));
ctx->cache->classes[canonicalName].fields.push_back(Cache::Class::ClassField{ ctx->cache->classes[canonicalName].fields.push_back(Cache::Class::ClassField{
name, nullptr, ctx->cache->classes[ast->name].fields[ai].baseClass}); name, args.back().type->getType(),
ctx->cache->classes[ast->name].fields[ai].baseClass});
ai++; ai++;
} }
} }
@ -495,27 +549,29 @@ TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector<Param> &args) {
for (const auto &sp : getClassMethods(stmt->suite)) for (const auto &sp : getClassMethods(stmt->suite))
if (sp && sp->getFunction()) { if (sp && sp->getFunction()) {
auto f = sp->getFunction(); auto f = sp->getFunction();
if (f->name == "__init__" && !f->args.empty() && f->args[0].name == "self") { // todo)) do this
// Set up deducedMembers that will be populated during AssignStmt evaluation // if (f->name == "__init__" && !f->args.empty() && f->args[0].name == "self") {
ctx->getBase()->deducedMembers = std::make_shared<std::vector<std::string>>(); // // Set up deducedMembers that will be populated during AssignStmt evaluation
auto transformed = transform(sp); // ctx->getBase()->deducedMembers =
transformed->getFunction()->attributes.set(Attr::RealizeWithoutSelf); // std::make_shared<std::vector<std::string>>(); auto transformed =
ctx->cache->functions[transformed->getFunction()->name].ast->attributes.set( // transform(sp);
Attr::RealizeWithoutSelf); // transformed->getFunction()->attributes.set(Attr::RealizeWithoutSelf);
int i = 0; // ctx->cache->functions[transformed->getFunction()->name].ast->attributes.set(
// Once done, add arguments // Attr::RealizeWithoutSelf);
for (auto &m : *(ctx->getBase()->deducedMembers)) { // int i = 0;
auto varName = ctx->generateCanonicalName(format("T{}", ++i)); // // Once done, add arguments
auto memberName = ctx->cache->rev(varName); // for (auto &m : *(ctx->getBase()->deducedMembers)) {
ctx->addType(memberName, varName, stmt->getSrcInfo())->generic = true; // auto varName = ctx->generateCanonicalName(format("T{}", ++i));
args.emplace_back(varName, N<IdExpr>("type"), nullptr, Param::Generic); // auto memberName = ctx->cache->rev(varName);
args.emplace_back(m, N<IdExpr>(varName)); // ctx->addType(memberName, varName, stmt->getSrcInfo())->generic = true;
ctx->cache->classes[stmt->name].fields.push_back( // args.emplace_back(varName, N<IdExpr>("type"), nullptr, Param::Generic);
Cache::Class::ClassField{m, nullptr, stmt->name}); // args.emplace_back(m, N<IdExpr>(varName));
} // ctx->cache->classes[canonicalName].fields.push_back(
ctx->getBase()->deducedMembers = nullptr; // Cache::Class::ClassField{m, nullptr, canonicalName});
return {transformed, f}; // }
} // ctx->getBase()->deducedMembers = nullptr;
// return {transformed, f};
// }
} }
return {nullptr, nullptr}; return {nullptr, nullptr};
} }
@ -594,8 +650,9 @@ StmtPtr TypecheckVisitor::codegenMagic(const std::string &op, const ExprPtr &typ
attr.set("autogenerated"); attr.set("autogenerated");
std::vector<Param> args; std::vector<Param> args;
args.reserve(allArgs.size());
for (auto &a : allArgs) for (auto &a : allArgs)
args.push_back(a); args.push_back(a.clone());
if (op == "new") { if (op == "new") {
ret = typExpr->clone(); ret = typExpr->clone();
@ -621,7 +678,7 @@ StmtPtr TypecheckVisitor::codegenMagic(const std::string &op, const ExprPtr &typ
a.defaultValue ? clone(a.defaultValue) a.defaultValue ? clone(a.defaultValue)
: N<CallExpr>(clone(a.type))); : N<CallExpr>(clone(a.type)));
} }
} else if (op == "raw") { } else if (op == "raw" || op == "dict") {
// Classes: def __raw__(self: T) // Classes: def __raw__(self: T)
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self")))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"))));
@ -651,7 +708,7 @@ StmtPtr TypecheckVisitor::codegenMagic(const std::string &op, const ExprPtr &typ
fargs.emplace_back("obj", typExpr->clone()); fargs.emplace_back("obj", typExpr->clone());
ret = I("bool"); ret = I("bool");
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"), I("obj")))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"), I("obj"))));
} else if (op == "hash") { } else if (op == "hash" || op == "len") {
// def __hash__(self: T) -> int // def __hash__(self: T) -> int
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
ret = I("int"); ret = I("int");
@ -661,26 +718,16 @@ StmtPtr TypecheckVisitor::codegenMagic(const std::string &op, const ExprPtr &typ
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
fargs.emplace_back("dest", N<IndexExpr>(I("Ptr"), I("byte"))); fargs.emplace_back("dest", N<IndexExpr>(I("Ptr"), I("byte")));
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"), I("dest")))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"), I("dest"))));
} else if (op == "unpickle") { } else if (op == "unpickle" || op == "from_py") {
// def __unpickle__(src: Ptr[byte]) -> T // def __unpickle__(src: Ptr[byte]) -> T
fargs.emplace_back("src", N<IndexExpr>(I("Ptr"), I("byte"))); fargs.emplace_back("src", N<IndexExpr>(I("Ptr"), I("byte")));
ret = typExpr->clone(); ret = typExpr->clone();
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("src"), typExpr->clone()))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("src"), typExpr->clone())));
} else if (op == "len") {
// def __len__(self: T) -> int
fargs.emplace_back("self", typExpr->clone());
ret = I("int");
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"))));
} else if (op == "to_py") { } else if (op == "to_py") {
// def __to_py__(self: T) -> Ptr[byte] // def __to_py__(self: T) -> Ptr[byte]
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
ret = N<IndexExpr>(I("Ptr"), I("byte")); ret = N<IndexExpr>(I("Ptr"), I("byte"));
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self")))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"))));
} else if (op == "from_py") {
// def __from_py__(src: Ptr[byte]) -> T
fargs.emplace_back("src", N<IndexExpr>(I("Ptr"), I("byte")));
ret = typExpr->clone();
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("src"), typExpr->clone())));
} else if (op == "to_gpu") { } else if (op == "to_gpu") {
// def __to_gpu__(self: T, cache) -> T // def __to_gpu__(self: T, cache) -> T
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
@ -702,10 +749,6 @@ StmtPtr TypecheckVisitor::codegenMagic(const std::string &op, const ExprPtr &typ
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());
ret = I("str"); ret = I("str");
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self")))); stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"))));
} else if (op == "dict") {
// def __dict__(self: T)
fargs.emplace_back("self", typExpr->clone());
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), I("self"))));
} else if (op == "add") { } else if (op == "add") {
// def __add__(self, obj) // def __add__(self, obj)
fargs.emplace_back("self", typExpr->clone()); fargs.emplace_back("self", typExpr->clone());

View File

@ -17,84 +17,84 @@ using namespace codon::error;
namespace codon::ast { namespace codon::ast {
TypecheckItem::TypecheckItem(TypecheckItem::Kind kind, std::string baseName, TypecheckItem::TypecheckItem(std::string canonicalName, std::string baseName,
std::string canonicalName, std::string moduleName, std::string moduleName, types::TypePtr type,
std::vector<int> scope, std::string importPath, std::vector<int> scope)
types::TypePtr type) : canonicalName(std::move(canonicalName)), baseName(std::move(baseName)),
: kind(kind), baseName(std::move(baseName)), moduleName(std::move(moduleName)), type(std::move(type)),
canonicalName(std::move(canonicalName)), moduleName(std::move(moduleName)), scope(std::move(scope)) {}
scope(std::move(scope)), importPath(std::move(importPath)),
type(std::move(type)) {}
TypeContext::TypeContext(Cache *cache, std::string filename) TypeContext::TypeContext(Cache *cache, std::string filename)
: Context<TypecheckItem>(std::move(filename)), cache(cache) { : Context<TypecheckItem>(std::move(filename)), cache(cache) {
bases.emplace_back(""); bases.emplace_back();
scope.blocks.emplace_back(scope.counter = 0); scope.blocks.emplace_back(scope.counter = 0);
realizationBases.emplace_back();
pushSrcInfo(cache->generateSrcInfo()); // Always have srcInfo() around pushSrcInfo(cache->generateSrcInfo()); // Always have srcInfo() around
} }
TypeContext::Base::Base(std::string name, Attr *attributes)
: name(std::move(name)), attributes(attributes) {}
void TypeContext::add(const std::string &name, const TypeContext::Item &var) { void TypeContext::add(const std::string &name, const TypeContext::Item &var) {
auto v = find(name); auto v = find(name);
if (v && v->noShadow) if (v && !v->canShadow)
E(Error::ID_INVALID_BIND, getSrcInfo(), name); E(Error::ID_INVALID_BIND, getSrcInfo(), name);
Context<TypecheckItem>::add(name, var); Context<TypecheckItem>::add(name, var);
} }
TypeContext::Item TypeContext::addVar(const std::string &name, TypeContext::Item TypeContext::addVar(const std::string &name,
const std::string &canonicalName, const std::string &canonicalName,
const SrcInfo &srcInfo, const types::TypePtr &type,
const types::TypePtr &type) { const SrcInfo &srcInfo) {
seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
auto t = std::make_shared<TypecheckItem>(TypecheckItem::Var, getBaseName(), seqassert(type->getLink(), "bad var");
canonicalName, getModule(), scope.blocks); auto t = std::make_shared<TypecheckItem>(canonicalName, getBaseName(), getModule(),
type, scope.blocks);
t->setSrcInfo(srcInfo); t->setSrcInfo(srcInfo);
t->type = type;
Context<TypecheckItem>::add(name, t); Context<TypecheckItem>::add(name, t);
Context<TypecheckItem>::add(canonicalName, t); addAlwaysVisible(t);
// LOG("added var/{}: {}", t->isVar() ? "v" : (t->isFunc() ? "f" : "t"),
// canonicalName);
return t; return t;
} }
TypeContext::Item TypeContext::addType(const std::string &name, TypeContext::Item TypeContext::addType(const std::string &name,
const std::string &canonicalName, const std::string &canonicalName,
const SrcInfo &srcInfo, const types::TypePtr &type,
const types::TypePtr &type) { const SrcInfo &srcInfo) {
seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
auto t = std::make_shared<TypecheckItem>(TypecheckItem::Type, getBaseName(), // seqassert(type->getClass(), "bad type");
canonicalName, getModule(), scope.blocks); auto t = std::make_shared<TypecheckItem>(canonicalName, getBaseName(), getModule(),
type, scope.blocks);
t->setSrcInfo(srcInfo); t->setSrcInfo(srcInfo);
t->type = type;
Context<TypecheckItem>::add(name, t); Context<TypecheckItem>::add(name, t);
Context<TypecheckItem>::add(canonicalName, t); addAlwaysVisible(t);
// LOG("added typ/{}: {}", t->isVar() ? "v" : (t->isFunc() ? "f" : "t"),
// canonicalName);
return t; return t;
} }
TypeContext::Item TypeContext::addFunc(const std::string &name, TypeContext::Item TypeContext::addFunc(const std::string &name,
const std::string &canonicalName, const std::string &canonicalName,
const SrcInfo &srcInfo, const types::TypePtr &type,
const types::TypePtr &type) { const SrcInfo &srcInfo) {
seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
auto t = std::make_shared<TypecheckItem>(TypecheckItem::Func, getBaseName(), seqassert(type->getFunc(), "bad func");
canonicalName, getModule(), scope.blocks); auto t = std::make_shared<TypecheckItem>(canonicalName, getBaseName(), getModule(),
type, scope.blocks);
t->setSrcInfo(srcInfo); t->setSrcInfo(srcInfo);
t->type = type;
Context<TypecheckItem>::add(name, t); Context<TypecheckItem>::add(name, t);
Context<TypecheckItem>::add(canonicalName, t); addAlwaysVisible(t);
// LOG("added fun/{}: {}", t->isVar() ? "v" : (t->isFunc() ? "f" : "t"),
// canonicalName);
return t; return t;
} }
TypeContext::Item TypeContext::addAlwaysVisible(const TypeContext::Item &item) { TypeContext::Item TypeContext::addAlwaysVisible(const TypeContext::Item &item) {
auto i = std::make_shared<TypecheckItem>(item->kind, item->baseName, if (!cache->typeCtx->Context<TypecheckItem>::find(item->canonicalName)) {
item->canonicalName, item->moduleName, cache->typeCtx->add(item->canonicalName, item);
std::vector<int>{0}, item->importPath);
auto stdlib = cache->imports[STDLIB_IMPORT].ctx; // Realizations etc.
if (!stdlib->find(i->canonicalName)) { if (!in(cache->reverseIdentifierLookup, item->canonicalName))
stdlib->add(i->canonicalName, i); cache->reverseIdentifierLookup[item->canonicalName] = item->canonicalName;
} }
return i; return item;
} }
TypeContext::Item TypeContext::find(const std::string &name) const { TypeContext::Item TypeContext::find(const std::string &name) const {
@ -106,7 +106,12 @@ TypeContext::Item TypeContext::find(const std::string &name) const {
// Note: the standard library items cannot be dominated. // Note: the standard library items cannot be dominated.
auto stdlib = cache->imports[STDLIB_IMPORT].ctx; auto stdlib = cache->imports[STDLIB_IMPORT].ctx;
if (stdlib.get() != this) if (stdlib.get() != this)
t = stdlib->find(name); t = stdlib->Context<TypecheckItem>::find(name);
// Maybe we are looking for a canonical identifier?
if (!t && cache->typeCtx.get() != this)
t = cache->typeCtx->Context<TypecheckItem>::find(name);
return t; return t;
} }
@ -116,158 +121,6 @@ TypeContext::Item TypeContext::forceFind(const std::string &name) const {
return f; return f;
} }
TypeContext::Item TypeContext::findDominatingBinding(const std::string &name,
TypecheckVisitor *tv) {
auto it = map.find(name);
if (it == map.end()) {
return find(name);
} else if (isCanonicalName(name)) {
return *(it->second.begin());
}
seqassert(!it->second.empty(), "corrupted TypecheckContext ({})", name);
// The item is found. Let's see is it accessible now.
std::string canonicalName;
auto lastGood = it->second.begin();
bool isOutside = (*lastGood)->getBaseName() != getBaseName();
int prefix = int(scope.blocks.size());
// Iterate through all bindings with the given name and find the closest binding that
// dominates the current scope.
for (auto i = it->second.begin(); i != it->second.end(); i++) {
// Find the longest block prefix between the binding and the current scope.
int p = std::min(prefix, int((*i)->scope.size()));
while (p >= 0 && (*i)->scope[p - 1] != scope.blocks[p - 1])
p--;
// We reached the toplevel. Break.
if (p < 0)
break;
// We went outside the function scope. Break.
if (!isOutside && (*i)->getBaseName() != getBaseName())
break;
prefix = p;
lastGood = i;
// The binding completely dominates the current scope. Break.
if ((*i)->scope.size() <= scope.blocks.size() &&
(*i)->scope.back() == scope.blocks[(*i)->scope.size() - 1])
break;
}
seqassert(lastGood != it->second.end(), "corrupted scoping ({})", name);
if (lastGood != it->second.begin() && !(*lastGood)->isVar())
E(Error::CLASS_INVALID_BIND, getSrcInfo(), name);
bool hasUsed = false;
types::TypePtr type = nullptr;
if ((*lastGood)->scope.size() == prefix) {
// The current scope is dominated by a binding. Use that binding.
canonicalName = (*lastGood)->canonicalName;
type = (*lastGood)->type;
} else {
// The current scope is potentially reachable by multiple bindings that are
// not dominated by a common binding. Create such binding in the scope that
// dominates (covers) all of them.
canonicalName = generateCanonicalName(name);
auto item = std::make_shared<TypecheckItem>(
(*lastGood)->kind, (*lastGood)->baseName, canonicalName,
(*lastGood)->moduleName,
std::vector<int>(scope.blocks.begin(), scope.blocks.begin() + prefix),
(*lastGood)->importPath);
item->accessChecked = {(*lastGood)->scope};
type = item->type = getUnbound(getSrcInfo());
lastGood = it->second.insert(++lastGood, item);
// Make sure to prepend a binding declaration: `var` and `var__used__ = False`
// to the dominating scope.
getBase()->preamble.push_back(tv->N<AssignStmt>(
tv->transform(tv->N<IdExpr>(canonicalName)), nullptr, nullptr));
getBase()->preamble.push_back(tv->N<AssignStmt>(
tv->transform(tv->N<IdExpr>(fmt::format("{}.__used__", canonicalName))),
tv->transform(tv->N<BoolExpr>(false)), nullptr));
// Reached the toplevel? Register the binding as global.
if (prefix == 1) {
cache->addGlobal(canonicalName);
cache->addGlobal(fmt::format("{}.__used__", canonicalName));
}
hasUsed = true;
}
// Remove all bindings after the dominant binding.
for (auto i = it->second.begin(); i != it->second.end(); i++) {
if (i == lastGood)
break;
if (!(*i)->canDominate())
continue;
// These bindings (and their canonical identifiers) will be replaced by the
// dominating binding during the type checking pass.
seqassert((*i)->canonicalName != canonicalName, "invalid replacement at {}: {}",
getSrcInfo(), canonicalName);
for (auto &ref : (*i)->references) {
ref->getId()->value = canonicalName;
tv->unify(type, ref->type);
}
auto update = tv->N<AssignStmt>(tv->N<IdExpr>(format("{}.__used__", canonicalName)),
tv->N<BoolExpr>(true));
update->setUpdate();
if (auto a = (*i)->root->getAssign()) {
a->lhs->getId()->value = canonicalName;
tv->unify(type, a->lhs->getType());
if (hasUsed) {
if (a->preamble) {
a->preamble->getAssign()->lhs->getId()->value = update->lhs->getId()->value;
} else {
a->preamble = tv->transform(update);
}
}
} else if (auto ts = dynamic_cast<TryStmt *>((*i)->root)) {
for (auto &c : ts->catches)
if (c.var == (*i)->canonicalName) {
c.var = canonicalName;
c.exc->setAttr(ExprAttr::Dominated);
tv->unify(type, c.exc->getType());
if (hasUsed) {
seqassert(c.suite->getSuite(), "not a Suite");
if (c.suite->getSuite() && !c.suite->getSuite()->stmts.empty() &&
c.suite->getSuite()->stmts[0]->getAssign() &&
c.suite->getSuite()->stmts[0]->getAssign()->lhs->isId(
format("{}.__used__", (*i)->canonicalName))) {
c.suite->getSuite()->stmts[0]->getAssign()->lhs->getId()->value =
update->lhs->getId()->value;
} else {
c.suite->getSuite()->stmts.insert(c.suite->getSuite()->stmts.begin(),
tv->transform(update));
}
}
}
} else if (auto fs = dynamic_cast<ForStmt *>((*i)->root)) {
fs->var->getId()->value = canonicalName;
fs->var->setAttr(ExprAttr::Dominated);
tv->unify(type, fs->var->getType());
if (hasUsed) {
seqassert(fs->suite->getSuite(), "not a Suite");
if (fs->suite->getSuite() && !fs->suite->getSuite()->stmts.empty() &&
fs->suite->getSuite()->stmts[0]->getAssign() &&
fs->suite->getSuite()->stmts[0]->getAssign()->lhs->isId(
format("{}.__used__", (*i)->canonicalName))) {
fs->suite->getSuite()->stmts[0]->getAssign()->lhs->getId()->value =
update->lhs->getId()->value;
} else {
fs->suite->getSuite()->stmts.insert(fs->suite->getSuite()->stmts.begin(),
tv->transform(update));
}
}
} else {
seqassert(false, "bad identifier root: '{}'", canonicalName);
}
auto it = std::find(stack.front().begin(), stack.front().end(), name);
if (it != stack.front().end())
stack.front().erase(it);
}
it->second.erase(it->second.begin(), lastGood);
return it->second.front();
}
/// Getters and setters /// Getters and setters
std::string TypeContext::getBaseName() const { return bases.back().name; } std::string TypeContext::getBaseName() const { return bases.back().name; }
@ -287,20 +140,23 @@ bool TypeContext::isCanonicalName(const std::string &name) const {
} }
std::string TypeContext::generateCanonicalName(const std::string &name, std::string TypeContext::generateCanonicalName(const std::string &name,
bool includeBase, bool zeroId) const { bool includeBase, bool noSuffix) const {
std::string newName = name; std::string newName = name;
bool alreadyGenerated = name.find('.') != std::string::npos; bool alreadyGenerated = name.find('.') != std::string::npos;
if (includeBase && !alreadyGenerated) { if (includeBase && !alreadyGenerated) {
std::string base = getBaseName(); std::string base = getBaseName();
if (base.empty()) if (base.empty())
base = getModule(); base = getModule();
if (base == "std.internal.core") if (base == "std.internal.core") {
noSuffix = true;
base = ""; base = "";
}
newName = (base.empty() ? "" : (base + ".")) + newName; newName = (base.empty() ? "" : (base + ".")) + newName;
} }
auto num = cache->identifierCount[newName]++; auto num = cache->identifierCount[newName]++;
newName = format("{}.{}", newName, num); if (!noSuffix && !alreadyGenerated)
if (name != newName && !zeroId) newName = format("{}.{}", newName, num);
if (name != newName)
cache->identifierCount[newName]++; cache->identifierCount[newName]++;
cache->reverseIdentifierLookup[newName] = name; cache->reverseIdentifierLookup[newName] = name;
return newName; return newName;
@ -308,7 +164,106 @@ std::string TypeContext::generateCanonicalName(const std::string &name,
void TypeContext::enterConditionalBlock() { scope.blocks.push_back(++scope.counter); } void TypeContext::enterConditionalBlock() { scope.blocks.push_back(++scope.counter); }
void TypeContext::leaveConditionalBlock(std::vector<StmtPtr> *stmts) { ExprPtr NameVisitor::transform(const std::shared_ptr<Expr> &expr) {
NameVisitor v(tv);
if (expr)
expr->accept(v);
return v.resultExpr ? v.resultExpr : expr;
}
ExprPtr NameVisitor::transform(std::shared_ptr<Expr> &expr) {
NameVisitor v(tv);
if (expr)
expr->accept(v);
if (v.resultExpr)
expr = v.resultExpr;
return expr;
}
StmtPtr NameVisitor::transform(const std::shared_ptr<Stmt> &stmt) {
NameVisitor v(tv);
if (stmt)
stmt->accept(v);
return v.resultStmt ? v.resultStmt : stmt;
}
StmtPtr NameVisitor::transform(std::shared_ptr<Stmt> &stmt) {
NameVisitor v(tv);
if (stmt)
stmt->accept(v);
if (v.resultExpr)
stmt = v.resultStmt;
return stmt;
}
void NameVisitor::visit(IdExpr *expr) {
while (auto s = in(tv->getCtx()->getBase()->replacements, expr->value)) {
expr->value = s->first;
tv->unify(expr->type, tv->getCtx()->forceFind(s->first)->type);
}
}
void NameVisitor::visit(AssignStmt *stmt) {
seqassert(stmt->lhs->getId(), "invalid AssignStmt {}", stmt->lhs);
std::string lhs = stmt->lhs->getId()->value;
if (auto changed = in(tv->getCtx()->getBase()->replacements, lhs)) {
while (auto s = in(tv->getCtx()->getBase()->replacements, lhs))
lhs = changed->first, changed = s;
if (stmt->rhs && changed->second) {
// Mark the dominating binding as used: `var.__used__ = True`
auto u =
N<AssignStmt>(N<IdExpr>(fmt::format("{}.__used__", lhs)), N<BoolExpr>(true));
u->setUpdate();
stmt->setUpdate();
resultStmt = N<SuiteStmt>(u, stmt->shared_from_this());
} else if (changed->second && !stmt->rhs) {
// This assignment was a declaration only.
// Just mark the dominating binding as used: `var.__used__ = True`
stmt->lhs = N<IdExpr>(fmt::format("{}.__used__", lhs));
stmt->rhs = N<BoolExpr>(true);
stmt->setUpdate();
}
seqassert(stmt->rhs, "bad domination statement: '{}'", stmt->toString());
}
}
void NameVisitor::visit(TryStmt *stmt) {
for (auto &c : stmt->catches) {
if (!c.var.empty()) {
// Handle dominated except bindings
auto changed = in(tv->getCtx()->getBase()->replacements, c.var);
while (auto s = in(tv->getCtx()->getBase()->replacements, c.var))
c.var = s->first, changed = s;
if (changed && changed->second) {
auto update =
N<AssignStmt>(N<IdExpr>(format("{}.__used__", c.var)), N<BoolExpr>(true));
update->setUpdate();
c.suite = N<SuiteStmt>(update, c.suite);
}
if (changed)
c.exc->setAttr(ExprAttr::Dominated);
}
}
}
void NameVisitor::visit(ForStmt *stmt) {
auto var = stmt->var->getId();
seqassert(var, "corrupt for variable: {}", stmt->var);
auto changed = in(tv->getCtx()->getBase()->replacements, var->value);
while (auto s = in(tv->getCtx()->getBase()->replacements, var->value))
var->value = s->first, changed = s;
if (changed && changed->second) {
auto u =
N<AssignStmt>(N<IdExpr>(format("{}.__used__", var->value)), N<BoolExpr>(true));
u->setUpdate();
stmt->suite = N<SuiteStmt>(u, stmt->suite);
}
if (changed)
var->setAttr(ExprAttr::Dominated);
}
void TypeContext::leaveConditionalBlock(std::vector<StmtPtr> *stmts,
TypecheckVisitor *tv) {
if (stmts && in(scope.stmts, scope.blocks.back())) {
stmts->insert(stmts->begin(), scope.stmts[scope.blocks.back()].begin(),
scope.stmts[scope.blocks.back()].end());
NameVisitor nv(tv);
for (auto &s : *stmts)
nv.transform(s);
}
scope.blocks.pop_back(); scope.blocks.pop_back();
} }
@ -334,17 +289,13 @@ TypeContext::Base *TypeContext::getClassBase() {
return nullptr; return nullptr;
} }
TypeContext::RealizationBase *TypeContext::getRealizationBase() { size_t TypeContext::getRealizationDepth() const { return bases.size(); }
return &(realizationBases.back());
}
size_t TypeContext::getRealizationDepth() const { return realizationBases.size(); }
std::string TypeContext::getRealizationStackName() const { std::string TypeContext::getRealizationStackName() const {
if (realizationBases.empty()) if (bases.empty())
return ""; return "";
std::vector<std::string> s; std::vector<std::string> s;
for (auto &b : realizationBases) for (auto &b : bases)
if (b.type) if (b.type)
s.push_back(b.type->realizedName()); s.push_back(b.type->realizedName());
return join(s, ":"); return join(s, ":");
@ -577,21 +528,31 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
void TypeContext::dump(int pad) { void TypeContext::dump(int pad) {
auto ordered = auto ordered =
std::map<std::string, decltype(map)::mapped_type>(map.begin(), map.end()); std::map<std::string, decltype(map)::mapped_type>(map.begin(), map.end());
LOG("base: {}", getRealizationStackName()); LOG("current module: {} ({})", moduleName.module, moduleName.path);
LOG("current base: {} / {}", getRealizationStackName(), getBase()->name);
for (auto &i : ordered) { for (auto &i : ordered) {
std::string s; std::string s;
auto t = i.second.front(); auto t = i.second.front();
LOG("{}{:.<25} {}", std::string(size_t(pad) * 2, ' '), i.first, t->type); LOG("{}{:.<25}", std::string(size_t(pad) * 2, ' '), i.first);
LOG(" ... kind: {}", t->isType() * 100 + t->isFunc() * 10 + t->isVar());
LOG(" ... canonical: {}", t->canonicalName);
LOG(" ... base: {}", t->baseName);
LOG(" ... module: {}", t->moduleName);
LOG(" ... type: {}", t->type ? t->type->debugString(2) : "<null>");
LOG(" ... scope: {}", t->scope);
LOG(" ... access: {}", t->accessChecked);
LOG(" ... shdw/dom: {} / {}", t->canShadow, t->avoidDomination);
LOG(" ... gnrc/sttc: {} / {}", t->generic, int(t->staticType));
} }
} }
std::string TypeContext::debugInfo() { std::string TypeContext::debugInfo() {
return fmt::format("[{}:i{}@{}]", getRealizationBase()->name, return fmt::format("[{}:i{}@{}]", getBase()->name, getBase()->iteration,
getRealizationBase()->iteration, getSrcInfo()); getSrcInfo());
} }
std::shared_ptr<std::pair<std::vector<types::TypePtr>, std::vector<types::TypePtr>>> std::shared_ptr<std::pair<std::vector<types::TypePtr>, std::vector<types::TypePtr>>>
TypeContext::getFunctionArgs(types::TypePtr t) { TypeContext::getFunctionArgs(const types::TypePtr &t) {
if (!t->getFunc()) if (!t->getFunc())
return nullptr; return nullptr;
auto fn = t->getFunc(); auto fn = t->getFunc();
@ -604,7 +565,7 @@ TypeContext::getFunctionArgs(types::TypePtr t) {
return ret; return ret;
} }
std::shared_ptr<std::string> TypeContext::getStaticString(types::TypePtr t) { std::shared_ptr<std::string> TypeContext::getStaticString(const types::TypePtr &t) {
if (auto s = t->getStatic()) { if (auto s = t->getStatic()) {
auto r = s->evaluate(); auto r = s->evaluate();
if (r.type == StaticValue::STRING) if (r.type == StaticValue::STRING)
@ -613,7 +574,7 @@ std::shared_ptr<std::string> TypeContext::getStaticString(types::TypePtr t) {
return nullptr; return nullptr;
} }
std::shared_ptr<int64_t> TypeContext::getStaticInt(types::TypePtr t) { std::shared_ptr<int64_t> TypeContext::getStaticInt(const types::TypePtr &t) {
if (auto s = t->getStatic()) { if (auto s = t->getStatic()) {
auto r = s->evaluate(); auto r = s->evaluate();
if (r.type == StaticValue::INT) if (r.type == StaticValue::INT)
@ -622,7 +583,7 @@ std::shared_ptr<int64_t> TypeContext::getStaticInt(types::TypePtr t) {
return nullptr; return nullptr;
} }
types::FuncTypePtr TypeContext::extractFunction(types::TypePtr t) { types::FuncTypePtr TypeContext::extractFunction(const types::TypePtr &t) {
if (auto f = t->getFunc()) if (auto f = t->getFunc())
return f; return f;
if (auto p = t->getPartial()) if (auto p = t->getPartial())

View File

@ -23,49 +23,42 @@ class TypecheckVisitor;
* Can be either a function, a class (type), or a variable. * Can be either a function, a class (type), or a variable.
*/ */
struct TypecheckItem : public SrcObject { struct TypecheckItem : public SrcObject {
/// Identifier kind
enum Kind { Func, Type, Var } kind;
/// Base name (e.g., foo.bar.baz)
std::string baseName;
/// Unique identifier (canonical name) /// Unique identifier (canonical name)
std::string canonicalName; std::string canonicalName;
/// Base name (e.g., foo.bar.baz)
std::string baseName;
/// Full module name /// Full module name
std::string moduleName; std::string moduleName;
/// Type
types::TypePtr type = nullptr;
/// Full base scope information /// Full base scope information
std::vector<int> scope; std::vector<int> scope = {0};
/// Non-empty string if a variable is import variable
std::string importPath;
/// List of scopes where the identifier is accessible /// List of scopes where the identifier is accessible
/// without __used__ check /// without __used__ check
std::vector<std::vector<int>> accessChecked; std::vector<std::vector<int>> accessChecked;
/// Set if an identifier cannot be shadowed /// Set if an identifier cannot be shadowed
/// (e.g., global-marked variables) /// (e.g., global-marked variables)
bool noShadow = false; bool canShadow = true;
/// Set if an identifier is a class or a function generic
bool generic = false;
/// Set if an identifier is a static variable.
char staticType = 0;
/// Set if an identifier should not be dominated /// Set if an identifier should not be dominated
/// (e.g., a loop variable in a comprehension). /// (e.g., a loop variable in a comprehension).
bool avoidDomination = false; bool avoidDomination = false;
std::list<ExprPtr> references; /// Set if an identifier is a class or a function generic
Stmt *root = nullptr; bool generic = false;
/// Set if an identifier is a static variable.
char staticType = 0;
/// Type TypecheckItem(std::string, std::string, std::string, types::TypePtr,
types::TypePtr type = nullptr; std::vector<int> = {0});
TypecheckItem(Kind, std::string, std::string, std::string, std::vector<int> = {},
std::string = "", types::TypePtr = nullptr);
/* Convenience getters */ /* Convenience getters */
std::string getBaseName() const { return baseName; } std::string getBaseName() const { return baseName; }
std::string getModule() const { return moduleName; } std::string getModule() const { return moduleName; }
bool isVar() const { return kind == Var; } bool isVar() const { return type->getLink() != nullptr && !generic; }
bool isFunc() const { return kind == Func; } bool isFunc() const { return type->getFunc() != nullptr; }
bool isType() const { return kind == Type; } bool isType() const { return !isFunc() && !isVar(); }
bool isImport() const { return !importPath.empty(); }
bool isGlobal() const { return scope.size() == 1 && baseName.empty(); } bool isGlobal() const { return scope.size() == 1 && baseName.empty(); }
/// True if an identifier is within a conditional block /// True if an identifier is within a conditional block
/// (i.e., a block that might not be executed during the runtime) /// (i.e., a block that might not be executed during the runtime)
@ -100,15 +93,25 @@ struct TypeContext : public Context<TypecheckItem> {
struct Base { struct Base {
/// Canonical name of a function or a class that owns this base. /// Canonical name of a function or a class that owns this base.
std::string name; std::string name;
/// Function type
types::TypePtr type = nullptr;
/// The return type of currently realized function
types::TypePtr returnType = nullptr;
/// Typechecking iteration
int iteration = 0;
/// Tracks function attributes (e.g. if it has @atomic or @test attributes). /// Tracks function attributes (e.g. if it has @atomic or @test attributes).
/// Only set for functions. /// Only set for functions.
Attr *attributes; Attr *attributes = nullptr;
/// Set if the base is class base and if class is marked with @deduce.
/// Stores the list of class fields in the order of traversal. struct {
std::shared_ptr<std::vector<std::string>> deducedMembers = nullptr; /// Set if the base is class base and if class is marked with @deduce.
/// Canonical name of `self` parameter that is used to deduce class fields /// Stores the list of class fields in the order of traversal.
/// (e.g., self in self.foo). std::shared_ptr<std::vector<std::string>> deducedMembers = nullptr;
std::string selfName; /// Canonical name of `self` parameter that is used to deduce class fields
/// (e.g., self in self.foo).
std::string selfName;
} deduce;
/// Map of captured identifiers (i.e., identifiers not defined in a function). /// Map of captured identifiers (i.e., identifiers not defined in a function).
/// Captured (canonical) identifiers are mapped to the new canonical names /// Captured (canonical) identifiers are mapped to the new canonical names
/// (representing the canonical function argument names that are appended to the /// (representing the canonical function argument names that are appended to the
@ -123,8 +126,6 @@ struct TypeContext : public Context<TypecheckItem> {
/// Scope that defines the base. /// Scope that defines the base.
std::vector<int> scope; std::vector<int> scope;
std::vector<StmtPtr> preamble;
/// Set of seen global identifiers used to prevent later creation of local variables /// Set of seen global identifiers used to prevent later creation of local variables
/// with the same name. /// with the same name.
std::unordered_map<std::string, ExprPtr> seenGlobalIdentifiers; std::unordered_map<std::string, ExprPtr> seenGlobalIdentifiers;
@ -142,8 +143,9 @@ struct TypeContext : public Context<TypecheckItem> {
}; };
std::vector<Loop> loops; std::vector<Loop> loops;
std::unordered_map<std::string, std::pair<std::string, bool>> replacements;
public: public:
explicit Base(std::string name, Attr *attributes = nullptr);
Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); } Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); }
bool isType() const { return attributes == nullptr; } bool isType() const { return attributes == nullptr; }
}; };
@ -153,7 +155,8 @@ struct TypeContext : public Context<TypecheckItem> {
struct BaseGuard { struct BaseGuard {
TypeContext *holder; TypeContext *holder;
BaseGuard(TypeContext *holder, const std::string &name) : holder(holder) { BaseGuard(TypeContext *holder, const std::string &name) : holder(holder) {
holder->bases.emplace_back(Base(name)); holder->bases.emplace_back();
holder->bases.back().name = name;
holder->bases.back().scope = holder->scope.blocks; holder->bases.back().scope = holder->scope.blocks;
holder->addBlock(); holder->addBlock();
} }
@ -163,10 +166,10 @@ struct TypeContext : public Context<TypecheckItem> {
} }
}; };
/// Set if the standard library is currently being loaded.
bool isStdlibLoading = false;
/// Current module. The default module is named `__main__`. /// Current module. The default module is named `__main__`.
ImportFile moduleName = {ImportFile::PACKAGE, "", ""}; ImportFile moduleName = {ImportFile::PACKAGE, "", ""};
/// Set if the standard library is currently being loaded.
bool isStdlibLoading = false;
/// Tracks if we are in a dependent part of a short-circuiting expression (e.g. b in a /// Tracks if we are in a dependent part of a short-circuiting expression (e.g. b in a
/// and b) to disallow assignment expressions there. /// and b) to disallow assignment expressions there.
bool isConditionalExpr = false; bool isConditionalExpr = false;
@ -176,21 +179,6 @@ struct TypeContext : public Context<TypecheckItem> {
/// Set if all assignments should not be dominated later on. /// Set if all assignments should not be dominated later on.
bool avoidDomination = false; bool avoidDomination = false;
/// A realization base definition. Each function realization defines a new base scope.
/// Used to properly realize enclosed functions and to prevent mess with mutually
/// recursive enclosed functions.
struct RealizationBase {
/// Function name
std::string name;
/// Function type
types::TypePtr type = nullptr;
/// The return type of currently realized function
types::TypePtr returnType = nullptr;
/// Typechecking iteration
int iteration = 0;
};
std::vector<RealizationBase> realizationBases;
/// The current type-checking level (for type instantiation and generalization). /// The current type-checking level (for type instantiation and generalization).
int typecheckLevel = 0; int typecheckLevel = 0;
std::set<types::TypePtr> pendingDefaults; std::set<types::TypePtr> pendingDefaults;
@ -215,13 +203,11 @@ public:
void add(const std::string &name, const Item &var) override; void add(const std::string &name, const Item &var) override;
/// Convenience method for adding an object to the context. /// Convenience method for adding an object to the context.
Item addVar(const std::string &name, const std::string &canonicalName, Item addVar(const std::string &name, const std::string &canonicalName,
const SrcInfo &srcInfo = SrcInfo(), const types::TypePtr &type = nullptr); const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo());
Item addType(const std::string &name, const std::string &canonicalName, Item addType(const std::string &name, const std::string &canonicalName,
const SrcInfo &srcInfo = SrcInfo(), const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo());
const types::TypePtr &type = nullptr);
Item addFunc(const std::string &name, const std::string &canonicalName, Item addFunc(const std::string &name, const std::string &canonicalName,
const SrcInfo &srcInfo = SrcInfo(), const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo());
const types::TypePtr &type = nullptr);
/// Add the item to the standard library module, thus ensuring its visibility from all /// Add the item to the standard library module, thus ensuring its visibility from all
/// modules. /// modules.
Item addAlwaysVisible(const Item &item); Item addAlwaysVisible(const Item &item);
@ -231,9 +217,6 @@ public:
/// Get an item that exists in the context. If the item does not exist, assertion is /// Get an item that exists in the context. If the item does not exist, assertion is
/// raised. /// raised.
Item forceFind(const std::string &name) const; Item forceFind(const std::string &name) const;
/// Get an item from the context. Perform domination analysis for accessing items
/// defined in the conditional blocks (i.e., Python scoping).
Item findDominatingBinding(const std::string &name, TypecheckVisitor *);
/// Return a canonical name of the current base. /// Return a canonical name of the current base.
/// An empty string represents the toplevel base. /// An empty string represents the toplevel base.
@ -247,11 +230,12 @@ public:
void enterConditionalBlock(); void enterConditionalBlock();
/// Leave a conditional block. Populate stmts (if set) with the declarations of /// Leave a conditional block. Populate stmts (if set) with the declarations of
/// newly added identifiers that dominate the children blocks. /// newly added identifiers that dominate the children blocks.
void leaveConditionalBlock(std::vector<StmtPtr> *stmts = nullptr); void leaveConditionalBlock(std::vector<StmtPtr> *stmts = nullptr,
TypecheckVisitor *t = nullptr);
/// Generate a unique identifier (name) for a given string. /// Generate a unique identifier (name) for a given string.
std::string generateCanonicalName(const std::string &name, bool includeBase = false, std::string generateCanonicalName(const std::string &name, bool includeBase = false,
bool zeroId = false) const; bool noSuffix = false) const;
/// True if we are at the toplevel. /// True if we are at the toplevel.
bool isGlobal() const; bool isGlobal() const;
/// True if we are within a conditional block. /// True if we are within a conditional block.
@ -277,8 +261,6 @@ public:
public: public:
/// Get the current realization depth (i.e., the number of nested realizations). /// Get the current realization depth (i.e., the number of nested realizations).
size_t getRealizationDepth() const; size_t getRealizationDepth() const;
/// Get the current base.
RealizationBase *getRealizationBase();
/// Get the name of the current realization stack (e.g., `fn1:fn2:...`). /// Get the name of the current realization stack (e.g., `fn1:fn2:...`).
std::string getRealizationStackName() const; std::string getRealizationStackName() const;
@ -343,10 +325,10 @@ private:
public: public:
std::shared_ptr<std::pair<std::vector<types::TypePtr>, std::vector<types::TypePtr>>> std::shared_ptr<std::pair<std::vector<types::TypePtr>, std::vector<types::TypePtr>>>
getFunctionArgs(types::TypePtr t); getFunctionArgs(const types::TypePtr &);
std::shared_ptr<std::string> getStaticString(types::TypePtr t); std::shared_ptr<std::string> getStaticString(const types::TypePtr &);
std::shared_ptr<int64_t> getStaticInt(types::TypePtr t); std::shared_ptr<int64_t> getStaticInt(const types::TypePtr &);
types::FuncTypePtr extractFunction(types::TypePtr t); types::FuncTypePtr extractFunction(const types::TypePtr &);
}; };
} // namespace codon::ast } // namespace codon::ast

View File

@ -64,7 +64,7 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
ctx->enterConditionalBlock(); ctx->enterConditionalBlock();
if (!c.var.empty()) { if (!c.var.empty()) {
c.var = ctx->generateCanonicalName(c.var); c.var = ctx->generateCanonicalName(c.var);
ctx->addVar(ctx->cache->rev(c.var), c.var, c.suite->getSrcInfo()); ctx->addVar(ctx->cache->rev(c.var), c.var, ctx->getUnbound());
} }
transform(c.exc); transform(c.exc);
if (c.exc && c.exc->type->is("pyobj")) { if (c.exc && c.exc->type->is("pyobj")) {
@ -92,8 +92,7 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
transformType(c.exc); transformType(c.exc);
if (!c.var.empty()) { if (!c.var.empty()) {
// Handle dominated except bindings // Handle dominated except bindings
auto val = ctx->addVar(c.var, c.var, getSrcInfo(), c.exc->getType()); auto val = ctx->addVar(c.var, c.var, c.exc->getType());
val->root = stmt;
unify(val->type, c.exc->getType()); unify(val->type, c.exc->getType());
} }
ctx->blockLevel++; ctx->blockLevel++;
@ -110,7 +109,7 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
pyCatchStmt->suite->getSuite()->stmts.push_back(N<ThrowStmt>(nullptr)); pyCatchStmt->suite->getSuite()->stmts.push_back(N<ThrowStmt>(nullptr));
TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt}; TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt};
auto val = ctx->addVar(pyVar, pyVar, getSrcInfo(), c.exc->getType()); auto val = ctx->addVar(pyVar, pyVar, c.exc->getType());
unify(val->type, c.exc->getType()); unify(val->type, c.exc->getType());
ctx->blockLevel++; ctx->blockLevel++;
transform(c.suite); transform(c.suite);
@ -142,12 +141,11 @@ void TypecheckVisitor::visit(ThrowStmt *stmt) {
transform(stmt->expr); transform(stmt->expr);
if (!(stmt->expr->getCall() && if (!(stmt->expr->getCall() &&
stmt->expr->getCall()->expr->isId("__internal__.set_header:0"))) { stmt->expr->getCall()->expr->isId("__internal__.set_header"))) {
stmt->expr = transform(N<CallExpr>( stmt->expr = transform(N<CallExpr>(
N<DotExpr>(N<IdExpr>("__internal__"), "set_header"), stmt->expr, N<DotExpr>(N<IdExpr>("__internal__"), "set_header"), stmt->expr,
N<StringExpr>(ctx->getRealizationBase()->name), N<StringExpr>(ctx->getBase()->name), N<StringExpr>(stmt->getSrcInfo().file),
N<StringExpr>(stmt->getSrcInfo().file), N<IntExpr>(stmt->getSrcInfo().line), N<IntExpr>(stmt->getSrcInfo().line), N<IntExpr>(stmt->getSrcInfo().col)));
N<IntExpr>(stmt->getSrcInfo().col)));
} }
if (stmt->expr->isDone()) if (stmt->expr->isDone())
stmt->setDone(); stmt->setDone();

View File

@ -22,7 +22,7 @@ void TypecheckVisitor::visit(YieldExpr *expr) {
E(Error::FN_OUTSIDE_ERROR, expr, "yield"); E(Error::FN_OUTSIDE_ERROR, expr, "yield");
unify(expr->type, ctx->getUnbound()); unify(expr->type, ctx->getUnbound());
unify(ctx->getRealizationBase()->returnType, unify(ctx->getBase()->returnType,
ctx->instantiateGeneric(ctx->forceFind("Generator")->type, {expr->type})); ctx->instantiateGeneric(ctx->forceFind("Generator")->type, {expr->type}));
if (realize(expr->type)) if (realize(expr->type))
expr->setDone(); expr->setDone();
@ -48,19 +48,19 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
if (transform(stmt->expr)) { if (transform(stmt->expr)) {
// Wrap expression to match the return type // Wrap expression to match the return type
if (!ctx->getRealizationBase()->returnType->getUnbound()) if (!ctx->getBase()->returnType->getUnbound())
if (!wrapExpr(stmt->expr, ctx->getRealizationBase()->returnType)) { if (!wrapExpr(stmt->expr, ctx->getBase()->returnType)) {
return; return;
} }
// Special case: partialize functions if we are returning them // Special case: partialize functions if we are returning them
if (stmt->expr->getType()->getFunc() && if (stmt->expr->getType()->getFunc() &&
!(ctx->getRealizationBase()->returnType->getClass() && !(ctx->getBase()->returnType->getClass() &&
ctx->getRealizationBase()->returnType->is("Function"))) { ctx->getBase()->returnType->is("Function"))) {
stmt->expr = partializeFunction(stmt->expr->type->getFunc()); stmt->expr = partializeFunction(stmt->expr->type->getFunc());
} }
unify(ctx->getRealizationBase()->returnType, stmt->expr->type); unify(ctx->getBase()->returnType, stmt->expr->type);
} else { } else {
// Just set the expr for the translation stage. However, do not unify the return // Just set the expr for the translation stage. However, do not unify the return
// type! This might be a `return` in a generator. // type! This might be a `return` in a generator.
@ -82,7 +82,7 @@ void TypecheckVisitor::visit(YieldStmt *stmt) {
E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); E(Error::FN_OUTSIDE_ERROR, stmt, "yield");
stmt->expr = transform(stmt->expr ? stmt->expr : N<CallExpr>(N<IdExpr>("NoneType"))); stmt->expr = transform(stmt->expr ? stmt->expr : N<CallExpr>(N<IdExpr>("NoneType")));
unify(ctx->getRealizationBase()->returnType, unify(ctx->getBase()->returnType,
ctx->instantiateGeneric(ctx->forceFind("Generator")->type, {stmt->expr->type})); ctx->instantiateGeneric(ctx->forceFind("Generator")->type, {stmt->expr->type}));
if (stmt->expr->isDone()) if (stmt->expr->isDone())
@ -104,7 +104,7 @@ void TypecheckVisitor::visit(GlobalStmt *stmt) {
E(Error::FN_OUTSIDE_ERROR, stmt, stmt->nonLocal ? "nonlocal" : "global"); E(Error::FN_OUTSIDE_ERROR, stmt, stmt->nonLocal ? "nonlocal" : "global");
// Dominate the binding // Dominate the binding
auto val = ctx->findDominatingBinding(stmt->var, this); auto val = findDominatingBinding(stmt->var, ctx.get());
if (!val || !val->isVar()) if (!val || !val->isVar())
E(Error::ID_NOT_FOUND, stmt, stmt->var); E(Error::ID_NOT_FOUND, stmt, stmt->var);
if (val->getBaseName() == ctx->getBaseName()) if (val->getBaseName() == ctx->getBaseName())
@ -121,10 +121,10 @@ void TypecheckVisitor::visit(GlobalStmt *stmt) {
// Register as global if needed // Register as global if needed
ctx->cache->addGlobal(val->canonicalName); ctx->cache->addGlobal(val->canonicalName);
val = ctx->addVar(stmt->var, val->canonicalName, stmt->getSrcInfo()); val = ctx->addVar(stmt->var, val->canonicalName, val->type);
val->baseName = ctx->getBaseName(); val->baseName = ctx->getBaseName();
// Globals/nonlocals cannot be shadowed in children scopes (as in Python) // Globals/nonlocals cannot be shadowed in children scopes (as in Python)
val->noShadow = true; val->canShadow = false;
// Erase the statement // Erase the statement
resultStmt = N<SuiteStmt>(); resultStmt = N<SuiteStmt>();
} }
@ -139,9 +139,6 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
return; return;
} }
// Function should be constructed only once
stmt->setDone();
// Parse attributes // Parse attributes
for (auto i = stmt->decorators.size(); i-- > 0;) { for (auto i = stmt->decorators.size(); i-- > 0;) {
auto [isAttr, attrName] = getDecorator(stmt->decorators[i]); auto [isAttr, attrName] = getDecorator(stmt->decorators[i]);
@ -168,24 +165,28 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
if (auto c = ctx->find(stmt->name)) { if (auto c = ctx->find(stmt->name)) {
if (c->isFunc() && c->getModule() == ctx->getModule() && if (c->isFunc() && c->getModule() == ctx->getModule() &&
c->getBaseName() == ctx->getBaseName()) c->getBaseName() == ctx->getBaseName())
rootName = c->canonicalName; rootName = ctx->cache->functions[c->canonicalName].rootName;
} }
} }
if (rootName.empty()) if (rootName.empty())
rootName = ctx->generateCanonicalName(stmt->name, true); rootName = ctx->generateCanonicalName(stmt->name, true, isClassMember);
// Append overload number to the name // Append overload number to the name
auto canonicalName = auto canonicalName = rootName;
format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); if (!ctx->cache->overloads[rootName].empty())
canonicalName += format(":{}", ctx->cache->overloads[rootName].size());
ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name; ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name;
// Ensure that function binding does not shadow anything. if (isClassMember) {
// Function bindings cannot be dominated either // Set the enclosing class name
if (!isClassMember) { stmt->attributes.parentClass = ctx->getBase()->name;
auto funcVal = ctx->find(stmt->name); // Add the method to the class' method list
if (funcVal && funcVal->noShadow) ctx->cache->classes[ctx->getBase()->name].methods[stmt->name] = rootName;
} else {
// Ensure that function binding does not shadow anything.
// Function bindings cannot be dominated either
auto funcVal = ctx->find(stmt->name);
if (funcVal && !funcVal->canShadow)
E(Error::CLASS_INVALID_BIND, stmt, stmt->name); E(Error::CLASS_INVALID_BIND, stmt, stmt->name);
funcVal = ctx->addFunc(stmt->name, rootName, stmt->getSrcInfo());
ctx->addAlwaysVisible(funcVal);
} }
std::vector<Param> args; std::vector<Param> args;
@ -208,7 +209,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Mark as method if the first argument is self // Mark as method if the first argument is self
if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") { if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") {
ctx->getBase()->selfName = name; // ctx->getBase()->selfName = name;
stmt->attributes.set(Attr::Method); stmt->attributes.set(Attr::Method);
} }
@ -237,46 +238,52 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Generic and static types // Generic and static types
auto generic = ctx->getUnbound(); auto generic = ctx->getUnbound();
auto typId = generic->getLink()->id; auto typId = generic->getLink()->id;
generic->genericName = ctx->cache->rev(a.name); generic->genericName = varName;
if (a.defaultValue) {
auto defType = transformType(clone(a.defaultValue));
generic->defaultType = defType->type;
}
if (auto st = getStaticGeneric(a.type.get())) { if (auto st = getStaticGeneric(a.type.get())) {
auto val = ctx->addVar(varName, name, stmt->getSrcInfo(), generic); auto val = ctx->addVar(varName, name, generic);
val->generic = true; val->generic = true;
val->staticType = st; val->staticType = st;
generic->isStatic = true; generic->isStatic = st;
if (a.defaultValue) {
auto defType = transform(clone(a.defaultValue));
generic->defaultType = defType->type;
}
} else { } else {
auto val = ctx->addType(varName, name, stmt->getSrcInfo(), generic); auto val = ctx->addType(varName, name, generic);
val->generic = true; val->generic = true;
if (a.defaultValue) {
auto defType = transformType(clone(a.defaultValue));
generic->defaultType = defType->type;
}
} }
explicits.emplace_back(a.name, ctx->cache->rev(a.name), explicits.emplace_back(name, varName, generic->generalize(ctx->typecheckLevel),
generic->generalize(ctx->typecheckLevel), typId); typId);
} }
} }
// Prepare list of all generic types // Prepare list of all generic types
std::vector<TypePtr> generics;
ClassTypePtr parentClass = nullptr; ClassTypePtr parentClass = nullptr;
if (isClassMember && stmt->attributes.has(Attr::Method)) { if (isClassMember && stmt->attributes.has(Attr::Method)) {
// Get class generics (e.g., T for `class Cls[T]: def foo:`) // Get class generics (e.g., T for `class Cls[T]: def foo:`)
auto parentClassAST = ctx->cache->classes[stmt->attributes.parentClass].ast.get(); // auto parentClassAST =
// ctx->cache->classes[stmt->attributes.parentClass].ast.get();
parentClass = ctx->forceFind(stmt->attributes.parentClass)->type->getClass(); parentClass = ctx->forceFind(stmt->attributes.parentClass)->type->getClass();
parentClass = parentClass->instantiate(ctx->typecheckLevel - 1, nullptr, nullptr) parentClass = parentClass->instantiate(ctx->typecheckLevel - 1, nullptr, nullptr)
->getClass(); ->getClass();
seqassert(parentClass, "parent class not set"); // seqassert(parentClass, "parent class not set");
for (int i = 0, j = 0, k = 0; i < parentClassAST->args.size(); i++) { // for (int i = 0, j = 0, k = 0; i < parentClassAST->args.size(); i++) {
if (parentClassAST->args[i].status != Param::Normal) { // if (parentClassAST->args[i].status != Param::Normal) {
generics.push_back(parentClassAST->args[i].status == Param::Generic // generics.push_back(parentClassAST->args[i].status == Param::Generic
? parentClass->generics[j++].type // ? parentClass->generics[j++].type
: parentClass->hiddenGenerics[k++].type); // : parentClass->hiddenGenerics[k++].type);
ctx->addType(parentClassAST->args[i].name, parentClassAST->args[i].name, // ctx->addType(parentClassAST->args[i].name, parentClassAST->args[i].name,
getSrcInfo(), generics.back()); // generics.back())
} // ->generic = true;
} // }
// }
} }
// Add function generics // Add function generics
std::vector<TypePtr> generics;
for (const auto &i : explicits) for (const auto &i : explicits)
generics.push_back(ctx->find(i.name)->type); generics.push_back(ctx->find(i.name)->type);
@ -284,11 +291,6 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Base type: `Function[[args,...], ret]` // Base type: `Function[[args,...], ret]`
baseType = getFuncTypeBase(stmt->args.size() - explicits.size()); baseType = getFuncTypeBase(stmt->args.size() - explicits.size());
ctx->typecheckLevel++; ctx->typecheckLevel++;
if (stmt->ret) {
unify(baseType->generics[1].type, transformType(stmt->ret)->getType());
} else {
generics.push_back(unify(baseType->generics[1].type, ctx->getUnbound()));
}
// Parse arguments to the context. Needs to be done after adding generics // Parse arguments to the context. Needs to be done after adding generics
// to support cases like `foo(a: T, T: type)` // to support cases like `foo(a: T, T: type)`
@ -306,36 +308,42 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
std::string canName = stmt->args[ai].name; std::string canName = stmt->args[ai].name;
trimStars(canName); trimStars(canName);
if (!stmt->args[ai].type) { if (!stmt->args[ai].type) {
if (parentClass && ai == 0 && ctx->cache->rev(stmt->args[ai].name) == "self") { if (parentClass && ai == 0 && stmt->args[ai].name == "self") {
// Special case: self in methods // Special case: self in methods
unify(argType->args[aj], parentClass); unify(argType->args[aj], parentClass);
} else { } else {
unify(argType->args[aj], ctx->getUnbound()); unify(argType->args[aj], ctx->getUnbound());
generics.push_back(argType->args[aj]);
} }
generics.push_back(argType->args[aj++]);
} else if (startswith(stmt->args[ai].name, "*")) { } else if (startswith(stmt->args[ai].name, "*")) {
// Special case: `*args: type` and `**kwargs: type`. Do not add this type to the // Special case: `*args: type` and `**kwargs: type`. Do not add this type to the
// signature (as the real type is `Tuple[type, ...]`); it will be used during // signature (as the real type is `Tuple[type, ...]`); it will be used during
// call typechecking // call typechecking
unify(argType->args[aj], ctx->getUnbound()); unify(argType->args[aj], ctx->getUnbound());
generics.push_back(argType->args[aj++]); generics.push_back(argType->args[aj]);
} else { } else {
unify(argType->args[aj], transformType(stmt->args[ai].type)->getType()); unify(argType->args[aj], transformType(stmt->args[ai].type)->getType());
generics.push_back(argType->args[aj++]); // generics.push_back(argType->args[aj++]);
} }
ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo(), aj++;
argType->args[aj]); // ctx->addVar(ctx->cache->rev(canName), canName, argType->args[aj]);
} }
ctx->typecheckLevel--; ctx->typecheckLevel--;
// Parse the return type // Parse the return type
ret = transformType(stmt->ret, false); ret = transformType(stmt->ret, false);
if (ret) {
unify(baseType->generics[1].type, ret->getType());
} else {
generics.push_back(unify(baseType->generics[1].type, ctx->getUnbound()));
}
// Generalize generics and remove them from the context // Generalize generics and remove them from the context
for (const auto &g : generics) { for (const auto &g : generics) {
for (auto &u : g->getUnbounds()) for (auto &u : g->getUnbounds())
if (u->getUnbound()) if (u->getUnbound()) {
u->getUnbound()->kind = LinkType::Generic; u->getUnbound()->kind = LinkType::Generic;
}
} }
// Parse function body // Parse function body
@ -349,45 +357,56 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
ctx->getBase()->captures = &captures; ctx->getBase()->captures = &captures;
if (stmt->attributes.has("std.internal.attributes.pycapture")) if (stmt->attributes.has("std.internal.attributes.pycapture"))
ctx->getBase()->pyCaptures = &pyCaptures; ctx->getBase()->pyCaptures = &pyCaptures;
suite = clone(stmt->suite);
// suite = SimplifyVisitor(ctx, // suite = SimplifyVisitor(ctx,
// preamble).transformConditionalScope(stmt->suite); // preamble).transformConditionalScope(stmt->suite);
} }
} }
} }
stmt->attributes.module = stmt->attributes.module = ctx->moduleName.path;
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", // format(
ctx->moduleName.module); // "{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" :
// "::", ctx->moduleName.module);
ctx->cache->overloads[rootName].push_back(canonicalName); ctx->cache->overloads[rootName].push_back(canonicalName);
// Make function AST and cache it for later realization
auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes);
ctx->cache->functions[canonicalName].module = ctx->getModule();
ctx->cache->functions[canonicalName].ast = f;
ctx->cache->functions[canonicalName].origAst =
std::static_pointer_cast<FunctionStmt>(stmt->clone());
ctx->cache->functions[canonicalName].isToplevel =
ctx->getModule().empty() && ctx->isGlobal();
ctx->cache->functions[canonicalName].rootName = rootName;
f->setDone();
// Construct the type // Construct the type
auto funcTyp = std::make_shared<types::FuncType>( auto funcTyp = std::make_shared<types::FuncType>(
baseType, ctx->cache->functions[stmt->name].ast.get(), explicits); baseType, ctx->cache->functions[canonicalName].ast.get(), explicits);
funcTyp->setSrcInfo(getSrcInfo()); funcTyp->setSrcInfo(getSrcInfo());
if (isClassMember && stmt->attributes.has(Attr::Method)) { if (isClassMember && stmt->attributes.has(Attr::Method)) {
funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type; funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type;
} }
funcTyp = std::static_pointer_cast<types::FuncType>( funcTyp = std::static_pointer_cast<types::FuncType>(
funcTyp->generalize(ctx->typecheckLevel)); funcTyp->generalize(ctx->typecheckLevel));
ctx->cache->functions[canonicalName].type = funcTyp;
ctx->addFunc(stmt->name, canonicalName, funcTyp);
if (isClassMember)
ctx->remove(stmt->name);
// Special method handling // Special method handling
if (isClassMember) { if (isClassMember) {
// Set the enclosing class name
stmt->attributes.parentClass = ctx->getBase()->name;
// Add the method to the class' method list
ctx->cache->classes[ctx->getBase()->name].methods[stmt->name] = rootName;
auto m = auto m =
ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(), ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(),
ctx->cache->rev(stmt->name)); ctx->cache->rev(canonicalName));
bool found = false; bool found = false;
for (auto &i : ctx->cache->overloads[m]) for (auto &i : ctx->cache->overloads[m])
if (i == stmt->name) { if (i == canonicalName) {
ctx->cache->functions[i].type = funcTyp; ctx->cache->functions[i].type = funcTyp;
found = true; found = true;
break; break;
} }
seqassert(found, "cannot find matching class method for {}", stmt->name); seqassert(found, "cannot find matching class method for {}", canonicalName);
} else { } else {
// Hack so that we can later use same helpers for class overloads // Hack so that we can later use same helpers for class overloads
ctx->cache->classes[".toplevel"].methods[stmt->name] = rootName; ctx->cache->classes[".toplevel"].methods[stmt->name] = rootName;
@ -410,22 +429,6 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// args.push_back(kw); // args.push_back(kw);
// partialArgs.emplace_back("", N<EllipsisExpr>(EllipsisExpr::PARTIAL)); // partialArgs.emplace_back("", N<EllipsisExpr>(EllipsisExpr::PARTIAL));
// } // }
// Make function AST and cache it for later realization
auto f = N<FunctionStmt>(canonicalName, ret, args, suite, stmt->attributes);
ctx->cache->functions[canonicalName].ast = f;
ctx->cache->functions[canonicalName].origAst =
std::static_pointer_cast<FunctionStmt>(stmt->clone());
ctx->cache->functions[canonicalName].isToplevel =
ctx->getModule().empty() && ctx->isGlobal();
ctx->cache->functions[canonicalName].rootName = rootName;
// Update the visited table
// Functions should always be visible, so add them to the toplevel
auto val = std::make_shared<TypecheckItem>(TypecheckItem::Func, ctx->getBaseName(),
stmt->name, ctx->getModule());
val->type = funcTyp;
ctx->addToplevel(stmt->name, val);
ctx->cache->functions[stmt->name].type = funcTyp;
// Ensure that functions with @C, @force_realize, and @export attributes can be // Ensure that functions with @C, @force_realize, and @export attributes can be
// realized // realized
@ -436,13 +439,13 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
} }
// Debug information // Debug information
LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp); LOG("[stmt] added func {}: {}", canonicalName, funcTyp->debugString(2));
// Expression to be used if function binding is modified by captures or decorators // Expression to be used if function binding is modified by captures or decorators
ExprPtr finalExpr = nullptr; ExprPtr finalExpr = nullptr;
// If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)` // If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)`
// if (!captures.empty()) { // if (!captures.empty()) {
// finalExpr = N<CallExpr>(N<IdExpr>(stmt->name), partialArgs); // finalExpr = N<CallExpr>(N<IdExpr>(canonicalName), partialArgs);
// // Add updated self reference in case function is recursive! // // Add updated self reference in case function is recursive!
// auto pa = partialArgs; // auto pa = partialArgs;
// for (auto &a : pa) { // for (auto &a : pa) {
@ -463,13 +466,13 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
E(Error::FN_NO_DECORATORS, stmt->decorators[i]); E(Error::FN_NO_DECORATORS, stmt->decorators[i]);
// Replace each decorator with `decorator(finalExpr)` in the reverse order // Replace each decorator with `decorator(finalExpr)` in the reverse order
finalExpr = N<CallExpr>(stmt->decorators[i], finalExpr = N<CallExpr>(stmt->decorators[i],
finalExpr ? finalExpr : N<IdExpr>(stmt->name)); finalExpr ? finalExpr : N<IdExpr>(canonicalName));
} }
} }
if (finalExpr) { if (finalExpr) {
resultStmt = resultStmt =
N<SuiteStmt>(f, transform(N<AssignStmt>(N<IdExpr>(stmt->name), finalExpr))); N<SuiteStmt>(f, transform(N<AssignStmt>(N<IdExpr>(canonicalName), finalExpr)));
} else { } else {
resultStmt = f; resultStmt = f;
} }
@ -603,11 +606,8 @@ std::pair<bool, std::string> TypecheckVisitor::getDecorator(const ExprPtr &e) {
if (id && id->getId()) { if (id && id->getId()) {
auto ci = ctx->find(id->getId()->value); auto ci = ctx->find(id->getId()->value);
if (ci && ci->isFunc()) { if (ci && ci->isFunc()) {
if (ctx->cache->overloads[ci->canonicalName].size() == 1) { return {ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute,
return {ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0]] ci->canonicalName};
.ast->attributes.isAttribute,
ci->canonicalName};
}
} }
} }
return {false, ""}; return {false, ""};

View File

@ -71,15 +71,12 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
if (!stmt->what) { if (!stmt->what) {
// Case: import foo // Case: import foo
auto name = stmt->as.empty() ? path : stmt->as; auto name = stmt->as.empty() ? path : stmt->as;
auto var = importVar + "_var"; // Construct `import_var = Import([path], [module])` (for printing imports etc.)
// Construct `import_var = Import([module], [path])` (for printing imports etc.)
resultStmt = N<SuiteStmt>( resultStmt = N<SuiteStmt>(
resultStmt, transform(N<AssignStmt>(N<IdExpr>(var), resultStmt,
N<CallExpr>(N<IdExpr>("Import"), transform(N<AssignStmt>(
N<StringExpr>(file->module), N<IdExpr>(name), N<CallExpr>(N<IdExpr>("Import"), N<StringExpr>(file->path),
N<StringExpr>(file->path)), N<StringExpr>(file->module)))));
N<IdExpr>("Import"))));
ctx->addVar(name, var, stmt->getSrcInfo())->importPath = file->path;
} else if (stmt->what->isId("*")) { } else if (stmt->what->isId("*")) {
// Case: from foo import * // Case: from foo import *
seqassert(stmt->as.empty(), "renamed star-import"); seqassert(stmt->as.empty(), "renamed star-import");
@ -91,7 +88,7 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
// `__` while the standard library is being loaded // `__` while the standard library is being loaded
auto c = i.second.front(); auto c = i.second.front();
if (c->isConditional() && i.first.find('.') == std::string::npos) { if (c->isConditional() && i.first.find('.') == std::string::npos) {
c = import.ctx->findDominatingBinding(i.first, this); c = findDominatingBinding(i.first, import.ctx.get());
} }
// Imports should ignore noShadow property // Imports should ignore noShadow property
ctx->Context<TypecheckItem>::add(i.first, c); ctx->Context<TypecheckItem>::add(i.first, c);
@ -106,14 +103,11 @@ void TypecheckVisitor::visit(ImportStmt *stmt) {
if (!c) if (!c)
E(Error::IMPORT_NO_NAME, i, i->value, file->module); E(Error::IMPORT_NO_NAME, i, i->value, file->module);
if (c->isConditional()) if (c->isConditional())
c = import.ctx->findDominatingBinding(i->value, this); c = findDominatingBinding(i->value, import.ctx.get());
// Imports should ignore noShadow property // Imports should ignore noShadow property
ctx->Context<TypecheckItem>::add(stmt->as.empty() ? i->value : stmt->as, c); ctx->Context<TypecheckItem>::add(stmt->as.empty() ? i->value : stmt->as, c);
} }
resultStmt = transform(!resultStmt ? N<SuiteStmt>() : resultStmt); // erase it
if (!resultStmt) {
resultStmt = N<SuiteStmt>(); // erase it
}
} }
/// Transform special `from C` and `from python` imports. /// Transform special `from C` and `from python` imports.
@ -204,9 +198,10 @@ StmtPtr TypecheckVisitor::transformCImport(const std::string &name,
StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Expr *type, StmtPtr TypecheckVisitor::transformCVarImport(const std::string &name, const Expr *type,
const std::string &altName) { const std::string &altName) {
auto canonical = ctx->generateCanonicalName(name); auto canonical = ctx->generateCanonicalName(name);
auto val = ctx->addVar(altName.empty() ? name : altName, canonical); auto typ = transformType(type->clone());
val->noShadow = true; auto val = ctx->addVar(altName.empty() ? name : altName, canonical, typ->type);
auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, transformType(type->clone())); val->canShadow = false;
auto s = N<AssignStmt>(N<IdExpr>(canonical), nullptr, typ);
s->lhs->setAttr(ExprAttr::ExternVar); s->lhs->setAttr(ExprAttr::ExternVar);
return s; return s;
} }
@ -313,18 +308,17 @@ StmtPtr TypecheckVisitor::transformNewImport(const ImportFile &file) {
auto ictx = std::make_shared<TypeContext>(ctx->cache, file.path); auto ictx = std::make_shared<TypeContext>(ctx->cache, file.path);
ictx->isStdlibLoading = ctx->isStdlibLoading; ictx->isStdlibLoading = ctx->isStdlibLoading;
ictx->moduleName = file; ictx->moduleName = file;
auto import = ctx->cache->imports.insert({file.path, {file.path, ictx}}).first; auto import =
import->second.moduleName = file.module; ctx->cache->imports.insert({file.path, {file.module, file.path, ictx}}).first;
// __name__ = [import name] // __name__ = [import name]
StmtPtr n = StmtPtr n = N<AssignStmt>(N<IdExpr>("__name__"), N<StringExpr>(file.module));
N<AssignStmt>(N<IdExpr>("__name__"), N<StringExpr>(ictx->moduleName.module)); if (file.module == "internal.core") {
if (ictx->moduleName.module == "internal.core") {
// str is not defined when loading internal.core; __name__ is not needed anyway // str is not defined when loading internal.core; __name__ is not needed anyway
n = nullptr; n = nullptr;
} }
n = N<SuiteStmt>(n, parseFile(ctx->cache, file.path)); n = N<SuiteStmt>(n, parseFile(ctx->cache, file.path));
n = TypecheckVisitor(ictx).transform(n); n = TypecheckVisitor(ictx, preamble).transform(n);
if (!ctx->cache->errors.empty()) if (!ctx->cache->errors.empty())
throw exc::ParserException(); throw exc::ParserException();
// Add comment to the top of import for easier dump inspection // Add comment to the top of import for easier dump inspection
@ -341,8 +335,8 @@ StmtPtr TypecheckVisitor::transformNewImport(const ImportFile &file) {
std::string importDoneVar; std::string importDoneVar;
// `import_[I]_done = False` (set to True upon successful import) // `import_[I]_done = False` (set to True upon successful import)
ctx->cache->imports[MAIN_IMPORT].ctx->bases[0].preamble.push_back(N<AssignStmt>( preamble->push_back(N<AssignStmt>(N<IdExpr>(importDoneVar = importVar + "_done"),
N<IdExpr>(importDoneVar = importVar + "_done"), N<BoolExpr>(false))); N<BoolExpr>(false)));
ctx->cache->addGlobal(importDoneVar); ctx->cache->addGlobal(importDoneVar);
// Wrap all imported top-level statements into a function. // Wrap all imported top-level statements into a function.
@ -371,12 +365,11 @@ StmtPtr TypecheckVisitor::transformNewImport(const ImportFile &file) {
} }
// Create import function manually with ForceRealize // Create import function manually with ForceRealize
ctx->cache->functions[importVar + ":0"].ast = ctx->cache->functions[importVar].ast =
N<FunctionStmt>(importVar + ":0", nullptr, std::vector<Param>{}, N<FunctionStmt>(importVar, nullptr, std::vector<Param>{}, N<SuiteStmt>(stmts),
N<SuiteStmt>(stmts), Attr({Attr::ForceRealize})); Attr({Attr::ForceRealize}));
ctx->cache->imports[MAIN_IMPORT].ctx->bases[0].preamble.push_back( preamble->push_back(ctx->cache->functions[importVar].ast->clone());
ctx->cache->functions[importVar + ":0"].ast->clone()); ctx->cache->overloads[importVar].push_back(importVar);
ctx->cache->overloads[importVar].push_back(importVar + ":0");
} }
return nullptr; return nullptr;
} }

View File

@ -48,13 +48,11 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
if (!result) if (!result)
return nullptr; return nullptr;
for (ctx->getRealizationBase()->iteration = 1;; for (ctx->getBase()->iteration = 1;; ctx->getBase()->iteration++) {
ctx->getRealizationBase()->iteration++) { LOG_TYPECHECK("[iter] {} :: {}", ctx->getBase()->name, ctx->getBase()->iteration);
LOG_TYPECHECK("[iter] {} :: {}", ctx->getRealizationBase()->name, if (ctx->getBase()->iteration >= MAX_TYPECHECK_ITER)
ctx->getRealizationBase()->iteration);
if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER)
error(result, "cannot typecheck '{}' in reasonable time", error(result, "cannot typecheck '{}' in reasonable time",
ctx->cache->rev(ctx->getRealizationBase()->name)); ctx->cache->rev(ctx->getBase()->name));
// Keep iterating until: // Keep iterating until:
// (1) success: the statement is marked as done; or // (1) success: the statement is marked as done; or
@ -65,12 +63,12 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
ctx->changedNodes = 0; ctx->changedNodes = 0;
auto returnEarly = ctx->returnEarly; auto returnEarly = ctx->returnEarly;
ctx->returnEarly = false; ctx->returnEarly = false;
TypecheckVisitor(ctx).transform(result); TypecheckVisitor(ctx, preamble).transform(result);
std::swap(ctx->changedNodes, changedNodes); std::swap(ctx->changedNodes, changedNodes);
std::swap(ctx->returnEarly, returnEarly); std::swap(ctx->returnEarly, returnEarly);
ctx->typecheckLevel--; ctx->typecheckLevel--;
if (ctx->getRealizationBase()->iteration == 1 && isToplevel) { if (ctx->getBase()->iteration == 1 && isToplevel) {
// Realize all @force_realize functions // Realize all @force_realize functions
for (auto &f : ctx->cache->functions) { for (auto &f : ctx->cache->functions) {
auto &attr = f.second.ast->attributes; auto &attr = f.second.ast->attributes;
@ -94,8 +92,8 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
// their default values and then run another round to see if anything changed. // their default values and then run another round to see if anything changed.
bool anotherRound = false; bool anotherRound = false;
// Special case: return type might have default as well (e.g., Union) // Special case: return type might have default as well (e.g., Union)
if (ctx->getRealizationBase()->returnType) if (ctx->getBase()->returnType)
ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType); ctx->pendingDefaults.insert(ctx->getBase()->returnType);
for (auto &unbound : ctx->pendingDefaults) { for (auto &unbound : ctx->pendingDefaults) {
if (auto tu = unbound->getUnion()) { if (auto tu = unbound->getUnion()) {
// Seal all dynamic unions after the iteration is over // Seal all dynamic unions after the iteration is over
@ -225,10 +223,9 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
LOG_REALIZE("[realize] ty {} -> {}", realized->name, realized->realizedTypeName()); LOG_REALIZE("[realize] ty {} -> {}", realized->name, realized->realizedTypeName());
// Realizations should always be visible, so add them to the toplevel // Realizations should always be visible, so add them to the toplevel
auto val = std::make_shared<TypecheckItem>( auto val = std::make_shared<TypecheckItem>(realized->realizedTypeName(), "",
TypecheckItem::Type, "", realized->realizedTypeName(), ctx->getModule()); ctx->getModule(), realized);
val->type = realized; ctx->addAlwaysVisible(val);
ctx->addToplevel(realized->realizedTypeName(), val);
auto realization = auto realization =
ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()] = ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()] =
std::make_shared<Cache::Class::ClassRealization>(); std::make_shared<Cache::Class::ClassRealization>();
@ -250,10 +247,12 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
std::map<std::string, SrcInfo> memberInfo; // needed for IR std::map<std::string, SrcInfo> memberInfo; // needed for IR
for (auto &field : ctx->cache->classes[realized->name].fields) { for (auto &field : ctx->cache->classes[realized->name].fields) {
auto ftyp = ctx->instantiate(field.type, realized); auto ftyp = ctx->instantiate(field.type, realized);
if (!realize(ftyp)) if (!realize(ftyp)) {
E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), field.name, realize(ftyp);
ftyp->prettyString()); E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), ctx->cache->rev(field.name),
LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, ftyp); realized->prettyString());
}
// LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, ftyp);
realization->fields.emplace_back(field.name, ftyp); realization->fields.emplace_back(field.name, ftyp);
names.emplace_back(field.name); names.emplace_back(field.name);
typeArgs.emplace_back(makeIRType(ftyp->getClass().get())); typeArgs.emplace_back(makeIRType(ftyp->getClass().get()));
@ -272,10 +271,9 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
// Fix for partial types // Fix for partial types
if (auto p = type->getPartial()) { if (auto p = type->getPartial()) {
auto pt = std::make_shared<PartialType>(realized->getRecord(), p->func, p->known); auto pt = std::make_shared<PartialType>(realized->getRecord(), p->func, p->known);
auto val = std::make_shared<TypecheckItem>(TypecheckItem::Type, "", auto val =
pt->realizedName(), ctx->getModule()); std::make_shared<TypecheckItem>(pt->realizedName(), "", ctx->getModule(), pt);
val->type = pt; ctx->addAlwaysVisible(val);
ctx->addToplevel(pt->realizedName(), val);
ctx->cache->classes[pt->name].realizations[pt->realizedName()] = ctx->cache->classes[pt->name].realizations[pt->realizedName()] =
ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()]; ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()];
} }
@ -291,24 +289,31 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
} }
} }
seqassert(in(ctx->cache->imports, type->ast->attributes.module) != nullptr,
"bad module: '{}'", type->ast->attributes.module);
auto &imp = ctx->cache->imports[type->ast->attributes.module];
auto oldCtx = this->ctx;
this->ctx = imp.ctx;
// LOG("=> {}", ctx->moduleName.module, ctx->moduleName.path);
if (ctx->getRealizationDepth() > MAX_REALIZATION_DEPTH) { if (ctx->getRealizationDepth() > MAX_REALIZATION_DEPTH) {
E(Error::MAX_REALIZATION, getSrcInfo(), ctx->cache->rev(type->ast->name)); E(Error::MAX_REALIZATION, getSrcInfo(), ctx->cache->rev(type->ast->name));
} }
LOG_REALIZE("[realize] fn {} -> {} : base {} ; depth = {}", type->ast->name,
type->realizedName(), ctx->getRealizationStackName(),
ctx->getRealizationDepth());
getLogger().level++; getLogger().level++;
ctx->addBlock(); ctx->addBlock();
ctx->typecheckLevel++; ctx->typecheckLevel++;
// Find function parents // Find function parents
ctx->realizationBases.push_back( ctx->bases.push_back({type->ast->name, type->getFunc(), type->getRetType()});
{type->ast->name, type->getFunc(), type->getRetType()}); LOG("[realize] fn {} -> {} : base {} ; depth = {} ; ctx-base: {}", type->ast->name,
type->realizedName(), ctx->getRealizationStackName(), ctx->getRealizationDepth(),
ctx->getBaseName());
// Clone the generic AST that is to be realized // Clone the generic AST that is to be realized
auto ast = generateSpecialAst(type); auto ast = generateSpecialAst(type);
addFunctionGenerics(type); addFunctionGenerics(type);
ctx->getBase()->attributes = &(ast->attributes);
// Internal functions have no AST that can be realized // Internal functions have no AST that can be realized
bool hasAst = ast->suite && !ast->attributes.has(Attr::Internal); bool hasAst = ast->suite && !ast->attributes.has(Attr::Internal);
@ -317,8 +322,8 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
if (ast->args[i].status == Param::Normal) { if (ast->args[i].status == Param::Normal) {
std::string varName = ast->args[i].name; std::string varName = ast->args[i].name;
trimStars(varName); trimStars(varName);
ctx->addVar(varName, varName, getSrcInfo(), auto v = ctx->addVar(ctx->cache->rev(varName), varName,
std::make_shared<LinkType>(type->getArgTypes()[j++])); std::make_shared<LinkType>(type->getArgTypes()[j++]));
} }
// Populate realization table in advance to support recursive realizations // Populate realization table in advance to support recursive realizations
@ -332,9 +337,8 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
// Realizations should always be visible, so add them to the toplevel // Realizations should always be visible, so add them to the toplevel
auto val = auto val =
std::make_shared<TypecheckItem>(TypecheckItem::Func, "", key, ctx->getModule()); std::make_shared<TypecheckItem>(key, "", ctx->getModule(), type->getFunc());
val->type = type->getFunc(); ctx->addAlwaysVisible(val);
ctx->addToplevel(key, val);
if (hasAst) { if (hasAst) {
auto oldBlockLevel = ctx->blockLevel; auto oldBlockLevel = ctx->blockLevel;
@ -348,13 +352,15 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
// Lambda typecheck failures are "ignored" as they are treated as statements, // Lambda typecheck failures are "ignored" as they are treated as statements,
// not functions. // not functions.
// TODO: generalize this further. // TODO: generalize this further.
// LOG("{}", ast->suite->toString(2)); LOG("[error=>] {}", ast->suite->toString(2));
// inferTypes(ast->suite, ctx);
error("cannot typecheck the program"); error("cannot typecheck the program");
} }
ctx->realizationBases.pop_back(); ctx->bases.pop_back();
ctx->popBlock(); ctx->popBlock();
ctx->typecheckLevel--; ctx->typecheckLevel--;
getLogger().level--; getLogger().level--;
this->ctx = oldCtx;
return nullptr; // inference must be delayed return nullptr; // inference must be delayed
} }
@ -362,6 +368,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
// function has no return statement // function has no return statement
if (!ast->ret && type->getRetType()->getUnbound()) if (!ast->ret && type->getRetType()->getUnbound())
unify(type->getRetType(), ctx->forceFind("NoneType")->type); unify(type->getRetType(), ctx->forceFind("NoneType")->type);
// LOG("-> {} {}", key, ret->toString(2));
} }
// Realize the return type // Realize the return type
auto ret = realize(type->getRetType()); auto ret = realize(type->getRetType());
@ -387,14 +394,14 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
} }
if (force) if (force)
realizations[type->realizedName()]->ast = r->ast; realizations[type->realizedName()]->ast = r->ast;
val = std::make_shared<TypecheckItem>(TypecheckItem::Func, "", type->realizedName(), val = std::make_shared<TypecheckItem>(type->realizedName(), "", ctx->getModule(),
ctx->getModule()); type->getFunc());
val->type = type->getFunc(); ctx->addAlwaysVisible(val);
ctx->addToplevel(type->realizedName(), val); ctx->bases.pop_back();
ctx->realizationBases.pop_back();
ctx->popBlock(); ctx->popBlock();
ctx->typecheckLevel--; ctx->typecheckLevel--;
getLogger().level--; getLogger().level--;
this->ctx = oldCtx;
return type->getFunc(); return type->getFunc();
} }
@ -403,7 +410,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force)
/// Intended to be called once the typechecking is done. /// Intended to be called once the typechecking is done.
/// TODO: add JIT compatibility. /// TODO: add JIT compatibility.
StmtPtr TypecheckVisitor::prepareVTables() { StmtPtr TypecheckVisitor::prepareVTables() {
auto rep = "__internal__.class_populate_vtables:0"; // see internal.codon auto rep = "__internal__.class_populate_vtables"; // see internal.codon
if (!in(ctx->cache->functions, rep))
return nullptr;
auto &initFn = ctx->cache->functions[rep]; auto &initFn = ctx->cache->functions[rep];
auto suite = N<SuiteStmt>(); auto suite = N<SuiteStmt>();
for (auto &[_, cls] : ctx->cache->classes) { for (auto &[_, cls] : ctx->cache->classes) {
@ -417,7 +426,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
continue; continue;
// __internal__.class_set_rtti_vtable(real.ID, size, real.type) // __internal__.class_set_rtti_vtable(real.ID, size, real.type)
suite->stmts.push_back(N<ExprStmt>( suite->stmts.push_back(N<ExprStmt>(
N<CallExpr>(N<IdExpr>("__internal__.class_set_rtti_vtable:0"), N<CallExpr>(N<IdExpr>("__internal__.class_set_rtti_vtable"),
N<IntExpr>(real->id), N<IntExpr>(vtSz + 2), NT<IdExpr>(r)))); N<IntExpr>(real->id), N<IntExpr>(vtSz + 2), NT<IdExpr>(r))));
// LOG("[poly] {} -> {}", r, real->id); // LOG("[poly] {} -> {}", r, real->id);
vtSz = 0; vtSz = 0;
@ -431,7 +440,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
// p[real.ID].__setitem__(f.ID, Function[<TYPE_F>](f).__raw__()) // p[real.ID].__setitem__(f.ID, Function[<TYPE_F>](f).__raw__())
LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, fn); LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, fn);
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>( suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("__internal__.class_set_rtti_vtable_fn:0"), N<IdExpr>("__internal__.class_set_rtti_vtable_fn"),
N<IntExpr>(real->id), N<IntExpr>(vtSz + id), N<IntExpr>(real->id), N<IntExpr>(vtSz + id),
N<CallExpr>(N<DotExpr>( N<CallExpr>(N<DotExpr>(
N<CallExpr>( N<CallExpr>(
@ -457,7 +466,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
typ->ast = initFn.ast.get(); typ->ast = initFn.ast.get();
realizeFunc(typ.get(), true); realizeFunc(typ.get(), true);
auto &initDist = ctx->cache->functions["__internal__.class_base_derived_dist:0"]; auto &initDist = ctx->cache->functions["__internal__.class_base_derived_dist"];
// def class_base_derived_dist(B, D): // def class_base_derived_dist(B, D):
// return Tuple[<types before B is reached in D>].__elemsize__ // return Tuple[<types before B is reached in D>].__elemsize__
auto oldAst = initDist.ast; auto oldAst = initDist.ast;
@ -583,7 +592,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
nullptr); nullptr);
std::vector<ExprPtr> callArgs; std::vector<ExprPtr> callArgs;
callArgs.emplace_back( callArgs.emplace_back(
N<CallExpr>(N<IdExpr>("__internal__.class_base_to_derived:0"), N<CallExpr>(N<IdExpr>("__internal__.class_base_to_derived"),
N<IdExpr>(fp->ast->args[0].name), N<IdExpr>(cp->realizedName()), N<IdExpr>(fp->ast->args[0].name), N<IdExpr>(cp->realizedName()),
N<IdExpr>(real->type->realizedName()))); N<IdExpr>(real->type->realizedName())));
for (size_t i = 1; i < args.size(); i++) for (size_t i = 1; i < args.size(); i++)
@ -782,11 +791,11 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto ast = std::dynamic_pointer_cast<FunctionStmt>( auto ast = std::dynamic_pointer_cast<FunctionStmt>(
clone(ctx->cache->functions[type->ast->name].ast)); clone(ctx->cache->functions[type->ast->name].ast));
if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__iter__:0") && if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__iter__") &&
type->getArgTypes()[0]->getHeterogenousTuple()) { type->getArgTypes()[0]->getHeterogenousTuple()) {
// Special case: do not realize auto-generated heterogenous __iter__ // Special case: do not realize auto-generated heterogenous __iter__
E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
} else if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__getitem__:0") && } else if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__getitem__") &&
type->getArgTypes()[0]->getHeterogenousTuple()) { type->getArgTypes()[0]->getHeterogenousTuple()) {
// Special case: do not realize auto-generated heterogenous __getitem__ // Special case: do not realize auto-generated heterogenous __getitem__
E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable");
@ -814,15 +823,15 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
ll.push_back(format("ret {{}} %{}", as.size())); ll.push_back(format("ret {{}} %{}", as.size()));
items[0] = N<ExprStmt>(N<StringExpr>(combine2(ll, "\n"))); items[0] = N<ExprStmt>(N<StringExpr>(combine2(ll, "\n")));
ast->suite = N<SuiteStmt>(items); ast->suite = N<SuiteStmt>(items);
} else if (startswith(ast->name, "Union.__new__:0")) { } else if (startswith(ast->name, "Union.__new__")) {
auto unionType = type->funcParent->getUnion(); auto unionType = type->funcParent->getUnion();
seqassert(unionType, "expected union, got {}", type->funcParent); seqassert(unionType, "expected union, got {}", type->funcParent);
StmtPtr suite = N<ReturnStmt>(N<CallExpr>( StmtPtr suite = N<ReturnStmt>(N<CallExpr>(
N<IdExpr>("__internal__.new_union:0"), N<IdExpr>(type->ast->args[0].name), N<IdExpr>("__internal__.new_union"), N<IdExpr>(type->ast->args[0].name),
N<IdExpr>(unionType->realizedTypeName()))); N<IdExpr>(unionType->realizedTypeName())));
ast->suite = suite; ast->suite = suite;
} else if (startswith(ast->name, "__internal__.new_union:0")) { } else if (startswith(ast->name, "__internal__.new_union")) {
// Special case: __internal__.new_union // Special case: __internal__.new_union
// def __internal__.new_union(value, U[T0, ..., TN]): // def __internal__.new_union(value, U[T0, ..., TN]):
// if isinstance(value, T0): // if isinstance(value, T0):
@ -842,7 +851,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
suite->stmts.push_back(N<IfStmt>( suite->stmts.push_back(N<IfStmt>(
N<CallExpr>(N<IdExpr>("isinstance"), N<IdExpr>(objVar), N<CallExpr>(N<IdExpr>("isinstance"), N<IdExpr>(objVar),
NT<IdExpr>(t->realizedName())), NT<IdExpr>(t->realizedName())),
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_make:0"), N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_make"),
N<IntExpr>(tag), N<IdExpr>(objVar), N<IntExpr>(tag), N<IdExpr>(objVar),
N<IdExpr>(unionType->realizedTypeName()))))); N<IdExpr>(unionType->realizedTypeName())))));
// Check for Union[T] // Check for Union[T]
@ -852,8 +861,8 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
NT<InstantiateExpr>(NT<IdExpr>("Union"), NT<InstantiateExpr>(NT<IdExpr>("Union"),
std::vector<ExprPtr>{NT<IdExpr>(t->realizedName())})), std::vector<ExprPtr>{NT<IdExpr>(t->realizedName())})),
N<ReturnStmt>( N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.union_make:0"), N<IntExpr>(tag), N<CallExpr>(N<IdExpr>("__internal__.union_make"), N<IntExpr>(tag),
N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), N<CallExpr>(N<IdExpr>("__internal__.get_union"),
N<IdExpr>(objVar), NT<IdExpr>(t->realizedName())), N<IdExpr>(objVar), NT<IdExpr>(t->realizedName())),
N<IdExpr>(unionType->realizedTypeName()))))); N<IdExpr>(unionType->realizedTypeName())))));
tag++; tag++;
@ -861,7 +870,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
suite->stmts.push_back(N<ExprStmt>(N<CallExpr>( suite->stmts.push_back(N<ExprStmt>(N<CallExpr>(
N<IdExpr>("compile_error"), N<StringExpr>("invalid union constructor")))); N<IdExpr>("compile_error"), N<StringExpr>("invalid union constructor"))));
ast->suite = suite; ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union:0")) { } else if (startswith(ast->name, "__internal__.get_union")) {
// Special case: __internal__.get_union // Special case: __internal__.get_union
// def __internal__.new_union(union: Union[T0,...,TN], T): // def __internal__.new_union(union: Union[T0,...,TN], T):
// if __internal__.union_get_tag(union) == 0: // if __internal__.union_get_tag(union) == 0:
@ -878,10 +887,10 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
for (const auto &t : unionTypes) { for (const auto &t : unionTypes) {
if (t->realizedName() == targetType->realizedName()) { if (t->realizedName() == targetType->realizedName()) {
suite->stmts.push_back(N<IfStmt>( suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"), N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag"),
N<IdExpr>(selfVar)), N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag)), "==", N<IntExpr>(tag)),
N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"), N<ReturnStmt>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data"),
N<IdExpr>(selfVar), N<IdExpr>(selfVar),
NT<IdExpr>(t->realizedName()))))); NT<IdExpr>(t->realizedName())))));
} }
@ -891,7 +900,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"), N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union getter")))); N<StringExpr>("invalid union getter"))));
ast->suite = suite; ast->suite = suite;
} else if (startswith(ast->name, "__internal__._get_union_method:0")) { } else if (startswith(ast->name, "__internal__._get_union_method")) {
// def __internal__._get_union_method(union: Union[T0,...,TN], method, *args, **kw): // def __internal__._get_union_method(union: Union[T0,...,TN], method, *args, **kw):
// if __internal__.union_get_tag(union) == 0: // if __internal__.union_get_tag(union) == 0:
// return __internal__.union_get_data(union, T0).method(*args, **kw) // return __internal__.union_get_data(union, T0).method(*args, **kw)
@ -907,7 +916,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
int tag = 0; int tag = 0;
for (auto &t : unionTypes) { for (auto &t : unionTypes) {
auto callee = auto callee =
N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"), N<DotExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_data"),
N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())), N<IdExpr>(selfVar), NT<IdExpr>(t->realizedName())),
fnName); fnName);
auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1))); auto args = N<StarExpr>(N<IdExpr>(ast->args[2].name.substr(1)));
@ -919,7 +928,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
suite->stmts.push_back(N<IfStmt>( suite->stmts.push_back(N<IfStmt>(
N<BinaryExpr>( N<BinaryExpr>(
check, "&&", check, "&&",
N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag:0"), N<BinaryExpr>(N<CallExpr>(N<IdExpr>("__internal__.union_get_tag"),
N<IdExpr>(selfVar)), N<IdExpr>(selfVar)),
"==", N<IntExpr>(tag))), "==", N<IntExpr>(tag))),
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(callee, args, kwargs))))); N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(callee, args, kwargs)))));
@ -931,7 +940,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>())); // suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->forceFind("Union")->type)); unify(type->getRetType(), ctx->instantiate(ctx->forceFind("Union")->type));
ast->suite = suite; ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) { } else if (startswith(ast->name, "__internal__.get_union_first")) {
// def __internal__.get_union_first(union: Union[T0]): // def __internal__.get_union_first(union: Union[T0]):
// return __internal__.union_get_data(union, T0) // return __internal__.union_get_data(union, T0)
auto unionType = type->getArgTypes()[0]->getUnion(); auto unionType = type->getArgTypes()[0]->getUnion();
@ -939,7 +948,7 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
auto selfVar = ast->args[0].name; auto selfVar = ast->args[0].name;
auto suite = N<SuiteStmt>(N<ReturnStmt>( auto suite = N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("__internal__.union_get_data:0"), N<IdExpr>(selfVar), N<CallExpr>(N<IdExpr>("__internal__.union_get_data"), N<IdExpr>(selfVar),
NT<IdExpr>(unionTypes[0]->realizedName())))); NT<IdExpr>(unionTypes[0]->realizedName()))));
ast->suite = suite; ast->suite = suite;
} }

View File

@ -87,7 +87,7 @@ void TypecheckVisitor::visit(WhileStmt *stmt) {
ctx->leaveConditionalBlock(); ctx->leaveConditionalBlock();
// Dominate loop variables // Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars) { for (auto &var : ctx->getBase()->getLoop()->seenVars) {
ctx->findDominatingBinding(var, this); findDominatingBinding(var, ctx.get());
} }
ctx->getBase()->loops.pop_back(); ctx->getBase()->loops.pop_back();
@ -140,15 +140,16 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
ctx->enterConditionalBlock(); ctx->enterConditionalBlock();
ctx->getBase()->loops.push_back({breakVar, ctx->scope.blocks, {}}); ctx->getBase()->loops.push_back({breakVar, ctx->scope.blocks, {}});
std::string varName; std::string varName;
TypeContext::Item val = nullptr;
if (auto i = stmt->var->getId()) { if (auto i = stmt->var->getId()) {
auto val = ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value), val = ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value),
stmt->var->getSrcInfo()); ctx->getUnbound());
val->avoidDomination = ctx->avoidDomination; val->avoidDomination = ctx->avoidDomination;
transform(stmt->var); transform(stmt->var);
stmt->suite = N<SuiteStmt>(stmt->suite); stmt->suite = N<SuiteStmt>(stmt->suite);
} else { } else {
varName = ctx->cache->getTemporaryVar("for"); varName = ctx->cache->getTemporaryVar("for");
auto val = ctx->addVar(varName, varName, stmt->var->getSrcInfo()); val = ctx->addVar(varName, varName, ctx->getUnbound());
auto var = N<IdExpr>(varName); auto var = N<IdExpr>(varName);
std::vector<StmtPtr> stmts; std::vector<StmtPtr> stmts;
// Add for_var = [for variables] // Add for_var = [for variables]
@ -162,9 +163,6 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
seqassert(var, "corrupt for variable: {}", stmt->var); seqassert(var, "corrupt for variable: {}", stmt->var);
// Unify iterator variable and the iterator type // Unify iterator variable and the iterator type
auto val = ctx->addVar(var->value, var->value, getSrcInfo(),
ctx->getUnbound(stmt->var->getSrcInfo()));
val->root = stmt;
if (iterType && iterType->name != "Generator") if (iterType && iterType->name != "Generator")
E(Error::EXPECTED_GENERATOR, stmt->iter); E(Error::EXPECTED_GENERATOR, stmt->iter);
unify(stmt->var->type, unify(stmt->var->type,
@ -181,13 +179,12 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
resultStmt = N<SuiteStmt>(assign, N<ForStmt>(*stmt), resultStmt = N<SuiteStmt>(assign, N<ForStmt>(*stmt),
N<IfStmt>(transform(N<IdExpr>(breakVar)), N<IfStmt>(transform(N<IdExpr>(breakVar)),
transformConditionalScope(stmt->elseSuite))); transformConditionalScope(stmt->elseSuite)));
val->root = resultStmt->getSuite()->stmts[1].get();
} }
ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts));
// Dominate loop variables // Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars) for (auto &var : ctx->getBase()->getLoop()->seenVars)
ctx->findDominatingBinding(var, this); findDominatingBinding(var, ctx.get());
ctx->getBase()->loops.pop_back(); ctx->getBase()->loops.pop_back();
if (stmt->iter->isDone() && stmt->suite->isDone()) if (stmt->iter->isDone() && stmt->suite->isDone())
@ -341,7 +338,7 @@ TypecheckVisitor::transformStaticLoopCall(
auto stmt = N<AssignStmt>(N<IdExpr>(vars[0]), nullptr, nullptr); auto stmt = N<AssignStmt>(N<IdExpr>(vars[0]), nullptr, nullptr);
std::vector<std::shared_ptr<codon::SrcObject>> block; std::vector<std::shared_ptr<codon::SrcObject>> block;
if (startswith(fn->value, "statictuple:0")) { if (startswith(fn->value, "statictuple")) {
auto &args = iter->getCall()->args[0].value->getCall()->args; auto &args = iter->getCall()->args[0].value->getCall()->args;
if (vars.size() != 1) if (vars.size() != 1)
error("expected one item"); error("expected one item");
@ -356,7 +353,7 @@ TypecheckVisitor::transformStaticLoopCall(
} }
block.push_back(wrap(stmt->clone())); block.push_back(wrap(stmt->clone()));
} }
} else if (fn && startswith(fn->value, "std.internal.types.range.staticrange:0")) { } else if (fn && startswith(fn->value, "std.internal.types.range.staticrange")) {
if (vars.size() != 1) if (vars.size() != 1)
error("expected one item"); error("expected one item");
auto st = auto st =
@ -367,7 +364,7 @@ TypecheckVisitor::transformStaticLoopCall(
fn->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt(); fn->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt();
if (abs(st - ed) / abs(step) > MAX_STATIC_ITER) if (abs(st - ed) / abs(step) > MAX_STATIC_ITER)
E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, abs(st - ed) / abs(step)); E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, abs(st - ed) / abs(step));
for (int i = st; step > 0 ? i < ed : i > ed; i += step) { for (int64_t i = st; step > 0 ? i < ed : i > ed; i += step) {
stmt->rhs = N<IntExpr>(i); stmt->rhs = N<IntExpr>(i);
stmt->type = NT<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int")); stmt->type = NT<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int"));
block.push_back(wrap(stmt->clone())); block.push_back(wrap(stmt->clone()));
@ -379,7 +376,7 @@ TypecheckVisitor::transformStaticLoopCall(
fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt();
if (ed > MAX_STATIC_ITER) if (ed > MAX_STATIC_ITER)
E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed); E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed);
for (int i = 0; i < ed; i++) { for (int64_t i = 0; i < ed; i++) {
stmt->rhs = N<IntExpr>(i); stmt->rhs = N<IntExpr>(i);
stmt->type = NT<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int")); stmt->type = NT<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int"));
block.push_back(wrap(stmt->clone())); block.push_back(wrap(stmt->clone()));
@ -402,8 +399,8 @@ TypecheckVisitor::transformStaticLoopCall(
if (typ->getHeterogenousTuple()) { if (typ->getHeterogenousTuple()) {
auto &ast = ctx->cache->functions[method].ast; auto &ast = ctx->cache->functions[method].ast;
if (ast->hasAttr("autogenerated") && if (ast->hasAttr("autogenerated") &&
(endswith(ast->name, ".__iter__:0") || (endswith(ast->name, ".__iter__") ||
endswith(ast->name, ".__getitem__:0"))) { endswith(ast->name, ".__getitem__"))) {
// ignore __getitem__ and other heterogenuous methods // ignore __getitem__ and other heterogenuous methods
continue; continue;
} }
@ -436,7 +433,7 @@ TypecheckVisitor::transformStaticLoopCall(
} else { } else {
error("bad call to staticenumerate"); error("bad call to staticenumerate");
} }
} else if (fn && startswith(fn->value, "std.internal.internal.vars:0")) { } else if (fn && startswith(fn->value, "std.internal.internal.vars")) {
if (auto fna = ctx->getFunctionArgs(fn->type)) { if (auto fna = ctx->getFunctionArgs(fn->type)) {
auto [generics, args] = *fna; auto [generics, args] = *fna;
@ -467,7 +464,7 @@ TypecheckVisitor::transformStaticLoopCall(
} else { } else {
error("bad call to vars"); error("bad call to vars");
} }
} else if (fn && startswith(fn->value, "std.internal.static.vars_types:0")) { } else if (fn && startswith(fn->value, "std.internal.static.vars_types")) {
if (auto fna = ctx->getFunctionArgs(fn->type)) { if (auto fna = ctx->getFunctionArgs(fn->type)) {
auto [generics, args] = *fna; auto [generics, args] = *fna;

View File

@ -369,6 +369,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
TypePtr t = nullptr; TypePtr t = nullptr;
if (expr->typeParams[i]->isStatic()) { if (expr->typeParams[i]->isStatic()) {
t = Type::makeStatic(ctx->cache, expr->typeParams[i]); t = Type::makeStatic(ctx->cache, expr->typeParams[i]);
t = ctx->instantiate(t);
} else { } else {
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
transformType(expr->typeParams[i]); transformType(expr->typeParams[i]);
@ -458,13 +459,14 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
} }
/// Division and modulus implementations. /// Division and modulus implementations.
std::pair<int, int> divMod(const std::shared_ptr<TypeContext> &ctx, int a, int b) { std::pair<int64_t, int64_t> divMod(const std::shared_ptr<TypeContext> &ctx, int64_t a,
int64_t b) {
if (!b) if (!b)
E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo()); E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo());
if (ctx->cache->pythonCompat) { if (ctx->cache->pythonCompat) {
// Use Python implementation. // Use Python implementation.
int d = a / b; int64_t d = a / b;
int m = a - d * b; int64_t m = a - d * b;
if (m && ((b ^ m) < 0)) { if (m && ((b ^ m) < 0)) {
m += b; m += b;
d -= 1; d -= 1;
@ -800,7 +802,7 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
auto classItem = in(ctx->cache->classes, tuple->name); auto classItem = in(ctx->cache->classes, tuple->name);
seqassert(classItem, "cannot find class '{}'", tuple->name); seqassert(classItem, "cannot find class '{}'", tuple->name);
auto sz = classItem->fields.size(); auto sz = int64_t(classItem->fields.size());
int64_t start = 0, stop = sz, step = 1; int64_t start = 0, stop = sz, step = 1;
if (getInt(&start, index)) { if (getInt(&start, index)) {
// Case: `tuple[int]` // Case: `tuple[int]`

View File

@ -30,95 +30,43 @@ StmtPtr TypecheckVisitor::apply(
Cache *cache, const StmtPtr &node, const std::string &file, Cache *cache, const StmtPtr &node, const std::string &file,
const std::unordered_map<std::string, std::string> &defines, const std::unordered_map<std::string, std::string> &defines,
const std::unordered_map<std::string, std::string> &earlyDefines, bool barebones) { const std::unordered_map<std::string, std::string> &earlyDefines, bool barebones) {
auto preamble = std::vector<StmtPtr>(); auto preamble = std::make_shared<std::vector<StmtPtr>>();
seqassertn(cache->module, "cache's module is not set"); seqassertn(cache->module, "cache's module is not set");
#define N std::make_shared
// Load standard library if it has not been loaded // Load standard library if it has not been loaded
if (!in(cache->imports, STDLIB_IMPORT)) { if (!in(cache->imports, STDLIB_IMPORT))
// Load the internal.__init__ loadStdLibrary(cache, preamble, earlyDefines, barebones);
auto stdlib = std::make_shared<TypeContext>(cache, STDLIB_IMPORT);
auto stdlibPath =
getImportFile(cache->argv0, STDLIB_INTERNAL_MODULE, "", true, cache->module0);
const std::string initFile = "__init__.codon";
if (!stdlibPath || !endswith(stdlibPath->path, initFile))
E(Error::COMPILER_NO_STDLIB);
/// Use __init_test__ for faster testing (e.g., #%% name,barebones)
/// TODO: get rid of it one day...
if (barebones) {
stdlibPath->path =
stdlibPath->path.substr(0, stdlibPath->path.size() - initFile.size()) +
"__init_test__.codon";
}
stdlib->setFilename(stdlibPath->path);
cache->imports[STDLIB_IMPORT] = {stdlibPath->path, stdlib};
stdlib->isStdlibLoading = true;
stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
// Load the standard library
stdlib->setFilename(stdlibPath->path);
// Core definitions
auto core = TypecheckVisitor(stdlib).transform(
parseCode(stdlib->cache, stdlibPath->path, "from internal.core import *"));
preamble.insert(preamble.end(), stdlib->getBase()->preamble.begin(),
stdlib->getBase()->preamble.end());
stdlib->getBase()->preamble.clear();
preamble.push_back(core);
for (auto &d : earlyDefines) {
// Load early compile-time defines (for standard library)
auto def = TypecheckVisitor(stdlib).transform(
N<AssignStmt>(N<IdExpr>(d.first), N<IntExpr>(d.second),
N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int"))));
preamble.insert(preamble.end(), stdlib->getBase()->preamble.begin(),
stdlib->getBase()->preamble.end());
stdlib->getBase()->preamble.clear();
preamble.push_back(def);
}
auto std =
TypecheckVisitor(stdlib).transform(parseFile(stdlib->cache, stdlibPath->path));
preamble.insert(preamble.end(), stdlib->getBase()->preamble.begin(),
stdlib->getBase()->preamble.end());
stdlib->getBase()->preamble.clear();
preamble.push_back(std);
stdlib->isStdlibLoading = false;
}
// Set up the context and the cache // Set up the context and the cache
auto ctx = std::make_shared<TypeContext>(cache, file); auto ctx = std::make_shared<TypeContext>(cache, file);
cache->imports[file].filename = file; cache->imports[file] = cache->imports[MAIN_IMPORT] = {MAIN_IMPORT, file, ctx};
cache->imports[file].ctx = ctx;
cache->imports[MAIN_IMPORT] = {file, ctx};
ctx->setFilename(file); ctx->setFilename(file);
ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN}; ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN};
if (!cache->typeCtx)
cache->typeCtx = std::make_shared<TypeContext>(cache);
// Prepare the code // Prepare the code
auto suite = N<SuiteStmt>(); auto tv = TypecheckVisitor(ctx, preamble);
suite->stmts.push_back(N<ClassStmt>(".toplevel", std::vector<Param>{}, nullptr, auto suite = tv.N<SuiteStmt>();
std::vector<ExprPtr>{N<IdExpr>(Attr::Internal)})); suite->stmts.push_back(
tv.N<ClassStmt>(".toplevel", std::vector<Param>{}, nullptr,
std::vector<ExprPtr>{tv.N<IdExpr>(Attr::Internal)}));
// Load compile-time defines (e.g., codon run -DFOO=1 ...)
for (auto &d : defines) { for (auto &d : defines) {
// Load compile-time defines (e.g., codon run -DFOO=1 ...)
suite->stmts.push_back( suite->stmts.push_back(
N<AssignStmt>(N<IdExpr>(d.first), N<IntExpr>(d.second), tv.N<AssignStmt>(tv.N<IdExpr>(d.first), tv.N<IntExpr>(d.second),
N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("int")))); tv.N<IndexExpr>(tv.N<IdExpr>("Static"), tv.N<IdExpr>("int"))));
} }
// Set up __name__ // Set up __name__
suite->stmts.push_back( suite->stmts.push_back(
N<AssignStmt>(N<IdExpr>("__name__"), N<StringExpr>(MODULE_MAIN))); tv.N<AssignStmt>(tv.N<IdExpr>("__name__"), tv.N<StringExpr>(MODULE_MAIN)));
suite->stmts.push_back(node); suite->stmts.push_back(node);
auto v = TypecheckVisitor(ctx); auto n = tv.inferTypes(suite, true);
auto n = v.inferTypes(suite, true);
if (!n) { if (!n) {
v.error("cannot typecheck the program"); tv.error("cannot typecheck the program");
} }
suite = N<SuiteStmt>(); suite = tv.N<SuiteStmt>();
suite->stmts.push_back(N<SuiteStmt>(preamble)); suite->stmts.push_back(tv.N<SuiteStmt>(*preamble));
suite->stmts.insert(suite->stmts.end(), ctx->getBase()->preamble.begin(),
ctx->getBase()->preamble.end());
ctx->getBase()->preamble.clear();
// Add dominated assignment declarations // Add dominated assignment declarations
if (in(ctx->scope.stmts, ctx->scope.blocks.back())) if (in(ctx->scope.stmts, ctx->scope.blocks.back()))
@ -127,8 +75,7 @@ StmtPtr TypecheckVisitor::apply(
ctx->scope.stmts[ctx->scope.blocks.back()].end()); ctx->scope.stmts[ctx->scope.blocks.back()].end());
suite->stmts.push_back(n); suite->stmts.push_back(n);
if (n->getSuite()) if (n->getSuite())
v.prepareVTables(); tv.prepareVTables();
#undef N
if (!ctx->cache->errors.empty()) if (!ctx->cache->errors.empty())
throw exc::ParserException(); throw exc::ParserException();
@ -136,23 +83,75 @@ StmtPtr TypecheckVisitor::apply(
return suite; return suite;
} }
void TypecheckVisitor::loadStdLibrary(
Cache *cache, const std::shared_ptr<std::vector<StmtPtr>> &preamble,
const std::unordered_map<std::string, std::string> &earlyDefines, bool barebones) {
// Load the internal.__init__
auto stdlib = std::make_shared<TypeContext>(cache, STDLIB_IMPORT);
auto stdlibPath =
getImportFile(cache->argv0, STDLIB_INTERNAL_MODULE, "", true, cache->module0);
const std::string initFile = "__init__.codon";
if (!stdlibPath || !endswith(stdlibPath->path, initFile))
E(Error::COMPILER_NO_STDLIB);
/// Use __init_test__ for faster testing (e.g., #%% name,barebones)
/// TODO: get rid of it one day...
if (barebones) {
stdlibPath->path =
stdlibPath->path.substr(0, stdlibPath->path.size() - initFile.size()) +
"__init_test__.codon";
}
stdlib->setFilename(stdlibPath->path);
cache->imports[stdlibPath->path] =
cache->imports[STDLIB_IMPORT] = {STDLIB_IMPORT, stdlibPath->path, stdlib};
// Load the standard library
stdlib->isStdlibLoading = true;
stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
stdlib->setFilename(stdlibPath->path);
// 1. Core definitions
auto core = TypecheckVisitor(stdlib, preamble)
.transform(parseCode(stdlib->cache, stdlibPath->path,
"from internal.core import *"));
preamble->push_back(core);
LOG("core done");
// 2. Load early compile-time defines (for standard library)
for (auto &d : earlyDefines) {
auto tv = TypecheckVisitor(stdlib, preamble);
auto def = tv.transform(
tv.N<AssignStmt>(tv.N<IdExpr>(d.first), tv.N<IntExpr>(d.second),
tv.N<IndexExpr>(tv.N<IdExpr>("Static"), tv.N<IdExpr>("int"))));
preamble->push_back(def);
}
LOG("defs done");
// 3. Load stdlib
auto std = TypecheckVisitor(stdlib, preamble)
.transform(parseFile(stdlib->cache, stdlibPath->path));
preamble->push_back(std);
stdlib->isStdlibLoading = false;
LOG("stdlib done");
}
/// Simplify an AST node. Assumes that the standard library is loaded. /// Simplify an AST node. Assumes that the standard library is loaded.
StmtPtr TypecheckVisitor::apply(const std::shared_ptr<TypeContext> &ctx, StmtPtr TypecheckVisitor::apply(const std::shared_ptr<TypeContext> &ctx,
const StmtPtr &node, const std::string &file) { const StmtPtr &node, const std::string &file) {
auto oldFilename = ctx->getFilename(); auto oldFilename = ctx->getFilename();
ctx->setFilename(file); ctx->setFilename(file);
auto v = TypecheckVisitor(ctx); auto preamble = std::make_shared<std::vector<StmtPtr>>();
auto n = v.inferTypes(node, true); auto tv = TypecheckVisitor(ctx, preamble);
auto n = tv.inferTypes(node, true);
ctx->setFilename(oldFilename); ctx->setFilename(oldFilename);
if (!n) { if (!n) {
v.error("cannot typecheck the program"); tv.error("cannot typecheck the program");
} }
if (!ctx->cache->errors.empty()) { if (!ctx->cache->errors.empty()) {
throw exc::ParserException(); throw exc::ParserException();
} }
auto suite = std::make_shared<SuiteStmt>(ctx->getBase()->preamble); auto suite = std::make_shared<SuiteStmt>(*preamble);
ctx->getBase()->preamble.clear();
suite->stmts.push_back(n); suite->stmts.push_back(n);
return suite; return suite;
} }
@ -160,14 +159,16 @@ StmtPtr TypecheckVisitor::apply(const std::shared_ptr<TypeContext> &ctx,
/**************************************************************************************/ /**************************************************************************************/
TypecheckVisitor::TypecheckVisitor(std::shared_ptr<TypeContext> ctx, TypecheckVisitor::TypecheckVisitor(std::shared_ptr<TypeContext> ctx,
const std::shared_ptr<std::vector<StmtPtr>> &pre,
const std::shared_ptr<std::vector<StmtPtr>> &stmts) const std::shared_ptr<std::vector<StmtPtr>> &stmts)
: ctx(std::move(ctx)) { : ctx(std::move(ctx)) {
preamble = pre ? pre : std::make_shared<std::vector<StmtPtr>>();
prependStmts = stmts ? stmts : std::make_shared<std::vector<StmtPtr>>(); prependStmts = stmts ? stmts : std::make_shared<std::vector<StmtPtr>>();
} }
/**************************************************************************************/ /**************************************************************************************/
ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { return transform(expr); } ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { return transform(expr, true); }
/// Transform an expression node. /// Transform an expression node.
ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes) { ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes) {
@ -178,7 +179,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes) {
unify(expr->type, ctx->getUnbound()); unify(expr->type, ctx->getUnbound());
auto typ = expr->type; auto typ = expr->type;
if (!expr->done) { if (!expr->done) {
TypecheckVisitor v(ctx, prependStmts); TypecheckVisitor v(ctx, preamble, prependStmts);
v.setSrcInfo(expr->getSrcInfo()); v.setSrcInfo(expr->getSrcInfo());
ctx->pushSrcInfo(expr->getSrcInfo()); ctx->pushSrcInfo(expr->getSrcInfo());
expr->accept(v); expr->accept(v);
@ -236,7 +237,7 @@ StmtPtr TypecheckVisitor::transform(StmtPtr &stmt) {
if (!stmt || stmt->done) if (!stmt || stmt->done)
return stmt; return stmt;
TypecheckVisitor v(ctx); TypecheckVisitor v(ctx, preamble);
v.setSrcInfo(stmt->getSrcInfo()); v.setSrcInfo(stmt->getSrcInfo());
ctx->pushSrcInfo(stmt->getSrcInfo()); ctx->pushSrcInfo(stmt->getSrcInfo());
stmt->accept(v); stmt->accept(v);
@ -254,8 +255,8 @@ StmtPtr TypecheckVisitor::transform(StmtPtr &stmt) {
} }
if (stmt->done) if (stmt->done)
ctx->changedNodes++; ctx->changedNodes++;
// LOG_TYPECHECK("[stmt] {}: {}{}", getSrcInfo(), stmt, stmt->isDone() ? "[done]" : // LOG("[stmt] {}: {} {}", getSrcInfo(), split(stmt->toString(1), '\n').front(),
// ""); // stmt->isDone() ? "[done]" : "");
return stmt; return stmt;
} }
@ -534,7 +535,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
expr = transform(N<CallExpr>(expr, N<EllipsisExpr>(EllipsisExpr::PARTIAL))); expr = transform(N<CallExpr>(expr, N<EllipsisExpr>(EllipsisExpr::PARTIAL)));
else else
expr = transform(N<CallExpr>( expr = transform(N<CallExpr>(
N<IdExpr>("__internal__.class_ctr:0"), N<IdExpr>("__internal__.class_ctr"),
std::vector<CallExpr::Arg>{{"T", expr}, std::vector<CallExpr::Arg>{{"T", expr},
{"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}})); {"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}}));
} }
@ -576,7 +577,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
!expectedClass->getUnion()) { !expectedClass->getUnion()) {
// Extract union types via __internal__.get_union // Extract union types via __internal__.get_union
if (auto t = realize(expectedClass)) { if (auto t = realize(expectedClass)) {
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union:0"), expr, expr = transform(N<CallExpr>(N<IdExpr>("__internal__.get_union"), expr,
N<IdExpr>(t->realizedName()))); N<IdExpr>(t->realizedName())));
} else { } else {
return false; return false;
@ -587,7 +588,7 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
expectedClass->getUnion()->addType(exprClass); expectedClass->getUnion()->addType(exprClass);
if (auto t = realize(expectedClass)) { if (auto t = realize(expectedClass)) {
if (expectedClass->unify(exprClass.get(), nullptr) == -1) if (expectedClass->unify(exprClass.get(), nullptr) == -1)
expr = transform(N<CallExpr>(N<IdExpr>("__internal__.new_union:0"), expr, expr = transform(N<CallExpr>(N<IdExpr>("__internal__.new_union"), expr,
NT<IdExpr>(t->realizedName()))); NT<IdExpr>(t->realizedName())));
} else { } else {
return false; return false;
@ -656,4 +657,36 @@ TypecheckVisitor::unpackTupleTypes(ExprPtr expr) {
return ret; return ret;
} }
TypePtr TypecheckVisitor::getClassGeneric(const types::ClassTypePtr &cls, int idx) {
seqassert(idx < cls->generics.size(), "bad generic");
return cls->generics[idx].type;
}
std::string TypecheckVisitor::getClassStaticStr(const types::ClassTypePtr &cls,
int idx) {
int i = 0;
for (auto &g : cls->generics) {
if (g.type->getStatic() &&
g.type->getStatic()->expr->staticValue.type == StaticValue::STRING) {
if (i++ == idx) {
return g.type->getStatic()->evaluate().getString();
}
}
}
seqassert(false, "bad string static generic");
return "";
}
int64_t TypecheckVisitor::getClassStaticInt(const types::ClassTypePtr &cls, int idx) {
int i = 0;
for (auto &g : cls->generics) {
if (g.type->getStatic() &&
g.type->getStatic()->expr->staticValue.type == StaticValue::INT) {
if (i++ == idx) {
return g.type->getStatic()->evaluate().getInt();
}
}
}
seqassert(false, "bad int static generic");
return -1;
}
} // namespace codon::ast } // namespace codon::ast

View File

@ -26,7 +26,8 @@ class TypecheckVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
/// Shared simplification context. /// Shared simplification context.
std::shared_ptr<TypeContext> ctx; std::shared_ptr<TypeContext> ctx;
/// Statements to prepend before the current statement. /// Statements to prepend before the current statement.
std::shared_ptr<std::vector<StmtPtr>> prependStmts; std::shared_ptr<std::vector<StmtPtr>> prependStmts = nullptr;
std::shared_ptr<std::vector<StmtPtr>> preamble = nullptr;
/// Each new expression is stored here (as @c visit does not return anything) and /// Each new expression is stored here (as @c visit does not return anything) and
/// later returned by a @c transform call. /// later returned by a @c transform call.
@ -45,9 +46,15 @@ public:
static StmtPtr apply(const std::shared_ptr<TypeContext> &cache, const StmtPtr &node, static StmtPtr apply(const std::shared_ptr<TypeContext> &cache, const StmtPtr &node,
const std::string &file = "<internal>"); const std::string &file = "<internal>");
private:
static void loadStdLibrary(Cache *, const std::shared_ptr<std::vector<StmtPtr>> &,
const std::unordered_map<std::string, std::string> &,
bool);
public: public:
explicit TypecheckVisitor( explicit TypecheckVisitor(
std::shared_ptr<TypeContext> ctx, std::shared_ptr<TypeContext> ctx,
const std::shared_ptr<std::vector<StmtPtr>> &preamble = nullptr,
const std::shared_ptr<std::vector<StmtPtr>> &stmts = nullptr); const std::shared_ptr<std::vector<StmtPtr>> &stmts = nullptr);
public: // Convenience transformators public: // Convenience transformators
@ -89,6 +96,7 @@ private: // Node typechecking rules
/* Identifier access expressions (access.cpp) */ /* Identifier access expressions (access.cpp) */
void visit(IdExpr *) override; void visit(IdExpr *) override;
TypeContext::Item findDominatingBinding(const std::string &, TypeContext *);
bool checkCapture(const TypeContext::Item &); bool checkCapture(const TypeContext::Item &);
void visit(DotExpr *) override; void visit(DotExpr *) override;
std::pair<size_t, TypeContext::Item> getImport(const std::vector<std::string> &); std::pair<size_t, TypeContext::Item> getImport(const std::vector<std::string> &);
@ -235,8 +243,8 @@ private: // Node typechecking rules
void visit(ClassStmt *) override; void visit(ClassStmt *) override;
std::vector<ClassStmt *> parseBaseClasses(std::vector<ExprPtr> &, std::vector<ClassStmt *> parseBaseClasses(std::vector<ExprPtr> &,
std::vector<Param> &, const Attr &, std::vector<Param> &, const Attr &,
const std::string &, const std::string &, const ExprPtr &,
const ExprPtr & = nullptr); types::ClassTypePtr &);
std::pair<StmtPtr, FunctionStmt *> autoDeduceMembers(ClassStmt *, std::pair<StmtPtr, FunctionStmt *> autoDeduceMembers(ClassStmt *,
std::vector<Param> &); std::vector<Param> &);
std::vector<StmtPtr> getClassMethods(const StmtPtr &s); std::vector<StmtPtr> getClassMethods(const StmtPtr &s);
@ -254,13 +262,15 @@ private: // Node typechecking rules
void visit(CommentStmt *stmt) override; void visit(CommentStmt *stmt) override;
void visit(CustomStmt *) override; void visit(CustomStmt *) override;
private: public:
/* Type inference (infer.cpp) */ /* Type inference (infer.cpp) */
types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b); types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b);
types::TypePtr unify(types::TypePtr &&a, const types::TypePtr &b) { types::TypePtr unify(types::TypePtr &&a, const types::TypePtr &b) {
auto x = a; auto x = a;
return unify(x, b); return unify(x, b);
} }
private:
StmtPtr inferTypes(StmtPtr, bool isToplevel = false); StmtPtr inferTypes(StmtPtr, bool isToplevel = false);
types::TypePtr realize(types::TypePtr); types::TypePtr realize(types::TypePtr);
types::TypePtr realizeFunc(types::FuncType *, bool = false); types::TypePtr realizeFunc(types::FuncType *, bool = false);
@ -271,6 +281,10 @@ private:
codon::ir::Func * codon::ir::Func *
makeIRFunction(const std::shared_ptr<Cache::Function::FunctionRealization> &); makeIRFunction(const std::shared_ptr<Cache::Function::FunctionRealization> &);
types::TypePtr getClassGeneric(const types::ClassTypePtr &, int = 0);
std::string getClassStaticStr(const types::ClassTypePtr &, int = 0);
int64_t getClassStaticInt(const types::ClassTypePtr &, int = 0);
private: private:
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member, const std::string &member,
@ -293,6 +307,7 @@ private:
public: public:
bool isTuple(const std::string &s) const { return startswith(s, TYPE_TUPLE); } bool isTuple(const std::string &s) const { return startswith(s, TYPE_TUPLE); }
std::shared_ptr<TypeContext> getCtx() const { return ctx; }
friend class Cache; friend class Cache;
friend class TypeContext; friend class TypeContext;
@ -308,4 +323,21 @@ private: // Helpers
const std::function<std::shared_ptr<codon::SrcObject>(StmtPtr)> &); const std::function<std::shared_ptr<codon::SrcObject>(StmtPtr)> &);
}; };
class NameVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
TypecheckVisitor *tv;
ExprPtr resultExpr = nullptr;
StmtPtr resultStmt = nullptr;
public:
NameVisitor(TypecheckVisitor *tv) : tv(tv) {}
ExprPtr transform(const std::shared_ptr<Expr> &expr) override;
ExprPtr transform(std::shared_ptr<Expr> &expr) override;
StmtPtr transform(const std::shared_ptr<Stmt> &stmt) override;
StmtPtr transform(std::shared_ptr<Stmt> &stmt) override;
void visit(IdExpr *expr) override;
void visit(AssignStmt *stmt) override;
void visit(TryStmt *stmt) override;
void visit(ForStmt *stmt) override;
};
} // namespace codon::ast } // namespace codon::ast

View File

@ -4,6 +4,7 @@
from internal.attributes import * from internal.attributes import *
from internal.static import static_print as __static_print__ from internal.static import static_print as __static_print__
from internal.types.ptr import * from internal.types.ptr import *
from internal.types.str import * from internal.types.str import *
from internal.types.int import * from internal.types.int import *
@ -33,19 +34,20 @@ from internal.types.collections.tuple import *
import internal.c_stubs as _C import internal.c_stubs as _C
from internal.format import * from internal.format import *
from internal.builtin import * from internal.builtin import *
from internal.builtin import _jit_display from internal.builtin import _jit_display
from internal.str import * from internal.str import *
from internal.sort import sorted from internal.sort import sorted
from openmp import Ident as __OMPIdent, for_par # from openmp import Ident as __OMPIdent, for_par
from gpu import _gpu_loop_outline_template # from gpu import _gpu_loop_outline_template
from internal.file import File, gzFile, open, gzopen from internal.file import File, gzFile, open, gzopen
from pickle import pickle, unpickle from pickle import pickle, unpickle
from internal.dlopen import dlsym as _dlsym from internal.dlopen import dlsym as _dlsym
import internal.python # import internal.python
if __py_numerics__: # if __py_numerics__:
import internal.pynumerics # import internal.pynumerics
if __py_extension__: # if __py_extension__:
internal.python.ensure_initialized() # internal.python.ensure_initialized()

View File

@ -7,10 +7,6 @@ class __internal__:
class __magic__: class __magic__:
pass pass
@__internal__
class __magic__:
pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__ @__notuple__
@ -163,6 +159,15 @@ class __array__:
def __new__(sz: Static[int]) -> Array[T]: def __new__(sz: Static[int]) -> Array[T]:
pass pass
@dataclass(init=True)
@tuple
@__internal__
class Import:
path: Static[str]
name: str
def __new__(P: Static[str], name: str) -> Import[P]:
return (name, )
def __ptr__(var): def __ptr__(var):
pass pass

View File

@ -260,11 +260,6 @@ class __internal__:
if msg: if msg:
raise OSError(prefix + msg) raise OSError(prefix + msg)
@pure
@llvm
def opt_tuple_new(T: type) -> Optional[T]:
ret { i1, {=T} } { i1 false, {=T} undef }
@pure @pure
@llvm @llvm
def opt_ref_new(T: type) -> Optional[T]: def opt_ref_new(T: type) -> Optional[T]:
@ -630,12 +625,8 @@ class __magic__:
return slf.__repr__() return slf.__repr__()
@dataclass(init=True) @extend
@tuple
class Import: class Import:
name: str
file: str
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<module '{self.name}' from '{self.file}'>" return f"<module '{self.name}' from '{self.file}'>"

View File

@ -1948,3 +1948,11 @@ class _PyWrap:
if obj.head.pytype != pytype: if obj.head.pytype != pytype:
_conversion_error(T.__name__) _conversion_error(T.__name__)
return obj.data return obj.data
class PyError(Static[Exception]):
pytype: pyobj
def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)):
super().__init__("PyError", message)
self.pytype = pytype

View File

@ -39,3 +39,10 @@ class Array:
return (e - s, self.ptr + s) return (e - s, self.ptr + s)
array = Array array = Array
# Forward declarations
@dataclass(init=False)
class List:
len: int
arr: Array[T]
T: type

View File

@ -287,6 +287,9 @@ class int:
@extend @extend
class float: class float:
def __complex__(self) -> complex:
return complex(self, 0.0)
def __suffix_j__(x: float) -> complex: def __suffix_j__(x: float) -> complex:
return complex(0, x) return complex(0, x)
@ -566,3 +569,8 @@ class complex64:
declare float @llvm.log.f32(float) declare float @llvm.log.f32(float)
%y = call float @llvm.log.f32(float %x) %y = call float @llvm.log.f32(float %x)
ret float %y ret float %y
@extend
class int:
def __complex__(self) -> complex:
return complex(float(self), 0.0)

View File

@ -85,13 +85,6 @@ class CError(Static[Exception]):
super().__init__("CError", message) super().__init__("CError", message)
self.python_type = self.__class__._pytype self.python_type = self.__class__._pytype
class PyError(Static[Exception]):
pytype: pyobj
def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)):
super().__init__("PyError", message)
self.pytype = pytype
class TypeError(Static[Exception]): class TypeError(Static[Exception]):
_pytype: ClassVar[cobj] = cobj() _pytype: ClassVar[cobj] = cobj()
def __init__(self, message: str = ""): def __init__(self, message: str = ""):

View File

@ -2,7 +2,6 @@
from internal.attributes import commutative from internal.attributes import commutative
from internal.gc import alloc_atomic, free from internal.gc import alloc_atomic, free
from internal.types.complex import complex
@extend @extend
class float: class float:
@ -41,9 +40,6 @@ class float:
%1 = zext i1 %0 to i8 %1 = zext i1 %0 to i8
ret i8 %1 ret i8 %1
def __complex__(self) -> complex:
return complex(self, 0.0)
def __pos__(self) -> float: def __pos__(self) -> float:
return self return self

View File

@ -1,7 +1,6 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io> # Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
from internal.attributes import commutative, associative, distributive from internal.attributes import commutative, associative, distributive
from internal.types.complex import complex
@extend @extend
class int: class int:
@ -29,9 +28,6 @@ class int:
%tmp = sitofp i64 %self to double %tmp = sitofp i64 %self to double
ret double %tmp ret double %tmp
def __complex__(self) -> complex:
return complex(float(self), 0.0)
def __index__(self) -> int: def __index__(self) -> int:
return self return self

View File

@ -1,5 +1,13 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io> # Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
@extend
class __internal__:
@pure
@llvm
def opt_tuple_new(T: type) -> Optional[T]:
ret { i1, {=T} } { i1 false, {=T} undef }
@extend @extend
class Optional: class Optional:
def __new__() -> Optional[T]: def __new__() -> Optional[T]:

View File

@ -190,12 +190,6 @@ class Ptr:
ptr = Ptr ptr = Ptr
Jar = Ptr[byte] Jar = Ptr[byte]
# Forward declarations
class List:
len: int
arr: Array[T]
T: type
@extend @extend
class NoneType: class NoneType:
def __new__() -> NoneType: def __new__() -> NoneType: