1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Add and fix IR attributes; Fix Counter test

This commit is contained in:
Ibrahim Numanagić 2022-02-19 12:39:24 -08:00
parent d59ec14cc6
commit 354e7f80fe
7 changed files with 142 additions and 35 deletions

View File

@ -18,7 +18,7 @@ namespace ast {
Expr::Expr() Expr::Expr()
: type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC), : type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC),
done(false) {} done(false), attributes(0) {}
types::TypePtr Expr::getType() const { return type; } types::TypePtr Expr::getType() const { return type; }
void Expr::setType(types::TypePtr t) { this->type = std::move(t); } void Expr::setType(types::TypePtr t) { this->type = std::move(t); }
bool Expr::isType() const { return isTypeExpr; } bool Expr::isType() const { return isTypeExpr; }
@ -31,6 +31,8 @@ std::string Expr::wrapType(const std::string &sexpr) const {
done ? "*" : ""); done ? "*" : "");
} }
bool Expr::isStatic() const { return staticValue.type != StaticValue::NOT_STATIC; } bool Expr::isStatic() const { return staticValue.type != StaticValue::NOT_STATIC; }
bool Expr::hasAttr(int attr) const { return (attributes & (1 << attr)); }
void Expr::setAttr(int attr) { attributes |= (1 << attr); }
StaticValue::StaticValue(StaticValue::Type t) : value(), type(t), evaluated(false) {} StaticValue::StaticValue(StaticValue::Type t) : value(), type(t), evaluated(false) {}
StaticValue::StaticValue(int64_t i) : value(i), type(INT), evaluated(true) {} StaticValue::StaticValue(int64_t i) : value(i), type(INT), evaluated(true) {}
@ -397,14 +399,11 @@ StmtExpr::StmtExpr(std::shared_ptr<Stmt> stmt, std::shared_ptr<Stmt> stmt2,
stmts.push_back(std::move(stmt2)); stmts.push_back(std::move(stmt2));
} }
StmtExpr::StmtExpr(const StmtExpr &expr) StmtExpr::StmtExpr(const StmtExpr &expr)
: Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)), : Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {}
attributes(expr.attributes) {}
std::string StmtExpr::toString() const { std::string StmtExpr::toString() const {
return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString())); return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString()));
} }
ACCEPT_IMPL(StmtExpr, ASTVisitor); ACCEPT_IMPL(StmtExpr, ASTVisitor);
bool StmtExpr::hasAttr(const std::string &attr) const { return in(attributes, attr); }
void StmtExpr::setAttr(const std::string &attr) { attributes.insert(attr); }
PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {} PtrExpr::PtrExpr(ExprPtr expr) : Expr(), expr(std::move(expr)) {}
PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {} PtrExpr::PtrExpr(const PtrExpr &expr) : Expr(expr), expr(ast::clone(expr.expr)) {}

View File

@ -78,6 +78,9 @@ struct Expr : public codon::SrcObject {
/// type-checking procedure was successful). /// type-checking procedure was successful).
bool done; bool done;
/// Set of attributes.
int attributes;
public: public:
Expr(); Expr();
Expr(const Expr &expr) = default; Expr(const Expr &expr) = default;
@ -124,6 +127,10 @@ public:
virtual const TupleExpr *getTuple() const { return nullptr; } virtual const TupleExpr *getTuple() const { return nullptr; }
virtual const UnaryExpr *getUnary() const { return nullptr; } virtual const UnaryExpr *getUnary() const { return nullptr; }
/// Attribute helpers
bool hasAttr(int attr) const;
void setAttr(int attr);
protected: protected:
/// Add a type to S-expression string. /// Add a type to S-expression string.
std::string wrapType(const std::string &sexpr) const; std::string wrapType(const std::string &sexpr) const;
@ -600,8 +607,6 @@ struct RangeExpr : public Expr {
struct StmtExpr : public Expr { struct StmtExpr : public Expr {
std::vector<std::shared_ptr<Stmt>> stmts; std::vector<std::shared_ptr<Stmt>> stmts;
ExprPtr expr; ExprPtr expr;
/// Set of attributes.
std::set<std::string> attributes;
StmtExpr(std::vector<std::shared_ptr<Stmt>> stmts, ExprPtr expr); StmtExpr(std::vector<std::shared_ptr<Stmt>> stmts, ExprPtr expr);
StmtExpr(std::shared_ptr<Stmt> stmt, ExprPtr expr); StmtExpr(std::shared_ptr<Stmt> stmt, ExprPtr expr);
@ -612,10 +617,6 @@ struct StmtExpr : public Expr {
ACCEPT(ASTVisitor); ACCEPT(ASTVisitor);
const StmtExpr *getStmtExpr() const override { return this; } const StmtExpr *getStmtExpr() const override { return this; }
/// Attribute helpers
bool hasAttr(const std::string &attr) const;
void setAttr(const std::string &attr);
}; };
/// Pointer expression (__ptr__(expr)). /// Pointer expression (__ptr__(expr)).
@ -672,5 +673,15 @@ struct StackAllocExpr : Expr {
#undef ACCEPT #undef ACCEPT
enum ExprAttr {
SequenceItem,
StarSequenceItem,
List,
Set,
Dict,
Partial,
__LAST__
};
} // namespace ast } // namespace ast
} // namespace codon } // namespace codon

View File

@ -9,7 +9,6 @@
#include "codon/parser/common.h" #include "codon/parser/common.h"
#include "codon/parser/peg/peg.h" #include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/simplify/simplify.h"
#include "codon/sir/attribute.h"
using fmt::format; using fmt::format;
@ -32,6 +31,8 @@ ExprPtr SimplifyVisitor::transform(const ExprPtr &expr, bool allowTypes,
ctx->canAssign = oldAssign; ctx->canAssign = oldAssign;
if (!allowTypes && v.resultExpr && v.resultExpr->isType()) if (!allowTypes && v.resultExpr && v.resultExpr->isType())
error("unexpected type expression"); error("unexpected type expression");
if (v.resultExpr)
v.resultExpr->attributes |= expr->attributes;
return v.resultExpr; return v.resultExpr;
} }
@ -180,16 +181,20 @@ void SimplifyVisitor::visit(ListExpr *expr) {
for (const auto &it : expr->items) { for (const auto &it : expr->items) {
if (auto star = it->getStar()) { if (auto star = it->getStar()) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it")); ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>( stmts.push_back(transform(N<ForStmt>(
clone(forVar), star->what->clone(), clone(forVar), st,
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(forVar)))))); N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(forVar))))));
} else { } else {
stmts.push_back(transform( auto st = clone(it);
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(it))))); st->setAttr(ExprAttr::SequenceItem);
stmts.push_back(
transform(N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), st))));
} }
} }
auto e = N<StmtExpr>(stmts, transform(var)); auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::ListLiteralAttribute::AttributeName); e->setAttr(ExprAttr::List);
resultExpr = e; resultExpr = e;
ctx->popBlock(); ctx->popBlock();
} }
@ -203,15 +208,19 @@ void SimplifyVisitor::visit(SetExpr *expr) {
for (auto &it : expr->items) for (auto &it : expr->items)
if (auto star = it->getStar()) { if (auto star = it->getStar()) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it")); ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>( stmts.push_back(transform(N<ForStmt>(
clone(forVar), star->what->clone(), clone(forVar), st,
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(forVar)))))); N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(forVar))))));
} else { } else {
stmts.push_back(transform( auto st = clone(it);
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), clone(it))))); st->setAttr(ExprAttr::SequenceItem);
stmts.push_back(
transform(N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "add"), st))));
} }
auto e = N<StmtExpr>(stmts, transform(var)); auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::SetLiteralAttribute::AttributeName); e->setAttr(ExprAttr::Set);
resultExpr = e; resultExpr = e;
ctx->popBlock(); ctx->popBlock();
} }
@ -225,17 +234,23 @@ void SimplifyVisitor::visit(DictExpr *expr) {
for (auto &it : expr->items) for (auto &it : expr->items)
if (auto star = CAST(it.value, KeywordStarExpr)) { if (auto star = CAST(it.value, KeywordStarExpr)) {
ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it")); ExprPtr forVar = N<IdExpr>(ctx->cache->getTemporaryVar("it"));
auto st = star->what->clone();
st->setAttr(ExprAttr::StarSequenceItem);
stmts.push_back(transform(N<ForStmt>( stmts.push_back(transform(N<ForStmt>(
clone(forVar), N<CallExpr>(N<DotExpr>(star->what->clone(), "items")), clone(forVar), N<CallExpr>(N<DotExpr>(st, "items")),
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__setitem__"), N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__setitem__"),
N<IndexExpr>(clone(forVar), N<IntExpr>(0)), N<IndexExpr>(clone(forVar), N<IntExpr>(0)),
N<IndexExpr>(clone(forVar), N<IntExpr>(1))))))); N<IndexExpr>(clone(forVar), N<IntExpr>(1)))))));
} else { } else {
stmts.push_back(transform(N<ExprStmt>(N<CallExpr>( auto k = clone(it.key);
N<DotExpr>(clone(var), "__setitem__"), clone(it.key), clone(it.value))))); k->setAttr(ExprAttr::SequenceItem);
auto v = clone(it.value);
v->setAttr(ExprAttr::SequenceItem);
stmts.push_back(transform(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__setitem__"), k, v))));
} }
auto e = N<StmtExpr>(stmts, transform(var)); auto e = N<StmtExpr>(stmts, transform(var));
e->setAttr(ir::DictLiteralAttribute::AttributeName); e->setAttr(ExprAttr::Dict);
resultExpr = e; resultExpr = e;
ctx->popBlock(); ctx->popBlock();
} }

