mirror of https://github.com/exaloop/codon.git
New Tuple machinery (#462)
* Refactor Tuple class * Add Tuple[N,...] support; Fix inline operators with list brackets * Fix Tuple[N,...] support * Merge Sequre SIMD changes * Fix repeat-tuple unification * Fix staticlen in OpenMP * Fix isinstance unification * Fix delayed unification with static realization * Fix CI * Cleanup * Use "fcmp une" in float.__ne__() --------- Co-authored-by: A. R. Shajii <ars@ars.me>pull/484/head
parent
9933954e30
commit
ce459c5667
|
@ -1745,7 +1745,7 @@ void LLVMVisitor::visit(const InternalFunc *x) {
|
|||
}
|
||||
|
||||
else if (internalFuncMatchesIgnoreArgs<RecordType>("__new__", x)) {
|
||||
auto *recordType = cast<RecordType>(parentType);
|
||||
auto *recordType = cast<RecordType>(cast<FuncType>(x->getType())->getReturnType());
|
||||
seqassertn(args.size() == std::distance(recordType->begin(), recordType->end()),
|
||||
"args size does not match");
|
||||
result = llvm::UndefValue::get(getLLVMType(recordType));
|
||||
|
|
|
@ -672,8 +672,6 @@ void ClassStmt::parseDecorators() {
|
|||
E(Error::CLASS_BAD_DECORATOR, d);
|
||||
}
|
||||
}
|
||||
if (startswith(name, TYPE_TUPLE))
|
||||
tupleMagics["contains"] = true;
|
||||
if (attributes.has("deduce"))
|
||||
tupleMagics["new"] = false;
|
||||
if (!attributes.has(Attr::Tuple)) {
|
||||
|
@ -681,12 +679,7 @@ void ClassStmt::parseDecorators() {
|
|||
tupleMagics["new"] = tupleMagics["raw"] = true;
|
||||
tupleMagics["len"] = false;
|
||||
}
|
||||
if (startswith(name, TYPE_TUPLE)) {
|
||||
tupleMagics["add"] = true;
|
||||
tupleMagics["mul"] = true;
|
||||
} else {
|
||||
tupleMagics["dict"] = true;
|
||||
}
|
||||
tupleMagics["dict"] = true;
|
||||
// Internal classes do not get any auto-generated members.
|
||||
attributes.magics.clear();
|
||||
if (!attributes.has(Attr::Internal)) {
|
||||
|
|
|
@ -104,8 +104,6 @@ std::string ClassType::debugString(char mode) const {
|
|||
}
|
||||
// Special formatting for Functions and Tuples
|
||||
auto n = mode == 0 ? niceName : name;
|
||||
if (startswith(n, TYPE_TUPLE))
|
||||
n = "Tuple";
|
||||
return fmt::format("{}{}", n, gs.empty() ? "" : fmt::format("[{}]", join(gs, ",")));
|
||||
}
|
||||
|
||||
|
@ -130,13 +128,13 @@ std::string ClassType::realizedTypeName() const {
|
|||
|
||||
RecordType::RecordType(Cache *cache, std::string name, std::string niceName,
|
||||
std::vector<Generic> generics, std::vector<TypePtr> args,
|
||||
bool noTuple)
|
||||
bool noTuple, const std::shared_ptr<StaticType> &repeats)
|
||||
: ClassType(cache, std::move(name), std::move(niceName), std::move(generics)),
|
||||
args(std::move(args)), noTuple(false) {}
|
||||
args(std::move(args)), noTuple(false), repeats(repeats) {}
|
||||
|
||||
RecordType::RecordType(const ClassTypePtr &base, std::vector<TypePtr> args,
|
||||
bool noTuple)
|
||||
: ClassType(base), args(std::move(args)), noTuple(noTuple) {}
|
||||
bool noTuple, const std::shared_ptr<StaticType> &repeats)
|
||||
: ClassType(base), args(std::move(args)), noTuple(noTuple), repeats(repeats) {}
|
||||
|
||||
int RecordType::unify(Type *typ, Unification *us) {
|
||||
if (auto tr = typ->getRecord()) {
|
||||
|
@ -148,6 +146,27 @@ int RecordType::unify(Type *typ, Unification *us) {
|
|||
return generics[0].type->unify(t64.get(), us);
|
||||
}
|
||||
|
||||
// TODO: we now support very limited unification strategy where repetitions must
|
||||
// match. We should expand this later on...
|
||||
if (repeats || tr->repeats) {
|
||||
if (!repeats && tr->repeats) {
|
||||
auto n = std::make_shared<StaticType>(cache, args.size());
|
||||
if (tr->repeats->unify(n.get(), us) == -1)
|
||||
return -1;
|
||||
} else if (!tr->repeats) {
|
||||
auto n = std::make_shared<StaticType>(cache, tr->args.size());
|
||||
if (repeats->unify(n.get(), us) == -1)
|
||||
return -1;
|
||||
} else {
|
||||
if (repeats->unify(tr->repeats.get(), us) == -1)
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
if (getRepeats() != -1)
|
||||
flatten();
|
||||
if (tr->getRepeats() != -1)
|
||||
tr->flatten();
|
||||
|
||||
int s1 = 2, s = 0;
|
||||
if (args.size() != tr->args.size())
|
||||
return -1;
|
||||
|
@ -158,7 +177,7 @@ int RecordType::unify(Type *typ, Unification *us) {
|
|||
return -1;
|
||||
}
|
||||
// Handle Tuple<->@tuple: when unifying tuples, only record members matter.
|
||||
if (startswith(name, TYPE_TUPLE) || startswith(tr->name, TYPE_TUPLE)) {
|
||||
if (name == TYPE_TUPLE || tr->name == TYPE_TUPLE) {
|
||||
if (!args.empty() || (!noTuple && !tr->noTuple)) // prevent POD<->() unification
|
||||
return s1 + int(name == tr->name);
|
||||
else
|
||||
|
@ -177,7 +196,8 @@ TypePtr RecordType::generalize(int atLevel) {
|
|||
auto a = args;
|
||||
for (auto &t : a)
|
||||
t = t->generalize(atLevel);
|
||||
return std::make_shared<RecordType>(c, a, noTuple);
|
||||
auto r = repeats ? repeats->generalize(atLevel)->getStatic() : nullptr;
|
||||
return std::make_shared<RecordType>(c, a, noTuple, r);
|
||||
}
|
||||
|
||||
TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
|
||||
|
@ -187,11 +207,17 @@ TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
|
|||
auto a = args;
|
||||
for (auto &t : a)
|
||||
t = t->instantiate(atLevel, unboundCount, cache);
|
||||
return std::make_shared<RecordType>(c, a, noTuple);
|
||||
auto r = repeats ? repeats->instantiate(atLevel, unboundCount, cache)->getStatic()
|
||||
: nullptr;
|
||||
return std::make_shared<RecordType>(c, a, noTuple, r);
|
||||
}
|
||||
|
||||
std::vector<TypePtr> RecordType::getUnbounds() const {
|
||||
std::vector<TypePtr> u;
|
||||
if (repeats) {
|
||||
auto tu = repeats->getUnbounds();
|
||||
u.insert(u.begin(), tu.begin(), tu.end());
|
||||
}
|
||||
for (auto &a : args) {
|
||||
auto tu = a->getUnbounds();
|
||||
u.insert(u.begin(), tu.begin(), tu.end());
|
||||
|
@ -202,21 +228,58 @@ std::vector<TypePtr> RecordType::getUnbounds() const {
|
|||
}
|
||||
|
||||
bool RecordType::canRealize() const {
|
||||
return std::all_of(args.begin(), args.end(),
|
||||
return getRepeats() >= 0 &&
|
||||
std::all_of(args.begin(), args.end(),
|
||||
[](auto &a) { return a->canRealize(); }) &&
|
||||
this->ClassType::canRealize();
|
||||
}
|
||||
|
||||
bool RecordType::isInstantiated() const {
|
||||
return std::all_of(args.begin(), args.end(),
|
||||
return (!repeats || repeats->isInstantiated()) &&
|
||||
std::all_of(args.begin(), args.end(),
|
||||
[](auto &a) { return a->isInstantiated(); }) &&
|
||||
this->ClassType::isInstantiated();
|
||||
}
|
||||
|
||||
std::string RecordType::debugString(char mode) const {
|
||||
return fmt::format("{}", this->ClassType::debugString(mode));
|
||||
std::string RecordType::realizedName() const {
|
||||
if (!_rn.empty())
|
||||
return _rn;
|
||||
if (name == TYPE_TUPLE) {
|
||||
std::vector<std::string> gs;
|
||||
auto n = getRepeats();
|
||||
if (n == -1)
|
||||
gs.push_back(repeats->realizedName());
|
||||
for (int i = 0; i < std::max(n, int64_t(0)); i++)
|
||||
for (auto &a : args)
|
||||
gs.push_back(a->realizedName());
|
||||
std::string s = join(gs, ",");
|
||||
if (canRealize())
|
||||
const_cast<RecordType *>(this)->_rn =
|
||||
fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s));
|
||||
return _rn;
|
||||
}
|
||||
return ClassType::realizedName();
|
||||
}
|
||||
|
||||
std::string RecordType::debugString(char mode) const {
|
||||
if (name == TYPE_TUPLE) {
|
||||
std::vector<std::string> gs;
|
||||
auto n = getRepeats();
|
||||
if (n == -1)
|
||||
gs.push_back(repeats->debugString(mode));
|
||||
for (int i = 0; i < std::max(n, int64_t(0)); i++)
|
||||
for (auto &a : args)
|
||||
gs.push_back(a->debugString(mode));
|
||||
return fmt::format("{}{}", name,
|
||||
gs.empty() ? "" : fmt::format("[{}]", join(gs, ",")));
|
||||
} else {
|
||||
return fmt::format("{}{}", repeats ? repeats->debugString(mode) + "," : "",
|
||||
this->ClassType::debugString(mode));
|
||||
}
|
||||
}
|
||||
|
||||
std::string RecordType::realizedTypeName() const { return realizedName(); }
|
||||
|
||||
std::shared_ptr<RecordType> RecordType::getHeterogenousTuple() {
|
||||
seqassert(canRealize(), "{} not realizable", toString());
|
||||
if (args.size() > 1) {
|
||||
|
@ -228,4 +291,25 @@ std::shared_ptr<RecordType> RecordType::getHeterogenousTuple() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns -1 if the type cannot be realized yet
|
||||
int64_t RecordType::getRepeats() const {
|
||||
if (!repeats)
|
||||
return 1;
|
||||
if (repeats->canRealize())
|
||||
return std::max(repeats->evaluate().getInt(), int64_t(0));
|
||||
return -1;
|
||||
}
|
||||
|
||||
void RecordType::flatten() {
|
||||
auto n = getRepeats();
|
||||
seqassert(n >= 0, "bad call to flatten");
|
||||
|
||||
auto a = args;
|
||||
args.clear();
|
||||
for (int64_t i = 0; i < n; i++)
|
||||
args.insert(args.end(), a.begin(), a.end());
|
||||
|
||||
repeats = nullptr;
|
||||
}
|
||||
|
||||
} // namespace codon::ast::types
|
||||
|
|
|
@ -78,12 +78,15 @@ struct RecordType : public ClassType {
|
|||
/// List of tuple arguments.
|
||||
std::vector<TypePtr> args;
|
||||
bool noTuple;
|
||||
std::shared_ptr<StaticType> repeats = nullptr;
|
||||
|
||||
explicit RecordType(
|
||||
Cache *cache, std::string name, std::string niceName,
|
||||
std::vector<ClassType::Generic> generics = std::vector<ClassType::Generic>(),
|
||||
std::vector<TypePtr> args = std::vector<TypePtr>(), bool noTuple = false);
|
||||
RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, bool noTuple = false);
|
||||
std::vector<TypePtr> args = std::vector<TypePtr>(), bool noTuple = false,
|
||||
const std::shared_ptr<StaticType> &repeats = nullptr);
|
||||
RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, bool noTuple = false,
|
||||
const std::shared_ptr<StaticType> &repeats = nullptr);
|
||||
|
||||
public:
|
||||
int unify(Type *typ, Unification *undo) override;
|
||||
|
@ -96,11 +99,16 @@ public:
|
|||
bool canRealize() const override;
|
||||
bool isInstantiated() const override;
|
||||
std::string debugString(char mode) const override;
|
||||
std::string realizedName() const override;
|
||||
std::string realizedTypeName() const override;
|
||||
|
||||
std::shared_ptr<RecordType> getRecord() override {
|
||||
return std::static_pointer_cast<RecordType>(shared_from_this());
|
||||
}
|
||||
std::shared_ptr<RecordType> getHeterogenousTuple() override;
|
||||
|
||||
int64_t getRepeats() const;
|
||||
void flatten();
|
||||
};
|
||||
|
||||
} // namespace codon::ast::types
|
||||
|
|
|
@ -99,6 +99,10 @@ bool FuncType::canRealize() const {
|
|||
return generics;
|
||||
}
|
||||
|
||||
std::string FuncType::realizedTypeName() const {
|
||||
return this->ClassType::realizedName();
|
||||
}
|
||||
|
||||
bool FuncType::isInstantiated() const {
|
||||
TypePtr removed = nullptr;
|
||||
auto retType = getRetType();
|
||||
|
|
|
@ -48,6 +48,7 @@ public:
|
|||
bool isInstantiated() const override;
|
||||
std::string debugString(char mode) const override;
|
||||
std::string realizedName() const override;
|
||||
std::string realizedTypeName() const override;
|
||||
|
||||
std::shared_ptr<FuncType> getFunc() override {
|
||||
return std::static_pointer_cast<FuncType>(shared_from_this());
|
||||
|
|
|
@ -140,7 +140,9 @@ StaticValue StaticType::evaluate() const {
|
|||
cache->typeCtx->addBlock();
|
||||
for (auto &g : generics)
|
||||
cache->typeCtx->add(TypecheckItem::Type, g.name, g.type);
|
||||
auto oldChangedNodes = cache->typeCtx->changedNodes;
|
||||
auto en = TypecheckVisitor(cache->typeCtx).transform(expr->clone());
|
||||
cache->typeCtx->changedNodes = oldChangedNodes;
|
||||
seqassert(en->isStatic() && en->staticValue.evaluated, "{} cannot be evaluated", en);
|
||||
cache->typeCtx->popBlock();
|
||||
return en->staticValue;
|
||||
|
|
|
@ -94,25 +94,23 @@ int CallableTrait::unify(Type *typ, Unification *us) {
|
|||
starArgTypes.insert(starArgTypes.end(), inArgs.begin() + i, inArgs.end());
|
||||
|
||||
auto tv = TypecheckVisitor(cache->typeCtx);
|
||||
auto name = tv.generateTuple(starArgTypes.size());
|
||||
auto t = cache->typeCtx->forceFind(name)->type;
|
||||
t = cache->typeCtx->instantiateGeneric(t, starArgTypes)->getClass();
|
||||
auto t = cache->typeCtx->instantiateTuple(starArgTypes)->getClass();
|
||||
if (t->unify(trInArgs[star].get(), us) == -1)
|
||||
return -1;
|
||||
}
|
||||
if (kwStar < trInArgs.size()) {
|
||||
auto tv = TypecheckVisitor(cache->typeCtx);
|
||||
std::vector<std::string> names;
|
||||
std::vector<TypePtr> starArgTypes;
|
||||
if (auto tp = tr->getPartial()) {
|
||||
auto ts = tp->args.back()->getRecord();
|
||||
seqassert(ts, "bad partial *args/**kwargs");
|
||||
auto &ff = cache->classes[ts->name].fields;
|
||||
auto ff = tv.getClassFields(ts.get());
|
||||
for (size_t i = 0; i < ts->args.size(); i++) {
|
||||
names.emplace_back(ff[i].name);
|
||||
starArgTypes.emplace_back(ts->args[i]);
|
||||
}
|
||||
}
|
||||
auto tv = TypecheckVisitor(cache->typeCtx);
|
||||
auto name = tv.generateTuple(starArgTypes.size(), TYPE_KWTUPLE, names);
|
||||
auto t = cache->typeCtx->forceFind(name)->type;
|
||||
t = cache->typeCtx->instantiateGeneric(t, starArgTypes)->getClass();
|
||||
|
|
|
@ -46,4 +46,16 @@ public:
|
|||
std::string debugString(char mode) const override;
|
||||
};
|
||||
|
||||
struct VariableTupleTrait : public Trait {
|
||||
TypePtr size;
|
||||
|
||||
public:
|
||||
explicit VariableTupleTrait(TypePtr size);
|
||||
int unify(Type *typ, Unification *undo) override;
|
||||
TypePtr generalize(int atLevel) override;
|
||||
TypePtr instantiate(int atLevel, int *unboundCount,
|
||||
std::unordered_map<int, TypePtr> *cache) override;
|
||||
std::string debugString(char mode) const override;
|
||||
};
|
||||
|
||||
} // namespace codon::ast::types
|
||||
|
|
|
@ -150,9 +150,7 @@ void UnionType::seal() {
|
|||
pendingTypes[i]->getLink()->kind == LinkType::Unbound)
|
||||
break;
|
||||
std::vector<TypePtr> typeSet(pendingTypes.begin(), pendingTypes.begin() + i);
|
||||
auto name = tv.generateTuple(typeSet.size());
|
||||
auto t = cache->typeCtx->instantiateGeneric(
|
||||
cache->typeCtx->forceFind(name)->type->getClass(), typeSet);
|
||||
auto t = cache->typeCtx->instantiateTuple(typeSet);
|
||||
Unification us;
|
||||
generics[0].type->unify(t.get(), &us);
|
||||
}
|
||||
|
|
|
@ -152,20 +152,16 @@ ir::Func *Cache::realizeFunction(types::FuncTypePtr type,
|
|||
|
||||
ir::types::Type *Cache::makeTuple(const std::vector<types::TypePtr> &types) {
|
||||
auto tv = TypecheckVisitor(typeCtx);
|
||||
auto name = tv.generateTuple(types.size());
|
||||
auto t = typeCtx->find(name);
|
||||
seqassertn(t && t->type, "cannot find {}", name);
|
||||
return realizeType(t->type->getClass(), types);
|
||||
auto t = typeCtx->instantiateTuple(types);
|
||||
return realizeType(t, types);
|
||||
}
|
||||
|
||||
ir::types::Type *Cache::makeFunction(const std::vector<types::TypePtr> &types) {
|
||||
auto tv = TypecheckVisitor(typeCtx);
|
||||
seqassertn(!types.empty(), "types must have at least one argument");
|
||||
|
||||
auto tup = tv.generateTuple(types.size() - 1);
|
||||
const auto &ret = types[0];
|
||||
auto argType = typeCtx->instantiateGeneric(
|
||||
typeCtx->find(tup)->type,
|
||||
auto argType = typeCtx->instantiateTuple(
|
||||
std::vector<types::TypePtr>(types.begin() + 1, types.end()));
|
||||
auto t = typeCtx->find("Function");
|
||||
seqassertn(t && t->type, "cannot find 'Function'");
|
||||
|
@ -175,8 +171,7 @@ ir::types::Type *Cache::makeFunction(const std::vector<types::TypePtr> &types) {
|
|||
ir::types::Type *Cache::makeUnion(const std::vector<types::TypePtr> &types) {
|
||||
auto tv = TypecheckVisitor(typeCtx);
|
||||
|
||||
auto tup = tv.generateTuple(types.size());
|
||||
auto argType = typeCtx->instantiateGeneric(typeCtx->find(tup)->type, types);
|
||||
auto argType = typeCtx->instantiateTuple(types);
|
||||
auto t = typeCtx->find("Union");
|
||||
seqassertn(t && t->type, "cannot find 'Union'");
|
||||
return realizeType(t->type->getClass(), {argType});
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#define STDLIB_IMPORT ":stdlib:"
|
||||
#define STDLIB_INTERNAL_MODULE "internal"
|
||||
|
||||
#define TYPE_TUPLE "Tuple.N"
|
||||
#define TYPE_TUPLE "Tuple"
|
||||
#define TYPE_KWTUPLE "KwTuple.N"
|
||||
#define TYPE_TYPEVAR "TypeVar"
|
||||
#define TYPE_CALLABLE "Callable"
|
||||
|
|
|
@ -14,10 +14,6 @@ using namespace codon::error;
|
|||
namespace codon::ast {
|
||||
|
||||
void SimplifyVisitor::visit(IdExpr *expr) {
|
||||
if (startswith(expr->value, TYPE_TUPLE)) {
|
||||
expr->markType();
|
||||
return;
|
||||
}
|
||||
auto val = ctx->findDominatingBinding(expr->value);
|
||||
|
||||
if (!val && ctx->getBase()->pyCaptures) {
|
||||
|
|
|
@ -95,6 +95,18 @@ StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr t
|
|||
if (auto idx = lhs->getIndex()) {
|
||||
// Case: a[x] = b
|
||||
seqassert(!type, "unexpected type annotation");
|
||||
if (auto b = rhs->getBinary()) {
|
||||
if (mustExist && b->inPlace && !b->rexpr->getId()) {
|
||||
auto var = ctx->cache->getTemporaryVar("assign");
|
||||
seqassert(rhs->getBinary(), "not a bin");
|
||||
return transform(N<SuiteStmt>(
|
||||
N<AssignStmt>(N<IdExpr>(var), idx->index),
|
||||
N<ExprStmt>(N<CallExpr>(
|
||||
N<DotExpr>(idx->expr, "__setitem__"), N<IdExpr>(var),
|
||||
N<BinaryExpr>(N<IndexExpr>(idx->expr->clone(), N<IdExpr>(var)), b->op,
|
||||
b->rexpr, true)))));
|
||||
}
|
||||
}
|
||||
return transform(N<ExprStmt>(
|
||||
N<CallExpr>(N<DotExpr>(idx->expr, "__setitem__"), idx->index, rhs)));
|
||||
}
|
||||
|
|
|
@ -528,7 +528,23 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE
|
|||
// Classes: def __new__() -> T
|
||||
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), typExpr->clone())));
|
||||
}
|
||||
} else if (op == "init") {
|
||||
}
|
||||
// else if (startswith(op, "new.")) {
|
||||
// // special handle for tuple[t1, t2, ...]
|
||||
// int sz = atoi(op.substr(4).c_str());
|
||||
// std::vector<ExprPtr> ts;
|
||||
// for (int i = 0; i < sz; i++) {
|
||||
// fargs.emplace_back(format("a{}", i + 1), I(format("T{}", i + 1)));
|
||||
// ts.emplace_back(I(format("T{}", i + 1)));
|
||||
// }
|
||||
// for (int i = 0; i < sz; i++) {
|
||||
// fargs.emplace_back(format("T{}", i + 1), I("type"));
|
||||
// }
|
||||
// ret = N<InstantiateExpr>(I(TYPE_TUPLE), ts);
|
||||
// ret->markType();
|
||||
// attr.set(Attr::Internal);
|
||||
// }
|
||||
else if (op == "init") {
|
||||
// Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> None:
|
||||
// self.aI = aI ...
|
||||
ret = I("NoneType");
|
||||
|
|
|
@ -54,13 +54,12 @@ void SimplifyVisitor::visit(ChainBinaryExpr *expr) {
|
|||
}
|
||||
|
||||
/// Transform index into an instantiation @c InstantiateExpr if possible.
|
||||
/// Generate tuple class `Tuple.N` for `Tuple[T1, ... TN]` (and `tuple[...]`).
|
||||
/// Generate tuple class `Tuple` for `Tuple[T1, ... TN]` (and `tuple[...]`).
|
||||
/// The rest is handled during the type checking.
|
||||
void SimplifyVisitor::visit(IndexExpr *expr) {
|
||||
if (expr->expr->isId("tuple") || expr->expr->isId("Tuple")) {
|
||||
// Special case: tuples. Change to Tuple.N
|
||||
if (expr->expr->isId("tuple") || expr->expr->isId(TYPE_TUPLE)) {
|
||||
auto t = expr->index->getTuple();
|
||||
expr->expr = NT<IdExpr>(format(TYPE_TUPLE "{}", t ? t->items.size() : 1));
|
||||
expr->expr = NT<IdExpr>(TYPE_TUPLE);
|
||||
} else if (expr->expr->isId("Static")) {
|
||||
// Special case: static types. Ensure that static is supported
|
||||
if (!expr->index->isId("int") && !expr->index->isId("str"))
|
||||
|
@ -84,7 +83,7 @@ void SimplifyVisitor::visit(IndexExpr *expr) {
|
|||
if (i->getList() && expr->expr->isType()) {
|
||||
// Special case: `A[[A, B], C]` -> `A[Tuple[A, B], C]` (e.g., in
|
||||
// `Function[...]`)
|
||||
i = N<IndexExpr>(N<IdExpr>("Tuple"), N<TupleExpr>(i->getList()->items));
|
||||
i = N<IndexExpr>(N<IdExpr>(TYPE_TUPLE), N<TupleExpr>(i->getList()->items));
|
||||
}
|
||||
transform(i, true);
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace codon::ast {
|
|||
* - All imports are flattened resulting in a single self-containing
|
||||
* (and fairly large) AST
|
||||
* - All identifiers are normalized (no two distinct objects share the same name)
|
||||
* - Variadic classes (e.g., Tuple.N) are generated
|
||||
* - Variadic classes (e.g., Tuple) are generated
|
||||
* - Any AST node that can be trivially expressed as a set of "simpler" nodes
|
||||
* type is simplified. If a transformation requires a type information,
|
||||
* it is done during the type checking.
|
||||
|
|
|
@ -55,7 +55,12 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
|
|||
cache->codegenCtx->add(TranslateItem::Var, g.first, g.second);
|
||||
}
|
||||
|
||||
TranslateVisitor(cache->codegenCtx).transform(stmts);
|
||||
auto tv = TranslateVisitor(cache->codegenCtx);
|
||||
tv.transform(stmts);
|
||||
for (auto &[fn, f] : cache->functions)
|
||||
if (startswith(fn, TYPE_TUPLE)) {
|
||||
tv.transformFunctionRealizations(fn, f.ast->attributes.has(Attr::LLVM));
|
||||
}
|
||||
cache->populatePythonModule();
|
||||
return main;
|
||||
}
|
||||
|
@ -223,9 +228,10 @@ void TranslateVisitor::visit(CallExpr *expr) {
|
|||
seqassert(!expr->args[i].value->getEllipsis(), "ellipsis not elided");
|
||||
if (i + 1 == expr->args.size() && isVariadic) {
|
||||
auto call = expr->args[i].value->getCall();
|
||||
seqassert(call && call->expr->getId() &&
|
||||
startswith(call->expr->getId()->value, TYPE_TUPLE),
|
||||
"expected *args tuple");
|
||||
seqassert(
|
||||
call && call->expr->getId() &&
|
||||
startswith(call->expr->getId()->value, std::string(TYPE_TUPLE) + "["),
|
||||
"expected *args tuple: '{}'", call->toString());
|
||||
for (auto &arg : call->args)
|
||||
items.emplace_back(transform(arg.value));
|
||||
} else {
|
||||
|
@ -524,20 +530,7 @@ void TranslateVisitor::visit(ThrowStmt *stmt) {
|
|||
|
||||
void TranslateVisitor::visit(FunctionStmt *stmt) {
|
||||
// Process all realizations.
|
||||
for (auto &real : ctx->cache->functions[stmt->name].realizations) {
|
||||
if (!in(ctx->cache->pendingRealizations, make_pair(stmt->name, real.first)))
|
||||
continue;
|
||||
ctx->cache->pendingRealizations.erase(make_pair(stmt->name, real.first));
|
||||
|
||||
LOG_TYPECHECK("[translate] generating fn {}", real.first);
|
||||
real.second->ir->setSrcInfo(getSrcInfo());
|
||||
const auto &ast = real.second->ast;
|
||||
seqassert(ast, "AST not set for {}", real.first);
|
||||
if (!stmt->attributes.has(Attr::LLVM))
|
||||
transformFunction(real.second->type.get(), ast.get(), real.second->ir);
|
||||
else
|
||||
transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir);
|
||||
}
|
||||
transformFunctionRealizations(stmt->name, stmt->attributes.has(Attr::LLVM));
|
||||
}
|
||||
|
||||
void TranslateVisitor::visit(ClassStmt *stmt) {
|
||||
|
@ -555,6 +548,24 @@ codon::ir::types::Type *TranslateVisitor::getType(const types::TypePtr &t) {
|
|||
return i->getType();
|
||||
}
|
||||
|
||||
void TranslateVisitor::transformFunctionRealizations(const std::string &name,
|
||||
bool isLLVM) {
|
||||
for (auto &real : ctx->cache->functions[name].realizations) {
|
||||
if (!in(ctx->cache->pendingRealizations, make_pair(name, real.first)))
|
||||
continue;
|
||||
ctx->cache->pendingRealizations.erase(make_pair(name, real.first));
|
||||
|
||||
LOG_TYPECHECK("[translate] generating fn {}", real.first);
|
||||
real.second->ir->setSrcInfo(getSrcInfo());
|
||||
const auto &ast = real.second->ast;
|
||||
seqassert(ast, "AST not set for {}", real.first);
|
||||
if (!isLLVM)
|
||||
transformFunction(real.second->type.get(), ast.get(), real.second->ir);
|
||||
else
|
||||
transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir);
|
||||
}
|
||||
}
|
||||
|
||||
void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *ast,
|
||||
ir::Func *func) {
|
||||
std::vector<std::string> names;
|
||||
|
|
|
@ -66,6 +66,7 @@ public:
|
|||
private:
|
||||
ir::types::Type *getType(const types::TypePtr &t);
|
||||
|
||||
void transformFunctionRealizations(const std::string &name, bool isLLVM);
|
||||
void transformFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func);
|
||||
void transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func);
|
||||
|
||||
|
|
|
@ -20,13 +20,7 @@ using namespace types;
|
|||
/// replace it with its value (e.g., a @c IntExpr ). Also ensure that the identifier of
|
||||
/// a generic function or a type is fully qualified (e.g., replace `Ptr` with
|
||||
/// `Ptr[byte]`).
|
||||
/// For tuple identifiers, generate appropriate class. See @c generateTuple for
|
||||
/// details.
|
||||
void TypecheckVisitor::visit(IdExpr *expr) {
|
||||
// Generate tuple stubs if needed
|
||||
if (isTuple(expr->value))
|
||||
generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1)));
|
||||
|
||||
// Replace identifiers that have been superseded by domination analysis during the
|
||||
// simplification
|
||||
while (auto s = in(ctx->cache->replacements, expr->value))
|
||||
|
@ -187,7 +181,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
|
|||
return nullptr;
|
||||
|
||||
// Check if this is a method or member access
|
||||
if (ctx->findMethod(typ->name, expr->member).empty())
|
||||
if (ctx->findMethod(typ.get(), expr->member).empty())
|
||||
return getClassMember(expr, args);
|
||||
auto bestMethod = getBestOverload(expr, args);
|
||||
|
||||
|
@ -210,10 +204,9 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
|
|||
std::vector<ExprPtr> ids;
|
||||
for (auto &t : fn->getArgTypes())
|
||||
ids.push_back(NT<IdExpr>(t->realizedName()));
|
||||
auto name = generateTuple(ids.size());
|
||||
auto fnType = NT<InstantiateExpr>(
|
||||
NT<IdExpr>("Function"),
|
||||
std::vector<ExprPtr>{NT<InstantiateExpr>(NT<IdExpr>(name), ids),
|
||||
std::vector<ExprPtr>{NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), ids),
|
||||
NT<IdExpr>(fn->getRetType()->realizedName())});
|
||||
// Function[Tuple[TArg1, TArg2, ...],TRet](
|
||||
// __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
|
||||
|
@ -264,7 +257,7 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
|
|||
|
||||
// Case: object member access (`obj.member`)
|
||||
if (!expr->expr->isType()) {
|
||||
if (auto member = ctx->findMember(typ->name, expr->member)) {
|
||||
if (auto member = ctx->findMember(typ, expr->member)) {
|
||||
unify(expr->type, ctx->instantiate(member, typ));
|
||||
if (expr->expr->isDone() && realize(expr->type))
|
||||
expr->setDone();
|
||||
|
@ -340,7 +333,8 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
|
|||
{"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}}));
|
||||
}
|
||||
|
||||
// For debugging purposes: ctx->findMethod(typ->name, expr->member);
|
||||
// For debugging purposes:
|
||||
// ctx->findMethod(typ.get(), expr->member);
|
||||
E(Error::DOT_NO_ATTR, expr, typ->prettyString(), expr->member);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -372,7 +366,7 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
|
|||
bool addSelf = true;
|
||||
if (auto dot = expr->getDot()) {
|
||||
auto methods =
|
||||
ctx->findMethod(dot->expr->type->getClass()->name, dot->member, false);
|
||||
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
|
||||
if (!methods.empty() && methods.front()->ast->attributes.has(Attr::StaticMethod))
|
||||
addSelf = false;
|
||||
}
|
||||
|
@ -410,7 +404,7 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
|
|||
if (auto dot = expr->getDot()) {
|
||||
// Case: method overloads (DotExpr)
|
||||
auto methods =
|
||||
ctx->findMethod(dot->expr->type->getClass()->name, dot->member, false);
|
||||
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
|
||||
auto m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
|
||||
bestMethod = m.empty() ? nullptr : m[0];
|
||||
} else if (auto id = expr->getId()) {
|
||||
|
|
|
@ -161,7 +161,7 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
|
|||
transform(stmt->lhs);
|
||||
|
||||
if (auto lhsClass = stmt->lhs->getType()->getClass()) {
|
||||
auto member = ctx->findMember(lhsClass->name, stmt->member);
|
||||
auto member = ctx->findMember(lhsClass, stmt->member);
|
||||
|
||||
if (!member && stmt->lhs->isType()) {
|
||||
// Case: class variables
|
||||
|
|
|
@ -160,7 +160,7 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallExpr::Arg> &args) {
|
|||
return false;
|
||||
if (!typ->getRecord())
|
||||
E(Error::CALL_BAD_UNPACK, args[ai], typ->prettyString());
|
||||
auto &fields = ctx->cache->classes[typ->name].fields;
|
||||
auto fields = getClassFields(typ.get());
|
||||
for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
|
||||
args.insert(args.begin() + ai,
|
||||
{"", transform(N<DotExpr>(clone(star->what), fields[i].name))});
|
||||
|
@ -176,9 +176,9 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallExpr::Arg> &args) {
|
|||
}
|
||||
if (!typ)
|
||||
return false;
|
||||
if (!typ->getRecord() || startswith(typ->name, TYPE_TUPLE))
|
||||
if (!typ->getRecord() || typ->name == TYPE_TUPLE)
|
||||
E(Error::CALL_BAD_KWUNPACK, args[ai], typ->prettyString());
|
||||
auto &fields = ctx->cache->classes[typ->name].fields;
|
||||
auto fields = getClassFields(typ.get());
|
||||
for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
|
||||
args.insert(args.begin() + ai,
|
||||
{fields[i].name,
|
||||
|
@ -338,7 +338,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
|||
auto e = getPartialArg(-1);
|
||||
auto t = e->getType()->getRecord();
|
||||
seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple", e);
|
||||
auto &ff = ctx->cache->classes[t->name].fields;
|
||||
auto ff = getClassFields(t.get());
|
||||
for (int i = 0; i < t->getRecord()->args.size(); i++) {
|
||||
names.emplace_back(ff[i].name);
|
||||
values.emplace_back(
|
||||
|
@ -663,10 +663,9 @@ ExprPtr TypecheckVisitor::transformSuper() {
|
|||
if (typ->getRecord()) {
|
||||
// Case: tuple types. Return `tuple(obj.args...)`
|
||||
std::vector<ExprPtr> members;
|
||||
for (auto &field : ctx->cache->classes[name].fields)
|
||||
for (auto &field : getClassFields(superTyp.get()))
|
||||
members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name));
|
||||
ExprPtr e = transform(
|
||||
N<CallExpr>(N<IdExpr>(format(TYPE_TUPLE "{}", members.size())), members));
|
||||
ExprPtr e = transform(N<TupleExpr>(members));
|
||||
e->type = unify(superTyp, e->type); // see super_tuple test for this line
|
||||
return e;
|
||||
} else {
|
||||
|
@ -732,8 +731,8 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
|
|||
}
|
||||
|
||||
expr->staticValue.type = StaticValue::INT;
|
||||
if (typExpr->isId("Tuple") || typExpr->isId("tuple")) {
|
||||
return transform(N<BoolExpr>(startswith(typ->name, TYPE_TUPLE)));
|
||||
if (typExpr->isId(TYPE_TUPLE) || typExpr->isId("tuple")) {
|
||||
return transform(N<BoolExpr>(typ->name == TYPE_TUPLE));
|
||||
} else if (typExpr->isId("ByVal")) {
|
||||
return transform(N<BoolExpr>(typ->getRecord() != nullptr));
|
||||
} else if (typExpr->isId("ByRef")) {
|
||||
|
@ -766,7 +765,10 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
|
|||
|
||||
// Check super types (i.e., statically inherited) as well
|
||||
for (auto &tx : getSuperTypes(typ->getClass())) {
|
||||
if (tx->unify(typExpr->type.get(), nullptr) >= 0)
|
||||
types::Type::Unification us;
|
||||
auto s = tx->unify(typExpr->type.get(), &us);
|
||||
us.undo();
|
||||
if (s >= 0)
|
||||
return transform(N<BoolExpr>(true));
|
||||
}
|
||||
return transform(N<BoolExpr>(false));
|
||||
|
@ -841,8 +843,8 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
|
|||
}
|
||||
}
|
||||
|
||||
bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() ||
|
||||
ctx->findMember(typ->getClass()->name, member);
|
||||
bool exists = !ctx->findMethod(typ->getClass().get(), member).empty() ||
|
||||
ctx->findMember(typ->getClass(), member);
|
||||
if (exists && args.size() > 1)
|
||||
exists &= findBestMethod(typ, member, args) != nullptr;
|
||||
return transform(N<BoolExpr>(exists));
|
||||
|
@ -890,26 +892,23 @@ ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) {
|
|||
return expr->clone();
|
||||
|
||||
std::vector<ExprPtr> items;
|
||||
auto tn = generateTuple(ctx->cache->classes[cls->name].fields.size());
|
||||
for (auto &ft : ctx->cache->classes[cls->name].fields) {
|
||||
for (auto &ft : getClassFields(cls.get())) {
|
||||
auto t = ctx->instantiate(ft.type, cls);
|
||||
auto rt = realize(t);
|
||||
seqassert(rt, "cannot realize '{}' in {}", t, ft.name);
|
||||
items.push_back(NT<IdExpr>(t->realizedName()));
|
||||
}
|
||||
auto e = transform(NT<InstantiateExpr>(N<IdExpr>(tn), items));
|
||||
auto e = transform(NT<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), items));
|
||||
return e;
|
||||
}
|
||||
|
||||
std::vector<ExprPtr> args;
|
||||
args.reserve(ctx->cache->classes[cls->name].fields.size());
|
||||
std::string var = ctx->cache->getTemporaryVar("tup");
|
||||
for (auto &field : ctx->cache->classes[cls->name].fields)
|
||||
for (auto &field : getClassFields(cls.get()))
|
||||
args.emplace_back(N<DotExpr>(N<IdExpr>(var), field.name));
|
||||
|
||||
return transform(N<StmtExpr>(
|
||||
N<AssignStmt>(N<IdExpr>(var), expr->args.front().value),
|
||||
N<CallExpr>(N<IdExpr>(format("{}{}", TYPE_TUPLE, args.size())), args)));
|
||||
return transform(N<StmtExpr>(N<AssignStmt>(N<IdExpr>(var), expr->args.front().value),
|
||||
N<TupleExpr>(args)));
|
||||
}
|
||||
|
||||
/// Transform type function to a type IdExpr identifier.
|
||||
|
@ -1092,7 +1091,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
|
|||
return {true, nullptr};
|
||||
|
||||
size_t idx = 0;
|
||||
for (auto &f : ctx->cache->classes[typ->name].fields) {
|
||||
for (auto &f : getClassFields(typ.get())) {
|
||||
auto k = N<StringExpr>(f.name);
|
||||
auto v = N<DotExpr>(expr->args[0].value, f.name);
|
||||
if (withIdx) {
|
||||
|
@ -1119,10 +1118,10 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
|
|||
error("invalid index");
|
||||
typ = t->getRecord()->args[n];
|
||||
} else {
|
||||
if (n < 0 || n >= ctx->cache->classes[t->getClass()->name].fields.size())
|
||||
auto f = getClassFields(t->getClass().get());
|
||||
if (n < 0 || n >= f.size())
|
||||
error("invalid index");
|
||||
typ = ctx->instantiate(ctx->cache->classes[t->getClass()->name].fields[n].type,
|
||||
t->getClass());
|
||||
typ = ctx->instantiate(f[n].type, t->getClass());
|
||||
}
|
||||
typ = realize(typ);
|
||||
return {true, transform(NT<IdExpr>(typ->realizedName()))};
|
||||
|
@ -1141,8 +1140,8 @@ std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cl
|
|||
result.push_back(cls);
|
||||
for (auto &name : ctx->cache->classes[cls->name].staticParentClasses) {
|
||||
auto parentTyp = ctx->instantiate(ctx->forceFind(name)->type)->getClass();
|
||||
for (auto &field : ctx->cache->classes[cls->name].fields) {
|
||||
for (auto &parentField : ctx->cache->classes[name].fields)
|
||||
for (auto &field : getClassFields(cls.get())) {
|
||||
for (auto &parentField : getClassFields(parentTyp.get()))
|
||||
if (field.name == parentField.name) {
|
||||
unify(ctx->instantiate(field.type, cls),
|
||||
ctx->instantiate(parentField.type, parentTyp));
|
||||
|
|
|
@ -117,7 +117,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
|||
LOG_REALIZE(" - member: {}: {}", m.name, m.type);
|
||||
}
|
||||
|
||||
/// Generate a tuple class `Tuple.N[T1,...,TN]`.
|
||||
/// Generate a tuple class `Tuple[T1,...,TN]`.
|
||||
/// @param len Tuple length (`N`)
|
||||
/// @param name Tuple name. `Tuple` by default.
|
||||
/// Can be something else (e.g., `KwTuple`)
|
||||
|
|
|
@ -126,8 +126,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
|
|||
seqassert(collectionTyp->getRecord() &&
|
||||
collectionTyp->getRecord()->args.size() == 2,
|
||||
"bad dict");
|
||||
auto tname = generateTuple(2);
|
||||
auto tt = unify(typ, ctx->instantiate(ctx->getType(tname)))->getRecord();
|
||||
auto tt = unify(typ, ctx->instantiateTuple(2))->getRecord();
|
||||
auto nt = collectionTyp->getRecord()->args;
|
||||
for (int di = 0; di < 2; di++) {
|
||||
if (!nt[di]->getClass())
|
||||
|
@ -135,7 +134,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
|
|||
else if (auto dt = superTyp(nt[di]->getClass(), tt->args[di]->getClass()))
|
||||
nt[di] = dt;
|
||||
}
|
||||
collectionTyp = ctx->instantiateGeneric(ctx->getType(tname), nt);
|
||||
collectionTyp = ctx->instantiateTuple(nt);
|
||||
}
|
||||
}
|
||||
if (!done)
|
||||
|
@ -195,9 +194,9 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
|
|||
}
|
||||
|
||||
/// Transform tuples.
|
||||
/// Generate tuple classes (e.g., `Tuple.N`) if not available.
|
||||
/// Generate tuple classes (e.g., `Tuple`) if not available.
|
||||
/// @example
|
||||
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
|
||||
/// `(a1, ..., aN)` -> `Tuple.__new__(a1, ..., aN)`
|
||||
void TypecheckVisitor::visit(TupleExpr *expr) {
|
||||
expr->setType(ctx->getUnbound());
|
||||
for (int ai = 0; ai < expr->items.size(); ai++)
|
||||
|
@ -213,7 +212,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
|
|||
return; // continue later when the type becomes known
|
||||
if (!typ->getRecord())
|
||||
E(Error::CALL_BAD_UNPACK, star, typ->prettyString());
|
||||
auto &ff = ctx->cache->classes[typ->name].fields;
|
||||
auto ff = getClassFields(typ.get());
|
||||
for (int i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
|
||||
expr->items.insert(expr->items.begin() + ai,
|
||||
transform(N<DotExpr>(clone(star->what), ff[i].name)));
|
||||
|
@ -224,15 +223,14 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
|
|||
} else {
|
||||
expr->items[ai] = transform(expr->items[ai]);
|
||||
}
|
||||
auto tupleName = generateTuple(expr->items.size());
|
||||
resultExpr =
|
||||
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
|
||||
auto s = ctx->generateTuple(expr->items.size());
|
||||
resultExpr = transform(N<CallExpr>(N<IdExpr>(s), clone(expr->items)));
|
||||
unify(expr->type, resultExpr->type);
|
||||
}
|
||||
|
||||
/// Transform a tuple generator expression.
|
||||
/// @example
|
||||
/// `tuple(expr for i in tuple_generator)` -> `Tuple.N.__new__(expr...)`
|
||||
/// `tuple(expr for i in tuple_generator)` -> `Tuple.__new__(expr...)`
|
||||
void TypecheckVisitor::visit(GeneratorExpr *expr) {
|
||||
seqassert(expr->kind == GeneratorExpr::Generator && expr->loops.size() == 1 &&
|
||||
expr->loops[0].conds.empty(),
|
||||
|
@ -282,8 +280,7 @@ void TypecheckVisitor::visit(GeneratorExpr *expr) {
|
|||
}
|
||||
|
||||
auto tuple = gen->type->getRecord();
|
||||
if (!tuple ||
|
||||
!(startswith(tuple->name, TYPE_TUPLE) || startswith(tuple->name, TYPE_KWTUPLE)))
|
||||
if (!tuple || !(tuple->name == TYPE_TUPLE || startswith(tuple->name, TYPE_KWTUPLE)))
|
||||
E(Error::CALL_BAD_ITER, gen, gen->type->prettyString());
|
||||
|
||||
// `a := tuple[i]; expr...` for each i
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/visitors/format/format.h"
|
||||
#include "codon/parser/visitors/simplify/ctx.h"
|
||||
#include "codon/parser/visitors/typecheck/typecheck.h"
|
||||
|
||||
using fmt::format;
|
||||
using namespace codon::error;
|
||||
|
@ -97,14 +99,18 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
|
|||
for (auto &i : genericCache) {
|
||||
if (auto l = i.second->getLink()) {
|
||||
i.second->setSrcInfo(srcInfo);
|
||||
if (l->defaultType)
|
||||
if (l->defaultType) {
|
||||
pendingDefaults.insert(i.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (t->getUnion() && !t->getUnion()->isSealed()) {
|
||||
t->setSrcInfo(srcInfo);
|
||||
pendingDefaults.insert(t);
|
||||
}
|
||||
if (auto r = t->getRecord())
|
||||
if (r->repeats && r->repeats->canRealize())
|
||||
r->flatten();
|
||||
return t;
|
||||
}
|
||||
|
||||
|
@ -126,9 +132,112 @@ TypeContext::instantiateGeneric(const SrcInfo &srcInfo, const types::TypePtr &ro
|
|||
return instantiate(srcInfo, root, g);
|
||||
}
|
||||
|
||||
std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeName,
|
||||
std::shared_ptr<types::RecordType>
|
||||
TypeContext::instantiateTuple(const SrcInfo &srcInfo,
|
||||
const std::vector<types::TypePtr> &generics) {
|
||||
auto key = generateTuple(generics.size());
|
||||
auto root = forceFind(key)->type->getRecord();
|
||||
return instantiateGeneric(srcInfo, root, generics)->getRecord();
|
||||
}
|
||||
|
||||
std::string TypeContext::generateTuple(size_t n) {
|
||||
auto key = format("_{}:{}", TYPE_TUPLE, n);
|
||||
if (!in(cache->classes, key)) {
|
||||
cache->classes[key].fields.clear();
|
||||
cache->classes[key].ast =
|
||||
std::static_pointer_cast<ClassStmt>(clone(cache->classes[TYPE_TUPLE].ast));
|
||||
auto root = std::make_shared<types::RecordType>(cache, TYPE_TUPLE, TYPE_TUPLE);
|
||||
for (size_t i = 0; i < n; i++) { // generate unique ID
|
||||
auto g = getUnbound()->getLink();
|
||||
g->kind = types::LinkType::Generic;
|
||||
g->genericName = format("T{}", i + 1);
|
||||
auto gn = cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(g->genericName);
|
||||
root->generics.emplace_back(gn, g->genericName, g, g->id);
|
||||
root->args.emplace_back(g);
|
||||
cache->classes[key].ast->args.emplace_back(
|
||||
g->genericName, std::make_shared<IdExpr>("type"), nullptr, Param::Generic);
|
||||
cache->classes[key].fields.push_back(
|
||||
Cache::Class::ClassField{format("item{}", i + 1), g, ""});
|
||||
}
|
||||
std::vector<ExprPtr> eTypeArgs;
|
||||
for (size_t i = 0; i < n; i++)
|
||||
eTypeArgs.push_back(std::make_shared<IdExpr>(format("T{}", i + 1)));
|
||||
auto eType = std::make_shared<InstantiateExpr>(std::make_shared<IdExpr>(TYPE_TUPLE),
|
||||
eTypeArgs);
|
||||
eType->type = root;
|
||||
cache->classes[key].mro = {eType};
|
||||
addToplevel(key, std::make_shared<TypecheckItem>(TypecheckItem::Type, root));
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
std::shared_ptr<types::RecordType> TypeContext::instantiateTuple(size_t n) {
|
||||
std::vector<types::TypePtr> t(n);
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
auto g = getUnbound()->getLink();
|
||||
g->genericName = format("T{}", i + 1);
|
||||
t[i] = g;
|
||||
}
|
||||
return instantiateTuple(getSrcInfo(), t);
|
||||
}
|
||||
|
||||
std::vector<types::FuncTypePtr> TypeContext::findMethod(types::ClassType *type,
|
||||
const std::string &method,
|
||||
bool hideShadowed) const {
|
||||
bool hideShadowed) {
|
||||
auto typeName = type->name;
|
||||
if (type->is(TYPE_TUPLE)) {
|
||||
auto sz = type->getRecord()->getRepeats();
|
||||
if (sz != -1)
|
||||
type->getRecord()->flatten();
|
||||
sz = int64_t(type->getRecord()->args.size());
|
||||
typeName = format("_{}:{}", TYPE_TUPLE, sz);
|
||||
if (in(cache->classes[TYPE_TUPLE].methods, method) &&
|
||||
!in(cache->classes[typeName].methods, method)) {
|
||||
auto type = forceFind(typeName)->type;
|
||||
|
||||
cache->classes[typeName].methods[method] =
|
||||
cache->classes[TYPE_TUPLE].methods[method];
|
||||
auto &o = cache->overloads[cache->classes[typeName].methods[method]];
|
||||
auto f = cache->functions[o[0].name];
|
||||
f.realizations.clear();
|
||||
|
||||
seqassert(f.type, "tuple fn type not yet set");
|
||||
f.ast->attributes.parentClass = typeName;
|
||||
f.ast = std::static_pointer_cast<FunctionStmt>(clone(f.ast));
|
||||
f.ast->name = format("{}{}", f.ast->name.substr(0, f.ast->name.size() - 1), sz);
|
||||
f.ast->attributes.set(Attr::Method);
|
||||
|
||||
auto eType = clone(cache->classes[typeName].mro[0]);
|
||||
eType->type = nullptr;
|
||||
for (auto &a : f.ast->args)
|
||||
if (a.type && a.type->isId(TYPE_TUPLE)) {
|
||||
a.type = eType;
|
||||
}
|
||||
if (f.ast->ret && f.ast->ret->isId(TYPE_TUPLE))
|
||||
f.ast->ret = eType;
|
||||
// TODO: resurrect Tuple[N].__new__(defaults...)
|
||||
if (method == "__new__") {
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
auto n = format("item{}", i + 1);
|
||||
f.ast->args.emplace_back(
|
||||
cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(n),
|
||||
std::make_shared<IdExpr>(format("T{}", i + 1))
|
||||
// std::make_shared<CallExpr>(
|
||||
// std::make_shared<IdExpr>(format("T{}", i + 1)))
|
||||
);
|
||||
}
|
||||
}
|
||||
cache->reverseIdentifierLookup[f.ast->name] = method;
|
||||
cache->functions[f.ast->name] = f;
|
||||
cache->functions[f.ast->name].type =
|
||||
TypecheckVisitor(cache->typeCtx).makeFunctionType(f.ast.get());
|
||||
addToplevel(f.ast->name,
|
||||
std::make_shared<TypecheckItem>(TypecheckItem::Func,
|
||||
cache->functions[f.ast->name].type));
|
||||
o.push_back(Cache::Overload{f.ast->name, 0});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<types::FuncTypePtr> vv;
|
||||
std::unordered_set<std::string> signatureLoci;
|
||||
|
||||
|
@ -166,9 +275,26 @@ std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeN
|
|||
return vv;
|
||||
}
|
||||
|
||||
types::TypePtr TypeContext::findMember(const std::string &typeName,
|
||||
types::TypePtr TypeContext::findMember(const types::ClassTypePtr &type,
|
||||
const std::string &member) const {
|
||||
if (auto cls = in(cache->classes, typeName)) {
|
||||
if (type->is(TYPE_TUPLE)) {
|
||||
if (!startswith(member, "item") || member.size() < 5)
|
||||
return nullptr;
|
||||
int id = 0;
|
||||
for (int i = 4; i < member.size(); i++) {
|
||||
if (member[i] >= '0' + (i == 4) && member[i] <= '9')
|
||||
id = id * 10 + member[i] - '0';
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
auto sz = type->getRecord()->getRepeats();
|
||||
if (sz != -1)
|
||||
type->getRecord()->flatten();
|
||||
if (id < 1 || id > type->getRecord()->args.size())
|
||||
return nullptr;
|
||||
return type->getRecord()->args[id - 1];
|
||||
}
|
||||
if (auto cls = in(cache->classes, type->name)) {
|
||||
for (auto &pt : cls->mro) {
|
||||
if (auto pc = pt->type->getClass()) {
|
||||
auto mc = in(cache->classes, pc->name);
|
||||
|
|
|
@ -132,14 +132,22 @@ public:
|
|||
return instantiateGeneric(getSrcInfo(), std::move(root), generics);
|
||||
}
|
||||
|
||||
std::shared_ptr<types::RecordType>
|
||||
instantiateTuple(const SrcInfo &info, const std::vector<types::TypePtr> &generics);
|
||||
std::shared_ptr<types::RecordType>
|
||||
instantiateTuple(const std::vector<types::TypePtr> &generics) {
|
||||
return instantiateTuple(getSrcInfo(), generics);
|
||||
}
|
||||
std::shared_ptr<types::RecordType> instantiateTuple(size_t n);
|
||||
std::string generateTuple(size_t n);
|
||||
|
||||
/// Returns the list of generic methods that correspond to typeName.method.
|
||||
std::vector<types::FuncTypePtr> findMethod(const std::string &typeName,
|
||||
std::vector<types::FuncTypePtr> findMethod(types::ClassType *type,
|
||||
const std::string &method,
|
||||
bool hideShadowed = true) const;
|
||||
bool hideShadowed = true);
|
||||
/// Returns the generic type of typeName.member, if it exists (nullptr otherwise).
|
||||
/// Special cases: __elemsize__ and __atomic__.
|
||||
types::TypePtr findMember(const std::string &typeName,
|
||||
const std::string &member) const;
|
||||
types::TypePtr findMember(const types::ClassTypePtr &, const std::string &) const;
|
||||
|
||||
using ReorderDoneFn =
|
||||
std::function<int(int, int, const std::vector<std::vector<int>> &, bool)>;
|
||||
|
|
|
@ -75,6 +75,42 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
|
|||
// Function should be constructed only once
|
||||
stmt->setDone();
|
||||
|
||||
auto funcTyp = makeFunctionType(stmt);
|
||||
// If this is a class method, update the method lookup table
|
||||
bool isClassMember = !stmt->attributes.parentClass.empty();
|
||||
if (isClassMember) {
|
||||
auto m =
|
||||
ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(),
|
||||
ctx->cache->rev(stmt->name));
|
||||
bool found = false;
|
||||
for (auto &i : ctx->cache->overloads[m])
|
||||
if (i.name == stmt->name) {
|
||||
ctx->cache->functions[i.name].type = funcTyp;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
seqassert(found, "cannot find matching class method for {}", stmt->name);
|
||||
}
|
||||
|
||||
// Update the visited table
|
||||
// Functions should always be visible, so add them to the toplevel
|
||||
ctx->addToplevel(stmt->name,
|
||||
std::make_shared<TypecheckItem>(TypecheckItem::Func, funcTyp));
|
||||
ctx->cache->functions[stmt->name].type = funcTyp;
|
||||
|
||||
// Ensure that functions with @C, @force_realize, and @export attributes can be
|
||||
// realized
|
||||
if (stmt->attributes.has(Attr::ForceRealize) || stmt->attributes.has(Attr::Export) ||
|
||||
(stmt->attributes.has(Attr::C) && !stmt->attributes.has(Attr::CVarArg))) {
|
||||
if (!funcTyp->canRealize())
|
||||
E(Error::FN_REALIZE_BUILTIN, stmt);
|
||||
}
|
||||
|
||||
// Debug information
|
||||
LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp);
|
||||
}
|
||||
|
||||
types::FuncTypePtr TypecheckVisitor::makeFunctionType(FunctionStmt *stmt) {
|
||||
// Handle generics
|
||||
bool isClassMember = !stmt->attributes.parentClass.empty();
|
||||
auto explicits = std::vector<ClassType::Generic>();
|
||||
|
@ -169,38 +205,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
|
|||
}
|
||||
funcTyp =
|
||||
std::static_pointer_cast<FuncType>(funcTyp->generalize(ctx->typecheckLevel));
|
||||
|
||||
// If this is a class method, update the method lookup table
|
||||
if (isClassMember) {
|
||||
auto m =
|
||||
ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(),
|
||||
ctx->cache->rev(stmt->name));
|
||||
bool found = false;
|
||||
for (auto &i : ctx->cache->overloads[m])
|
||||
if (i.name == stmt->name) {
|
||||
ctx->cache->functions[i.name].type = funcTyp;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
seqassert(found, "cannot find matching class method for {}", stmt->name);
|
||||
}
|
||||
|
||||
// Update the visited table
|
||||
// Functions should always be visible, so add them to the toplevel
|
||||
ctx->addToplevel(stmt->name,
|
||||
std::make_shared<TypecheckItem>(TypecheckItem::Func, funcTyp));
|
||||
ctx->cache->functions[stmt->name].type = funcTyp;
|
||||
|
||||
// Ensure that functions with @C, @force_realize, and @export attributes can be
|
||||
// realized
|
||||
if (stmt->attributes.has(Attr::ForceRealize) || stmt->attributes.has(Attr::Export) ||
|
||||
(stmt->attributes.has(Attr::C) && !stmt->attributes.has(Attr::CVarArg))) {
|
||||
if (!funcTyp->canRealize())
|
||||
E(Error::FN_REALIZE_BUILTIN, stmt);
|
||||
}
|
||||
|
||||
// Debug information
|
||||
LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp);
|
||||
return funcTyp;
|
||||
}
|
||||
|
||||
/// Make an empty partial call `fn(...)` for a given function.
|
||||
|
@ -232,11 +237,10 @@ ExprPtr TypecheckVisitor::partializeFunction(const types::FuncTypePtr &fn) {
|
|||
return call;
|
||||
}
|
||||
|
||||
/// Generate and return `Function[Tuple.N[args...], ret]` type
|
||||
/// Generate and return `Function[Tuple[args...], ret]` type
|
||||
std::shared_ptr<RecordType> TypecheckVisitor::getFuncTypeBase(size_t nargs) {
|
||||
auto baseType = ctx->instantiate(ctx->forceFind("Function")->type)->getRecord();
|
||||
unify(baseType->generics[0].type,
|
||||
ctx->instantiate(ctx->forceFind(generateTuple(nargs))->type)->getRecord());
|
||||
unify(baseType->generics[0].type, ctx->instantiateTuple(nargs)->getRecord());
|
||||
return baseType;
|
||||
}
|
||||
|
||||
|
|
|
@ -53,9 +53,12 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
|
|||
ctx->getRealizationBase()->iteration++) {
|
||||
LOG_TYPECHECK("[iter] {} :: {}", ctx->getRealizationBase()->name,
|
||||
ctx->getRealizationBase()->iteration);
|
||||
if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER)
|
||||
if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER) {
|
||||
error(result, "cannot typecheck '{}' in reasonable time",
|
||||
ctx->cache->rev(ctx->getRealizationBase()->name));
|
||||
ctx->getRealizationBase()->name.empty()
|
||||
? "toplevel"
|
||||
: ctx->cache->rev(ctx->getRealizationBase()->name));
|
||||
}
|
||||
|
||||
// Keep iterating until:
|
||||
// (1) success: the statement is marked as done; or
|
||||
|
@ -183,7 +186,6 @@ types::TypePtr TypecheckVisitor::realize(types::TypePtr typ) {
|
|||
}
|
||||
e.trackRealize(fmt::format("{}{}", name, name_args), getSrcInfo());
|
||||
}
|
||||
|
||||
} else {
|
||||
e.trackRealize(typ->prettyString(), getSrcInfo());
|
||||
}
|
||||
|
@ -198,9 +200,12 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
|
|||
if (!type || !type->canRealize())
|
||||
return nullptr;
|
||||
|
||||
if (auto tr = type->getRecord())
|
||||
tr->flatten();
|
||||
|
||||
// Check if the type fields are all initialized
|
||||
// (sometimes that's not the case: e.g., `class X: x: List[X]`)
|
||||
for (auto &field : ctx->cache->classes[type->name].fields) {
|
||||
for (auto field : getClassFields(type)) {
|
||||
if (!field.type)
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -247,16 +252,24 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
|
|||
std::vector<ir::types::Type *> typeArgs; // needed for IR
|
||||
std::vector<std::string> names; // needed for IR
|
||||
std::map<std::string, SrcInfo> memberInfo; // needed for IR
|
||||
for (auto &field : ctx->cache->classes[realized->name].fields) {
|
||||
if (realized->is(TYPE_TUPLE))
|
||||
realized->getRecord()->flatten();
|
||||
int i = 0;
|
||||
for (auto &field : getClassFields(realized.get())) {
|
||||
auto ftyp = ctx->instantiate(field.type, realized);
|
||||
// HACK: repeated tuples have no generics so this is needed to fix the instantiation
|
||||
// above
|
||||
if (realized->is(TYPE_TUPLE))
|
||||
unify(ftyp, realized->getRecord()->args[i]);
|
||||
|
||||
if (!realize(ftyp))
|
||||
E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), field.name,
|
||||
ftyp->prettyString());
|
||||
LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, ftyp);
|
||||
realization->fields.emplace_back(field.name, ftyp);
|
||||
names.emplace_back(field.name);
|
||||
typeArgs.emplace_back(makeIRType(ftyp->getClass().get()));
|
||||
memberInfo[field.name] = field.type->getSrcInfo();
|
||||
i++;
|
||||
}
|
||||
|
||||
// Set IR attributes
|
||||
|
@ -431,9 +444,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
|
|||
NT<InstantiateExpr>(
|
||||
NT<IdExpr>("Function"),
|
||||
std::vector<ExprPtr>{
|
||||
NT<InstantiateExpr>(
|
||||
NT<IdExpr>(format("{}{}", TYPE_TUPLE, ids.size())),
|
||||
ids),
|
||||
NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), ids),
|
||||
NT<IdExpr>(fn->getRetType()->realizedName())}),
|
||||
N<IdExpr>(fn->realizedName())),
|
||||
"__raw__")),
|
||||
|
@ -459,7 +470,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
|
|||
auto baseTyp = t->funcGenerics[0].type->getClass();
|
||||
auto derivedTyp = t->funcGenerics[1].type->getClass();
|
||||
|
||||
const auto &fields = ctx->cache->classes[derivedTyp->name].fields;
|
||||
const auto &fields = getClassFields(derivedTyp.get());
|
||||
auto types = std::vector<ExprPtr>{};
|
||||
auto found = false;
|
||||
for (auto &f : fields) {
|
||||
|
@ -471,13 +482,11 @@ StmtPtr TypecheckVisitor::prepareVTables() {
|
|||
types.push_back(NT<IdExpr>(ft->realizedName()));
|
||||
}
|
||||
}
|
||||
seqassert(found || ctx->cache->classes[baseTyp->name].fields.empty(),
|
||||
seqassert(found || getClassFields(baseTyp.get()).empty(),
|
||||
"cannot find distance between {} and {}", derivedTyp->name,
|
||||
baseTyp->name);
|
||||
StmtPtr suite = N<ReturnStmt>(
|
||||
N<DotExpr>(NT<InstantiateExpr>(
|
||||
NT<IdExpr>(format("{}{}", TYPE_TUPLE, types.size())), types),
|
||||
"__elemsize__"));
|
||||
N<DotExpr>(NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), types), "__elemsize__"));
|
||||
LOG_REALIZE("[poly] {} : {}", t, *suite);
|
||||
initDist.ast->suite = suite;
|
||||
t->ast = initDist.ast.get();
|
||||
|
@ -518,7 +527,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
|
|||
->vtables[cp->realizedName()];
|
||||
|
||||
// Add or extract thunk ID
|
||||
size_t vid;
|
||||
size_t vid = 0;
|
||||
if (auto i = in(vt.table, key)) {
|
||||
vid = i->second;
|
||||
} else {
|
||||
|
@ -681,14 +690,19 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
|||
seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics");
|
||||
handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]);
|
||||
} else if (auto tr = t->getRecord()) {
|
||||
seqassert(tr->getRepeats() >= 0, "repeats not resolved: '{}'", tr->debugString(2));
|
||||
tr->flatten();
|
||||
std::vector<ir::types::Type *> typeArgs;
|
||||
std::vector<std::string> names;
|
||||
std::map<std::string, SrcInfo> memberInfo;
|
||||
for (int ai = 0; ai < tr->args.size(); ai++) {
|
||||
names.emplace_back(ctx->cache->classes[t->name].fields[ai].name);
|
||||
auto n = t->name == TYPE_TUPLE ? format("item{}", ai + 1)
|
||||
: ctx->cache->classes[t->name].fields[ai].name;
|
||||
names.emplace_back(n);
|
||||
typeArgs.emplace_back(forceFindIRType(tr->args[ai]));
|
||||
memberInfo[ctx->cache->classes[t->name].fields[ai].name] =
|
||||
ctx->cache->classes[t->name].fields[ai].type->getSrcInfo();
|
||||
memberInfo[n] = t->name == TYPE_TUPLE
|
||||
? tr->getSrcInfo()
|
||||
: ctx->cache->classes[t->name].fields[ai].type->getSrcInfo();
|
||||
}
|
||||
auto record =
|
||||
ir::cast<ir::types::RecordType>(module->unsafeGetMemberedType(realizedName));
|
||||
|
|
|
@ -60,8 +60,8 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
|
|||
if ((resultStmt = transformStaticForLoop(stmt)))
|
||||
return;
|
||||
|
||||
bool maybeHeterogenous = startswith(iterType->name, TYPE_TUPLE) ||
|
||||
startswith(iterType->name, TYPE_KWTUPLE);
|
||||
bool maybeHeterogenous =
|
||||
iterType->name == TYPE_TUPLE || startswith(iterType->name, TYPE_KWTUPLE);
|
||||
if (maybeHeterogenous && !iterType->canRealize()) {
|
||||
return; // wait until the tuple is fully realizable
|
||||
} else if (maybeHeterogenous && iterType->getHeterogenousTuple()) {
|
||||
|
@ -336,7 +336,7 @@ TypecheckVisitor::transformStaticLoopCall(
|
|||
error("expected three items");
|
||||
auto typ = args[0]->getClass();
|
||||
size_t idx = 0;
|
||||
for (auto &f : ctx->cache->classes[typ->name].fields) {
|
||||
for (auto &f : getClassFields(typ.get())) {
|
||||
std::vector<StmtPtr> stmts;
|
||||
if (withIdx) {
|
||||
stmts.push_back(
|
||||
|
@ -370,7 +370,7 @@ TypecheckVisitor::transformStaticLoopCall(
|
|||
seqassert(typ, "vars_types expects a realizable type, got '{}' instead",
|
||||
generics[0]);
|
||||
size_t idx = 0;
|
||||
for (auto &f : ctx->cache->classes[typ->getClass()->name].fields) {
|
||||
for (auto &f : getClassFields(typ->getClass().get())) {
|
||||
auto ta = realize(ctx->instantiate(f.type, typ->getClass()));
|
||||
seqassert(ta, "cannot realize '{}'", f.type->debugString(1));
|
||||
std::vector<StmtPtr> stmts;
|
||||
|
|
|
@ -271,15 +271,29 @@ void TypecheckVisitor::visit(IndexExpr *expr) {
|
|||
/// Instantiate(foo, [bar]) -> Id("foo[bar]")
|
||||
void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
||||
transformType(expr->typeExpr);
|
||||
TypePtr typ =
|
||||
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
|
||||
|
||||
std::shared_ptr<types::StaticType> repeats = nullptr;
|
||||
if (expr->typeExpr->isId(TYPE_TUPLE) && !expr->typeParams.empty()) {
|
||||
transform(expr->typeParams[0]);
|
||||
if (expr->typeParams[0]->staticValue.type == StaticValue::INT) {
|
||||
repeats = Type::makeStatic(ctx->cache, expr->typeParams[0]);
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr typ = nullptr;
|
||||
size_t typeParamsSize = expr->typeParams.size() - (repeats != nullptr);
|
||||
if (expr->typeExpr->isId(TYPE_TUPLE)) {
|
||||
typ = ctx->instantiateTuple(typeParamsSize);
|
||||
} else {
|
||||
typ = ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
|
||||
}
|
||||
seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr);
|
||||
|
||||
auto &generics = typ->getClass()->generics;
|
||||
bool isUnion = typ->getUnion() != nullptr;
|
||||
if (!isUnion && expr->typeParams.size() != generics.size())
|
||||
if (!isUnion && typeParamsSize != generics.size())
|
||||
E(Error::GENERICS_MISMATCH, expr, ctx->cache->rev(typ->getClass()->name),
|
||||
generics.size(), expr->typeParams.size());
|
||||
generics.size(), typeParamsSize);
|
||||
|
||||
if (expr->typeExpr->isId(TYPE_CALLABLE)) {
|
||||
// Case: Callable[...] trait instantiation
|
||||
|
@ -303,7 +317,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
|||
typ->getLink()->trait = std::make_shared<TypeTrait>(expr->typeParams[0]->type);
|
||||
unify(expr->type, typ);
|
||||
} else {
|
||||
for (size_t i = 0; i < expr->typeParams.size(); i++) {
|
||||
for (size_t i = (repeats != nullptr); i < expr->typeParams.size(); i++) {
|
||||
transform(expr->typeParams[i]);
|
||||
TypePtr t = nullptr;
|
||||
if (expr->typeParams[i]->isStatic()) {
|
||||
|
@ -319,7 +333,10 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
|||
if (isUnion)
|
||||
typ->getUnion()->addType(t);
|
||||
else
|
||||
unify(t, generics[i].type);
|
||||
unify(t, generics[i - (repeats != nullptr)].type);
|
||||
}
|
||||
if (repeats) {
|
||||
typ->getRecord()->repeats = repeats;
|
||||
}
|
||||
if (isUnion) {
|
||||
typ->getUnion()->seal();
|
||||
|
@ -714,7 +731,7 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
|
|||
const ExprPtr &expr, const ExprPtr &index) {
|
||||
if (!tuple->getRecord())
|
||||
return {false, nullptr};
|
||||
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_KWTUPLE) &&
|
||||
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
|
||||
!startswith(tuple->name, TYPE_PARTIAL)) {
|
||||
if (tuple->is(TYPE_OPTIONAL)) {
|
||||
if (auto newTuple = tuple->generics[0].type->getClass()) {
|
||||
|
@ -743,16 +760,15 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
|
|||
return false;
|
||||
};
|
||||
|
||||
auto classItem = in(ctx->cache->classes, tuple->name);
|
||||
seqassert(classItem, "cannot find class '{}'", tuple->name);
|
||||
auto sz = classItem->fields.size();
|
||||
auto classFields = getClassFields(tuple.get());
|
||||
auto sz = int64_t(tuple->getRecord()->args.size());
|
||||
int64_t start = 0, stop = sz, step = 1;
|
||||
if (getInt(&start, index)) {
|
||||
// Case: `tuple[int]`
|
||||
auto i = translateIndex(start, stop);
|
||||
if (i < 0 || i >= stop)
|
||||
E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i);
|
||||
return {true, transform(N<DotExpr>(expr, classItem->fields[i].name))};
|
||||
return {true, transform(N<DotExpr>(expr, classFields[i].name))};
|
||||
} else if (auto slice = CAST(index, SliceExpr)) {
|
||||
// Case: `tuple[int:int:int]`
|
||||
if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) ||
|
||||
|
@ -773,11 +789,11 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
|
|||
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
|
||||
if (i < 0 || i >= sz)
|
||||
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
|
||||
te.push_back(N<DotExpr>(clone(var), classItem->fields[i].name));
|
||||
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
|
||||
}
|
||||
ExprPtr e = transform(N<StmtExpr>(
|
||||
std::vector<StmtPtr>{ass},
|
||||
N<CallExpr>(N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te)));
|
||||
ExprPtr e = transform(
|
||||
N<StmtExpr>(std::vector<StmtPtr>{ass},
|
||||
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
|
||||
return {true, e};
|
||||
}
|
||||
|
||||
|
|
|
@ -61,8 +61,9 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
|
|||
}
|
||||
seqassert(expr->type, "type not set for {}", expr);
|
||||
unify(typ, expr->type);
|
||||
if (expr->done)
|
||||
if (expr->done) {
|
||||
ctx->changedNodes++;
|
||||
}
|
||||
}
|
||||
realize(typ);
|
||||
LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), expr, expr->isDone() ? "[done]" : "");
|
||||
|
@ -182,7 +183,7 @@ TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &mem
|
|||
callArgs.push_back({"", std::make_shared<NoneExpr>()}); // dummy expression
|
||||
callArgs.back().value->setType(a);
|
||||
}
|
||||
auto methods = ctx->findMethod(typ->name, member, false);
|
||||
auto methods = ctx->findMethod(typ.get(), member, false);
|
||||
auto m = findMatchingMethods(typ, methods, callArgs);
|
||||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
@ -195,7 +196,7 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
|
|||
std::vector<CallExpr::Arg> callArgs;
|
||||
for (auto &a : args)
|
||||
callArgs.push_back({"", a});
|
||||
auto methods = ctx->findMethod(typ->name, member, false);
|
||||
auto methods = ctx->findMethod(typ.get(), member, false);
|
||||
auto m = findMatchingMethods(typ, methods, callArgs);
|
||||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
@ -210,7 +211,7 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(
|
|||
callArgs.push_back({n, std::make_shared<NoneExpr>()}); // dummy expression
|
||||
callArgs.back().value->setType(a);
|
||||
}
|
||||
auto methods = ctx->findMethod(typ->name, member, false);
|
||||
auto methods = ctx->findMethod(typ.get(), member, false);
|
||||
auto m = findMatchingMethods(typ, methods, callArgs);
|
||||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
@ -451,8 +452,8 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
|
|||
ExprPtr TypecheckVisitor::castToSuperClass(ExprPtr expr, ClassTypePtr superTyp,
|
||||
bool isVirtual) {
|
||||
ClassTypePtr typ = expr->type->getClass();
|
||||
for (auto &field : ctx->cache->classes[typ->name].fields) {
|
||||
for (auto &parentField : ctx->cache->classes[superTyp->name].fields)
|
||||
for (auto &field : getClassFields(typ.get())) {
|
||||
for (auto &parentField : getClassFields(superTyp.get()))
|
||||
if (field.name == parentField.name) {
|
||||
unify(ctx->instantiate(field.type, typ),
|
||||
ctx->instantiate(parentField.type, superTyp));
|
||||
|
@ -493,4 +494,16 @@ TypecheckVisitor::unpackTupleTypes(ExprPtr expr) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<Cache::Class::ClassField> &
|
||||
TypecheckVisitor::getClassFields(types::ClassType *t) {
|
||||
seqassert(t && in(ctx->cache->classes, t->name), "cannot find '{}'",
|
||||
t ? t->name : "<null>");
|
||||
if (t->is(TYPE_TUPLE) && !t->getRecord()->args.empty()) {
|
||||
auto key = ctx->generateTuple(t->getRecord()->args.size());
|
||||
return ctx->cache->classes[key].fields;
|
||||
} else {
|
||||
return ctx->cache->classes[t->name].fields;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -177,6 +177,10 @@ private: // Node typechecking rules
|
|||
ExprPtr partializeFunction(const types::FuncTypePtr &);
|
||||
std::shared_ptr<types::RecordType> getFuncTypeBase(size_t);
|
||||
|
||||
public:
|
||||
types::FuncTypePtr makeFunctionType(FunctionStmt *);
|
||||
|
||||
private:
|
||||
/* Classes (class.cpp) */
|
||||
void visit(ClassStmt *) override;
|
||||
void parseBaseClasses(ClassStmt *);
|
||||
|
@ -227,7 +231,8 @@ private:
|
|||
StmtPtr prepareVTables();
|
||||
|
||||
public:
|
||||
bool isTuple(const std::string &s) const { return startswith(s, TYPE_TUPLE); }
|
||||
bool isTuple(const std::string &s) const { return s == TYPE_TUPLE; }
|
||||
std::vector<Cache::Class::ClassField> &getClassFields(types::ClassType *);
|
||||
|
||||
friend class Cache;
|
||||
friend class types::CallableTrait;
|
||||
|
|
|
@ -19,7 +19,8 @@
|
|||
|
||||
#define DBG(c, ...) \
|
||||
fmt::print(codon::getLogger().log, "{}" c "\n", \
|
||||
std::string(2 * codon::getLogger().level, ' '), ##__VA_ARGS__)
|
||||
std::string(size_t(2) * size_t(codon::getLogger().level), ' '), \
|
||||
##__VA_ARGS__)
|
||||
#define LOG(c, ...) DBG(c, ##__VA_ARGS__)
|
||||
#define LOG_TIME(c, ...) \
|
||||
{ \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
|
||||
@tuple(container=False) # disallow default __getitem__
|
||||
class Vec[T, N: Static[int]]:
|
||||
ZERO_16x8i = Vec[u8,16](u8(0))
|
||||
|
@ -7,18 +8,50 @@ class Vec[T, N: Static[int]]:
|
|||
ZERO_32x8i = Vec[u8,32](u8(0))
|
||||
FF_32x8i = Vec[u8,32](u8(0xff))
|
||||
|
||||
ZERO_4x64f = Vec[f64,4](0.0)
|
||||
|
||||
@llvm
|
||||
def _mm_set1_epi8(val: u8) -> Vec[u8, 16]:
|
||||
%0 = insertelement <16 x i8> undef, i8 %val, i32 0
|
||||
%1 = shufflevector <16 x i8> %0, <16 x i8> undef, <16 x i32> zeroinitializer
|
||||
ret <16 x i8> %1
|
||||
|
||||
@llvm
|
||||
def _mm_set1_epi32(val: u32) -> Vec[u32, 4]:
|
||||
%0 = insertelement <4 x i32> undef, i32 %val, i32 0
|
||||
%1 = shufflevector <4 x i32> %0, <4 x i32> undef, <4 x i32> zeroinitializer
|
||||
ret <4 x i32> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_set1_epi8(val: u8) -> Vec[u8, 32]:
|
||||
%0 = insertelement <32 x i8> undef, i8 %val, i32 0
|
||||
%1 = shufflevector <32 x i8> %0, <32 x i8> undef, <32 x i32> zeroinitializer
|
||||
ret <32 x i8> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_set1_epi32(val: u32) -> Vec[u32, 8]:
|
||||
%0 = insertelement <8 x i32> undef, i32 %val, i32 0
|
||||
%1 = shufflevector <8 x i32> %0, <8 x i32> undef, <8 x i32> zeroinitializer
|
||||
ret <8 x i32> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_set1_epi64x(val: u64) -> Vec[u64, 4]:
|
||||
%0 = insertelement <4 x i64> undef, i64 %val, i32 0
|
||||
%1 = shufflevector <4 x i64> %0, <4 x i64> undef, <4 x i32> zeroinitializer
|
||||
ret <4 x i64> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_set1_epi64(val: u64) -> Vec[u64, 8]:
|
||||
%0 = insertelement <8 x i64> undef, i64 %val, i32 0
|
||||
%1 = shufflevector <8 x i64> %0, <8 x i64> undef, <8 x i32> zeroinitializer
|
||||
ret <8 x i64> %1
|
||||
|
||||
@llvm
|
||||
def _mm_load_epi32(data) -> Vec[u32, 4]:
|
||||
%0 = bitcast i32* %data to <4 x i32>*
|
||||
%1 = load <4 x i32>, <4 x i32>* %0, align 1
|
||||
ret <4 x i32> %1
|
||||
|
||||
@llvm
|
||||
def _mm_loadu_si128(data) -> Vec[u8, 16]:
|
||||
%0 = bitcast i8* %data to <16 x i8>*
|
||||
|
@ -31,12 +64,42 @@ class Vec[T, N: Static[int]]:
|
|||
%1 = load <32 x i8>, <32 x i8>* %0, align 1
|
||||
ret <32 x i8> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_load_epi32(data) -> Vec[u32, 8]:
|
||||
%0 = bitcast i32* %data to <8 x i32>*
|
||||
%1 = load <8 x i32>, <8 x i32>* %0
|
||||
ret <8 x i32> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_load_epi64(data) -> Vec[u64, 4]:
|
||||
%0 = bitcast i64* %data to <4 x i64>*
|
||||
%1 = load <4 x i64>, <4 x i64>* %0
|
||||
ret <4 x i64> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_load_epi64(data) -> Vec[u64, 8]:
|
||||
%0 = bitcast i64* %data to <8 x i64>*
|
||||
%1 = load <8 x i64>, <8 x i64>* %0
|
||||
ret <8 x i64> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_set1_ps(val: f32) -> Vec[f32, 8]:
|
||||
%0 = insertelement <8 x float> undef, float %val, i32 0
|
||||
%1 = shufflevector <8 x float> %0, <8 x float> undef, <8 x i32> zeroinitializer
|
||||
ret <8 x float> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_set1_pd(val: f64) -> Vec[f64, 4]:
|
||||
%0 = insertelement <4 x double> undef, double %val, i64 0
|
||||
%1 = shufflevector <4 x double> %0, <4 x double> undef, <4 x i32> zeroinitializer
|
||||
ret <4 x double> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_set1_pd(val: f64) -> Vec[f64, 8]:
|
||||
%0 = insertelement <8 x double> undef, double %val, i64 0
|
||||
%1 = shufflevector <8 x double> %0, <8 x double> undef, <8 x i32> zeroinitializer
|
||||
ret <8 x double> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_set1_ps(val: f32) -> Vec[f32, 16]:
|
||||
%0 = insertelement <16 x float> undef, float %val, i32 0
|
||||
|
@ -49,6 +112,18 @@ class Vec[T, N: Static[int]]:
|
|||
%1 = load <8 x float>, <8 x float>* %0
|
||||
ret <8 x float> %1
|
||||
|
||||
@llvm
|
||||
def _mm256_loadu_pd(data: Ptr[f64]) -> Vec[f64, 4]:
|
||||
%0 = bitcast double* %data to <4 x double>*
|
||||
%1 = load <4 x double>, <4 x double>* %0
|
||||
ret <4 x double> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_loadu_pd(data: Ptr[f64]) -> Vec[f64, 8]:
|
||||
%0 = bitcast double* %data to <8 x double>*
|
||||
%1 = load <8 x double>, <8 x double>* %0
|
||||
ret <8 x double> %1
|
||||
|
||||
@llvm
|
||||
def _mm512_loadu_ps(data: Ptr[f32]) -> Vec[f32, 16]:
|
||||
%0 = bitcast float* %data to <16 x float>*
|
||||
|
@ -72,6 +147,11 @@ class Vec[T, N: Static[int]]:
|
|||
%0 = bitcast <8 x i32> %vec to <8 x float>
|
||||
ret <8 x float> %0
|
||||
|
||||
@llvm
|
||||
def _mm256_castsi256_pd(vec: Vec[u64, 4]) -> Vec[f64, 4]:
|
||||
%0 = bitcast <4 x i64> %vec to <4 x double>
|
||||
ret <4 x double> %0
|
||||
|
||||
@llvm
|
||||
def _mm512_castsi512_ps(vec: Vec[u32, 16]) -> Vec[f32, 16]:
|
||||
%0 = bitcast <16 x i32> %vec to <16 x float>
|
||||
|
@ -92,10 +172,58 @@ class Vec[T, N: Static[int]]:
|
|||
return Vec._mm256_loadu_si256(x)
|
||||
if isinstance(x, str):
|
||||
return Vec._mm256_loadu_si256(x.ptr)
|
||||
if isinstance(T, u32) and N == 4:
|
||||
if isinstance(x, int):
|
||||
assert x >= 0, "SIMD: No support for negative int vectors added yet."
|
||||
return Vec._mm_set1_epi32(u32(x))
|
||||
if isinstance(x, u32):
|
||||
return Vec._mm_set1_epi32(x)
|
||||
if isinstance(x, Ptr[u32]):
|
||||
return Vec._mm_load_epi32(x)
|
||||
if isinstance(x, List[u32]):
|
||||
return Vec._mm_load_epi32(x.arr.ptr)
|
||||
if isinstance(x, str):
|
||||
return Vec._mm_load_epi32(x.ptr)
|
||||
if isinstance(T, u32) and N == 8:
|
||||
if isinstance(x, int):
|
||||
assert x >= 0, "SIMD: No support for negative int vectors added yet."
|
||||
return Vec._mm256_set1_epi32(u32(x))
|
||||
if isinstance(x, u64):
|
||||
return Vec._mm256_set1_epi32(x)
|
||||
if isinstance(x, Ptr[u64]):
|
||||
return Vec._mm256_load_epi32(x)
|
||||
if isinstance(x, List[u64]):
|
||||
return Vec._mm256_load_epi32(x.arr.ptr)
|
||||
if isinstance(x, str):
|
||||
return Vec._mm256_load_epi32(x.ptr)
|
||||
if isinstance(T, u64) and N == 4:
|
||||
if isinstance(x, int):
|
||||
assert x >= 0, "SIMD: No support for negative int vectors added yet."
|
||||
return Vec._mm256_set1_epi64x(u64(x))
|
||||
if isinstance(x, u64):
|
||||
return Vec._mm256_set1_epi64x(x)
|
||||
if isinstance(x, Ptr[u64]):
|
||||
return Vec._mm256_load_epi64(x)
|
||||
if isinstance(x, List[u64]):
|
||||
return Vec._mm256_load_epi64(x.arr.ptr)
|
||||
if isinstance(x, str):
|
||||
return Vec._mm256_load_epi64(x.ptr)
|
||||
if isinstance(T, u64) and N == 8:
|
||||
if isinstance(x, int):
|
||||
assert x >= 0, "SIMD: No support for negative int vectors added yet."
|
||||
return Vec._mm512_set1_epi64(u64(x))
|
||||
if isinstance(x, u64):
|
||||
return Vec._mm512_set1_epi64(x)
|
||||
if isinstance(x, Ptr[u64]):
|
||||
return Vec._mm512_load_epi64(x)
|
||||
if isinstance(x, List[u64]):
|
||||
return Vec._mm512_load_epi64(x.arr.ptr)
|
||||
if isinstance(x, str):
|
||||
return Vec._mm512_load_epi64(x.ptr)
|
||||
if isinstance(T, f32) and N == 8:
|
||||
if isinstance(x, f32):
|
||||
return Vec._mm256_set1_ps(x)
|
||||
if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!]
|
||||
if isinstance(x, Ptr[f32]):
|
||||
return Vec._mm256_loadu_ps(x)
|
||||
if isinstance(x, List[f32]):
|
||||
return Vec._mm256_loadu_ps(x.arr.ptr)
|
||||
|
@ -104,12 +232,26 @@ class Vec[T, N: Static[int]]:
|
|||
if isinstance(T, f32) and N == 16:
|
||||
if isinstance(x, f32):
|
||||
return Vec._mm512_set1_ps(x)
|
||||
if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!]
|
||||
if isinstance(x, Ptr[f32]):
|
||||
return Vec._mm512_loadu_ps(x)
|
||||
if isinstance(x, List[f32]):
|
||||
return Vec._mm512_loadu_ps(x.arr.ptr)
|
||||
if isinstance(x, Vec[u8, 32]):
|
||||
return Vec._mm512_castsi512_ps(Vec._mm512_cvtepi8_epi64(x))
|
||||
if isinstance(T, f64) and N == 4:
|
||||
if isinstance(x, f64):
|
||||
return Vec._mm256_set1_pd(x)
|
||||
if isinstance(x, Ptr[f64]):
|
||||
return Vec._mm256_loadu_pd(x)
|
||||
if isinstance(x, List[f64]):
|
||||
return Vec._mm256_loadu_pd(x.arr.ptr)
|
||||
if isinstance(T, f64) and N == 8:
|
||||
if isinstance(x, f64):
|
||||
return Vec._mm512_set1_pd(x)
|
||||
if isinstance(x, Ptr[f64]):
|
||||
return Vec._mm512_loadu_pd(x)
|
||||
if isinstance(x, List[f64]):
|
||||
return Vec._mm512_loadu_pd(x.arr.ptr)
|
||||
compile_error("invalid SIMD vector constructor")
|
||||
|
||||
def __new__(x: str, offset: int = 0) -> Vec[u8, N]:
|
||||
|
@ -185,6 +327,14 @@ class Vec[T, N: Static[int]]:
|
|||
def __and__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
|
||||
return Vec._mm_and_si256(self, other)
|
||||
|
||||
@llvm
|
||||
def _mm256_and_si256(x: Vec[u64, 4], y: Vec[u64, 4]) -> Vec[u64, 4]:
|
||||
%0 = and <4 x i64> %x, %y
|
||||
ret <4 x i64> %0
|
||||
|
||||
def __and__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[u64, 4]:
|
||||
return Vec._mm256_and_si256(self, other)
|
||||
|
||||
@llvm
|
||||
def _mm256_and_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
|
||||
%0 = bitcast <8 x float> %x to <8 x i32>
|
||||
|
@ -255,9 +405,6 @@ class Vec[T, N: Static[int]]:
|
|||
%0 = fadd <8 x float> %x, %y
|
||||
ret <8 x float> %0
|
||||
|
||||
def __add__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
|
||||
return Vec._mm256_add_ps(self, other)
|
||||
|
||||
def __rshift__(self: Vec[u8, 16], shift: Static[int]) -> Vec[u8, 16]:
|
||||
if shift == 0:
|
||||
return self
|
||||
|
@ -294,18 +441,428 @@ class Vec[T, N: Static[int]]:
|
|||
def sum(self: Vec[f32, 8], x: f32 = f32(0.0)) -> f32:
|
||||
return x + self[0] + self[1] + self[2] + self[3] + self[4] + self[5] + self[6] + self[7]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{','.join(self.scatter())}>"
|
||||
|
||||
# Methods below added ad-hoc. TODO: Add Intel intrinsics equivalents for them and integrate them neatly above.
|
||||
|
||||
# Constructors
|
||||
@llvm
|
||||
def __getitem__(self, n: Static[int]) -> T:
|
||||
%0 = extractelement <{=N} x {=T}> %self, i32 {=n}
|
||||
def generic_init[SLS: Static[int]](val: T) -> Vec[T, SLS]:
|
||||
%0 = insertelement <{=SLS} x {=T}> undef, {=T} %val, i32 0
|
||||
%1 = shufflevector <{=SLS} x {=T}> %0, <{=SLS} x {=T}> undef, <{=SLS} x i32> zeroinitializer
|
||||
ret <{=SLS} x {=T}> %1
|
||||
|
||||
@llvm
|
||||
def generic_load[SLS: Static[int]](data: Ptr[T]) -> Vec[T, SLS]:
|
||||
%0 = bitcast T* %data to <{=SLS} x {=T}>*
|
||||
%1 = load <{=SLS} x {=T}>, <{=SLS} x {=T}>* %0
|
||||
ret <{=SLS} x {=T}> %1
|
||||
|
||||
# Bitwise intrinsics
|
||||
@llvm
|
||||
def __and__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = and <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __and__(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = and <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def __or__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = or <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __or__(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = or <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def __xor__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = xor <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __xor__(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = xor <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def __lshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = shl <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __lshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = shl <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def __rshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = lshr <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __rshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = lshr <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def bit_flip(self: Vec[u1, N]) -> Vec[u1, N]:
|
||||
%0 = insertelement <{=N} x i1> undef, i1 1, i32 0
|
||||
%1 = shufflevector <{=N} x i1> %0, <{=N} x i1> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = xor <{=N} x i1> %self, %1
|
||||
ret <{=N} x i1> %2
|
||||
|
||||
@llvm
|
||||
def shift_half(self: Vec[u128, N]) -> Vec[u128, N]:
|
||||
%0 = insertelement <{=N} x i128> undef, i128 64, i32 0
|
||||
%1 = shufflevector <{=N} x i128> %0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = lshr <{=N} x i128> %self, %1
|
||||
ret <{=N} x i128> %2
|
||||
|
||||
# Comparisons
|
||||
@llvm
|
||||
def __ge__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
|
||||
%0 = icmp uge <{=N} x i64> %self, %other
|
||||
ret <{=N} x i1> %0
|
||||
|
||||
@llvm
|
||||
def __ge__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
|
||||
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = icmp uge <{=N} x i64> %self, %1
|
||||
ret <{=N} x i1> %2
|
||||
|
||||
@llvm
|
||||
def __ge__(self: Vec[f64, N], other: Vec[f64, N]) -> Vec[u1, N]:
|
||||
%0 = fcmp oge <{=N} x double> %self, %other
|
||||
ret <{=N} x i1> %0
|
||||
|
||||
@llvm
|
||||
def __ge__(self: Vec[f64, N], other: f64) -> Vec[u1, N]:
|
||||
%0 = insertelement <{=N} x double> undef, double %other, i32 0
|
||||
%1 = shufflevector <{=N} x double> %0, <{=N} x double> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = fcmp oge <{=N} x double> %self, %1
|
||||
ret <{=N} x i1> %2
|
||||
|
||||
@llvm
|
||||
def __le__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
|
||||
%0 = icmp ule <{=N} x i64> %self, %other
|
||||
ret <{=N} x i1> %0
|
||||
|
||||
@llvm
|
||||
def __le__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
|
||||
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = icmp ule <{=N} x i64> %self, %1
|
||||
ret <{=N} x i1> %2
|
||||
|
||||
# Arithmetic intrinsics
|
||||
@llvm
|
||||
def __neg__(self: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = sub <{=N} x {=T}> zeroinitializer, %self
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def __mod__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
|
||||
%0 = urem <{=N} x i64> %self, %other
|
||||
ret <{=N} x i64> %0
|
||||
|
||||
@llvm
|
||||
def __mod__(self: Vec[u64, N], other: u64) -> Vec[u64, N]:
|
||||
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = urem <{=N} x i64> %self, %1
|
||||
ret <{=N} x i64> %2
|
||||
|
||||
@llvm
|
||||
def add(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = add <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def add(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = add <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def fadd(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = fadd <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def fadd(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = fadd <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
def __add__(self: Vec[T, N], other) -> Vec[T, N]:
|
||||
if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
|
||||
return self.add(other)
|
||||
if isinstance(T, f32) or isinstance(T, f64):
|
||||
return self.fadd(other)
|
||||
compile_error("invalid SIMD vector addition")
|
||||
|
||||
@llvm
|
||||
def sub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = sub <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def sub(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = sub <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def fsub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = fsub <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def fsub(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = fsub <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
def __sub__(self: Vec[T, N], other) -> Vec[T, N]:
|
||||
if isinstance(T, u8) or isinstance(T, u64):
|
||||
return self.sub(other)
|
||||
if isinstance(T, f32) or isinstance(T, f64):
|
||||
return self.fsub(other)
|
||||
compile_error("invalid SIMD vector subtraction")
|
||||
|
||||
@llvm
|
||||
def mul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = mul <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def fmul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = fmul <{=N} x {=T}> %self, %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
@llvm
|
||||
def mul(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = mul <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
@llvm
|
||||
def fmul(self: Vec[T, N], other: T) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = fmul <{=N} x {=T}> %self, %1
|
||||
ret <{=N} x {=T}> %2
|
||||
|
||||
def __mul__(self: Vec[T, N], other) -> Vec[T, N]:
|
||||
if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
|
||||
return self.mul(other)
|
||||
if isinstance(T, f32) or isinstance(T, f64):
|
||||
return self.fmul(other)
|
||||
compile_error("invalid SIMD vector multiplication")
|
||||
|
||||
@llvm
|
||||
def __truediv__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[f64, N]:
|
||||
%0 = uitofp <{=N} x i64> %self to <{=N} x double>
|
||||
%1 = uitofp <{=N} x i64> %other to <{=N} x double>
|
||||
%2 = fdiv <{=N} x double> %0, %1
|
||||
ret <{=N} x double> %2
|
||||
|
||||
@llvm
|
||||
def __truediv__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[f64, 4]:
|
||||
%0 = uitofp <4 x i64> %self to <4 x double>
|
||||
%1 = uitofp <4 x i64> %other to <4 x double>
|
||||
%2 = fdiv <4 x double> %0, %1
|
||||
ret <4 x double> %2
|
||||
|
||||
@llvm
|
||||
def __truediv__(self: Vec[u64, 8], other: Vec[u64, 8]) -> Vec[f64, 8]:
|
||||
%0 = uitofp <8 x i64> %self to <8 x double>
|
||||
%1 = uitofp <8 x i64> %other to <8 x double>
|
||||
%2 = fdiv <8 x double> %0, %1
|
||||
ret <8 x double> %2
|
||||
|
||||
@llvm
|
||||
def __truediv__(self: Vec[u64, 8], other: u64) -> Vec[f64, 8]:
|
||||
%0 = uitofp <8 x i64> %self to <8 x double>
|
||||
%1 = uitofp i64 %other to double
|
||||
%2 = insertelement <8 x double> undef, double %1, i32 0
|
||||
%3 = shufflevector <8 x double> %0, <8 x double> undef, <8 x i32> zeroinitializer
|
||||
%4 = fdiv <8 x double> %0, %3
|
||||
ret <8 x double> %4
|
||||
|
||||
@llvm
|
||||
def add_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
|
||||
declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
|
||||
%0 = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
|
||||
ret {<{=N} x i64>, <{=N} x i1>} %0
|
||||
|
||||
@llvm
|
||||
def add_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
|
||||
declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
|
||||
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %1)
|
||||
ret {<{=N} x i64>, <{=N} x i1>} %2
|
||||
|
||||
@llvm
|
||||
def sub_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
|
||||
declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
|
||||
%0 = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
|
||||
ret {<{=N} x i64>, <{=N} x i1>} %0
|
||||
|
||||
def sub_overflow_commutative(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
|
||||
return Vec[u64, N](other).sub_overflow(self)
|
||||
|
||||
@llvm
|
||||
def sub_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
|
||||
declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
|
||||
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %1)
|
||||
ret {<{=N} x i64>, <{=N} x i1>} %2
|
||||
|
||||
@llvm
|
||||
def zext_mul(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u128, N]:
|
||||
%0 = zext <{=N} x i64> %self to <{=N} x i128>
|
||||
%1 = zext <{=N} x i64> %other to <{=N} x i128>
|
||||
%2 = mul nuw <{=N} x i128> %0, %1
|
||||
ret <{=N} x i128> %2
|
||||
|
||||
@llvm
|
||||
def zext_mul(self: Vec[u64, N], other: u64) -> Vec[u128, N]:
|
||||
%0 = zext <{=N} x i64> %self to <{=N} x i128>
|
||||
%1 = insertelement <{=N} x i64> undef, i64 %other, i32 0
|
||||
%2 = shufflevector <{=N} x i64> %1, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
|
||||
%3 = zext <{=N} x i64> %2 to <{=N} x i128>
|
||||
%4 = mul nuw <{=N} x i128> %0, %3
|
||||
ret <{=N} x i128> %4
|
||||
|
||||
@nocapture
|
||||
@llvm
|
||||
def mulx(self: Vec[u64, N], other: Vec[u64, N], hi: Ptr[Vec[u64, N]]) -> Vec[u64, N]:
|
||||
%0 = zext <{=N} x i64> %self to <{=N} x i128>
|
||||
%1 = zext <{=N} x i64> %other to <{=N} x i128>
|
||||
%2 = mul nuw <{=N} x i128> %0, %1
|
||||
%3 = lshr <{=N} x i128> %2, <i128 64, i128 64, i128 64, i128 64, i128 64, i128 64, i128 64, i128 64>
|
||||
%4 = trunc <{=N} x i128> %3 to <{=N} x i64>
|
||||
store <{=N} x i64> %4, <{=N} x i64>* %hi, align 8
|
||||
%5 = trunc <{=N} x i128> %2 to <{=N} x i64>
|
||||
ret <{=N} x i64> %5
|
||||
|
||||
def mulhi(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
|
||||
hi = Ptr[Vec[u64, N]](1)
|
||||
self.mulx(other, hi)
|
||||
return hi[0]
|
||||
|
||||
@llvm
|
||||
def sqrt(self: Vec[f64, N]) -> Vec[f64, N]:
|
||||
declare <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double>)
|
||||
%0 = call <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double> %self)
|
||||
ret <{=N} x double> %0
|
||||
|
||||
@llvm
|
||||
def log(self: Vec[f64, N]) -> Vec[f64, N]:
|
||||
declare <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double>)
|
||||
%0 = call <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double> %self)
|
||||
ret <{=N} x double> %0
|
||||
|
||||
@llvm
|
||||
def cos(self: Vec[f64, N]) -> Vec[f64, N]:
|
||||
declare <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double>)
|
||||
%0 = call <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double> %self)
|
||||
ret <{=N} x double> %0
|
||||
|
||||
@llvm
|
||||
def fabs(self: Vec[f64, N]) -> Vec[f64, N]:
|
||||
declare <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double>)
|
||||
%0 = call <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double> %self)
|
||||
ret <{=N} x double> %0
|
||||
|
||||
# Conversion intrinsics
|
||||
@llvm
|
||||
def zext_double(self: Vec[u64, N]) -> Vec[u128, N]:
|
||||
%0 = zext <{=N} x i64> %self to <{=N} x i128>
|
||||
ret <{=N} x i128> %0
|
||||
|
||||
@llvm
|
||||
def trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
|
||||
%0 = trunc <{=N} x i128> %self to <{=N} x i64>
|
||||
ret <{=N} x i64> %0
|
||||
|
||||
@llvm
|
||||
def shift_trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
|
||||
%0 = insertelement <{=N} x i128> undef, i128 64, i32 0
|
||||
%1 = shufflevector <{=N} x i128> %0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = lshr <{=N} x i128> %self, %1
|
||||
%3 = trunc <{=N} x i128> %2 to <{=N} x i64>
|
||||
ret <{=N} x i64> %3
|
||||
|
||||
@llvm
|
||||
def to_u64(self: Vec[f64, N]) -> Vec[u64, N]:
|
||||
%0 = fptoui <{=N} x double> %self to <{=N} x i64>
|
||||
ret <{=N} x i64> %0
|
||||
|
||||
@llvm
|
||||
def to_u64(self: Vec[u1, N]) -> Vec[u64, N]:
|
||||
%0 = zext <{=N} x i1> %self to <{=N} x i64>
|
||||
ret <{=N} x i64> %0
|
||||
|
||||
@llvm
|
||||
def to_float(self: Vec[T, N]) -> Vec[f64, N]:
|
||||
%0 = uitofp <{=N} x {=T}> %self to <{=N} x double>
|
||||
ret <{=N} x double> %0
|
||||
|
||||
# Predication intrinsics
|
||||
@llvm
|
||||
def sub_if(self: Vec[T, N], other: Vec[T, N], mask: Vec[u1, N]) -> Vec[T, N]:
|
||||
%0 = sub <{=N} x {=T}> %self, %other
|
||||
%1 = select <{=N} x i1> %mask, <{=N} x {=T}> %0, <{=N} x {=T}> %self
|
||||
ret <{=N} x {=T}> %1
|
||||
|
||||
@llvm
|
||||
def sub_if(self: Vec[T, N], other: T, mask: Vec[u1, N]) -> Vec[T, N]:
|
||||
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
|
||||
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
|
||||
%2 = sub <{=N} x {=T}> %self, %1
|
||||
%3 = select <{=N} x i1> %mask, <{=N} x {=T}> %2, <{=N} x {=T}> %self
|
||||
ret <{=N} x {=T}> %3
|
||||
|
||||
# Gather-scatter
|
||||
@llvm
|
||||
def __getitem__(self: Vec[T, N], idx) -> T:
|
||||
%0 = extractelement <{=N} x {=T}> %self, i64 %idx
|
||||
ret {=T} %0
|
||||
|
||||
def __repr__(self):
|
||||
if N == 8:
|
||||
return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}>"
|
||||
elif N == 16:
|
||||
return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}, {self[8]}, {self[9]}, {self[10]}, {self[11]}, {self[12]}, {self[13]}, {self[14]}, {self[15]}>"
|
||||
else:
|
||||
return "?"
|
||||
# Misc
|
||||
def copy(self: Vec[T, N]) -> Vec[T, N]:
|
||||
return self
|
||||
|
||||
@llvm
|
||||
def mask(self: Vec[T, N], mask: Vec[u1, N], other: Vec[T, N]) -> Vec[T, N]:
|
||||
%0 = select <{=N} x i1> %mask, <{=N} x {=T}> %self, <{=N} x {=T}> %other
|
||||
ret <{=N} x {=T}> %0
|
||||
|
||||
def scatter(self: Vec[T, N]) -> List[T]:
|
||||
return [self[i] for i in staticrange(N)]
|
||||
|
@ -314,3 +871,15 @@ class Vec[T, N: Static[int]]:
|
|||
u8x16 = Vec[u8, 16]
|
||||
u8x32 = Vec[u8, 32]
|
||||
f32x8 = Vec[f32, 8]
|
||||
|
||||
|
||||
@llvm
|
||||
def bitcast_scatter[N: Static[int]](ptr_in: Ptr[Vec[u64, N]]) -> Ptr[u64]:
|
||||
%0 = bitcast <{=N} x i64>* %ptr_in to i64*
|
||||
ret i64* %0
|
||||
|
||||
|
||||
@llvm
|
||||
def bitcast_vectorize[N: Static[int]](ptr_in: Ptr[u64]) -> Ptr[Vec[u64, N]]:
|
||||
%0 = bitcast i64* %ptr_in to <{=N} x i64>*
|
||||
ret <{=N} x i64>* %0
|
||||
|
|
|
@ -102,6 +102,59 @@ class str:
|
|||
ptr: Ptr[byte]
|
||||
len: int
|
||||
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class Tuple:
|
||||
@__internal__
|
||||
def __new__() -> Tuple:
|
||||
pass
|
||||
def __add__(self, obj):
|
||||
return __magic__.add(self, obj)
|
||||
def __mul__(self, n: Static[int]):
|
||||
return __magic__.mul(self, n)
|
||||
def __contains__(self, obj) -> bool:
|
||||
return __magic__.contains(self, obj)
|
||||
def __getitem__(self, idx: int):
|
||||
return __magic__.getitem(self, idx)
|
||||
def __iter__(self):
|
||||
yield from __magic__.iter(self)
|
||||
def __hash__(self) -> int:
|
||||
return __magic__.hash(self)
|
||||
def __repr__(self) -> str:
|
||||
return __magic__.repr(self)
|
||||
def __len__(self) -> int:
|
||||
return __magic__.len(self)
|
||||
def __eq__(self, obj: Tuple) -> bool:
|
||||
return __magic__.eq(self, obj)
|
||||
def __ne__(self, obj: Tuple) -> bool:
|
||||
return __magic__.ne(self, obj)
|
||||
def __gt__(self, obj: Tuple) -> bool:
|
||||
return __magic__.gt(self, obj)
|
||||
def __ge__(self, obj: Tuple) -> bool:
|
||||
return __magic__.ge(self, obj)
|
||||
def __lt__(self, obj: Tuple) -> bool:
|
||||
return __magic__.lt(self, obj)
|
||||
def __le__(self, obj: Tuple) -> bool:
|
||||
return __magic__.le(self, obj)
|
||||
def __pickle__(self, dest: Ptr[byte]):
|
||||
return __magic__.pickle(self, dest)
|
||||
def __unpickle__(src: Ptr[byte]) -> Tuple:
|
||||
return __magic__.unpickle(src)
|
||||
def __to_py__(self) -> Ptr[byte]:
|
||||
return __magic__.to_py(self)
|
||||
def __from_py__(src: Ptr[byte]) -> Tuple:
|
||||
return __magic__.from_py(src)
|
||||
def __to_gpu__(self, cache) -> Tuple:
|
||||
return __magic__.to_gpu(self, cache)
|
||||
def __from_gpu__(self, other: Tuple):
|
||||
return __magic__.from_gpu(self, other)
|
||||
def __from_gpu_new__(other: Tuple) -> Tuple:
|
||||
return __magic__.from_gpu_new(other)
|
||||
def __tuplesize__(self) -> int:
|
||||
return __magic__.tuplesize(self)
|
||||
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class Array:
|
||||
|
@ -134,8 +187,6 @@ class TypeVar[T]: pass
|
|||
class ByVal: pass
|
||||
@__internal__
|
||||
class ByRef: pass
|
||||
@__internal__
|
||||
class Tuple: pass
|
||||
|
||||
@__internal__
|
||||
class ClassVar[T]:
|
||||
|
@ -147,10 +198,7 @@ class RTTI:
|
|||
|
||||
@__internal__
|
||||
@tuple
|
||||
class ellipsis:
|
||||
def __new__() -> ellipsis:
|
||||
return ()
|
||||
Ellipsis = ellipsis()
|
||||
class ellipsis: pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
|
|
|
@ -6,14 +6,18 @@ from internal.gc import (
|
|||
)
|
||||
from internal.static import vars_types, tuple_type, vars as _vars
|
||||
|
||||
|
||||
def vars(obj, with_index: Static[int] = 0):
|
||||
return _vars(obj, with_index)
|
||||
|
||||
|
||||
__vtables__ = Ptr[Ptr[cobj]]()
|
||||
__vtable_size__ = 0
|
||||
|
||||
@extend
|
||||
class ellipsis:
|
||||
def __new__() -> ellipsis:
|
||||
return ()
|
||||
|
||||
Ellipsis = ellipsis()
|
||||
|
||||
@extend
|
||||
class __internal__:
|
||||
|
@ -441,7 +445,6 @@ class __internal__:
|
|||
else:
|
||||
return default
|
||||
|
||||
|
||||
@extend
|
||||
class __magic__:
|
||||
# always present
|
||||
|
@ -635,7 +638,6 @@ class __magic__:
|
|||
def str(slf) -> str:
|
||||
return slf.__repr__()
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
@tuple
|
||||
class Import:
|
||||
|
@ -670,19 +672,16 @@ class Function:
|
|||
def __call__(self, *args) -> TR:
|
||||
return Function.__call_internal__(self, args)
|
||||
|
||||
|
||||
@tuple
|
||||
class PyObject:
|
||||
refcnt: int
|
||||
pytype: Ptr[byte]
|
||||
|
||||
|
||||
@tuple
|
||||
class PyWrapper[T]:
|
||||
head: PyObject
|
||||
data: T
|
||||
|
||||
|
||||
@extend
|
||||
class RTTI:
|
||||
def __new__() -> RTTI:
|
||||
|
|
|
@ -37,7 +37,7 @@ class float:
|
|||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp one double %self, 0.000000e+00
|
||||
%0 = fcmp une double %self, 0.000000e+00
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
|
@ -118,10 +118,9 @@ class float:
|
|||
@pure
|
||||
@llvm
|
||||
def __ne__(a: float, b: float) -> bool:
|
||||
entry:
|
||||
%tmp = fcmp one double %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
%tmp = fcmp une double %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
|
@ -443,7 +442,7 @@ class float32:
|
|||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp one float %self, 0.000000e+00
|
||||
%0 = fcmp une float %self, 0.000000e+00
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
|
@ -521,10 +520,9 @@ class float32:
|
|||
@pure
|
||||
@llvm
|
||||
def __ne__(a: float32, b: float32) -> bool:
|
||||
entry:
|
||||
%tmp = fcmp one float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
%tmp = fcmp une float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
|
@ -760,3 +758,4 @@ class float:
|
|||
return float32.__new__(double)
|
||||
|
||||
f32 = float32
|
||||
f64 = float
|
||||
|
|
|
@ -507,11 +507,16 @@ class UInt:
|
|||
def len() -> int:
|
||||
return N
|
||||
|
||||
i1 = Int[1]
|
||||
i8 = Int[8]
|
||||
i16 = Int[16]
|
||||
i32 = Int[32]
|
||||
i64 = Int[64]
|
||||
i128 = Int[128]
|
||||
|
||||
u1 = UInt[1]
|
||||
u8 = UInt[8]
|
||||
u16 = UInt[16]
|
||||
u32 = UInt[32]
|
||||
u64 = UInt[64]
|
||||
u128 = UInt[128]
|
||||
|
|
|
@ -538,7 +538,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args):
|
|||
task = Ptr[TaskWithPrivates[P]](data)[0]
|
||||
priv = task.data
|
||||
gtid64 = int(gtid)
|
||||
if staticlen(S()) != 0:
|
||||
if staticlen(S) != 0:
|
||||
shared = Ptr[S](task.task.shareds)[0]
|
||||
_task_loop_body_stub(gtid64, priv, shared)
|
||||
else:
|
||||
|
|
|
@ -67,6 +67,16 @@ print f.a #: 20
|
|||
f[1] = -8
|
||||
print f.a #: 12
|
||||
|
||||
|
||||
def foo():
|
||||
print('foo')
|
||||
return 0
|
||||
v = [0]
|
||||
v[foo()] += 1
|
||||
#: foo
|
||||
print(v)
|
||||
#: [1]
|
||||
|
||||
#%% assign_err_1,barebones
|
||||
a, *b, c, *d = 1,2,3,4,5 #! multiple starred expressions in assignment
|
||||
|
||||
|
@ -163,12 +173,12 @@ assert True, "blah"
|
|||
try:
|
||||
assert False
|
||||
except AssertionError as e:
|
||||
print e.message[:15], e.message[-24:] #: Assert failed ( simplify_stmt.codon:164)
|
||||
print e.message[:15], e.message[-24:] #: Assert failed ( simplify_stmt.codon:174)
|
||||
|
||||
try:
|
||||
assert False, f"hehe {1}"
|
||||
except AssertionError as e:
|
||||
print e.message[:23], e.message[-24:] #: Assert failed: hehe 1 ( simplify_stmt.codon:169)
|
||||
print e.message[:23], e.message[-24:] #: Assert failed: hehe 1 ( simplify_stmt.codon:179)
|
||||
|
||||
#%% print,barebones
|
||||
print 1,
|
||||
|
|
|
@ -773,7 +773,7 @@ z = tuple(c)
|
|||
print z, z.__class__.__name__ #: (3, 'heh') Tuple[int,str]
|
||||
|
||||
#%% static_unify,barebones
|
||||
def foo(x: Callable[[1,2], 3]): pass #! '1' does not match expected type 'T1'
|
||||
def foo(x: Callable[[1,2], 3]): pass #! '2' does not match expected type 'T1'
|
||||
|
||||
#%% static_unify_2,barebones
|
||||
def foo(x: List[1]): pass #! '1' does not match expected type 'T'
|
||||
|
|
|
@ -1971,3 +1971,30 @@ foo(5) #! generic 'b' not provided
|
|||
def f(a, b, T: type):
|
||||
print(a, b)
|
||||
f(1, 2) #! generic 'T' not provided
|
||||
|
||||
#%% variardic_tuples,barebones
|
||||
|
||||
class Foo[N: Static[int]]:
|
||||
x: Tuple[N, str]
|
||||
|
||||
def __init__(self):
|
||||
self.x = ("hi", ) * N
|
||||
|
||||
f = Foo[5]()
|
||||
print(f.__class__.__name__)
|
||||
#: Foo[5]
|
||||
print(f.x.__class__.__name__)
|
||||
#: Tuple[str,str,str,str,str]
|
||||
print(f.x)
|
||||
#: ('hi', 'hi', 'hi', 'hi', 'hi')
|
||||
|
||||
print(Tuple[int, int].__class__.__name__)
|
||||
#: Tuple[int,int]
|
||||
print(Tuple[3, int].__class__.__name__)
|
||||
#: Tuple[int,int,int]
|
||||
print(Tuple[0].__class__.__name__)
|
||||
#: Tuple
|
||||
print(Tuple[-5, int].__class__.__name__)
|
||||
#: Tuple
|
||||
print(Tuple[5, int, str].__class__.__name__)
|
||||
#: Tuple[int,str,int,str,int,str,int,str,int,str]
|
||||
|
|
Loading…
Reference in New Issue