mirror of https://github.com/exaloop/codon.git
Fix partial handling in IR
parent
8dd7d2e0ea
commit
4b2dfaf28f
|
@ -9,6 +9,9 @@
|
|||
#include "codon/parser/visitors/simplify/simplify.h"
|
||||
#include "codon/parser/visitors/translate/translate.h"
|
||||
#include "codon/parser/visitors/typecheck/typecheck.h"
|
||||
#include "codon/sir/util/operator.h"
|
||||
#include "codon/sir/util/irtools.h"
|
||||
|
||||
|
||||
namespace codon {
|
||||
|
||||
|
@ -110,9 +113,22 @@ Compiler::parseCode(const std::string &file, const std::string &code, int startL
|
|||
const std::unordered_map<std::string, std::string> &defines) {
|
||||
return parse(/*isCode=*/true, file, code, startLine, testFlags, defines);
|
||||
}
|
||||
|
||||
struct DummyOp : public codon::ir::util::Operator {
|
||||
void handle(codon::ir::CallInstr *x) override {
|
||||
auto *M = x->getModule();
|
||||
auto *func = codon::ir::util::getFunc(x->getCallee());
|
||||
if (!func || func->getUnmangledName() != "foo")
|
||||
return;
|
||||
auto fn = M->getOrRealizeFunc("bar", {x->front()->getType()}, {});
|
||||
seqassert(fn, "did not succeed");
|
||||
auto result = codon::ir::util::call(fn, {x->front()});
|
||||
x->replaceAll(result);
|
||||
}
|
||||
};
|
||||
llvm::Error Compiler::compile() {
|
||||
pm->run(module.get());
|
||||
auto d = DummyOp();
|
||||
module->accept(d);
|
||||
llvisitor->visit(module.get());
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
|
|
@ -397,11 +397,14 @@ StmtExpr::StmtExpr(std::shared_ptr<Stmt> stmt, std::shared_ptr<Stmt> stmt2,
|
|||
stmts.push_back(std::move(stmt2));
|
||||
}
|
||||
StmtExpr::StmtExpr(const StmtExpr &expr)
|
||||
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {}
|
||||
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)),
|
||||
attributes(expr.attributes) {}
|
||||
std::string StmtExpr::toString() const {
|
||||
return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString()));
|
||||
}
|
||||
ACCEPT_IMPL(StmtExpr, ASTVisitor);
|
||||
bool StmtExpr::hasAttr(const std::string &attr) const { return in(attributes, attr); }
|
||||
void StmtExpr::setAttr(const std::string &attr) { attributes.insert(attr); }
|
||||
|
||||
PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {}
|
||||
PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {}
|
||||
|
|
|
@ -600,6 +600,8 @@ struct RangeExpr : public Expr {
|
|||
struct StmtExpr : public Expr {
|
||||
std::vector<std::shared_ptr<Stmt>> stmts;
|
||||
ExprPtr expr;
|
||||
/// Set of attributes.
|
||||
std::set<std::string> attributes;
|
||||
|
||||
StmtExpr(std::vector<std::shared_ptr<Stmt>> stmts, ExprPtr expr);
|
||||
StmtExpr(std::shared_ptr<Stmt> stmt, ExprPtr expr);
|
||||
|
@ -610,6 +612,10 @@ struct StmtExpr : public Expr {
|
|||
ACCEPT(ASTVisitor);
|
||||
|
||||
const StmtExpr *getStmtExpr() const override { return this; }
|
||||
|
||||
/// Attribute helpers
|
||||
bool hasAttr(const std::string &attr) const;
|
||||
void setAttr(const std::string &attr);
|
||||
};
|
||||
|
||||
/// Pointer expression (__ptr__(expr)).
|
||||
|
|
|
@ -42,6 +42,7 @@ std::string SuiteStmt::toString(int indent) const {
|
|||
}
|
||||
ACCEPT_IMPL(SuiteStmt, ASTVisitor);
|
||||
void SuiteStmt::flatten(StmtPtr s, std::vector<StmtPtr> &stmts) {
|
||||
// WARNING: does not preserve attributes!
|
||||
if (!s)
|
||||
return;
|
||||
auto suite = const_cast<SuiteStmt *>(s->getSuite());
|
||||
|
|
|
@ -188,6 +188,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
|
|||
std::shared_ptr<TranslateContext> codegenCtx;
|
||||
/// Set of function realizations that are to be translated to IR.
|
||||
std::set<std::pair<std::string, std::string>> pendingRealizations;
|
||||
/// Mapping of partial record names to function pointers and corresponding masks.
|
||||
std::unordered_map<std::string, std::pair<types::FuncTypePtr, std::vector<char>>>
|
||||
partials;
|
||||
|
||||
/// Custom operators
|
||||
std::unordered_map<std::string,
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/peg/peg.h"
|
||||
#include "codon/parser/visitors/simplify/simplify.h"
|
||||
#include "codon/sir/attribute.h"
|
||||
|
||||
using fmt::format;
|
||||
|
||||
|
@ -187,7 +188,9 @@ void SimplifyVisitor::visit(ListExpr *expr) {
|
|||
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(it)))));
|
||||
}
|
||||
}
|
||||
resultExpr = N<StmtExpr>(stmts, transform(var));
|
||||
auto e = N<StmtExpr>(stmts, transform(var));
|
||||
e->setAttr(ir::ListLiteralAttribute::AttributeName);
|
||||
resultExpr = e;
|
||||
ctx->popBlock();
|
||||
}
|
||||
|
||||
|
@ -207,7 +210,9 @@ void SimplifyVisitor::visit(SetExpr *expr) {
|
|||
stmts.push_back(transform(
|
||||
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(it)))));
|
||||
}
|
||||
resultExpr = N<StmtExpr>(stmts, transform(var));
|
||||
auto e = N<StmtExpr>(stmts, transform(var));
|
||||
e->setAttr(ir::SetLiteralAttribute::AttributeName);
|
||||
resultExpr = e;
|
||||
ctx->popBlock();
|
||||
}
|
||||
|
||||
|
@ -229,7 +234,9 @@ void SimplifyVisitor::visit(DictExpr *expr) {
|
|||
stmts.push_back(transform(N<ExprStmt>(N<CallExpr>(
|
||||
N<DotExpr>(clone(var), "__setitem__"), clone(it.key), clone(it.value)))));
|
||||
}
|
||||
resultExpr = N<StmtExpr>(stmts, transform(var));
|
||||
auto e = N<StmtExpr>(stmts, transform(var));
|
||||
e->setAttr(ir::DictLiteralAttribute::AttributeName);
|
||||
resultExpr = e;
|
||||
ctx->popBlock();
|
||||
}
|
||||
|
||||
|
@ -688,7 +695,9 @@ void SimplifyVisitor::visit(StmtExpr *expr) {
|
|||
for (auto &s : expr->stmts)
|
||||
stmts.emplace_back(transform(s));
|
||||
auto e = transform(expr->expr);
|
||||
resultExpr = N<StmtExpr>(stmts, e);
|
||||
auto s = N<StmtExpr>(stmts, e);
|
||||
s->attributes = expr->attributes;
|
||||
resultExpr = s;
|
||||
}
|
||||
|
||||
/**************************************************************************************/
|
||||
|
|
|
@ -209,6 +209,10 @@ void TranslateVisitor::visit(StmtExpr *expr) {
|
|||
transform(s);
|
||||
ctx->popSeries();
|
||||
result = make<ir::FlowInstr>(expr, bodySeries, transform(expr->expr));
|
||||
for (auto &a: expr->attributes) {
|
||||
// if (a == ir::ListLiteralAttribute::AttributeName)
|
||||
// result->setAttribute(ir::ListLiteralAttribute);
|
||||
}
|
||||
}
|
||||
|
||||
/************************************************************************************/
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/visitors/simplify/simplify.h"
|
||||
#include "codon/parser/visitors/typecheck/typecheck.h"
|
||||
#include "codon/sir/attribute.h"
|
||||
|
||||
using fmt::format;
|
||||
|
||||
|
@ -739,14 +740,7 @@ ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &e
|
|||
if (!tuple->getRecord())
|
||||
return nullptr;
|
||||
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL))
|
||||
// in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
|
||||
// Ptr, pyobj and str are internal types and have only one overloaded __getitem__
|
||||
return nullptr;
|
||||
// if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) {
|
||||
// ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]]
|
||||
// .size() != 1)
|
||||
// return nullptr;
|
||||
// }
|
||||
|
||||
// Extract a static integer value from a compatible expression.
|
||||
auto getInt = [&](int64_t *o, const ExprPtr &e) {
|
||||
|
@ -1121,30 +1115,32 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
|||
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->args))),
|
||||
clone(var)));
|
||||
}
|
||||
} else if (auto pc = callee->getPartial()) {
|
||||
ExprPtr var = N<IdExpr>(partialVar = ctx->cache->getTemporaryVar("pt"));
|
||||
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr),
|
||||
N<IdExpr>(pc->func->ast->name)));
|
||||
calleeFn = expr->expr->type->getFunc();
|
||||
for (int i = 0, j = 0; i < pc->known.size(); i++)
|
||||
if (pc->func->ast->args[i].generic) {
|
||||
if (pc->known[i])
|
||||
unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type);
|
||||
j++;
|
||||
}
|
||||
known = pc->known;
|
||||
seqassert(calleeFn, "not a function: {}", expr->expr->type->toString());
|
||||
} else if (!callee->getFunc()) {
|
||||
// Case 3: callee is not a named function. Route it through a __call__ method.
|
||||
ExprPtr newCall = N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->args);
|
||||
return transform(newCall, false, allowVoidExpr);
|
||||
} else {
|
||||
auto pc = callee->getPartial();
|
||||
if (pc) {
|
||||
ExprPtr var = N<IdExpr>(partialVar = ctx->cache->getTemporaryVar("pt"));
|
||||
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr),
|
||||
N<IdExpr>(pc->func->ast->name)));
|
||||
calleeFn = expr->expr->type->getFunc();
|
||||
for (int i = 0, j = 0; i < pc->known.size(); i++)
|
||||
if (pc->func->ast->args[i].generic) {
|
||||
if (pc->known[i])
|
||||
unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type);
|
||||
j++;
|
||||
}
|
||||
known = pc->known;
|
||||
seqassert(calleeFn, "not a function: {}", expr->expr->type->toString());
|
||||
} else if (!callee->getFunc()) {
|
||||
// Case 3: callee is not a named function. Route it through a __call__ method.
|
||||
ExprPtr newCall = N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->args);
|
||||
return transform(newCall, false, allowVoidExpr);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle named and default arguments
|
||||
std::vector<CallExpr::Arg> args;
|
||||
std::vector<ExprPtr> typeArgs;
|
||||
int typeArgCount = 0;
|
||||
// bool isPartial = false;
|
||||
int ellipsisStage = -1;
|
||||
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
|
||||
auto getPartialArg = [&](int pi) {
|
||||
|
@ -1383,12 +1379,14 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
|
|||
N<CallExpr>(N<IdExpr>(partialTypeName), newArgs)),
|
||||
N<IdExpr>(var));
|
||||
}
|
||||
const_cast<StmtExpr *>(call->getStmtExpr())
|
||||
->setAttr(ir::PartialFunctionAttribute::AttributeName);
|
||||
call = transform(call, false, allowVoidExpr);
|
||||
seqassert(call->type->getRecord() &&
|
||||
startswith(call->type->getRecord()->name, partialTypeName) &&
|
||||
!call->type->getPartial(),
|
||||
"bad partial transformation");
|
||||
call->type = N<PartialType>(call->type->getRecord(), calleeFn, newMask);
|
||||
// seqassert(call->type->getRecord() &&
|
||||
// startswith(call->type->getRecord()->name, partialTypeName) &&
|
||||
// !call->type->getPartial(),
|
||||
// "bad partial transformation");
|
||||
// call->type = N<PartialType>(call->type->getRecord(), calleeFn, newMask);
|
||||
return call;
|
||||
} else {
|
||||
// Case 2. Normal function call.
|
||||
|
@ -1636,9 +1634,12 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
|
|||
else if (!fn->ast->args[i].generic)
|
||||
tupleSize++;
|
||||
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
|
||||
if (!ctx->find(typeName))
|
||||
if (!ctx->find(typeName)) {
|
||||
ctx->cache->partials[typeName] = {
|
||||
std::static_pointer_cast<types::FuncType>(fn->shared_from_this()), mask};
|
||||
// 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them)
|
||||
generateTupleStub(tupleSize + 2, typeName, {}, false);
|
||||
}
|
||||
return typeName;
|
||||
}
|
||||
|
||||
|
@ -1710,12 +1711,14 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
|
|||
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
|
||||
N<CallExpr>(N<IdExpr>(kwName)))),
|
||||
N<IdExpr>(var));
|
||||
const_cast<StmtExpr *>(call->getStmtExpr())
|
||||
->setAttr(ir::PartialFunctionAttribute::AttributeName);
|
||||
call = transform(call, false, allowVoidExpr);
|
||||
seqassert(call->type->getRecord() &&
|
||||
startswith(call->type->getRecord()->name, partialTypeName) &&
|
||||
!call->type->getPartial(),
|
||||
"bad partial transformation");
|
||||
call->type = N<PartialType>(call->type->getRecord(), fn, mask);
|
||||
// seqassert(call->type->getRecord() &&
|
||||
// startswith(call->type->getRecord()->name, partialTypeName) &&
|
||||
// !call->type->getPartial(),
|
||||
// "bad partial transformation");
|
||||
// call->type = N<PartialType>(call->type->getRecord(), fn, mask);
|
||||
return call;
|
||||
}
|
||||
|
||||
|
|
|
@ -508,6 +508,13 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
|||
else
|
||||
typ = std::make_shared<ClassType>(
|
||||
stmt->name, ctx->cache->reverseIdentifierLookup[stmt->name]);
|
||||
if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) {
|
||||
seqassert(in(ctx->cache->partials, stmt->name),
|
||||
"invalid partial initialization: {}", stmt->name);
|
||||
typ = std::make_shared<PartialType>(typ->getRecord(),
|
||||
ctx->cache->partials[stmt->name].first,
|
||||
ctx->cache->partials[stmt->name].second);
|
||||
}
|
||||
typ->setSrcInfo(stmt->getSrcInfo());
|
||||
ctx->add(TypecheckItem::Type, stmt->name, typ);
|
||||
ctx->bases[0].visitedAsts[stmt->name] = {TypecheckItem::Type, typ};
|
||||
|
|
Loading…
Reference in New Issue