Fix missing generic issue

pull/335/head
Ibrahim Numanagić 2023-03-19 22:00:12 -07:00
parent bc951f29f9
commit a65a1cb881
8 changed files with 126 additions and 15 deletions

View File

@ -480,6 +480,56 @@ std::string FunctionStmt::getDocstr() {
return "";
}
// Search expression tree for a identifier
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
std::string what;
bool result;
public:
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
bool transform(const std::shared_ptr<Expr> &expr) override {
if (result)
return result;
IdSearchVisitor v(what);
if (expr)
expr->accept(v);
return result = v.result;
}
bool transform(const std::shared_ptr<Stmt> &stmt) override {
if (result)
return result;
IdSearchVisitor v(what);
if (stmt)
stmt->accept(v);
return result = v.result;
}
void visit(IdExpr *expr) override {
if (expr->value == what)
result = true;
}
};
/// Check if a function can be called with the given arguments.
/// See @c reorderNamedArgs for details.
std::unordered_set<std::string> FunctionStmt::getNonInferrableGenerics() {
std::unordered_set<std::string> nonInferrableGenerics;
for (auto &a : args) {
if (a.status == Param::Generic && !a.defaultValue) {
bool inferrable = false;
for (auto &b : args)
if (b.type && IdSearchVisitor(a.name).transform(b.type)) {
inferrable = true;
break;
}
if (ret && IdSearchVisitor(a.name).transform(ret))
inferrable = true;
if (!inferrable)
nonInferrableGenerics.insert(a.name);
}
}
return nonInferrableGenerics;
}
ClassStmt::ClassStmt(std::string name, std::vector<Param> args, StmtPtr suite,
std::vector<ExprPtr> decorators, std::vector<ExprPtr> baseClasses,
std::vector<ExprPtr> staticBaseClasses)

View File

@ -482,6 +482,7 @@ struct FunctionStmt : public Stmt {
FunctionStmt *getFunction() override { return this; }
std::string getDocstr();
std::unordered_set<std::string> getNonInferrableGenerics();
};
/// Class statement (@(attributes...) class name[generics...]: args... ; suite).

View File

@ -281,8 +281,10 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
};
// Handle reordered arguments (see @c reorderNamedArgs for details)
bool partial = false;
auto reorderFn = [&](int starArgIndex, int kwstarArgIndex,
const std::vector<std::vector<int>> &slots, bool partial) {
const std::vector<std::vector<int>> &slots, bool _partial) {
partial = _partial;
ctx->addBlock(); // add function generics to typecheck default arguments
addFunctionGenerics(calleeFn->getFunc().get());
for (size_t si = 0, pi = 0; si < slots.size(); si++) {
@ -410,18 +412,28 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
(!expr->hasAttr(ExprAttr::OrderedCall) &&
typeArgs.size() == calleeFn->funcGenerics.size()),
"bad vector sizes");
for (size_t si = 0;
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
si++) {
if (typeArgs[si]) {
auto typ = typeArgs[si]->type;
if (calleeFn->funcGenerics[si].type->isStaticType()) {
if (!typeArgs[si]->isStatic()) {
E(Error::EXPECTED_STATIC, typeArgs[si]);
if (!calleeFn->funcGenerics.empty()) {
auto niGenerics = calleeFn->ast->getNonInferrableGenerics();
for (size_t si = 0;
!expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size();
si++) {
if (typeArgs[si]) {
auto typ = typeArgs[si]->type;
if (calleeFn->funcGenerics[si].type->isStaticType()) {
if (!typeArgs[si]->isStatic()) {
E(Error::EXPECTED_STATIC, typeArgs[si]);
}
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
}
unify(typ, calleeFn->funcGenerics[si].type);
} else {
if (calleeFn->funcGenerics[si].type->getUnbound() &&
!calleeFn->ast->args[si].defaultValue &&
!partial &&
in(niGenerics, calleeFn->funcGenerics[si].name)) {
error("generic '{}' not provided", calleeFn->funcGenerics[si].niceName);
}
typ = Type::makeStatic(ctx->cache, typeArgs[si]);
}
unify(typ, calleeFn->funcGenerics[si].type);
}
}

View File

@ -199,6 +199,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
/// @example
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
void TypecheckVisitor::visit(TupleExpr *expr) {
expr->setType(ctx->getUnbound());
for (int ai = 0; ai < expr->items.size(); ai++)
if (auto star = expr->items[ai]->getStar()) {
// Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`)
@ -226,6 +227,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
auto tupleName = generateTuple(expr->items.size());
resultExpr =
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
unify(expr->type, resultExpr->type);
}
/// Transform a tuple generator expression.

View File

@ -217,20 +217,55 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(
return m.empty() ? nullptr : m[0];
}
// Search expression tree for a identifier
class IdSearchVisitor : public CallbackASTVisitor<bool, bool> {
std::string what;
bool result;
public:
IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {}
bool transform(const std::shared_ptr<Expr> &expr) override {
if (result)
return result;
IdSearchVisitor v(what);
if (expr)
expr->accept(v);
return v.result;
}
bool transform(const std::shared_ptr<Stmt> &stmt) override {
if (result)
return result;
IdSearchVisitor v(what);
if (stmt)
stmt->accept(v);
return v.result;
}
void visit(IdExpr *expr) override {
if (expr->value == what)
result = true;
}
};
/// Check if a function can be called with the given arguments.
/// See @c reorderNamedArgs for details.
int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
const std::vector<CallExpr::Arg> &args) {
std::vector<std::pair<types::TypePtr, size_t>> reordered;
auto niGenerics = fn->ast->getNonInferrableGenerics();
auto score = ctx->reorderNamedArgs(
fn.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
if (fn->ast->args[si].status == Param::Generic) {
if (slots[si].empty())
if (slots[si].empty()) {
// is this "real" type?
if (in(niGenerics, fn->ast->args[si].name) &&
!fn->ast->args[si].defaultValue)
return -1;
reordered.push_back({nullptr, 0});
else
} else {
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
}
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.push_back({nullptr, 0});

View File

@ -590,7 +590,7 @@ def f(x):
return g(x)
print f(5) #: 6
##% nested_generic_static,barebones
#%% nested_generic_static,barebones
def foo():
N: Static[int] = 5
Z: Static[int] = 15

View File

@ -191,7 +191,7 @@ foo(int, float) #! foo() takes 1 arguments (2 given)
#%% instantiate_err_2,barebones
def foo[N, T]():
return N()
foo(int) #! cannot typecheck the program
foo(int) #! generic 'T' not provided
#%% instantiate_err_3,barebones
Ptr[int, float]() #! Ptr takes 1 generics (2 given)

View File

@ -1856,3 +1856,14 @@ a.test2(2)
#: test:B 1
#: test2:B 2
#%% no_generic,barebones
def foo(a, b: Static[int]):
pass
foo(5) #! generic 'b' not provided
#%% no_generic_2,barebones
def f(a, b, T: type):
print(a, b)
f(1, 2) #! generic 'T' not provided