mirror of https://github.com/exaloop/codon.git
Fix missing generic issue
parent
bc951f29f9
commit
a65a1cb881
|
@ -480,6 +480,56 @@ std::string FunctionStmt::getDocstr() {
|
|||
return "";
|
||||
}
|
||||
|
||||
// Search expression tree for a identifier
|
||||
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
|
||||
std::string what;
|
||||
bool result;
|
||||
|
||||
public:
|
||||
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
|
||||
bool transform(const std::shared_ptr<Expr> &expr) override {
|
||||
if (result)
|
||||
return result;
|
||||
IdSearchVisitor v(what);
|
||||
if (expr)
|
||||
expr->accept(v);
|
||||
return result = v.result;
|
||||
}
|
||||
bool transform(const std::shared_ptr<Stmt> &stmt) override {
|
||||
if (result)
|
||||
return result;
|
||||
IdSearchVisitor v(what);
|
||||
if (stmt)
|
||||
stmt->accept(v);
|
||||
return result = v.result;
|
||||
}
|
||||
void visit(IdExpr *expr) override {
|
||||
if (expr->value == what)
|
||||
result = true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Check if a function can be called with the given arguments.
|
||||
/// See @c reorderNamedArgs for details.
|
||||
std::unordered_set<std::string> FunctionStmt::getNonInferrableGenerics() {
|
||||
std::unordered_set<std::string> nonInferrableGenerics;
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Generic && !a.defaultValue) {
|
||||
bool inferrable = false;
|
||||
for (auto &b : args)
|
||||
if (b.type && IdSearchVisitor(a.name).transform(b.type)) {
|
||||
inferrable = true;
|
||||
break;
|
||||
}
|
||||
if (ret && IdSearchVisitor(a.name).transform(ret))
|
||||
inferrable = true;
|
||||
if (!inferrable)
|
||||
nonInferrableGenerics.insert(a.name);
|
||||
}
|
||||
}
|
||||
return nonInferrableGenerics;
|
||||
}
|
||||
|
||||
ClassStmt::ClassStmt(std::string name, std::vector<Param> args, StmtPtr suite,
|
||||
std::vector<ExprPtr> decorators, std::vector<ExprPtr> baseClasses,
|
||||
std::vector<ExprPtr> staticBaseClasses)
|
||||
|
|
|
@ -482,6 +482,7 @@ struct FunctionStmt : public Stmt {
|
|||
|
||||
FunctionStmt *getFunction() override { return this; }
|
||||
std::string getDocstr();
|
||||
std::unordered_set<std::string> getNonInferrableGenerics();
|
||||
};
|
||||
|
||||
/// Class statement (@(attributes...) class name[generics...]: args... ; suite).
|
||||
|
|
|
@ -281,8 +281,10 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
|||
};
|
||||
|
||||
// Handle reordered arguments (see @c reorderNamedArgs for details)
|
||||
bool partial = false;
|
||||
auto reorderFn = [&](int starArgIndex, int kwstarArgIndex,
|
||||
const std::vector<std::vector<int>> &slots, bool partial) {
|
||||
const std::vector<std::vector<int>> &slots, bool _partial) {
|
||||
partial = _partial;
|
||||
ctx->addBlock(); // add function generics to typecheck default arguments
|
||||
addFunctionGenerics(calleeFn->getFunc().get());
|
||||
for (size_t si = 0, pi = 0; si < slots.size(); si++) {
|
||||
|
@ -410,18 +412,28 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
|||
(!expr->hasAttr(ExprAttr::OrderedCall) &&
|
||||
typeArgs.size() == calleeFn->funcGenerics.size()),
|
||||
"bad vector sizes");
|
||||
for (size_t si = 0;
|
||||
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
|
||||
si++) {
|
||||
if (typeArgs[si]) {
|
||||
auto typ = typeArgs[si]->type;
|
||||
if (calleeFn->funcGenerics[si].type->isStaticType()) {
|
||||
if (!typeArgs[si]->isStatic()) {
|
||||
E(Error::EXPECTED_STATIC, typeArgs[si]);
|
||||
if (!calleeFn->funcGenerics.empty()) {
|
||||
auto niGenerics = calleeFn->ast->getNonInferrableGenerics();
|
||||
for (size_t si = 0;
|
||||
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
|
||||
si++) {
|
||||
if (typeArgs[si]) {
|
||||
auto typ = typeArgs[si]->type;
|
||||
if (calleeFn->funcGenerics[si].type->isStaticType()) {
|
||||
if (!typeArgs[si]->isStatic()) {
|
||||
E(Error::EXPECTED_STATIC, typeArgs[si]);
|
||||
}
|
||||
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
|
||||
}
|
||||
unify(typ, calleeFn->funcGenerics[si].type);
|
||||
} else {
|
||||
if (calleeFn->funcGenerics[si].type->getUnbound() &&
|
||||
!calleeFn->ast->args[si].defaultValue &&
|
||||
!partial &&
|
||||
in(niGenerics, calleeFn->funcGenerics[si].name)) {
|
||||
error("generic '{}' not provided", calleeFn->funcGenerics[si].niceName);
|
||||
}
|
||||
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
|
||||
}
|
||||
unify(typ, calleeFn->funcGenerics[si].type);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -199,6 +199,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
|
|||
/// @example
|
||||
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
|
||||
void TypecheckVisitor::visit(TupleExpr *expr) {
|
||||
expr->setType(ctx->getUnbound());
|
||||
for (int ai = 0; ai < expr->items.size(); ai++)
|
||||
if (auto star = expr->items[ai]->getStar()) {
|
||||
// Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`)
|
||||
|
@ -226,6 +227,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
|
|||
auto tupleName = generateTuple(expr->items.size());
|
||||
resultExpr =
|
||||
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
|
||||
unify(expr->type, resultExpr->type);
|
||||
}
|
||||
|
||||
/// Transform a tuple generator expression.
|
||||
|
|
|
@ -217,20 +217,55 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(
|
|||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
||||
// Search expression tree for a identifier
|
||||
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
|
||||
std::string what;
|
||||
bool result;
|
||||
|
||||
public:
|
||||
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
|
||||
bool transform(const std::shared_ptr<Expr> &expr) override {
|
||||
if (result)
|
||||
return result;
|
||||
IdSearchVisitor v(what);
|
||||
if (expr)
|
||||
expr->accept(v);
|
||||
return v.result;
|
||||
}
|
||||
bool transform(const std::shared_ptr<Stmt> &stmt) override {
|
||||
if (result)
|
||||
return result;
|
||||
IdSearchVisitor v(what);
|
||||
if (stmt)
|
||||
stmt->accept(v);
|
||||
return v.result;
|
||||
}
|
||||
void visit(IdExpr *expr) override {
|
||||
if (expr->value == what)
|
||||
result = true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Check if a function can be called with the given arguments.
|
||||
/// See @c reorderNamedArgs for details.
|
||||
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
||||
const std::vector<CallExpr::Arg> &args) {
|
||||
std::vector<std::pair<types::TypePtr, size_t>> reordered;
|
||||
auto niGenerics = fn->ast->getNonInferrableGenerics();
|
||||
auto score = ctx->reorderNamedArgs(
|
||||
fn.get(), args,
|
||||
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
||||
for (int si = 0; si < slots.size(); si++) {
|
||||
if (fn->ast->args[si].status == Param::Generic) {
|
||||
if (slots[si].empty())
|
||||
if (slots[si].empty()) {
|
||||
// is this "real" type?
|
||||
if (in(niGenerics, fn->ast->args[si].name) &&
|
||||
!fn->ast->args[si].defaultValue)
|
||||
return -1;
|
||||
reordered.push_back({nullptr, 0});
|
||||
else
|
||||
} else {
|
||||
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
||||
}
|
||||
} else if (si == s || si == k || slots[si].size() != 1) {
|
||||
// Ignore *args, *kwargs and default arguments
|
||||
reordered.push_back({nullptr, 0});
|
||||
|
|
|
@ -590,7 +590,7 @@ def f(x):
|
|||
return g(x)
|
||||
print f(5) #: 6
|
||||
|
||||
##% nested_generic_static,barebones
|
||||
#%% nested_generic_static,barebones
|
||||
def foo():
|
||||
N: Static[int] = 5
|
||||
Z: Static[int] = 15
|
||||
|
|
|
@ -191,7 +191,7 @@ foo(int, float) #! foo() takes 1 arguments (2 given)
|
|||
#%% instantiate_err_2,barebones
|
||||
def foo[N, T]():
|
||||
return N()
|
||||
foo(int) #! cannot typecheck the program
|
||||
foo(int) #! generic 'T' not provided
|
||||
|
||||
#%% instantiate_err_3,barebones
|
||||
Ptr[int, float]() #! Ptr takes 1 generics (2 given)
|
||||
|
|
|
@ -1856,3 +1856,14 @@ a.test2(2)
|
|||
#: test:B 1
|
||||
#: test2:B 2
|
||||
|
||||
|
||||
#%% no_generic,barebones
|
||||
def foo(a, b: Static[int]):
|
||||
pass
|
||||
foo(5) #! generic 'b' not provided
|
||||
|
||||
|
||||
#%% no_generic_2,barebones
|
||||
def f(a, b, T: type):
|
||||
print(a, b)
|
||||
f(1, 2) #! generic 'T' not provided
|
||||
|
|
Loading…
Reference in New Issue