Fix auto-deduce classes

typecheck-v2
Ibrahim Numanagić 2024-09-02 18:16:07 -07:00
parent 487ea9af41
commit 5a484ea76e
4 changed files with 41 additions and 37 deletions

View File

@ -84,9 +84,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
// Parse and add class generics
std::vector<Param> args;
if (stmt->hasAttribute("deduce") && args.empty()) {
E(Error::CUSTOM, stmt, "not yet implemented");
} else if (stmt->hasAttribute(Attr::Extend)) {
if (stmt->hasAttribute(Attr::Extend)) {
for (auto &a : argsToParse) {
if (!a.isGeneric())
continue;
@ -96,6 +94,11 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
args.emplace_back(val->canonicalName, nullptr, nullptr, a.status);
}
} else {
if (stmt->hasAttribute(Attr::ClassDeduce) && args.empty()) {
autoDeduceMembers(stmt, argsToParse);
stmt->eraseAttribute(Attr::ClassDeduce);
}
// Add all generics before parent classes, fields and methods
for (auto &a : argsToParse) {
if (!a.isGeneric())
@ -492,36 +495,23 @@ std::vector<TypePtr> TypecheckVisitor::parseBaseClasses(
/// x: T1
/// y: T2```
/// @return the transformed init and the pointer to the original function.
std::pair<Stmt *, FunctionStmt *>
TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector<Param> &args) {
// std::pair<Stmt *, FunctionStmt *> init{nullptr, nullptr};
// for (const auto &sp : getClassMethods(stmt->suite))
// if (auto f = cast<FunctionStmt>(sp)) {
// // todo)) do this
// if (f->name == "__init__" && !f->args.empty() && f->args[0].name == "self") {
// // Set up deducedMembers that will be populated during AssignStmt evaluation
// ctx->getBase()->deducedMembers =
// std::make_shared<std::vector<std::string>>(); auto transformed =
// transform(sp);
// transformed->getFunction()->attributes.set(Attr::RealizeWithoutSelf);
// ctx->cache->functions[transformed->getFunction()->name].ast->attributes.set(
// Attr::RealizeWithoutSelf);
// int i = 0;
// // Once done, add arguments
// for (auto &m : *(ctx->getBase()->deducedMembers)) {
// auto varName = ctx->generateCanonicalName(format("T{}", ++i));
// auto memberName = ctx->cache->rev(varName);
// ctx->addType(memberName, varName, stmt->getSrcInfo())->generic = true;
// args.emplace_back(varName, N<IdExpr>(TYPE_TYPE), nullptr, Param::Generic);
// args.emplace_back(m, N<IdExpr>(varName));
// ctx->cache->classes[canonicalName].fields.push_back(
// Cache::Class::ClassField{m, nullptr, canonicalName});
// }
// ctx->getBase()->deducedMembers = nullptr;
// return {transformed, f};
// }
// }
return {nullptr, nullptr};
void TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector<Param> &args) {
std::set<std::string> members;
for (const auto &sp : getClassMethods(stmt->suite))
if (auto f = cast<FunctionStmt>(sp)) {
if (f->name == "__init__")
if (auto b = f->getAttribute<ir::StringListAttribute>(Attr::ClassDeduce)) {
f->setAttribute(Attr::RealizeWithoutSelf);
for (auto m : b->values)
members.insert(m);
}
}
for (auto m : members) {
auto genericName = fmt::format("T_{}", m);
args.emplace_back(genericName, N<IdExpr>(TYPE_TYPE), N<IdExpr>("NoneType"),
Param::Generic);
args.emplace_back(m, N<IdExpr>(genericName));
}
}
/// Return a list of all statements within a given class suite.

View File

@ -410,6 +410,15 @@ types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) {
}
// Realize the return type
auto ret = realize(type->getRetType());
if (ast->hasAttribute(Attr::RealizeWithoutSelf) &&
!extractFuncArgType(type)->canRealize()) { // For RealizeWithoutSelf
realizations.erase(key);
ctx->bases.pop_back();
ctx->popBlock();
ctx->typecheckLevel--;
getLogger().level--;
return nullptr;
}
seqassert(ret, "cannot realize return type '{}'", *(type->getRetType()));
// LOG("[realize] F {} -> {} => {}", type->getFuncName(), type->debugString(2),

View File

@ -199,8 +199,7 @@ private:
std::vector<Param> &, Stmt *,
const std::string &, Expr *,
types::ClassType *);
std::pair<Stmt *, FunctionStmt *> autoDeduceMembers(ClassStmt *,
std::vector<Param> &);
void autoDeduceMembers(ClassStmt *, std::vector<Param> &);
std::vector<Stmt *> getClassMethods(Stmt *s);
void transformNestedClasses(ClassStmt *, std::vector<Stmt *> &, std::vector<Stmt *> &,
std::vector<Stmt *> &);

View File

@ -378,11 +378,17 @@ print(f.x, f.y, f.__class__.__name__) #: ['s'] (1, 's') Foo[List[str],Tuple[int,
@deduce
class Bar:
def __init__(self, y):
def __init__(self, y: float):
self.y = Foo(y)
def __init__(self, y: str):
self.x = Foo(y)
b = Bar(3.1)
print(b.y.x, b.__class__.__name__) #: [3.1] Bar[Foo[List[float],Tuple[int,float]]]
print(b.x.__class__.__name__, b.y.__class__.__name__, b.y.x, b.__class__.__name__)
#: NoneType Foo[List[float],Tuple[int,float]] [3.1] Bar[NoneType,Foo[List[float],Tuple[int,float]]]
b = Bar('3.1')
print(b.x.__class__.__name__, b.y.__class__.__name__, b.x.x, b.__class__.__name__)
#: Foo[List[str],Tuple[int,str]] NoneType ['3.1'] Bar[Foo[List[str],Tuple[int,str]],NoneType]
#%% class_var,barebones
class Foo: