Add and fix IR attributes; Fix Counter test

pull/12/head
Ibrahim Numanagić 2022-02-19 12:39:24 -08:00
parent d59ec14cc6
commit 354e7f80fe
7 changed files with 142 additions and 35 deletions

View File

@ -18,7 +18,7 @@ namespace ast {
Expr::Expr()
: type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC),
done(false) {}
done(false), attributes(0) {}
types::TypePtr Expr::getType() const { return type; }
void Expr::setType(types::TypePtr t) { this->type = std::move(t); }
bool Expr::isType() const { return isTypeExpr; }
@ -31,6 +31,8 @@ std::string Expr::wrapType(const std::string &sexpr) const {
done ? "*" : "");
}
bool Expr::isStatic() const { return staticValue.type != StaticValue::NOT_STATIC; }
bool Expr::hasAttr(int attr) const { return (attributes & (1 << attr)); }
void Expr::setAttr(int attr) { attributes |= (1 << attr); }
StaticValue::StaticValue(StaticValue::Type t) : value(), type(t), evaluated(false) {}
StaticValue::StaticValue(int64_t i) : value(i), type(INT), evaluated(true) {}
@ -397,14 +399,11 @@ StmtExpr::StmtExpr(std::shared_ptr<Stmt> stmt, std::shared_ptr<Stmt> stmt2,
stmts.push_back(std::move(stmt2));
}
StmtExpr::StmtExpr(const StmtExpr &expr)
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)),
attributes(expr.attributes) {}
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {}
std::string StmtExpr::toString() const {
return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString()));
}
ACCEPT_IMPL(StmtExpr, ASTVisitor);
bool StmtExpr::hasAttr(const std::string &attr) const { return in(attributes, attr); }
void StmtExpr::setAttr(const std::string &attr) { attributes.insert(attr); }
PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {}
PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {}

View File

@ -78,6 +78,9 @@ struct Expr : public codon::SrcObject {
/// type-checking procedure was successful).
bool done;
/// Set of attributes.
int attributes;
public:
Expr();
Expr(const Expr &expr) = default;
@ -124,6 +127,10 @@ public:
virtual const TupleExpr *getTuple() const { return nullptr; }
virtual const UnaryExpr *getUnary() const { return nullptr; }
/// Attribute helpers
bool hasAttr(int attr) const;
void setAttr(int attr);
protected:
/// Add a type to S-expression string.
std::string wrapType(const std::string &sexpr) const;
@ -600,8 +607,6 @@ struct RangeExpr : public Expr {
struct StmtExpr : public Expr {
std::vector<std::shared_ptr<Stmt>> stmts;
ExprPtr expr;
/// Set of attributes.
std::set<std::string> attributes;
StmtExpr(std::vector<std::shared_ptr<Stmt>> stmts, ExprPtr expr);
StmtExpr(std::shared_ptr<Stmt> stmt, ExprPtr expr);
@ -612,10 +617,6 @@ struct StmtExpr : public Expr {
ACCEPT(ASTVisitor);
const StmtExpr *getStmtExpr() const override { return this; }
/// Attribute helpers
bool hasAttr(const std::string &attr) const;
void setAttr(const std::string &attr);
};
/// Pointer expression (__ptr__(expr)).
@ -672,5 +673,15 @@ struct StackAllocExpr : Expr {
#undef ACCEPT
enum ExprAttr {
SequenceItem,
StarSequenceItem,
List,
Set,
Dict,
Partial,
__LAST__
};
} // namespace ast
} // namespace codon

View File

@ -9,7 +9,6 @@
#include "codon/parser/common.h"
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/sir/attribute.h"
using fmt::format;
@ -32,6 +31,8 @@ ExprPtr SimplifyVisitor::transform(const ExprPtr &expr, bool allowTypes,
ctx->canAssign = oldAssign;
if (!allowTypes && v.resultExpr && v.resultExpr->isType())
error("unexpected type expression");
if (v.resultExpr)
v.resultExpr->attributes |= expr->attributes;
return v.resultExpr;
}
@ -180,16 +181,20 @@ void SimplifyVisitor::visit(ListExpr *expr) {
for (const auto &it : expr->items) {
if (auto star = it->getStar()) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>(
clone(forVar), star->what->clone(),
clone(forVar), st,
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(forVar))))));
} else {
stmts.push_back(transform(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(it)))));
auto st = clone(it);
st->setAttr(ExprAttr::SequenceItem);
stmts.push_back(
transform(N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), st))));
}
}
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::ListLiteralAttribute::AttributeName);
e->setAttr(ExprAttr::List);
resultExpr = e;
ctx->popBlock();
}
@ -203,15 +208,19 @@ void SimplifyVisitor::visit(SetExpr *expr) {
for (auto &it : expr->items)
if (auto star = it->getStar()) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>(
clone(forVar), star->what->clone(),
clone(forVar), st,
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(forVar))))));
} else {
stmts.push_back(transform(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(it)))));
auto st = clone(it);
st->setAttr(ExprAttr::SequenceItem);
stmts.push_back(
transform(N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), st))));
}
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::SetLiteralAttribute::AttributeName);
e->setAttr(ExprAttr::Set);
resultExpr = e;
ctx->popBlock();
}
@ -225,17 +234,23 @@ void SimplifyVisitor::visit(DictExpr *expr) {
for (auto &it : expr->items)
if (auto star = CAST(it.value, KeywordStarExpr)) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>(
clone(forVar), N<CallExpr>(N<DotExpr>(star->what->clone(), "items")),
clone(forVar), N<CallExpr>(N<DotExpr>(st, "items")),
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__setitem__"),
N<IndexExpr>(clone(forVar), N<IntExpr>(0)),
N<IndexExpr>(clone(forVar), N<IntExpr>(1)))))));
} else {
stmts.push_back(transform(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(clone(var), "__setitem__"), clone(it.key), clone(it.value)))));
auto k = clone(it.key);
k->setAttr(ExprAttr::SequenceItem);
auto v = clone(it.value);
v->setAttr(ExprAttr::SequenceItem);
stmts.push_back(transform(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__setitem__"), k, v))));
}
auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::DictLiteralAttribute::AttributeName);
e->setAttr(ExprAttr::Dict);
resultExpr = e;
ctx->popBlock();
}

View File

@ -57,8 +57,81 @@ ir::Func *TranslateVisitor::apply(Cache *cache, StmtPtr stmts) {
ir::Value *TranslateVisitor::transform(const ExprPtr &expr) {
TranslateVisitor v(ctx);
v.setSrcInfo(expr->getSrcInfo());
types::PartialType *p = nullptr;
if (expr->attributes) {
if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set) ||
expr->hasAttr(ExprAttr::Dict) || expr->hasAttr(ExprAttr::Partial)) {
ctx->seqItems.push_back(std::vector<std::pair<ExprAttr, ir::Value *>>());
}
if (expr->hasAttr(ExprAttr::Partial))
p = expr->type->getPartial().get();
// LOG("{} {}: {}", std::string(ctx->seqItems.size(), ' '), expr->attributes, expr->toString());
}
expr->accept(v);
return v.result;
ir::Value *ir = v.result;
if (expr->attributes) {
if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set)) {
std::vector<ir::LiteralElement> v;
for (auto &p : ctx->seqItems.back()) {
seqassert(p.first <= ExprAttr::StarSequenceItem, "invalid list/set element");
v.push_back(
ir::LiteralElement{p.second, p.first == ExprAttr::StarSequenceItem});
}
if (expr->hasAttr(ExprAttr::List))
ir->setAttribute(std::make_unique<ir::ListLiteralAttribute>(v));
else
ir->setAttribute(std::make_unique<ir::SetLiteralAttribute>(v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::Dict)) {
std::vector<ir::DictLiteralAttribute::KeyValuePair> v;
LOG("{} {}", expr->toString(), expr->getSrcInfo());
for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) {
auto &p = ctx->seqItems.back()[pi];
if (p.first == ExprAttr::StarSequenceItem) {
v.push_back({p.second, nullptr});
} else {
seqassert(p.first == ExprAttr::SequenceItem &&
pi + 1 < ctx->seqItems.back().size() &&
ctx->seqItems.back()[pi + 1].first == ExprAttr::SequenceItem,
"invalid dict element");
v.push_back({p.second, ctx->seqItems.back()[pi + 1].second});
pi++;
}
}
ir->setAttribute(std::make_unique<ir::DictLiteralAttribute>(v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::Partial)) {
std::vector<ir::Value *> v;
seqassert(p, "invalid partial element");
int j = 0;
for (int i = 0; i < p->known.size(); i++) {
if (p->known[i] && !p->func->ast->args[i].generic) {
seqassert(j < ctx->seqItems.back().size() &&
ctx->seqItems.back()[j].first == ExprAttr::SequenceItem,
"invalid partial element");
v.push_back(ctx->seqItems.back()[j++].second);
} else if (!p->func->ast->args[i].generic) {
v.push_back({nullptr});
}
}
// seqassert(j == ctx->seqItems.back().size(), "invalid partial element");
ir->setAttribute(std::make_unique<ir::PartialFunctionAttribute>(nullptr, v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::SequenceItem)) {
ctx->seqItems.back().push_back({ExprAttr::SequenceItem, ir});
}
if (expr->hasAttr(ExprAttr::StarSequenceItem)) {
ctx->seqItems.back().push_back({ExprAttr::StarSequenceItem, ir});
}
}
return ir;
}
void TranslateVisitor::defaultVisit(Expr *n) {
@ -211,10 +284,6 @@ void TranslateVisitor::visit(StmtExpr *expr) {
transform(s);
ctx->popSeries();
result = make<ir::FlowInstr>(expr, bodySeries, transform(expr->expr));
for (auto &a : expr->attributes) {
// if (a == ir::ListLiteralAttribute::AttributeName)
// result->setAttribute(ir::ListLiteralAttribute);
}
}
/************************************************************************************/

View File

@ -50,6 +50,8 @@ struct TranslateContext : public Context<TranslateItem> {
std::vector<codon::ir::BodiedFunc *> bases;
/// Stack of IR series (blocks).
std::vector<codon::ir::SeriesFlow *> series;
/// Stack of sequence items for attribute initialization.
std::vector<std::vector<std::pair<ExprAttr, ir::Value*>>> seqItems;
public:
TranslateContext(Cache *cache);

View File

@ -36,8 +36,10 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes, bool allowVo
ctx->allowActivation = false;
v.setSrcInfo(expr->getSrcInfo());
expr->accept(v);
if (v.resultExpr)
if (v.resultExpr) {
v.resultExpr->attributes |= expr->attributes;
expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr->toString());
unify(typ, expr->type);
if (disableActivation)
@ -1246,6 +1248,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
return -1;
},
known);
bool hasPartialArgs = partialStarArgs != nullptr,
hasPartialKwargs = partialKwstarArgs != nullptr;
if (isPartial) {
deactivateUnbounds(expr->args.back().value->getType().get());
expr->args.pop_back();
@ -1361,10 +1365,16 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
deactivateUnbounds(calleeFn.get());
std::vector<ExprPtr> newArgs;
for (auto &r : args)
if (!r.value->getEllipsis())
if (!r.value->getEllipsis()) {
newArgs.push_back(r.value);
newArgs.back()->setAttr(ExprAttr::SequenceItem);
}
newArgs.push_back(partialStarArgs);
if (hasPartialArgs)
newArgs.back()->setAttr(ExprAttr::SequenceItem);
newArgs.push_back(partialKwstarArgs);
if (hasPartialKwargs)
newArgs.back()->setAttr(ExprAttr::SequenceItem);
std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = nullptr;
@ -1379,8 +1389,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
N<CallExpr>(N<IdExpr>(partialTypeName), newArgs)),
N<IdExpr>(var));
}
const_cast<StmtExpr *>(call->getStmtExpr())
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call->setAttr(ExprAttr::Partial);
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getPartial(), "expected partial type");
return call;
@ -1722,8 +1731,7 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
N<CallExpr>(N<IdExpr>(kwName)))),
N<IdExpr>(var));
const_cast<StmtExpr *>(call->getStmtExpr())
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call->setAttr(ExprAttr::Partial);
call = transform(call, false, allowVoidExpr);
seqassert(call->type->getPartial(), "expected partial type");
return call;

View File

@ -339,6 +339,9 @@ class Counter[T](Dict[T,int]):
result |= other
return result
def __dict_do_op_throws__[F, Z](self, key: T, other: Z, op: F):
self.__dict_do_op__(key, other, 0, op)
@extend
class Dict: