Fix partial handling in IR

pull/12/head
Ibrahim Numanagić 2022-01-25 18:55:00 -08:00
parent 8dd7d2e0ea
commit 4b2dfaf28f
9 changed files with 94 additions and 42 deletions

View File

@ -9,6 +9,9 @@
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/parser/visitors/translate/translate.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include "codon/sir/util/operator.h"
#include "codon/sir/util/irtools.h"
namespace codon {
@ -110,9 +113,22 @@ Compiler::parseCode(const std::string &file, const std::string &code, int startL
const std::unordered_map<std::string, std::string> &defines) {
return parse(/*isCode=*/true, file, code, startLine, testFlags, defines);
}
struct DummyOp : public codon::ir::util::Operator {
void handle(codon::ir::CallInstr *x) override {
auto *M = x->getModule();
auto *func = codon::ir::util::getFunc(x->getCallee());
if (!func || func->getUnmangledName() != "foo")
return;
auto fn = M->getOrRealizeFunc("bar", {x->front()->getType()}, {});
seqassert(fn, "did not succeed");
auto result = codon::ir::util::call(fn, {x->front()});
x->replaceAll(result);
}
};
llvm::Error Compiler::compile() {
pm->run(module.get());
auto d = DummyOp();
module->accept(d);
llvisitor->visit(module.get());
return llvm::Error::success();
}

View File

@ -397,11 +397,14 @@ StmtExpr::StmtExpr(std::shared_ptr<Stmt> stmt, std::shared_ptr<Stmt> stmt2,
stmts.push_back(std::move(stmt2));
}
StmtExpr::StmtExpr(const StmtExpr &expr)
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {}
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)),
attributes(expr.attributes) {}
std::string StmtExpr::toString() const {
return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString()));
}
ACCEPT_IMPL(StmtExpr, ASTVisitor);
bool StmtExpr::hasAttr(const std::string &attr) const { return in(attributes, attr); }
void StmtExpr::setAttr(const std::string &attr) { attributes.insert(attr); }
PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {}
PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {}

View File

@ -600,6 +600,8 @@ struct RangeExpr : public Expr {
struct StmtExpr : public Expr {
std::vector<std::shared_ptr<Stmt>> stmts;
ExprPtr expr;
/// Set of attributes.
std::set<std::string> attributes;
StmtExpr(std::vector<std::shared_ptr<Stmt>> stmts, ExprPtr expr);
StmtExpr(std::shared_ptr<Stmt> stmt, ExprPtr expr);
@ -610,6 +612,10 @@ struct StmtExpr : public Expr {
ACCEPT(ASTVisitor);
const StmtExpr *getStmtExpr() const override { return this; }
/// Attribute helpers
bool hasAttr(const std::string &attr) const;
void setAttr(const std::string &attr);
};
/// Pointer expression (__ptr__(expr)).

View File

@ -42,6 +42,7 @@ std::string SuiteStmt::toString(int indent) const {
}
ACCEPT_IMPL(SuiteStmt, ASTVisitor);
void SuiteStmt::flatten(StmtPtr s, std::vector<StmtPtr> &stmts) {
// WARNING: does not preserve attributes!
if (!s)
return;
auto suite = const_cast<SuiteStmt *>(s->getSuite());

View File

@ -188,6 +188,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
std::shared_ptr<TranslateContext> codegenCtx;
/// Set of function realizations that are to be translated to IR.
std::set<std::pair<std::string, std::string>> pendingRealizations;
/// Mapping of partial record names to function pointers and corresponding masks.
std::unordered_map<std::string, std::pair<types::FuncTypePtr, std::vector<char>>>
partials;
/// Custom operators
std::unordered_map<std::string,

View File

@ -9,6 +9,7 @@
#include "codon/parser/common.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/sir/attribute.h"
using fmt::format;
@ -187,7 +188,9 @@ void SimplifyVisitor::visit(ListExpr *expr) {
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(it)))));
}
}
resultExpr = N<StmtExpr>(stmts, transform(var));
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::ListLiteralAttribute::AttributeName);
resultExpr = e;
ctx->popBlock();
}
@ -207,7 +210,9 @@ void SimplifyVisitor::visit(SetExpr *expr) {
stmts.push_back(transform(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(it)))));
}
resultExpr = N<StmtExpr>(stmts, transform(var));
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::SetLiteralAttribute::AttributeName);
resultExpr = e;
ctx->popBlock();
}
@ -229,7 +234,9 @@ void SimplifyVisitor::visit(DictExpr *expr) {
stmts.push_back(transform(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(clone(var), "__setitem__"), clone(it.key), clone(it.value)))));
}
resultExpr = N<StmtExpr>(stmts, transform(var));
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::DictLiteralAttribute::AttributeName);
resultExpr = e;
ctx->popBlock();
}
@ -688,7 +695,9 @@ void SimplifyVisitor::visit(StmtExpr *expr) {
for (auto &s : expr->stmts)
stmts.emplace_back(transform(s));
auto e = transform(expr->expr);
resultExpr = N<StmtExpr>(stmts, e);
auto s = N<StmtExpr>(stmts, e);
s->attributes = expr->attributes;
resultExpr = s;
}
/**************************************************************************************/

View File

@ -209,6 +209,10 @@ void TranslateVisitor::visit(StmtExpr *expr) {
transform(s);
ctx->popSeries();
result = make<ir::FlowInstr>(expr, bodySeries, transform(expr->expr));
for (auto &a: expr->attributes) {
// if (a == ir::ListLiteralAttribute::AttributeName)
// result->setAttribute(ir::ListLiteralAttribute);
}
}
/************************************************************************************/

View File

@ -10,6 +10,7 @@
#include "codon/parser/common.h"
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include "codon/sir/attribute.h"
using fmt::format;
@ -739,14 +740,7 @@ ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &e
if (!tuple->getRecord())
return nullptr;
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL))
// in(std::set<std::string>{"Ptr", "pyobj", "str", "Array"}, tuple->name))
// Ptr, pyobj and str are internal types and have only one overloaded __getitem__
return nullptr;
// if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) {
// ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]]
// .size() != 1)
// return nullptr;
// }
// Extract a static integer value from a compatible expression.
auto getInt = [&](int64_t *o, const ExprPtr &e) {
@ -1121,30 +1115,32 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->args))),
clone(var)));
}
} else if (auto pc = callee->getPartial()) {
ExprPtr var = N<IdExpr>(partialVar = ctx->cache->getTemporaryVar("pt"));
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr),
N<IdExpr>(pc->func->ast->name)));
calleeFn = expr->expr->type->getFunc();
for (int i = 0, j = 0; i < pc->known.size(); i++)
if (pc->func->ast->args[i].generic) {
if (pc->known[i])
unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type);
j++;
}
known = pc->known;
seqassert(calleeFn, "not a function: {}", expr->expr->type->toString());
} else if (!callee->getFunc()) {
// Case 3: callee is not a named function. Route it through a __call__ method.
ExprPtr newCall = N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->args);
return transform(newCall, false, allowVoidExpr);
} else {
auto pc = callee->getPartial();
if (pc) {
ExprPtr var = N<IdExpr>(partialVar = ctx->cache->getTemporaryVar("pt"));
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr),
N<IdExpr>(pc->func->ast->name)));
calleeFn = expr->expr->type->getFunc();
for (int i = 0, j = 0; i < pc->known.size(); i++)
if (pc->func->ast->args[i].generic) {
if (pc->known[i])
unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type);
j++;
}
known = pc->known;
seqassert(calleeFn, "not a function: {}", expr->expr->type->toString());
} else if (!callee->getFunc()) {
// Case 3: callee is not a named function. Route it through a __call__ method.
ExprPtr newCall = N<CallExpr>(N<DotExpr>(expr->expr, "__call__"), expr->args);
return transform(newCall, false, allowVoidExpr);
}
}
// Handle named and default arguments
std::vector<CallExpr::Arg> args;
std::vector<ExprPtr> typeArgs;
int typeArgCount = 0;
// bool isPartial = false;
int ellipsisStage = -1;
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
auto getPartialArg = [&](int pi) {
@ -1383,12 +1379,14 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
N<CallExpr>(N<IdExpr>(partialTypeName), newArgs)),
N<IdExpr>(var));
}
const_cast<StmtExpr *>(call->getStmtExpr())
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getRecord() &&
startswith(call->type->getRecord()->name, partialTypeName) &&
!call->type->getPartial(),
"bad partial transformation");
call->type = N<PartialType>(call->type->getRecord(), calleeFn, newMask);
// seqassert(call->type->getRecord() &&
// startswith(call->type->getRecord()->name, partialTypeName) &&
// !call->type->getPartial(),
// "bad partial transformation");
// call->type = N<PartialType>(call->type->getRecord(), calleeFn, newMask);
return call;
} else {
// Case 2. Normal function call.
@ -1636,9 +1634,12 @@ std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
else if (!fn->ast->args[i].generic)
tupleSize++;
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
if (!ctx->find(typeName))
if (!ctx->find(typeName)) {
ctx->cache->partials[typeName] = {
std::static_pointer_cast<types::FuncType>(fn->shared_from_this()), mask};
// 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them)
generateTupleStub(tupleSize + 2, typeName, {}, false);
}
return typeName;
}
@ -1710,12 +1711,14 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
N<CallExpr>(N<IdExpr>(kwName)))),
N<IdExpr>(var));
const_cast<StmtExpr *>(call->getStmtExpr())
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getRecord() &&
startswith(call->type->getRecord()->name, partialTypeName) &&
!call->type->getPartial(),
"bad partial transformation");
call->type = N<PartialType>(call->type->getRecord(), fn, mask);
// seqassert(call->type->getRecord() &&
// startswith(call->type->getRecord()->name, partialTypeName) &&
// !call->type->getPartial(),
// "bad partial transformation");
// call->type = N<PartialType>(call->type->getRecord(), fn, mask);
return call;
}

View File

@ -508,6 +508,13 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
else
typ = std::make_shared<ClassType>(
stmt->name, ctx->cache->reverseIdentifierLookup[stmt->name]);
if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) {
seqassert(in(ctx->cache->partials, stmt->name),
"invalid partial initialization: {}", stmt->name);
typ = std::make_shared<PartialType>(typ->getRecord(),
ctx->cache->partials[stmt->name].first,
ctx->cache->partials[stmt->name].second);
}
typ->setSrcInfo(stmt->getSrcInfo());
ctx->add(TypecheckItem::Type, stmt->name, typ);
ctx->bases[0].visitedAsts[stmt->name] = {TypecheckItem::Type, typ};