mirror of https://github.com/exaloop/codon.git
Fix missing generic issue
parent
bc951f29f9
commit
a65a1cb881
|
@ -480,6 +480,56 @@ std::string FunctionStmt::getDocstr() {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search expression tree for a identifier
|
||||||
|
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
|
||||||
|
std::string what;
|
||||||
|
bool result;
|
||||||
|
|
||||||
|
public:
|
||||||
|
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
|
||||||
|
bool transform(const std::shared_ptr<Expr> &expr) override {
|
||||||
|
if (result)
|
||||||
|
return result;
|
||||||
|
IdSearchVisitor v(what);
|
||||||
|
if (expr)
|
||||||
|
expr->accept(v);
|
||||||
|
return result = v.result;
|
||||||
|
}
|
||||||
|
bool transform(const std::shared_ptr<Stmt> &stmt) override {
|
||||||
|
if (result)
|
||||||
|
return result;
|
||||||
|
IdSearchVisitor v(what);
|
||||||
|
if (stmt)
|
||||||
|
stmt->accept(v);
|
||||||
|
return result = v.result;
|
||||||
|
}
|
||||||
|
void visit(IdExpr *expr) override {
|
||||||
|
if (expr->value == what)
|
||||||
|
result = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Check if a function can be called with the given arguments.
|
||||||
|
/// See @c reorderNamedArgs for details.
|
||||||
|
std::unordered_set<std::string> FunctionStmt::getNonInferrableGenerics() {
|
||||||
|
std::unordered_set<std::string> nonInferrableGenerics;
|
||||||
|
for (auto &a : args) {
|
||||||
|
if (a.status == Param::Generic && !a.defaultValue) {
|
||||||
|
bool inferrable = false;
|
||||||
|
for (auto &b : args)
|
||||||
|
if (b.type && IdSearchVisitor(a.name).transform(b.type)) {
|
||||||
|
inferrable = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (ret && IdSearchVisitor(a.name).transform(ret))
|
||||||
|
inferrable = true;
|
||||||
|
if (!inferrable)
|
||||||
|
nonInferrableGenerics.insert(a.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nonInferrableGenerics;
|
||||||
|
}
|
||||||
|
|
||||||
ClassStmt::ClassStmt(std::string name, std::vector<Param> args, StmtPtr suite,
|
ClassStmt::ClassStmt(std::string name, std::vector<Param> args, StmtPtr suite,
|
||||||
std::vector<ExprPtr> decorators, std::vector<ExprPtr> baseClasses,
|
std::vector<ExprPtr> decorators, std::vector<ExprPtr> baseClasses,
|
||||||
std::vector<ExprPtr> staticBaseClasses)
|
std::vector<ExprPtr> staticBaseClasses)
|
||||||
|
|
|
@ -482,6 +482,7 @@ struct FunctionStmt : public Stmt {
|
||||||
|
|
||||||
FunctionStmt *getFunction() override { return this; }
|
FunctionStmt *getFunction() override { return this; }
|
||||||
std::string getDocstr();
|
std::string getDocstr();
|
||||||
|
std::unordered_set<std::string> getNonInferrableGenerics();
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Class statement (@(attributes...) class name[generics...]: args... ; suite).
|
/// Class statement (@(attributes...) class name[generics...]: args... ; suite).
|
||||||
|
|
|
@ -281,8 +281,10 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
||||||
};
|
};
|
||||||
|
|
||||||
// Handle reordered arguments (see @c reorderNamedArgs for details)
|
// Handle reordered arguments (see @c reorderNamedArgs for details)
|
||||||
|
bool partial = false;
|
||||||
auto reorderFn = [&](int starArgIndex, int kwstarArgIndex,
|
auto reorderFn = [&](int starArgIndex, int kwstarArgIndex,
|
||||||
const std::vector<std::vector<int>> &slots, bool partial) {
|
const std::vector<std::vector<int>> &slots, bool _partial) {
|
||||||
|
partial = _partial;
|
||||||
ctx->addBlock(); // add function generics to typecheck default arguments
|
ctx->addBlock(); // add function generics to typecheck default arguments
|
||||||
addFunctionGenerics(calleeFn->getFunc().get());
|
addFunctionGenerics(calleeFn->getFunc().get());
|
||||||
for (size_t si = 0, pi = 0; si < slots.size(); si++) {
|
for (size_t si = 0, pi = 0; si < slots.size(); si++) {
|
||||||
|
@ -410,18 +412,28 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
||||||
(!expr->hasAttr(ExprAttr::OrderedCall) &&
|
(!expr->hasAttr(ExprAttr::OrderedCall) &&
|
||||||
typeArgs.size() == calleeFn->funcGenerics.size()),
|
typeArgs.size() == calleeFn->funcGenerics.size()),
|
||||||
"bad vector sizes");
|
"bad vector sizes");
|
||||||
for (size_t si = 0;
|
if (!calleeFn->funcGenerics.empty()) {
|
||||||
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
|
auto niGenerics = calleeFn->ast->getNonInferrableGenerics();
|
||||||
si++) {
|
for (size_t si = 0;
|
||||||
if (typeArgs[si]) {
|
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
|
||||||
auto typ = typeArgs[si]->type;
|
si++) {
|
||||||
if (calleeFn->funcGenerics[si].type->isStaticType()) {
|
if (typeArgs[si]) {
|
||||||
if (!typeArgs[si]->isStatic()) {
|
auto typ = typeArgs[si]->type;
|
||||||
E(Error::EXPECTED_STATIC, typeArgs[si]);
|
if (calleeFn->funcGenerics[si].type->isStaticType()) {
|
||||||
|
if (!typeArgs[si]->isStatic()) {
|
||||||
|
E(Error::EXPECTED_STATIC, typeArgs[si]);
|
||||||
|
}
|
||||||
|
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
|
||||||
|
}
|
||||||
|
unify(typ, calleeFn->funcGenerics[si].type);
|
||||||
|
} else {
|
||||||
|
if (calleeFn->funcGenerics[si].type->getUnbound() &&
|
||||||
|
!calleeFn->ast->args[si].defaultValue &&
|
||||||
|
!partial &&
|
||||||
|
in(niGenerics, calleeFn->funcGenerics[si].name)) {
|
||||||
|
error("generic '{}' not provided", calleeFn->funcGenerics[si].niceName);
|
||||||
}
|
}
|
||||||
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
|
|
||||||
}
|
}
|
||||||
unify(typ, calleeFn->funcGenerics[si].type);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -199,6 +199,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
|
||||||
/// @example
|
/// @example
|
||||||
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
|
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
|
||||||
void TypecheckVisitor::visit(TupleExpr *expr) {
|
void TypecheckVisitor::visit(TupleExpr *expr) {
|
||||||
|
expr->setType(ctx->getUnbound());
|
||||||
for (int ai = 0; ai < expr->items.size(); ai++)
|
for (int ai = 0; ai < expr->items.size(); ai++)
|
||||||
if (auto star = expr->items[ai]->getStar()) {
|
if (auto star = expr->items[ai]->getStar()) {
|
||||||
// Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`)
|
// Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`)
|
||||||
|
@ -226,6 +227,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
|
||||||
auto tupleName = generateTuple(expr->items.size());
|
auto tupleName = generateTuple(expr->items.size());
|
||||||
resultExpr =
|
resultExpr =
|
||||||
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
|
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
|
||||||
|
unify(expr->type, resultExpr->type);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transform a tuple generator expression.
|
/// Transform a tuple generator expression.
|
||||||
|
|
|
@ -217,20 +217,55 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(
|
||||||
return m.empty() ? nullptr : m[0];
|
return m.empty() ? nullptr : m[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Search expression tree for a identifier
|
||||||
|
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
|
||||||
|
std::string what;
|
||||||
|
bool result;
|
||||||
|
|
||||||
|
public:
|
||||||
|
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
|
||||||
|
bool transform(const std::shared_ptr<Expr> &expr) override {
|
||||||
|
if (result)
|
||||||
|
return result;
|
||||||
|
IdSearchVisitor v(what);
|
||||||
|
if (expr)
|
||||||
|
expr->accept(v);
|
||||||
|
return v.result;
|
||||||
|
}
|
||||||
|
bool transform(const std::shared_ptr<Stmt> &stmt) override {
|
||||||
|
if (result)
|
||||||
|
return result;
|
||||||
|
IdSearchVisitor v(what);
|
||||||
|
if (stmt)
|
||||||
|
stmt->accept(v);
|
||||||
|
return v.result;
|
||||||
|
}
|
||||||
|
void visit(IdExpr *expr) override {
|
||||||
|
if (expr->value == what)
|
||||||
|
result = true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Check if a function can be called with the given arguments.
|
/// Check if a function can be called with the given arguments.
|
||||||
/// See @c reorderNamedArgs for details.
|
/// See @c reorderNamedArgs for details.
|
||||||
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
||||||
const std::vector<CallExpr::Arg> &args) {
|
const std::vector<CallExpr::Arg> &args) {
|
||||||
std::vector<std::pair<types::TypePtr, size_t>> reordered;
|
std::vector<std::pair<types::TypePtr, size_t>> reordered;
|
||||||
|
auto niGenerics = fn->ast->getNonInferrableGenerics();
|
||||||
auto score = ctx->reorderNamedArgs(
|
auto score = ctx->reorderNamedArgs(
|
||||||
fn.get(), args,
|
fn.get(), args,
|
||||||
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
|
||||||
for (int si = 0; si < slots.size(); si++) {
|
for (int si = 0; si < slots.size(); si++) {
|
||||||
if (fn->ast->args[si].status == Param::Generic) {
|
if (fn->ast->args[si].status == Param::Generic) {
|
||||||
if (slots[si].empty())
|
if (slots[si].empty()) {
|
||||||
|
// is this "real" type?
|
||||||
|
if (in(niGenerics, fn->ast->args[si].name) &&
|
||||||
|
!fn->ast->args[si].defaultValue)
|
||||||
|
return -1;
|
||||||
reordered.push_back({nullptr, 0});
|
reordered.push_back({nullptr, 0});
|
||||||
else
|
} else {
|
||||||
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
||||||
|
}
|
||||||
} else if (si == s || si == k || slots[si].size() != 1) {
|
} else if (si == s || si == k || slots[si].size() != 1) {
|
||||||
// Ignore *args, *kwargs and default arguments
|
// Ignore *args, *kwargs and default arguments
|
||||||
reordered.push_back({nullptr, 0});
|
reordered.push_back({nullptr, 0});
|
||||||
|
|
|
@ -590,7 +590,7 @@ def f(x):
|
||||||
return g(x)
|
return g(x)
|
||||||
print f(5) #: 6
|
print f(5) #: 6
|
||||||
|
|
||||||
##% nested_generic_static,barebones
|
#%% nested_generic_static,barebones
|
||||||
def foo():
|
def foo():
|
||||||
N: Static[int] = 5
|
N: Static[int] = 5
|
||||||
Z: Static[int] = 15
|
Z: Static[int] = 15
|
||||||
|
|
|
@ -191,7 +191,7 @@ foo(int, float) #! foo() takes 1 arguments (2 given)
|
||||||
#%% instantiate_err_2,barebones
|
#%% instantiate_err_2,barebones
|
||||||
def foo[N, T]():
|
def foo[N, T]():
|
||||||
return N()
|
return N()
|
||||||
foo(int) #! cannot typecheck the program
|
foo(int) #! generic 'T' not provided
|
||||||
|
|
||||||
#%% instantiate_err_3,barebones
|
#%% instantiate_err_3,barebones
|
||||||
Ptr[int, float]() #! Ptr takes 1 generics (2 given)
|
Ptr[int, float]() #! Ptr takes 1 generics (2 given)
|
||||||
|
|
|
@ -1856,3 +1856,14 @@ a.test2(2)
|
||||||
#: test:B 1
|
#: test:B 1
|
||||||
#: test2:B 2
|
#: test2:B 2
|
||||||
|
|
||||||
|
|
||||||
|
#%% no_generic,barebones
|
||||||
|
def foo(a, b: Static[int]):
|
||||||
|
pass
|
||||||
|
foo(5) #! generic 'b' not provided
|
||||||
|
|
||||||
|
|
||||||
|
#%% no_generic_2,barebones
|
||||||
|
def f(a, b, T: type):
|
||||||
|
print(a, b)
|
||||||
|
f(1, 2) #! generic 'T' not provided
|
||||||
|
|
Loading…
Reference in New Issue