View File

@ -57,8 +57,81 @@ ir::Func *TranslateVisitor::apply(Cache *cache, StmtPtr stmts) {
ir::Value *TranslateVisitor::transform(const ExprPtr &expr) { ir::Value *TranslateVisitor::transform(const ExprPtr &expr) {
TranslateVisitor v(ctx); TranslateVisitor v(ctx);
v.setSrcInfo(expr->getSrcInfo()); v.setSrcInfo(expr->getSrcInfo());
types::PartialType *p = nullptr;
if (expr->attributes) {
if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set) ||
expr->hasAttr(ExprAttr::Dict) || expr->hasAttr(ExprAttr::Partial)) {
ctx->seqItems.push_back(std::vector<std::pair<ExprAttr, ir::Value *>>());
}
if (expr->hasAttr(ExprAttr::Partial))
p = expr->type->getPartial().get();
// LOG("{} {}: {}", std::string(ctx->seqItems.size(), ' '), expr->attributes, expr->toString());
}
expr->accept(v); expr->accept(v);
return v.result; ir::Value *ir = v.result;
if (expr->attributes) {
if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set)) {
std::vector<ir::LiteralElement> v;
for (auto &p : ctx->seqItems.back()) {
seqassert(p.first <= ExprAttr::StarSequenceItem, "invalid list/set element");
v.push_back(
ir::LiteralElement{p.second, p.first == ExprAttr::StarSequenceItem});
}
if (expr->hasAttr(ExprAttr::List))
ir->setAttribute(std::make_unique<ir::ListLiteralAttribute>(v));
else
ir->setAttribute(std::make_unique<ir::SetLiteralAttribute>(v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::Dict)) {
std::vector<ir::DictLiteralAttribute::KeyValuePair> v;
LOG("{} {}", expr->toString(), expr->getSrcInfo());
for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) {
auto &p = ctx->seqItems.back()[pi];
if (p.first == ExprAttr::StarSequenceItem) {
v.push_back({p.second, nullptr});
} else {
seqassert(p.first == ExprAttr::SequenceItem &&
pi + 1 < ctx->seqItems.back().size() &&
ctx->seqItems.back()[pi + 1].first == ExprAttr::SequenceItem,
"invalid dict element");
v.push_back({p.second, ctx->seqItems.back()[pi + 1].second});
pi++;
}
}
ir->setAttribute(std::make_unique<ir::DictLiteralAttribute>(v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::Partial)) {
std::vector<ir::Value *> v;
seqassert(p, "invalid partial element");
int j = 0;
for (int i = 0; i < p->known.size(); i++) {
if (p->known[i] && !p->func->ast->args[i].generic) {
seqassert(j < ctx->seqItems.back().size() &&
ctx->seqItems.back()[j].first == ExprAttr::SequenceItem,
"invalid partial element");
v.push_back(ctx->seqItems.back()[j++].second);
} else if (!p->func->ast->args[i].generic) {
v.push_back({nullptr});
}
}
// seqassert(j == ctx->seqItems.back().size(), "invalid partial element");
ir->setAttribute(std::make_unique<ir::PartialFunctionAttribute>(nullptr, v));
ctx->seqItems.pop_back();
}
if (expr->hasAttr(ExprAttr::SequenceItem)) {
ctx->seqItems.back().push_back({ExprAttr::SequenceItem, ir});
}
if (expr->hasAttr(ExprAttr::StarSequenceItem)) {
ctx->seqItems.back().push_back({ExprAttr::StarSequenceItem, ir});
}
}
return ir;
} }
void TranslateVisitor::defaultVisit(Expr *n) { void TranslateVisitor::defaultVisit(Expr *n) {
@ -211,10 +284,6 @@ void TranslateVisitor::visit(StmtExpr *expr) {
transform(s); transform(s);
ctx->popSeries(); ctx->popSeries();
result = make<ir::FlowInstr>(expr, bodySeries, transform(expr->expr)); result = make<ir::FlowInstr>(expr, bodySeries, transform(expr->expr));
for (auto &a : expr->attributes) {
// if (a == ir::ListLiteralAttribute::AttributeName)
// result->setAttribute(ir::ListLiteralAttribute);
}
} }
/************************************************************************************/ /************************************************************************************/

View File

@ -50,6 +50,8 @@ struct TranslateContext : public Context<TranslateItem> {
std::vector<codon::ir::BodiedFunc *> bases; std::vector<codon::ir::BodiedFunc *> bases;
/// Stack of IR series (blocks). /// Stack of IR series (blocks).
std::vector<codon::ir::SeriesFlow *> series; std::vector<codon::ir::SeriesFlow *> series;
/// Stack of sequence items for attribute initialization.
std::vector<std::vector<std::pair<ExprAttr, ir::Value*>>> seqItems;
public: public:
TranslateContext(Cache *cache); TranslateContext(Cache *cache);

View File

@ -36,8 +36,10 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes, bool allowVo
ctx->allowActivation = false; ctx->allowActivation = false;
v.setSrcInfo(expr->getSrcInfo()); v.setSrcInfo(expr->getSrcInfo());
expr->accept(v); expr->accept(v);
if (v.resultExpr) if (v.resultExpr) {
v.resultExpr->attributes |= expr->attributes;
expr = v.resultExpr; expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr->toString()); seqassert(expr->type, "type not set for {}", expr->toString());
unify(typ, expr->type); unify(typ, expr->type);
if (disableActivation) if (disableActivation)
@ -1246,6 +1248,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
return -1; return -1;
}, },
known); known);
bool hasPartialArgs = partialStarArgs != nullptr,
hasPartialKwargs = partialKwstarArgs != nullptr;
if (isPartial) { if (isPartial) {
deactivateUnbounds(expr->args.back().value->getType().get()); deactivateUnbounds(expr->args.back().value->getType().get());
expr->args.pop_back(); expr->args.pop_back();
@ -1361,10 +1365,16 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
deactivateUnbounds(calleeFn.get()); deactivateUnbounds(calleeFn.get());
std::vector<ExprPtr> newArgs; std::vector<ExprPtr> newArgs;
for (auto &r : args) for (auto &r : args)
if (!r.value->getEllipsis()) if (!r.value->getEllipsis()) {
newArgs.push_back(r.value); newArgs.push_back(r.value);
newArgs.back()->setAttr(ExprAttr::SequenceItem);
}
newArgs.push_back(partialStarArgs); newArgs.push_back(partialStarArgs);
if (hasPartialArgs)
newArgs.back()->setAttr(ExprAttr::SequenceItem);
newArgs.push_back(partialKwstarArgs); newArgs.push_back(partialKwstarArgs);
if (hasPartialKwargs)
newArgs.back()->setAttr(ExprAttr::SequenceItem);
std::string var = ctx->cache->getTemporaryVar("partial"); std::string var = ctx->cache->getTemporaryVar("partial");
ExprPtr call = nullptr; ExprPtr call = nullptr;
@ -1379,8 +1389,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
N<CallExpr>(N<IdExpr>(partialTypeName), newArgs)), N<CallExpr>(N<IdExpr>(partialTypeName), newArgs)),
N<IdExpr>(var)); N<IdExpr>(var));
} }
const_cast<StmtExpr *>(call->getStmtExpr()) call->setAttr(ExprAttr::Partial);
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call = transform(call, false, allowVoidExpr); call = transform(call, false, allowVoidExpr);
seqassert(call->type->getPartial(), "expected partial type"); seqassert(call->type->getPartial(), "expected partial type");
return call; return call;
@ -1722,8 +1731,7 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(), N<CallExpr>(N<IdExpr>(partialTypeName), N<TupleExpr>(),
N<CallExpr>(N<IdExpr>(kwName)))), N<CallExpr>(N<IdExpr>(kwName)))),
N<IdExpr>(var)); N<IdExpr>(var));
const_cast<StmtExpr *>(call->getStmtExpr()) call->setAttr(ExprAttr::Partial);
->setAttr(ir::PartialFunctionAttribute::AttributeName);
call = transform(call, false, allowVoidExpr); call = transform(call, false, allowVoidExpr);
seqassert(call->type->getPartial(), "expected partial type"); seqassert(call->type->getPartial(), "expected partial type");
return call; return call;

View File

@ -339,6 +339,9 @@ class Counter[T](Dict[T,int]):
result |= other result |= other
return result return result
def __dict_do_op_throws__[F, Z](self, key: T, other: Z, op: F):
self.__dict_do_op__(key, other, 0, op)
@extend @extend
class Dict: class Dict: