Support for overloaded functions [wip; base logic done]

pull/10/head
Ibrahim Numanagić 2021-12-28 20:58:20 -08:00
parent 58664374c7
commit fa7278e616
16 changed files with 241 additions and 165 deletions

View File

@ -50,6 +50,9 @@ types::ClassTypePtr Cache::findClass(const std::string &name) const {
types::FuncTypePtr Cache::findFunction(const std::string &name) const {
auto f = typeCtx->find(name);
if (f && f->type && f->kind == TypecheckItem::Func)
return f->type->getFunc();
f = typeCtx->find(name + ":0");
if (f && f->type && f->kind == TypecheckItem::Func)
return f->type->getFunc();
return nullptr;

View File

@ -106,19 +106,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// Non-simplified AST. Used for base class instantiation.
std::shared_ptr<ClassStmt> originalAst;
/// A class function method.
struct ClassMethod {
/// Canonical name of a method (e.g. __init__.1).
std::string name;
/// A corresponding generic function type.
types::FuncTypePtr type;
/// Method age (how many class extension were seen before a method definition).
/// Used to prevent the usage of a method before it was defined in the code.
int age;
};
/// Class method lookup table. Each name points to a list of ClassMethod instances
/// that share the same method name (a list because methods can be overloaded).
std::unordered_map<std::string, std::vector<ClassMethod>> methods;
/// Class method lookup table. Each non-canonical name points
/// to a root function name of a corresponding method.
std::unordered_map<std::string, std::string> methods;
/// A class field (member).
struct ClassField {
@ -177,6 +167,20 @@ struct Cache : public std::enable_shared_from_this<Cache> {
/// corresponding Function instance.
std::unordered_map<std::string, Function> functions;
struct Overload {
/// Canonical name of an overload (e.g. Foo.__init__.1).
std::string name;
/// Overload age (how many class extension were seen before a method definition).
/// Used to prevent the usage of an overload before it was defined in the code.
/// TODO: I have no recollection of how this was supposed to work. Most likely
/// it does not work at all...
int age;
};
/// Maps a "root" name of each function to the list of names of the function
/// overloads.
std::unordered_map<std::string, std::vector<Overload>> overloads;
/// Pointer to the later contexts needed for IR API access.
std::shared_ptr<TypeContext> typeCtx;
std::shared_ptr<TranslateContext> codegenCtx;

View File

@ -52,8 +52,6 @@ std::shared_ptr<peg::Grammar> initParser() {
template <typename T>
T parseCode(Cache *cache, const std::string &file, std::string code, int line_offset,
int col_offset, const std::string &rule) {
TIME("peg");
// Initialize
if (!grammar)
grammar = initParser();

View File

@ -14,8 +14,9 @@ namespace codon {
namespace ast {
SimplifyItem::SimplifyItem(Kind k, std::string base, std::string canonicalName,
bool global)
: kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global) {}
bool global, std::string moduleName)
: kind(k), base(move(base)), canonicalName(move(canonicalName)), global(global),
moduleName(move(moduleName)) {}
SimplifyContext::SimplifyContext(std::string filename, Cache *cache)
: Context<SimplifyItem>(move(filename)), cache(move(cache)),
@ -31,6 +32,7 @@ std::shared_ptr<SimplifyItem> SimplifyContext::add(SimplifyItem::Kind kind,
bool global) {
seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name);
auto t = std::make_shared<SimplifyItem>(kind, getBase(), canonicalName, global);
t->moduleName = getModule();
Context<SimplifyItem>::add(name, t);
Context<SimplifyItem>::add(canonicalName, t);
return t;
@ -60,6 +62,14 @@ std::string SimplifyContext::getBase() const {
return bases.back().name;
}
std::string SimplifyContext::getModule() const {
std::string base = moduleName.status == ImportFile::STDLIB ? "std." : "";
base += moduleName.module;
if (startswith(base, "__main__"))
base = base.substr(8);
return base;
}
std::string SimplifyContext::generateCanonicalName(const std::string &name,
bool includeBase,
bool zeroId) const {
@ -67,12 +77,8 @@ std::string SimplifyContext::generateCanonicalName(const std::string &name,
bool alreadyGenerated = name.find('.') != std::string::npos;
if (includeBase && !alreadyGenerated) {
std::string base = getBase();
if (base.empty()) {
base = moduleName.status == ImportFile::STDLIB ? "std." : "";
base += moduleName.module;
if (startswith(base, "__main__"))
base = base.substr(8);
}
if (base.empty())
base = getModule();
newName = (base.empty() ? "" : (base + ".")) + newName;
}
auto num = cache->identifierCount[newName]++;

View File

@ -32,13 +32,16 @@ struct SimplifyItem {
bool global;
/// Non-empty string if a variable is import variable
std::string importPath;
/// Full module name
std::string moduleName;
public:
SimplifyItem(Kind k, std::string base, std::string canonicalName,
bool global = false);
SimplifyItem(Kind k, std::string base, std::string canonicalName, bool global = false,
std::string moduleName = "");
/// Convenience getters.
std::string getBase() const { return base; }
std::string getModule() const { return moduleName; }
bool isGlobal() const { return global; }
bool isVar() const { return kind == Var; }
bool isFunc() const { return kind == Func; }
@ -107,6 +110,8 @@ public:
/// Return a canonical name of the top-most base, or an empty string if this is a
/// top-level base.
std::string getBase() const;
/// Return the current module.
std::string getModule() const;
/// Return the current base nesting level (note: bases, not blocks).
int getLevel() const { return bases.size(); }
/// Pretty-print the current context state.

View File

@ -437,25 +437,27 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (stmt->decorators.size() != 1)
error("__attribute__ cannot be mixed with other decorators");
attr.isAttribute = true;
} else if (d->isId(Attr::LLVM))
} else if (d->isId(Attr::LLVM)) {
attr.set(Attr::LLVM);
else if (d->isId(Attr::Python))
} else if (d->isId(Attr::Python)) {
attr.set(Attr::Python);
else if (d->isId(Attr::Internal))
} else if (d->isId(Attr::Internal)) {
attr.set(Attr::Internal);
else if (d->isId(Attr::Atomic))
} else if (d->isId(Attr::Atomic)) {
attr.set(Attr::Atomic);
else if (d->isId(Attr::Property))
} else if (d->isId(Attr::Property)) {
attr.set(Attr::Property);
else if (d->isId(Attr::ForceRealize))
} else if (d->isId(Attr::ForceRealize)) {
attr.set(Attr::ForceRealize);
else {
} else {
// Let's check if this is a attribute
auto dt = transform(clone(d));
if (dt && dt->getId()) {
auto ci = ctx->find(dt->getId()->value);
if (ci && ci->kind == SimplifyItem::Func) {
if (ctx->cache->functions[ci->canonicalName].ast->attributes.isAttribute) {
if (ctx->cache->overloads[ci->canonicalName].size() == 1)
if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
.ast->attributes.isAttribute) {
attr.set(ci->canonicalName);
continue;
}
@ -473,19 +475,22 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
}
bool isClassMember = ctx->inClass();
if (isClassMember && !endswith(stmt->name, ".dispatch") &&
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name].empty()) {
generateDispatch(stmt->name);
}
auto func_name = stmt->name;
if (endswith(stmt->name, ".dispatch"))
func_name = func_name.substr(0, func_name.size() - 9);
auto canonicalName = ctx->generateCanonicalName(
func_name, true, isClassMember && !endswith(stmt->name, ".dispatch"));
if (endswith(stmt->name, ".dispatch")) {
canonicalName += ".dispatch";
ctx->cache->reverseIdentifierLookup[canonicalName] = func_name;
std::string rootName;
if (isClassMember) {
auto &m = ctx->cache->classes[ctx->bases.back().name].methods;
auto i = m.find(stmt->name);
if (i != m.end())
rootName = i->second;
} else if (auto c = ctx->find(stmt->name)) {
if (c->isFunc() && c->getModule() == ctx->getModule() &&
c->getBase() == ctx->getBase())
rootName = c->canonicalName;
}
if (rootName.empty())
rootName = ctx->generateCanonicalName(stmt->name, true);
auto canonicalName =
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name;
bool isEnclosedFunc = ctx->inFunction();
if (attr.has(Attr::ForceRealize) && (ctx->getLevel() || isClassMember))
@ -495,7 +500,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ctx->bases = std::vector<SimplifyContext::Base>();
if (!isClassMember)
// Class members are added to class' method table
ctx->add(SimplifyItem::Func, func_name, canonicalName, ctx->isToplevel());
ctx->add(SimplifyItem::Func, stmt->name, rootName, ctx->isToplevel());
if (isClassMember)
ctx->bases.push_back(oldBases[0]);
ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base...
@ -614,8 +619,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// ... set the enclosing class name...
attr.parentClass = ctx->bases.back().name;
// ... add the method to class' method list ...
ctx->cache->classes[ctx->bases.back().name].methods[func_name].push_back(
{canonicalName, nullptr, ctx->cache->age});
ctx->cache->classes[ctx->bases.back().name].methods[stmt->name] = rootName;
// ... and if the function references outer class variable (by definition a
// generic), mark it as not static as it needs fully instantiated class to be
// realized. For example, in class A[T]: def foo(): pass, A.foo() can be realized
@ -624,6 +628,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (isMethod)
attr.set(Attr::Method);
}
ctx->cache->overloads[rootName].push_back({canonicalName, ctx->cache->age});
std::vector<CallExpr::Arg> partialArgs;
if (!captures.empty()) {
@ -649,21 +654,21 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ExprPtr finalExpr;
if (!captures.empty())
finalExpr = N<CallExpr>(N<IdExpr>(func_name), partialArgs);
finalExpr = N<CallExpr>(N<IdExpr>(stmt->name), partialArgs);
if (isClassMember && decorators.size())
error("decorators cannot be applied to class methods");
for (int j = int(decorators.size()) - 1; j >= 0; j--) {
if (auto c = const_cast<CallExpr *>(decorators[j]->getCall())) {
c->args.emplace(c->args.begin(),
CallExpr::Arg{"", finalExpr ? finalExpr : N<IdExpr>(func_name)});
CallExpr::Arg{"", finalExpr ? finalExpr : N<IdExpr>(stmt->name)});
finalExpr = N<CallExpr>(c->expr, c->args);
} else {
finalExpr =
N<CallExpr>(decorators[j], finalExpr ? finalExpr : N<IdExpr>(func_name));
N<CallExpr>(decorators[j], finalExpr ? finalExpr : N<IdExpr>(stmt->name));
}
}
if (finalExpr)
resultStmt = transform(N<AssignStmt>(N<IdExpr>(func_name), finalExpr));
resultStmt = transform(N<AssignStmt>(N<IdExpr>(stmt->name), finalExpr));
}
void SimplifyVisitor::visit(ClassStmt *stmt) {
@ -744,6 +749,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
ClassStmt *originalAST = nullptr;
auto classItem =
std::make_shared<SimplifyItem>(SimplifyItem::Type, "", "", ctx->isToplevel());
classItem->moduleName = ctx->getModule();
if (!extension) {
classItem->canonicalName = canonicalName =
ctx->generateCanonicalName(name, !attr.has(Attr::Internal));
@ -949,22 +955,29 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
for (int ai = 0; ai < baseASTs.size(); ai++) {
// FUNCS
for (auto &mm : ctx->cache->classes[baseASTs[ai]->name].methods)
for (auto &mf : mm.second) {
for (auto &mf : ctx->cache->overloads[mm.second]) {
auto f = ctx->cache->functions[mf.name].ast;
if (f->attributes.has("autogenerated"))
continue;
auto subs = substitutions[ai];
if (ctx->cache->classes[ctx->bases.back().name]
.methods[ctx->cache->reverseIdentifierLookup[f->name]]
.empty())
generateDispatch(ctx->cache->reverseIdentifierLookup[f->name]);
auto newName = ctx->generateCanonicalName(
std::string rootName;
auto &mts = ctx->cache->classes[ctx->bases.back().name].methods;
auto it = mts.find(ctx->cache->reverseIdentifierLookup[f->name]);
if (it != mts.end())
rootName = it->second;
else
rootName = ctx->generateCanonicalName(
ctx->cache->reverseIdentifierLookup[f->name], true);
auto newCanonicalName =
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
ctx->cache->reverseIdentifierLookup[newCanonicalName] =
ctx->cache->reverseIdentifierLookup[f->name];
auto nf = std::dynamic_pointer_cast<FunctionStmt>(
replace(std::static_pointer_cast<Stmt>(f), subs));
subs[nf->name] = N<IdExpr>(newName);
nf->name = newName;
subs[nf->name] = N<IdExpr>(newCanonicalName);
nf->name = newCanonicalName;
suite->stmts.push_back(nf);
nf->attributes.parentClass = ctx->bases.back().name;
@ -972,10 +985,10 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
if (nf->attributes.has(".changedSelf")) // replace self type with new class
nf->args[0].type = transformType(ctx->bases.back().ast);
preamble->functions.push_back(clone(nf));
ctx->cache->functions[newName].ast = nf;
ctx->cache->overloads[rootName].push_back({newCanonicalName, ctx->cache->age});
ctx->cache->functions[newCanonicalName].ast = nf;
ctx->cache->classes[ctx->bases.back().name]
.methods[ctx->cache->reverseIdentifierLookup[f->name]]
.push_back({newName, nullptr, ctx->cache->age});
.methods[ctx->cache->reverseIdentifierLookup[f->name]] = rootName;
}
}
for (auto sp : getClassMethods(stmt->suite))
@ -1248,8 +1261,10 @@ StmtPtr SimplifyVisitor::transformCImport(const std::string &name,
auto f = N<FunctionStmt>(name, ret ? ret->clone() : N<IdExpr>("void"), fnArgs,
nullptr, attr);
StmtPtr tf = transform(f); // Already in the preamble
if (!altName.empty())
if (!altName.empty()) {
ctx->add(altName, ctx->find(name));
ctx->remove(name);
}
return tf;
}
@ -1392,10 +1407,11 @@ void SimplifyVisitor::transformNewImport(const ImportFile &file) {
stmts[0] = N<SuiteStmt>();
// Add a def import(): ... manually to the cache and to the preamble (it won't be
// transformed here!).
ctx->cache->functions[importVar].ast =
N<FunctionStmt>(importVar, nullptr, std::vector<Param>{}, N<SuiteStmt>(stmts),
Attr({Attr::ForceRealize}));
preamble->functions.push_back(ctx->cache->functions[importVar].ast->clone());
ctx->cache->overloads[importVar].push_back({importVar, ctx->cache->age});
ctx->cache->functions[importVar + ":0"].ast =
N<FunctionStmt>(importVar + ":0", nullptr, std::vector<Param>{},
N<SuiteStmt>(stmts), Attr({Attr::ForceRealize}));
preamble->functions.push_back(ctx->cache->functions[importVar + ":0"].ast->clone());
;
}
}
@ -1768,15 +1784,5 @@ std::vector<StmtPtr> SimplifyVisitor::getClassMethods(const StmtPtr &s) {
return v;
}
void SimplifyVisitor::generateDispatch(const std::string &name) {
transform(N<FunctionStmt>(
name + ".dispatch", nullptr,
std::vector<Param>{Param("*args"), Param("**kwargs")},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(N<IdExpr>(ctx->bases.back().name), name),
N<StarExpr>(N<IdExpr>("args")), N<KeywordStarExpr>(N<IdExpr>("kwargs"))))),
Attr({"autogenerated"})));
}
} // namespace ast
} // namespace codon

View File

@ -296,7 +296,8 @@ void TranslateVisitor::visit(ForStmt *stmt) {
auto c = stmt->decorator->getCall();
seqassert(c, "for par is not a call: {}", stmt->decorator->toString());
auto fc = c->expr->getType()->getFunc();
seqassert(fc && fc->ast->name == "std.openmp.for_par", "for par is not a function");
seqassert(fc && fc->ast->name == "std.openmp.for_par:0",
"for par is not a function");
auto schedule =
fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString();
bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt();

View File

@ -313,6 +313,8 @@ private:
int64_t translateIndex(int64_t idx, int64_t len, bool clamp = false);
int64_t sliceAdjustIndices(int64_t length, int64_t *start, int64_t *stop,
int64_t step);
types::FuncTypePtr findDispatch(const std::string &fn);
std::string getRootName(const std::string &name);
friend struct Cache;
};

View File

@ -142,22 +142,23 @@ std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeN
auto m = cache->classes.find(typeName);
if (m != cache->classes.end()) {
auto t = m->second.methods.find(method);
if (t != m->second.methods.end() && !t->second.empty()) {
seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"),
"first method '{}' is not dispatch", t->second[0].name);
if (t != m->second.methods.end()) {
auto mt = cache->overloads[t->second];
std::unordered_set<std::string> signatureLoci;
std::vector<types::FuncTypePtr> vv;
for (int mti = int(t->second.size()) - 1; mti > 0; mti--) {
auto &mt = t->second[mti];
if (mt.age <= age) {
for (int mti = int(mt.size()) - 1; mti >= 0; mti--) {
auto &m = mt[mti];
if (endswith(m.name, ":dispatch"))
continue;
if (m.age <= age) {
if (hideShadowed) {
auto sig = cache->functions[mt.name].ast->signature();
auto sig = cache->functions[m.name].ast->signature();
if (!in(signatureLoci, sig)) {
signatureLoci.insert(sig);
vv.emplace_back(mt.type);
vv.emplace_back(cache->functions[m.name].type);
}
} else {
vv.emplace_back(mt.type);
vv.emplace_back(cache->functions[m.name].type);
}
}
}

View File

@ -105,6 +105,9 @@ void TypecheckVisitor::visit(IdExpr *expr) {
return;
}
auto val = ctx->find(expr->value);
if (!val) {
val = ctx->find(expr->value + ":0"); // is it function?!
}
seqassert(val, "cannot find IdExpr '{}' ({})", expr->value, expr->getSrcInfo());
auto t = ctx->instantiate(expr, val->type);
@ -725,14 +728,17 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr,
ExprPtr &index) {
if (!tuple->getRecord() ||
in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
if (!tuple->getRecord())
return nullptr;
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL))
// in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
// Ptr, pyobj and str are internal types and have only one overloaded __getitem__
return nullptr;
if (ctx->cache->classes[tuple->name].methods["__getitem__"].size() != 2)
// n.b.: there is dispatch as well
// TODO: be smarter! there might be a compatible getitem?
return nullptr;
// if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) {
// ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]]
// .size() != 1)
// return nullptr;
// }
// Extract a static integer value from a compatible expression.
auto getInt = [&](int64_t *o, const ExprPtr &e) {
@ -919,9 +925,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
} else if (methods.size() > 1) {
auto m = ctx->cache->classes.find(typ->name);
auto t = m->second.methods.find(expr->member);
seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"),
"first method is not dispatch");
bestMethod = t->second[0].type;
bestMethod = findDispatch(t->second);
} else {
bestMethod = methods[0];
}
@ -1036,7 +1040,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
ctx->bases.back().supers, expr->args);
if (m.empty())
error("no matching super methods are available");
// LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), m[0]->toString());
// LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(),
// m[0]->toString());
ExprPtr e = N<CallExpr>(N<IdExpr>(m[0]->ast->name), expr->args);
return transform(e, false, true);
}
@ -1700,7 +1705,7 @@ TypecheckVisitor::findBestMethod(const Expr *expr, const std::string &member,
std::vector<types::FuncTypePtr>
TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) {
if (func->ast->attributes.parentClass.empty() ||
endswith(func->ast->name, ".dispatch"))
endswith(func->ast->name, ":dispatch"))
return {};
auto p = ctx->find(func->ast->attributes.parentClass)->type;
if (!p || !p->getClass())
@ -1711,14 +1716,13 @@ TypecheckVisitor::findSuperMethods(const types::FuncTypePtr &func) {
std::vector<types::FuncTypePtr> result;
if (m != ctx->cache->classes.end()) {
auto t = m->second.methods.find(methodName);
if (t != m->second.methods.end() && !t->second.empty()) {
seqassert(!t->second.empty() && endswith(t->second[0].name, ".dispatch"),
"first method '{}' is not dispatch", t->second[0].name);
for (int mti = 1; mti < t->second.size(); mti++) {
auto &mt = t->second[mti];
if (mt.type->ast->name == func->ast->name)
if (t != m->second.methods.end()) {
for (auto &m : ctx->cache->overloads[t->second]) {
if (endswith(m.name, ":dispatch"))
continue;
if (m.name == func->ast->name)
break;
result.emplace_back(mt.type);
result.emplace_back(ctx->cache->functions[m.name].type);
}
}
}
@ -1860,5 +1864,45 @@ int64_t TypecheckVisitor::sliceAdjustIndices(int64_t length, int64_t *start,
return 0;
}
types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) {
for (auto &m : ctx->cache->overloads[fn])
if (endswith(ctx->cache->functions[m.name].ast->name, ":dispatch"))
return ctx->cache->functions[m.name].type;
// Generate dispatch and return it!
auto name = fn + ":dispatch";
ExprPtr root;
auto a = ctx->cache->functions[ctx->cache->overloads[fn][0].name].ast;
if (!a->attributes.parentClass.empty())
root = N<DotExpr>(N<IdExpr>(a->attributes.parentClass),
ctx->cache->reverseIdentifierLookup[fn]);
else
root = N<IdExpr>(fn);
root = N<CallExpr>(root, N<StarExpr>(N<IdExpr>("args")),
N<KeywordStarExpr>(N<IdExpr>("kwargs")));
auto ast = N<FunctionStmt>(
name, nullptr, std::vector<Param>{Param("*args"), Param("**kwargs")},
N<SuiteStmt>(N<IfStmt>(
N<CallExpr>(N<IdExpr>("isinstance"), root->clone(), N<IdExpr>("void")),
N<ExprStmt>(root->clone()), N<ReturnStmt>(root))),
Attr({"autogenerated"}));
ctx->cache->reverseIdentifierLookup[name] = ctx->cache->reverseIdentifierLookup[fn];
auto baseType =
ctx->instantiate(N<IdExpr>(name).get(), ctx->find(generateFunctionStub(2))->type,
nullptr, false)
->getRecord();
auto typ = std::make_shared<FuncType>(baseType, ast.get());
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel));
ctx->add(TypecheckItem::Func, name, typ);
ctx->cache->overloads[fn].insert(ctx->cache->overloads[fn].begin(), {name, 0});
ctx->cache->functions[name].ast = ast;
ctx->cache->functions[name].type = typ;
prependStmts->push_back(ast);
return typ;
}
} // namespace ast
} // namespace codon

View File

@ -474,12 +474,12 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
typ = std::static_pointer_cast<FuncType>(typ->generalize(ctx->typecheckLevel));
// Check if this is a class method; if so, update the class method lookup table.
if (isClassMember) {
auto &methods = ctx->cache->classes[attr.parentClass]
auto m = ctx->cache->classes[attr.parentClass]
.methods[ctx->cache->reverseIdentifierLookup[stmt->name]];
bool found = false;
for (auto &i : methods)
for (auto &i : ctx->cache->overloads[m])
if (i.name == stmt->name) {
i.type = typ;
ctx->cache->functions[i.name].type = typ;
found = true;
break;
}
@ -570,5 +570,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
stmt->done = true;
}
std::string TypecheckVisitor::getRootName(const std::string &name) {
auto p = name.rfind(':');
seqassert(p != std::string::npos, ": not found in {}", name);
return name.substr(0, p);
}
} // namespace ast
} // namespace codon

View File

@ -44,6 +44,7 @@ class str:
if c == '\n': d = "\\n"
elif c == '\r': d = "\\r"
elif c == '\t': d = "\\t"
elif c == '\a': d = "\\a"
elif c == '\\': d = "\\\\"
elif c == q: d = qe
else:

View File

@ -538,7 +538,7 @@ def foo() -> int:
a{=a}
foo()
#! not a type or static expression
#! while realizing foo (arguments foo)
#! while realizing foo:0 (arguments foo:0)
#%% function_llvm_err_4,barebones
a = 5

View File

@ -374,7 +374,7 @@ def foo(i, j, k):
return i + j + k
print foo(1.1, 2.2, 3.3) #: 6.6
p = foo(6, ...)
print p.__class__ #: foo[int,...,...]
print p.__class__ #: foo:0[int,...,...]
print p(2, 1) #: 9
print p(k=3, j=6) #: 15
q = p(k=1, ...)
@ -390,11 +390,11 @@ print 42 |> add_two #: 44
def moo(a, b, c=3):
print a, b, c
m = moo(b=2, ...)
print m.__class__ #: moo[...,int,...]
print m.__class__ #: moo:0[...,int,...]
m('s', 1.1) #: s 2 1.1
# #
n = m(c=2.2, ...)
print n.__class__ #: moo[...,int,float]
print n.__class__ #: moo:0[...,int,float]
n('x') #: x 2 2.2
print n('y').__class__ #: void
@ -403,11 +403,11 @@ def ff(a, b, c):
print ff(1.1, 2, True).__class__ #: Tuple[float,int,bool]
print ff(1.1, ...)(2, True).__class__ #: Tuple[float,int,bool]
y = ff(1.1, ...)(c=True, ...)
print y.__class__ #: ff[float,...,bool]
print y.__class__ #: ff:0[float,...,bool]
print ff(1.1, ...)(2, ...)(True).__class__ #: Tuple[float,int,bool]
print y('hei').__class__ #: Tuple[float,str,bool]
z = ff(1.1, ...)(c='s', ...)
print z.__class__ #: ff[float,...,str]
print z.__class__ #: ff:0[float,...,str]
#%% call_arguments_partial,barebones
def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T):
@ -432,7 +432,7 @@ l = [1]
def adder(a, b): return a+b
doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...))
#: int int
#: adder[.. Generator[int]
#: adder:0[ Generator[int]
#: 5
#: 1 Optional[int]
#: 5 int

View File

@ -252,7 +252,7 @@ try:
except MyError:
print "my"
except OSError as o:
print "os", o._hdr[0], len(o._hdr[1]), o._hdr[3][-20:], o._hdr[4]
print "os", o._hdr.typename, len(o._hdr.msg), o._hdr.file[-20:], o._hdr.line
#: os std.internal.types.error.OSError 9 typecheck_stmt.codon 249
finally:
print "whoa" #: whoa
@ -263,7 +263,7 @@ def foo():
try:
foo()
except MyError as e:
print e._hdr[0], e._hdr[1] #: MyError foo!
print e._hdr.typename, e._hdr.msg #: MyError foo!
#%% throw_error,barebones
raise 'hello' #! cannot throw non-exception (first object member must be of type ExcHeader)
@ -291,13 +291,13 @@ def foo(x):
print len(x)
foo(5) #: 4
def foo(x):
def foo2(x):
if isinstance(x, int):
print x+1
return
print len(x)
foo(1) #: 2
foo('s') #: 1
foo2(1) #: 2
foo2('s') #: 1
#%% super,barebones
class Foo:
@ -341,4 +341,4 @@ class Foo:
print 'foo-1', a
Foo.foo(1)
#! no matching super methods are available
#! while realizing Foo.foo.2
#! while realizing Foo.foo:0

View File

@ -199,10 +199,10 @@ def f[T](x: T) -> T:
print f(1.2).__class__ #: float
print f('s').__class__ #: str
def f[T](x: T):
return f(x - 1, T) if x else 1
print f(1) #: 1
print f(1.1).__class__ #: int
def f2[T](x: T):
return f2(x - 1, T) if x else 1
print f2(1) #: 1
print f2(1.1).__class__ #: int
#%% recursive_error,barebones
@ -215,7 +215,7 @@ def rec3(x, y): #- ('a, 'b) -> 'b
return y
rec3(1, 's')
#! cannot unify str and int
#! while realizing rec3 (arguments rec3[int,str])
#! while realizing rec3:0 (arguments rec3:0[int,str])
#%% instantiate_function_2,barebones
def fx[T](x: T) -> T:
@ -447,13 +447,13 @@ def f(x):
return g(x)
print f(5), f('s') #: 5 s
def f[U](x: U, y):
def f2[U](x: U, y):
def g[T, U](x: T, y: U):
return (x, y)
return g(y, x)
x, y = 1, 'haha'
print f(x, y).__class__ #: Tuple[str,int]
print f('aa', 1.1, U=str).__class__ #: Tuple[float,str]
print f2(x, y).__class__ #: Tuple[str,int]
print f2('aa', 1.1, U=str).__class__ #: Tuple[float,str]
#%% nested_fn_generic_error,barebones
def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u]
@ -464,7 +464,7 @@ print f(1.1, 1, int).__class__ #! cannot unify float and int
#%% fn_realization,barebones
def ff[T](x: T, y: tuple[T]):
print ff(T=str,...).__class__ #: ff[str,Tuple[str],str]
print ff(T=str,...).__class__ #: ff:0[str,Tuple[str],str]
return x
x = ff(1, (1,))
print x, x.__class__ #: 1 int
@ -474,7 +474,7 @@ def fg[T](x:T):
def g[T](y):
z = T()
return z
print fg(T=str,...).__class__ #: fg[str,str]
print fg(T=str,...).__class__ #: fg:0[str,str]
print g(1, T).__class__ #: int
fg(1)
print fg(1).__class__ #: void
@ -515,7 +515,7 @@ class A[T]:
def foo[W](t: V, u: V, v: V, w: W):
return (t, u, v, w)
print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo.2[bool,bool,bool,str,str]
print A.B.C[bool].foo(W=str, ...).__class__ #: A.B.C.foo:0[bool,bool,bool,str,str]
print A.B.C.foo(1,1,1,True) #: (1, 1, 1, True)
print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x')
print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x')
@ -734,10 +734,10 @@ def test(name, sort, key):
def foo(l, f):
return [f(i) for i in l]
test('hi', foo, lambda x: x+1) #: hi [2, 3, 4, 5]
# TODO
# def foof(l: List[int], x, f: Callable[[int], int]):
# return [f(i)+x for i in l]
# test('qsort', foof(..., 3, ...))
def foof(l: List[int], x, f: Callable[[int], int]):
return [f(i)+x for i in l]
test('qsort', foof(x=3, ...), lambda x: x+1) #: qsort [5, 6, 7, 8]
#%% class_fn_access,barebones
class X[T]:
@ -745,8 +745,7 @@ class X[T]:
return (x+x, y+y)
y = X[X[int]]()
print y.__class__ #: X[X[int]]
print X[float].foo(U=int, ...).__class__ #: X.foo.2[X[float],float,int,int]
# print y.foo.1[float].__class__
print X[float].foo(U=int, ...).__class__ #: X.foo:0[X[float],float,int,int]
print X[int]().foo(1, 's') #: (2, 'ss')
#%% class_partial_access,barebones
@ -754,7 +753,7 @@ class X[T]:
def foo[U](self, x, y: U):
return (x+x, y+y)
y = X[X[int]]()
print y.foo(U=float,...).__class__ #: X.foo.2[X[X[int]],...,...]
print y.foo(U=float,...).__class__ #: X.foo:0[X[X[int]],...,...]
print y.foo(1, 2.2, float) #: (2, 4.4)
#%% forward,barebones
@ -765,10 +764,10 @@ def bar[T](x):
print x, T.__class__
foo(bar, 1)
#: 1 int
#: bar[...]
#: bar:0[...]
foo(bar(...), 's')
#: s str
#: bar[...]
#: bar:0[...]
z = bar
z('s', int)
#: s int
@ -786,8 +785,8 @@ def foo(f, x):
def bar[T](x):
print x, T.__class__
foo(bar(T=int,...), 1)
#! too many arguments for bar[T1,int] (expected maximum 2, got 2)
#! while realizing foo (arguments foo[bar[...],int])
#! too many arguments for bar:0[T1,int] (expected maximum 2, got 2)
#! while realizing foo:0 (arguments foo:0[bar:0[...],int])
#%% sort_partial
def foo(x, y):
@ -806,16 +805,16 @@ def frec(x, y):
return grec(x, y) if bl(y) else 2
print frec(1, 2).__class__, frec('s', 1).__class__
#! expression with void type
#! while realizing frec (arguments frec[int,int])
#! while realizing frec:0 (arguments frec:0[int,int])
#%% return_fn,barebones
def retfn(a):
def inner(b, *args, **kwargs):
print a, b, args, kwargs
print inner.__class__ #: retfn.inner[...,...,int,...]
print inner.__class__ #: retfn:0.inner:0[...,...,int,...]
return inner(15, ...)
f = retfn(1)
print f.__class__ #: retfn.inner[int,...,int,...]
print f.__class__ #: retfn:0.inner:0[int,...,int,...]
f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar')
#%% decorator_manual,barebones
@ -823,7 +822,7 @@ def foo(x, *args, **kwargs):
print x, args, kwargs
return 1
def dec(fn, a):
print 'decorating', fn.__class__ #: decorating foo[...,...,...]
print 'decorating', fn.__class__ #: decorating foo:0[...,...,...]
def inner(*args, **kwargs):
print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True)
return fn(a, *args, **kwargs)
@ -846,7 +845,7 @@ def dec(fn, a):
return inner
ff = dec(foo, 10)
print ff(5.5, 's', z=True)
#: decorating foo[...,...,...]
#: decorating foo:0[...,...,...]
#: decorator (5.5, 's') (z: True)
#: 10 (5.5, 's') (z: True)
#: 1
@ -856,7 +855,7 @@ def zoo(e, b, *args):
return f'zoo: {e}, {b}, {args}'
print zoo(2, 3)
print zoo('s', 3)
#: decorating zoo[...,...,...]
#: decorating zoo:0[...,...,...]
#: decorator (2, 3) ()
#: zoo: 5, 2, (3)
#: decorator ('s', 3) ()
@ -869,9 +868,9 @@ def mydecorator(func):
print("after")
return inner
@mydecorator
def foo():
def foo2():
print("foo")
foo()
foo2()
#: before
#: foo
#: after
@ -891,7 +890,7 @@ def factorial(num):
return n
factorial(10)
#: 3628800
#: time needed for factorial[...] is 3628799
#: time needed for factorial:0[...] is 3628799
def dx1(func):
def inner():
@ -921,9 +920,9 @@ def dy2(func):
return inner
@dy1
@dy2
def num(a, b):
def num2(a, b):
return a+b
print(num(10, 20)) #: 3600
print(num2(10, 20)) #: 3600
#%% hetero_iter,barebones
e = (1, 2, 3, 'foo', 5, 'bar', 6)
@ -970,14 +969,14 @@ def tee(iterable, n=2):
return list(gen(d) for d in deques)
it = [1,2,3,4]
a, b = tee(it) #! cannot typecheck the program
#! while realizing tee (arguments tee[List[int],int])
#! while realizing tee:0 (arguments tee:0[List[int],int])
#%% new_syntax,barebones
def foo[T,U](x: type, y, z: Static[int] = 10):
print T.__class__, U.__class__, x.__class__, y.__class__, Int[z+1].__class__
return List[x]()
print foo(T=int,U=str,...).__class__ #: foo[T1,x,z,int,str]
print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo[T1,bool,5,int,str]
print foo(T=int,U=str,...).__class__ #: foo:0[T1,x,z,int,str]
print foo(T=int,U=str,z=5,x=bool,...).__class__ #: foo:0[T1,bool,5,int,str]
print foo(float,3,T=int,U=str,z=5).__class__ #: List[float]
foo(float,1,10,str,int) #: str int float int Int[11]
@ -993,11 +992,11 @@ print Foo[5,int,float,6].__class__ #: Foo[5,int,float,6]
print Foo(1.1, 10i32, [False], 10u66).__class__ #: Foo[66,bool,float,32]
def foo[N: Static[int]]():
def foo2[N: Static[int]]():
print Int[N].__class__, N
x: Static[int] = 5
y: Static[int] = 105 - x * 2
foo(y-x) #: Int[90] 90
foo2(y-x) #: Int[90] 90
if 1.1+2.2 > 0:
x: Static[int] = 88